In [3]:
# alright, we just tried a minimal example of a simple neural network
# It's not a good idea to train your neural network from scratch, it's better to use a pre-trained model

# We will use the TorchMD-ET model, and I really recommend for its simplicity and efficiency compared to group-equivariant models
from Geom2Vec.geom2vec.models.torchmd.main_model import create_model, get_args

import torch
import torch.nn as nn

In [5]:
from torch_geometric.data import Data, DataLoader
from torch_geometric.datasets import QM9 # We will use the QM9 dataset for this example

# Let's start by loading the dataset

path = '/project/dinner/zpengmei/Geom2Vec/Tutorial/data_sets/QM9'

# QM9 has many labels, let's just pick the first one for now, we can do this with Transform object in PyG

class QM9Transform:
    def __call__(self, data):
        # Select target.
        data.y = data.y[:, 0]
        return data

# Load the dataset, you can't do this on the compute node, you need to do this on the login node for internet access

dataset = QM9(path, transform=QM9Transform()).shuffle()

# Normalize targets to mean = 0 and std = 1.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
mean, std = mean[:, 0].item(), std[:, 0].item()

# split the dataset into training, validation and test sets

train_dataset = dataset[:110000]
val_dataset = dataset[110000:120000]
test_dataset = dataset[120000:]

# load your data into the DataLoader

test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)



In [8]:
from torch_scatter import scatter

class Net(nn.Module):

    def __init__(self, hidden_dim, output_dim):
        super(Net, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        pt_gnn_args = get_args(
            hidden_channels=hidden_dim, 
            num_layers=9,
            num_rbf=64,
            num_heads=8, 
            cutoff=7.5 # cutoff radius
        )
        self.representation_model = create_model(pt_gnn_args)

        # a simple MLP for regression
        self.task_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, z, pos, batch):

        x_rep,_ , _ = self.representation_model(z, pos, batch)
        x_rep = scatter(x_rep, batch, dim=0, reduce='add')
        out = self.task_head(x_rep)

        return out



In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(hidden_dim=256, output_dim=1).to(device)

checkpoint = torch.load('/project/dinner/zpengmei/subspace_pytorch/pretrain/denoise/denali/logs_models_denali_0.2/tensorboard_logs/ET_l9_256hc_9l_8head_2024-04-12_21-41-17/ET_best.pth')
model.representation_model.load_state_dict(checkpoint, strict=False)

# Now we have a pre-trained model, let's freeze it and train the task head
for param in model.representation_model.parameters():
    param.requires_grad = False


In [21]:
# you can always accelerate your model via automatic mixed precision


# define the optimizer and loss function
from torch.optim import Adam
from torch.nn import L1Loss
from tqdm import tqdm

optimizer = Adam(model.parameters(), lr=1e-3)
criterion = L1Loss()

# define the training loop

def train():
    model.train()

    total_loss = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        out = model(data.z, data.pos, data.batch) # GNN operates on atomic numbers, positions and batch vector which assigns each atom to a specific molecule
        loss = criterion(out.view(-1), data.y.view(-1))

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # show the batch loss in the progress bar
        tqdm.write(f'loss: {loss.item()}')

    return total_loss / len(train_loader)

# define the evaluation loop

@torch.no_grad()
def test(loader):
    model.eval()

    total_loss = 0
    for data in loader:
        data = data.to(device)
        out = model(data.z, data.pos, data.batch)
        total_loss += criterion(out.view(-1), data.y.view(-1)).item()

    return total_loss / len(loader)

In [22]:
# train the model
from tqdm import tqdm

best_val_loss = None
for epoch in tqdm(range(1, 101)):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

    # you can change the frequency of validation and test
    if epoch % 2 == 0:
        val_loss = test(val_loader)
        print(f'Val Loss: {val_loss:.4f}')

        if best_val_loss is None or val_loss <= best_val_loss:
            best_val_loss = val_loss
            # test the accuracy on the test set
            test_loss = test(test_loader)
            print(f'Test Loss: {test_loss:.4f}')


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:00<?, ?it/s]

loss: 0.5200073719024658
loss: 1.4172673225402832
loss: 0.5361088514328003
loss: 0.9493891596794128


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:00<?, ?it/s]

loss: 1.0238265991210938
loss: 0.8548645973205566
loss: 0.4784182906150818
loss: 0.5422019958496094
loss: 0.7181023955345154


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:00<?, ?it/s]

loss: 0.7830835580825806
loss: 0.781213104724884
loss: 0.48556357622146606
loss: 0.5631784796714783
loss: 0.6033120155334473


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:01<?, ?it/s]

loss: 0.7289230823516846
loss: 0.6909183263778687
loss: 0.6736429333686829
loss: 0.5923258066177368
loss: 0.5828412175178528


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:01<?, ?it/s]

loss: 0.571662187576294
loss: 0.623786211013794
loss: 0.6461403369903564
loss: 0.5875224471092224
loss: 0.5990049242973328


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:01<?, ?it/s]

