In [2]:
import torch_geometric
import os
from torch_geometric.datasets import MD17
from torch_geometric.nn import GCNConv
from torch.nn import Module, Embedding, Linear, MSELoss, LeakyReLU
from torch.optim import Adam
from torch_geometric.nn import global_mean_pool
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.loader import DataLoader
import torch
import wandb
from torch import Tensor

In [27]:
# setting up wandb
os.environ['WANDB_NOTEBOOK_NAME'] = 'toy_EGNN.ipynb'
wandb.login()

# reproducibility
torch.manual_seed(2002)



<torch._C.Generator at 0x1109700f0>

In [3]:
# load in dataset
dataset = MD17(root='../data/benzene', name='benzene', transform=None, pre_transform=None)

# 80/10/10 split
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# build train, val, test datasets out of main dataset
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

# turn into DataLoaders for batching efficiency
train_loader = DataLoader(train_dataset, batch_size=128)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)



In [39]:
class ToyGCN(Module):
    def __init__(self):
        super().__init__()
        
        # initialize layers
        self.embedding = Embedding(118, 16)
        self.conv1 = GCNConv(16, 16)
        self.lin1 = Linear(16, 16)
        self.conv2 = GCNConv(16, 16)
        self.lin2 = Linear(16, 4)
        self.lin3 = Linear(4, 1)
        self.non_linearity = LeakyReLU()

    # define forward pass
    def forward(self, data):
        # get relevant parts from data arg
        edge_index = data.edge_index        
        pos = data.pos
        pos.requires_grad = True
        
        # calculate distances between nodes
        edge_attr = torch.sqrt(torch.sum(torch.square(pos[edge_index[0,:]] - pos[edge_index[1,:]]),dim=1))
        
        # initialize E_hat
        E_hat = data.z

        # embed E_hat
        E_hat = self.embedding(E_hat)
        
        # conv layer 1
        E_hat = self.conv1(E_hat, edge_index, edge_attr)
        E_hat = self.non_linearity(E_hat)
        
        # linear layer 1
        E_hat = self.lin1(E_hat)
        E_hat = self.non_linearity(E_hat)
        
        # conv layer 2
        E_hat = self.conv2(E_hat, edge_index, edge_attr)
        E_hat = self.non_linearity(E_hat)
        
        # linear layer 2
        E_hat = self.lin2(E_hat)
        E_hat = self.non_linearity(E_hat)
        
        # linear layer 3: compression
        E_hat = self.lin3(E_hat)
        E_hat = self.non_linearity(E_hat)
        
        # combine representations of all nodes
        # into single graph-level prediction
        E_hat = global_mean_pool(E_hat, data.batch)
        #* sum pool, not mean pool
        
        # calculate the force on each atom, which is the negative gradient of the atom's position
        F_hat = -torch.autograd.grad(E_hat.sum(), pos, create_graph=True)[0]
        
        return E_hat, F_hat

In [41]:
model = ToyGCN()

In [30]:
model = ToyGCN()

base_learning_rate = 0.0001
num_epochs = 1

optimizer = Adam(model.parameters(), base_learning_rate)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, patience=1, threshold=0)
loss_fn = MSELoss()

val_mean_losses = []
val_mean_E_losses = []
val_mean_F_losses = []

test_losses = []
test_E_losses = []
test_F_losses = []

learning_rates = [base_learning_rate]

In [31]:
wandb.init(
    project = "EGNN",
    config = {
        "model": "2-layer GCNN",
        "embed_dimension": 8,
        "optimizer": "Adam",
        "base_learning_rate": 0.0001,
        "dataset": "MD17",
        "scheduler": "ReduceLROnPlateau",
        "epochs": num_epochs
    }
)

0,1
E_train_losses,▁
F_train_losses,▁
train_losses,▁
training_rates,▁

0,1
E_train_losses,0.1309
F_train_losses,0.70616
train_losses,0.83707
training_rates,0.0001


