In [4]:
import sys
sys.path.append("../src")

from precomputed_dataset import precomputedDataset
from modules import MPNNLayer, MPNNTokenizer, SelfAttentionEncoder, PredictionHead
from model import MPNNTransformerModel

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

In [6]:
def main():
    torch.manual_seed(0)
    #Using GPU 3, check if available before running
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    #device = torch.device("cpu")
    print("Using device:", device)
    
    # Load Dataset
    h5_path = "../data/10000samples.h5"  # <- CHANGE THIS
    ds = precomputedDataset(h5_path)

    loader = DataLoader(ds, batch_size=32, num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=2)

    # take feature dims from one sample
    sample0 = ds[0]
    node_in_dim = sample0.x.shape[-1]
    edge_in_dim = sample0.edge_attr.shape[-1]

    # Build model
    model = MPNNTransformerModel(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        num_output_sources=1,  
    ).to(device)

    # Optimizer, loss, scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=25,   # every 10 epochs
        gamma=0.8      # multiply LR by 0.8
    )

    # 6) Train loop (overfit)
    num_epochs = 50  # usually enough to see memorization
    for epoch in range(1, num_epochs + 1):

        model.train()
        running_loss = 0.0

        for data in loader:
            data = data.to(device)

            pred = model.forward_from_data(data)  # expected shape: [I, 2] with I=1
            
            target = data.y

            loss = F.mse_loss(pred, target)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            

        scheduler.step()

        avg_loss = running_loss / len(loader)
        
        #print occasionally
        if epoch % 25 == 0 or epoch == 1:
            print(f"Epoch {epoch:4d} | loss = {avg_loss:.6f}")

        if avg_loss < 1e-6:
            print(f"Early stopping at epoch {epoch} with loss {avg_loss:.6f}")
            break

    #Quick evaluation on the same samples

    model.eval()
    print("\nPredictions vs targets:")
    with torch.no_grad():
        for i, data in enumerate(ds):
            data = data.to(device)
            pred = model.forward_from_data(data)  # [1,2]
            target = data.y
            if target.dim() == 1:
                target = target.unsqueeze(0)
            if target.shape[-1] == 3:
                target = target[:, :2]

            print(f"[{i:02d}] pred={pred.squeeze(0).cpu().numpy()}  target={target.squeeze(0).cpu().numpy()}")

if __name__ == "__main__":
    main()        


Using device: cuda:2


Epoch    1 | loss = 0.028299
Epoch   25 | loss = 0.027470
Epoch   50 | loss = 0.027456

Predictions vs targets:
[00] pred=[0.0002824  0.00355356]  target=[[0.0637504  0.27527907]]
[01] pred=[0.00028245 0.00355353]  target=[[ 0.03151155 -0.06910084]]
[02] pred=[0.00028247 0.00355352]  target=[[ 0.26129934 -0.24150708]]
[03] pred=[0.00028233 0.0035536 ]  target=[[-0.26930133  0.33665705]]
[04] pred=[0.00028238 0.00355357]  target=[[-0.15904541 -0.26057345]]
[05] pred=[0.00028245 0.00355353]  target=[[ 0.16570383 -0.19515383]]
[06] pred=[0.00028247 0.00355352]  target=[[ 0.09985827 -0.24750595]]
[07] pred=[0.0002824  0.00355356]  target=[[0.11258166 0.08941735]]
[08] pred=[0.00028238 0.00355357]  target=[[-0.01980888  0.10539575]]
[09] pred=[0.00028246 0.00355352]  target=[[-0.28913948 -0.3706684 ]]
[10] pred=[0.00028241 0.00355355]  target=[[-0.2512472  -0.37273756]]
[11] pred=[0.00028241 0.00355356]  target=[[0.07799777 0.09916284]]
[12] pred=[0.00028239 0.00355357]  target=[[-0.3211778