loss: 0.43501219153404236
loss: 0.5819124579429626
loss: 0.6337181329727173
loss: 0.5733524560928345
loss: 0.6440577507019043


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:01<?, ?it/s]

loss: 0.6440508365631104
loss: 0.5117762088775635
loss: 0.5486366748809814
loss: 0.521108865737915
loss: 0.608310341835022


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:01<?, ?it/s]

loss: 0.5508701205253601
loss: 0.506677508354187
loss: 0.4590603709220886
loss: 0.5459403991699219
loss: 0.4940509796142578


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:02<?, ?it/s]

loss: 0.5093805193901062
loss: 0.5123497247695923
loss: 0.5690261125564575
loss: 0.5355029702186584
loss: 0.6420365571975708


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:02<?, ?it/s]

loss: 0.4739060401916504
loss: 0.45418739318847656
loss: 0.5201643109321594
loss: 0.5374792218208313
loss: 0.4717155396938324


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:02<?, ?it/s]

loss: 0.49420779943466187
loss: 0.5692234039306641
loss: 0.5540987253189087
loss: 0.537847638130188
loss: 0.5032912492752075


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:02<?, ?it/s]

loss: 0.5461724996566772
loss: 0.47859805822372437
loss: 0.489015132188797
loss: 0.5713825225830078
loss: 0.49997639656066895


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:03<?, ?it/s]

loss: 0.5413932204246521
loss: 0.49666327238082886
loss: 0.5045497417449951
loss: 0.5448247194290161
loss: 0.4937674403190613


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:03<?, ?it/s]

loss: 0.5411781668663025
loss: 0.486542284488678
loss: 0.46902143955230713
loss: 0.575258731842041
loss: 0.569782018661499


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:03<?, ?it/s]

loss: 0.4520796239376068
loss: 0.466191828250885
loss: 0.4747021496295929
loss: 0.4974052906036377
loss: 0.5050523281097412


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:03<?, ?it/s]

loss: 0.5825802087783813
loss: 0.46620848774909973
loss: 0.5143654942512512
loss: 0.48063504695892334
loss: 0.46445274353027344


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:04<?, ?it/s]

loss: 0.4981779456138611
loss: 0.5354379415512085
loss: 0.474453330039978
loss: 0.557253360748291
loss: 0.43746915459632874


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:04<?, ?it/s]

loss: 0.5012587308883667
loss: 0.49300187826156616
loss: 0.6555629372596741
loss: 0.5537574291229248
loss: 0.4676777124404907


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:04<?, ?it/s]

loss: 0.43724489212036133
loss: 0.4777649939060211
loss: 0.4391739070415497
loss: 0.5670796632766724
loss: 0.461762934923172


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:04<?, ?it/s]

loss: 0.4494091868400574
loss: 0.4682985842227936
loss: 0.4750130772590637
loss: 0.48602044582366943
loss: 0.4622310400009155


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:05<?, ?it/s]

loss: 0.5157997012138367
loss: 0.5318399667739868
loss: 0.4941781759262085
loss: 0.47181665897369385
loss: 0.48596546053886414


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:05<?, ?it/s]

loss: 0.5119925737380981
loss: 0.5438975095748901
loss: 0.527809739112854
loss: 0.530188798904419
loss: 0.5553250312805176


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:05<?, ?it/s]

loss: 0.5000234246253967
loss: 0.4896053671836853
loss: 0.5014001131057739
loss: 0.38670942187309265
loss: 0.46811947226524353


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:05<?, ?it/s]

loss: 0.5459826588630676
loss: 0.513683557510376
loss: 0.40534287691116333
loss: 0.5084928274154663
loss: 0.4886758029460907


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:06<?, ?it/s]

loss: 0.4758909046649933
loss: 0.4649079442024231
loss: 0.45940619707107544
loss: 0.46280282735824585
loss: 0.5146344900131226


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:06<?, ?it/s]

loss: 0.4735131859779358
loss: 0.5110135674476624
loss: 0.6561450362205505
loss: 0.4792681336402893
loss: 0.5401037335395813


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:06<?, ?it/s]

loss: 0.39926043152809143
loss: 0.3911281228065491
loss: 0.45827996730804443
loss: 0.43375465273857117
loss: 0.43107128143310547


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:06<?, ?it/s]

loss: 0.49711892008781433
loss: 0.5112186670303345
loss: 0.48219603300094604
loss: 0.5261686444282532
loss: 0.4770900309085846


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:07<?, ?it/s]

loss: 0.5050716400146484
loss: 0.424006849527359
loss: 0.5261406302452087
loss: 0.5121563673019409
loss: 0.46737784147262573


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:07<?, ?it/s]

loss: 0.46901875734329224
loss: 0.4551428556442261
loss: 0.5050830245018005
loss: 0.4518866539001465
loss: 0.4814850687980652


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:07<?, ?it/s]