In [32]:
def CalcF_loss(F: Tensor, F_hat: Tensor) -> Tensor:
        # average square of the magnitude of the difference between the predicted and actual force vectors on each atom
        # also squared error, just a more complicated calculation
        # multiplied by 1e8 so that the losses for F and E begin on the same OOM
        F_error = F_hat - F
        F_squared_error = torch.square(F_error)
        F_atomwise_error_magnitudes = torch.sum(F_squared_error, dim=1)
        F_loss = torch.div(torch.sum(F_squared_error), F.size()[0]) * 1e8
        return F_loss

In [35]:
for epoch in range(num_epochs):
    TRAINING
    track gradients
    model.train()
    
    i = 0
    
    # loop through loader
    for data in train_loader:
        # target values
        E = data.energy
        F = data.force
        
        # predictions from the model
        E_hat, F_hat = model(data)
        
        # clear out the gradients from autograd (?)
        # (not sure whether this is necessary anymore because I changed the code that calculates the force predictions, but I don't think it can hurt)
        optimizer.zero_grad()

        # squared error for energy loss
        E_loss = loss_fn(E_hat, E)

        # a version of squared error for force loss
        F_loss = CalcF_loss(F_hat.view(-1), F.view(-1))
        
        # canonical loss
        loss = F_loss + E_loss
    
        # calculate gradients
        loss.backward()
        
        # update
        optimizer.step()
    
        # clear gradients again for good measure
        optimizer.zero_grad()
        
        # save loss every 100 goes
        if i%100 == 0:
            wandb.log({"train_losses": loss.item()})
            wandb.log({"E_train_losses": E_loss.item()})
            wandb.log({"F_train_losses": F_loss.item()})
            
            # save learning rate
            lr = optimizer.param_groups[0]['lr']
            wandb.log({"training_rates": lr})
        i+=1
        
    
    # VAL
    epoch_losses = []
    epoch_E_losses = []
    epoch_F_losses = []
    
    # do not keep track of gradients
    model.eval()
    
    # loop through val loader
    for data in val_loader:
        # target values
        E = data.energy
        F = data.force
        
        # predictions from the model
        E_hat, F_hat = model(data)
        
        # clear out the gradients from autograd (?)
        # (not sure whether this is necessary anymore because I changed the code that calculates the force predictions, but I don't think it can hurt)
        optimizer.zero_grad()

        # squared error for energy loss
        E_loss = loss_fn(E_hat, E)
        
        # a version of squared error for force loss
        F_loss = CalcF_loss(F_hat, F)
        
        # canonical loss
        loss = F_loss + E_loss
        
        # track F_loss, E_loss, canonical loss
        epoch_losses.append(loss.item())
        epoch_E_losses.append(E_loss.item())
        epoch_F_losses.append(F_loss.item())
    
    epoch_mean_loss = torch.mean(torch.tensor(epoch_losses)).item()
    epoch_mean_E_loss = torch.mean(torch.tensor(epoch_E_losses)).item()
    epoch_mean_F_loss = torch.mean(torch.tensor(epoch_F_losses)).item()
    
    val_mean_losses.append(epoch_mean_loss)
    val_mean_E_losses.append(epoch_mean_E_loss)
    val_mean_F_losses.append(epoch_mean_F_loss)
    
    wandb.log({"epoch_mean_loss": epoch_mean_loss})
    wandb.log({"epoch_mean_E_loss": epoch_mean_E_loss})
    wandb.log({"epoch_mean_F_loss": epoch_mean_F_loss})
    
    # print out the results of the epoch
    print(f'EPOCH {epoch+1} OF {num_epochs} | VAL MEAN LOSS: {epoch_mean_loss}')
    
    # if this is our best val performance yet, save the weights
    if min(val_mean_losses) == epoch_mean_loss:
        torch.save(model, 'weights/toy_EGNN.pth')
        
    scheduler.step(epoch_mean_loss)

