In [3]:
# Now that the environment is ready, let's start by building a simple GNN.
# We will use the PyTorch Geometric (PyG) library.

import os

import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.datasets import QM9 # We will use the QM9 dataset for this example

# Let's start by loading the dataset

# Root directory where the QM9 dataset is stored.
# Set the QM9_ROOT environment variable to point somewhere else; when it is
# unset, the original cluster location is used.
path = os.environ.get('QM9_ROOT', '/project/dinner/zpengmei/Geom2Vec/Tutorial/data_sets/QM9')

# QM9 has many labels. For now we keep only the first one, which we can do with a transform object in PyG.

class QM9Transform:
    """Transform that keeps a single target column of ``data.y``.

    Args:
        target: Column index of ``data.y`` to keep. Defaults to 0 (the first
            label), which is what ``QM9Transform()`` selects.
    """

    def __init__(self, target=0):
        self.target = target

    def __call__(self, data):
        """Overwrite ``data.y`` with the selected column and return ``data``.

        The object is modified in place; the returned value is the same
        object that was passed in.
        """
        # Select target.
        data.y = data.y[:, self.target]
        return data

    def __repr__(self):
        return f'{self.__class__.__name__}(target={self.target})'

# Load the dataset. This step needs internet access, so run it on the login node rather than on a compute node.

# Seed PyTorch's RNG before shuffling so that the shuffled order, and any
# train/val/test split sliced from it, is the same on every run.
torch.manual_seed(42)
dataset = QM9(path, transform=QM9Transform()).shuffle()

# Normalize targets to mean = 0 and std = 1 (column-wise, over every target column).
# NOTE(review): the statistics are computed over the whole dataset, so they
# include molecules that are later used for validation and testing.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
# Keep the statistics of column 0 (the column QM9Transform() selects) as plain
# Python floats; presumably for converting normalized errors back to the
# original units.
mean, std = mean[:, 0].item(), std[:, 0].item()

dataset

QM9(130831)

In [4]:
# split the dataset into training, validation and test sets

# The dataset was shuffled when it was loaded, so contiguous slices are random
# subsets: the first 110000 molecules for training, the next 10000 for
# validation and everything after index 120000 for testing.
train_dataset, val_dataset, test_dataset = (
    dataset[:110000],
    dataset[110000:120000],
    dataset[120000:],
)

# One DataLoader per split with 128 molecules per batch; only the training
# loader reshuffles its samples.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)




In [5]:
# let's use a naive SchNet model for this example
from torch_geometric.nn import SchNet

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the SchNet model, then move its parameters onto the chosen device.
net = SchNet(hidden_channels=128, num_filters=128, num_interactions=6)
net = net.to(device)
print(net)

SchNet(hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0)


In [6]:
# Automatic mixed precision (AMP) can often speed up training on CUDA GPUs
from torch.cuda.amp import autocast, GradScaler

# Gradient scaling is only useful for CUDA mixed precision. Enable it only when
# a GPU is available, so that on a CPU-only machine the scaler is an explicit
# no-op instead of being requested and then disabled with a warning.
scaler = GradScaler(enabled=torch.cuda.is_available())

# define the optimizer and loss function
from torch.optim import Adam
from torch.nn import L1Loss

# L1Loss with its default settings is the mean absolute error between
# prediction and target.
criterion = L1Loss()
# Adam over every parameter of the network, with a learning rate of 1e-3.
optimizer = Adam(net.parameters(), lr=1e-3)

# define the training loop

def train():
    """Train ``net`` for one epoch on ``train_loader`` and return the mean loss.

    The forward pass and the loss are computed under ``autocast``; the
    backward pass and the optimizer step go through ``scaler``.

    Each batch loss is weighted by the number of molecules in the batch, so
    the returned value is an average per molecule. Averaging the per-batch
    means instead would over-weight a smaller final batch.

    Returns:
        float: ``criterion`` loss averaged over all molecules seen this epoch.
    """
    net.train()

    total_loss = 0.0
    num_graphs = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        with autocast():
            # GNN operates on atomic numbers, positions and the batch vector
            # which assigns each atom to a specific molecule.
            out = net(data.z, data.pos, data.batch)
            loss = criterion(out.view(-1), data.y.view(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # `loss` is a mean over the batch; multiply by the batch size so the
        # running total is a sum over molecules.
        total_loss += loss.item() * data.num_graphs
        num_graphs += data.num_graphs

    return total_loss / num_graphs

# define the evaluation loop

@torch.no_grad()
def test(loader):
    """Evaluate ``net`` on every batch of ``loader`` and return the mean loss.

    Runs with gradients disabled and the network in eval mode. Each batch loss
    is weighted by the number of molecules in the batch, so the result is an
    average per molecule and does not depend on how the data is split into
    batches (for example, a smaller final batch).

    Args:
        loader: Iterable of batches exposing ``z``, ``pos``, ``batch``, ``y``
            and ``num_graphs``.

    Returns:
        float: ``criterion`` loss averaged over all molecules in ``loader``.
    """
    net.eval()

    total_loss = 0.0
    num_graphs = 0
    for data in loader:
        data = data.to(device)
        with autocast():
            out = net(data.z, data.pos, data.batch)
            batch_loss = criterion(out.view(-1), data.y.view(-1)).item()

        # `batch_loss` is a mean over the batch; weight it by the batch size.
        total_loss += batch_loss * data.num_graphs
        num_graphs += data.num_graphs

    return total_loss / num_graphs


In [10]:
# train the model
from tqdm import tqdm

# Lowest validation loss seen so far; None until the first validation pass.
best_val_loss = None
for epoch in tqdm(range(1, 101)):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

    # Validation (and possibly testing) runs on even epochs only; change the
    # modulus to change the frequency.
    if epoch % 2 != 0:
        continue

    val_loss = test(val_loader)
    print(f'Val Loss: {val_loss:.4f}')

    # Only touch the test set when the validation loss is the first one
    # recorded or matches/improves on the best so far.
    if not (best_val_loss is None or val_loss <= best_val_loss):
        continue

    best_val_loss = val_loss
    test_loss = test(test_loader)
    print(f'Test Loss: {test_loss:.4f}')


  1%|█                                                                                                         | 1/100 [00:14<23:57, 14.52s/it]

Epoch: 001, Loss: 0.3401
Epoch: 002, Loss: 0.2925
Val Loss: 0.2890


  2%|██                                                                                                        | 2/100 [00:33<27:38, 16.92s/it]

Test Loss: 0.2916


  3%|███▏                                                                                                      | 3/100 [00:47<25:27, 15.74s/it]

Epoch: 003, Loss: 0.2605


  3%|███▏                                                                                                      | 3/100 [00:53<28:49, 17.83s/it]


KeyboardInterrupt: 