loss: 0.4819481372833252
loss: 0.534110963344574
loss: 0.4251474142074585
loss: 0.5552624464035034
loss: 0.48748791217803955


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:07<?, ?it/s]

loss: 0.41850006580352783
loss: 0.5276327729225159
loss: 0.46134114265441895
loss: 0.508721649646759
loss: 0.46908894181251526


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:08<?, ?it/s]

loss: 0.521167516708374
loss: 0.48932141065597534
loss: 0.46153396368026733
loss: 0.4699472188949585
loss: 0.44696980714797974


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:08<?, ?it/s]

loss: 0.4940582811832428
loss: 0.5041068196296692
loss: 0.5258564352989197
loss: 0.4459315538406372
loss: 0.480510950088501


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:08<?, ?it/s]

loss: 0.42443615198135376
loss: 0.39924484491348267
loss: 0.5269871354103088
loss: 0.4557105004787445
loss: 0.44570356607437134


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:08<?, ?it/s]

loss: 0.5410065054893494
loss: 0.46842634677886963
loss: 0.4822745621204376
loss: 0.5665594339370728
loss: 0.6483844518661499


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:08<?, ?it/s]

loss: 0.4846384525299072
loss: 0.43887749314308167
loss: 0.5246306657791138
loss: 0.5096452236175537
loss: 0.49886399507522583


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:09<?, ?it/s]

loss: 0.4653892517089844
loss: 0.49057167768478394
loss: 0.42611604928970337
loss: 0.4733826518058777
loss: 0.4001539349555969


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:09<?, ?it/s]

loss: 0.4658302366733551
loss: 0.49045228958129883
loss: 0.4130503535270691
loss: 0.45727071166038513
loss: 0.49089258909225464


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:09<?, ?it/s]

loss: 0.4751192033290863
loss: 0.44711166620254517
loss: 0.5647507905960083
loss: 0.5285124778747559
loss: 0.43312662839889526


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:09<?, ?it/s]

loss: 0.49078306555747986
loss: 0.4528542160987854
loss: 0.48955193161964417
loss: 0.5036443471908569
loss: 0.4933597445487976


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:10<?, ?it/s]

loss: 0.46323317289352417
loss: 0.4822259545326233
loss: 0.47299495339393616
loss: 0.5207004547119141
loss: 0.4995160698890686


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:10<?, ?it/s]

loss: 0.5509845018386841
loss: 0.4649443030357361
loss: 0.47470855712890625
loss: 0.4716569781303406
loss: 0.4292088449001312


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:10<?, ?it/s]

loss: 0.5389305949211121
loss: 0.4684444069862366
loss: 0.5153197050094604
loss: 0.5853538513183594
loss: 0.4639259874820709


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:10<?, ?it/s]

loss: 0.44693800806999207
loss: 0.5210949182510376
loss: 0.46951574087142944
loss: 0.44994473457336426
loss: 0.4922454059123993


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:11<?, ?it/s]

loss: 0.5363374948501587
loss: 0.4727838337421417
loss: 0.40182483196258545
loss: 0.47520095109939575
loss: 0.5307572484016418


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:11<?, ?it/s]

loss: 0.5116574168205261
loss: 0.4617246389389038
loss: 0.503728449344635
loss: 0.49595189094543457
loss: 0.518054187297821


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:11<?, ?it/s]

loss: 0.4496854245662689
loss: 0.553199291229248
loss: 0.4581904411315918
loss: 0.4520050883293152
loss: 0.5408293604850769


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:11<?, ?it/s]

loss: 0.42990002036094666
loss: 0.5800740718841553
loss: 0.4322085380554199
loss: 0.4585990905761719
loss: 0.5278629064559937


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:12<?, ?it/s]

loss: 0.489776074886322
loss: 0.46418118476867676
loss: 0.4931550920009613
loss: 0.4981940984725952
loss: 0.4076874256134033


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:12<?, ?it/s]

loss: 0.5457660555839539
loss: 0.5236266255378723
loss: 0.4469335675239563
loss: 0.47784915566444397
loss: 0.5025129914283752


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:12<?, ?it/s]

loss: 0.4776964783668518
loss: 0.5233088731765747
loss: 0.5217671394348145
loss: 0.5834490060806274
loss: 0.44416719675064087


                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
                                                                                                                                               
  0%|                                                                                                                  | 0/100 [00:12<?, ?it/s]

loss: 0.5613101124763489
loss: 0.48074063658714294
loss: 0.5003455877304077
loss: 0.49599993228912354
loss: 0.5002671480178833


                                                                                                                                               
                                                                                                                                               
 31%|████████████████████████████████▏                                                                       | 266/860 [00:13<00:29, 20.42it/s]
  0%|                                                                                                                  | 0/100 [00:13<?, ?it/s]


loss: 0.48379987478256226
loss: 0.46983882784843445


KeyboardInterrupt: 