# TEST
for data in test_loader:
    # target values
    E = data.energy
    F = data.force
    
    # predictions from the model
    E_hat, F_hat = model(data)
    
    # clear out the gradients from autograd (?)
    # (not sure whether this is necessary anymore because I changed the code that calculates the force predictions, but I don't think it can hurt)
    optimizer.zero_grad()

    # squared error for energy loss
    E_loss = loss_fn(E_hat, E)
    
    # a version of squared error for force loss
    F_loss = CalcF_loss(F_hat, F)
    
    # canonical loss
    loss = F_loss + E_loss
    
    # save losses
    test_losses.append(loss.item())
    test_E_losses.append(E_loss.item())
    test_F_losses.append(F_loss.item())

# save and print mean test loss
test_mean_loss = torch.mean(torch.tensor(test_losses)).item()
test_mean_E_loss = torch.mean(torch.tensor(test_E_losses)).item()
test_mean_F_loss = torch.mean(torch.tensor(test_F_losses)).item()

wandb.log({"test_mean_loss": test_mean_loss})
wandb.log({"test_mean_E_loss": test_mean_E_loss})
wandb.log({"test_mean_F_loss": test_mean_F_loss})

print(f'TEST MEAN LOSS: {test_mean_loss}')

wandb.finish()

EPOCH 1 OF 1 | VAL MEAN LOSS: 0.7702529430389404


  return F.mse_loss(input, target, reduction=self.reduction)


TEST MEAN LOSS: 0.7712029218673706


0,1
E_train_losses,███████▇▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
F_train_losses,▁▁▁▁▁▁▁▆▂▁▁▁▂█▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁
epoch_mean_E_loss,▁▁
epoch_mean_F_loss,▁▁
epoch_mean_loss,▁▁
test_mean_E_loss,▁
test_mean_F_loss,▁
test_mean_loss,▁
train_losses,▁▁▁▁▁▁▁▆▂▁▁▁▂█▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁
training_rates,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
E_train_losses,7e-05
F_train_losses,0.67965
epoch_mean_E_loss,7e-05
epoch_mean_F_loss,0.77019
epoch_mean_loss,0.77025
test_mean_E_loss,7e-05
test_mean_F_loss,0.77114
test_mean_loss,0.7712
train_losses,0.67973
training_rates,0.0001


In [37]:
data = train_dataset[0]

In [38]:
E_hat, F_hat = model(data)

In [40]:
print(data.force)
print(F_hat)
print(F_hat - data.force)

tensor([[ 1.4450e-04, -5.4219e-06, -2.6049e-05],
        [-1.1722e-04,  4.0329e-05,  2.9376e-06],
        [-4.2287e-06, -7.1831e-05, -3.1632e-05],
        [ 4.5677e-05,  7.7480e-05,  5.2616e-05],
        [-3.8878e-05, -9.1347e-05, -2.6461e-05],
        [ 2.2421e-06,  4.5081e-05, -2.8685e-05],
        [ 6.1412e-06,  5.6802e-05, -2.6758e-06],
        [-1.2903e-05, -3.2683e-05, -7.0144e-07],
        [ 4.6507e-06,  2.8377e-05,  2.6073e-05],
        [ 6.3945e-06, -4.1892e-05, -3.3723e-06],
        [ 4.5682e-05,  4.0348e-05,  5.4611e-06],
        [-8.2055e-05, -4.5241e-05,  3.2487e-05]])
tensor([[-4.9841e-06, -2.0859e-06, -2.1329e-06],
        [ 1.3823e-06,  1.7118e-06,  9.9114e-07],
        [ 4.5945e-06, -3.4423e-06, -3.7060e-06],
        [ 2.5680e-06,  1.8788e-06,  4.3330e-06],
        [-2.7541e-06, -1.5846e-06, -7.1490e-06],
        [-6.4554e-06, -1.1896e-06,  5.2323e-07],
        [-1.0568e-05,  1.2414e-05, -5.6979e-06],
        [ 4.8518e-06,  1.4566e-05, -1.2702e-05],
        [ 1.5571e-0