In [3]:
from h5Dataset import h5Dataset
from modules import MPNNLayer, MPNNTokenizer, SelfAttentionEncoder, PredictionHead
from model import MPNNTransformerModel

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

In [5]:
def main():
    """Overfit an MPNNTransformerModel on a small HDF5 sample set as a sanity check.

    Loads the dataset at ``h5_path``, trains for at most 500 epochs with AdamW,
    a StepLR schedule and an MSE loss against ``data.y``, printing the average
    batch loss on epoch 1 and every 25th epoch. Training stops early once the
    average loss drops below 1e-6. Afterwards the model is evaluated on the
    very same samples and each prediction is printed next to its target.
    """
    torch.manual_seed(0)

    # Use GPU 3 (CUDA index 2) only when that index actually exists.
    # `is_available()` alone is true on hosts with one or two GPUs, where
    # selecting "cuda:2" would fail as soon as the model is moved to it.
    has_third_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 2
    device = torch.device("cuda:2" if has_third_gpu else "cpu")

    # Load Dataset
    h5_path = "../data/samples/100samples.h5"  # <- CHANGE THIS
    ds = h5Dataset(h5_path)

    loader = DataLoader(ds, batch_size=10)

    # Take the input feature widths from the first sample.
    sample0 = ds[0]
    node_in_dim = sample0.x.shape[-1]
    edge_in_dim = sample0.edge_attr.shape[-1]

    # Build model
    model = MPNNTransformerModel(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        num_output_sources=1,
    ).to(device)

    # Optimizer and learning-rate schedule
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=25,   # decay every 25 epochs
        gamma=0.8       # multiply LR by 0.8
    )

    # Train loop (deliberately overfitting the small sample set)
    num_epochs = 500  # usually enough to see memorization
    for epoch in range(1, num_epochs + 1):

        model.train()
        running_loss = 0.0

        for data in loader:
            data = data.to(device)

            # NOTE(review): original note says the expected shape is [I, 2]
            # with I=1 - confirm against the model's batched output.
            pred = model.forward_from_data(data)

            target = data.y

            loss = F.mse_loss(pred, target)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # The schedule is stepped once per epoch, not once per batch.
        scheduler.step()

        avg_loss = running_loss / len(loader)

        # Print occasionally
        if epoch % 25 == 0 or epoch == 1:
            print(f"Epoch {epoch:4d} | loss = {avg_loss:.6f}")

        if avg_loss < 1e-6:
            print(f"Early stopping at epoch {epoch} with loss {avg_loss:.6f}")
            break

    # Quick evaluation on the same samples that were used for training
    model.eval()
    print("\nPredictions vs targets:")
    with torch.no_grad():
        for i, data in enumerate(ds):
            data = data.to(device)
            pred = model.forward_from_data(data)
            target = data.y
            if target.dim() == 1:
                target = target.unsqueeze(0)
            if target.shape[-1] == 3:
                # Slice the LAST axis (the one the guard inspects); `[:, :2]`
                # would cut axis 1 instead when the target has extra dims.
                target = target[..., :2]

            print(f"[{i:02d}] pred={pred.squeeze(0).cpu().numpy()}  target={target.squeeze(0).cpu().numpy()}")

# Entry point: run the overfitting experiment when this cell/script is executed
# directly. The captured output below this cell shows the call running here.
if __name__ == "__main__":
    main()        


Epoch    1 | loss = 0.053237
Epoch   25 | loss = 0.013544
Epoch   50 | loss = 0.012356
Epoch   75 | loss = 0.011797
Epoch  100 | loss = 0.011814
Epoch  125 | loss = 0.008432
Epoch  150 | loss = 0.006433
Epoch  175 | loss = 0.005515
Epoch  200 | loss = 0.003683
Epoch  225 | loss = 0.003715
Epoch  250 | loss = 0.003160
Epoch  275 | loss = 0.003258
Epoch  300 | loss = 0.000798
Epoch  325 | loss = 0.000720
Epoch  350 | loss = 0.000324
Epoch  375 | loss = 0.000235
Epoch  400 | loss = 0.000187
Epoch  425 | loss = 0.000157
Epoch  450 | loss = 0.000135
Epoch  475 | loss = 0.000119
Epoch  500 | loss = 0.000108

Predictions vs targets:
[00] pred=[0.06399404 0.26747346]  target=[[0.0637504  0.27527907]]
[01] pred=[ 0.03137429 -0.06574981]  target=[[ 0.03151155 -0.06910084]]
[02] pred=[ 0.25985107 -0.24152297]  target=[[ 0.26129934 -0.24150708]]
[03] pred=[-0.01402315  0.1424833 ]  target=[[-0.01454815  0.1408742 ]]
[04] pred=[-0.29331306  0.14965004]  target=[[-0.28629878  0.15405327]]
[05] pred=