In [1]:
import torch
from torch import Tensor

import os

from torch_geometric.datasets import MD17

from torch.nn import Embedding, Linear, MSELoss, SiLU
from torch_geometric.nn import global_add_pool
from torch_geometric.nn.conv import MessagePassing

from gaussian_rbf import gaussian_rbf
from losses import CalcF_squared_loss, CalcF_absolute_loss

from torch.optim import Adam
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.loader import DataLoader

import wandb

In [2]:
# setting up wandb
os.environ['WANDB_NOTEBOOK_NAME'] = 'EGNN4.ipynb'
wandb.login()

# reproducibility
torch.manual_seed(2002)

[34m[1mwandb[0m: Currently logged in as: [33msharshe[0m. Use [1m`wandb login --relogin`[0m to force relogin


<torch._C.Generator at 0x10da24030>

In [3]:
class EGNN4(MessagePassing):
    def __init__(self):
        super().__init__()
        
        # activation function
        self.act = SiLU()
        
        # initialize layers
        # 118 atomic numbers into 32-dimensional space
        self.embedding = Embedding(118,64)
        
        # 64 dimensions for the embedding of the neighbor
        # 8 for the embedding of the distance
        self.message_lin = Linear(64 + 8, 64)
        
        # 64 dimensions for the current node embedding
        # 64 for the message
        self.update_lin = Linear(64 + 64, 64)
        
        # 64 dimensions for the embedding in and out
        self.atomwise_lin1 = Linear(64, 64)
        self.atomwise_lin2 = Linear(64, 64)
        self.atomwise_lin3 = Linear(64, 64)
        
        # compress the 32-dimensional node embedding to 1 dimension
        self.compress_lin1 = Linear(64, 8)
        self.compress_lin2 = Linear(8, 1)
        
    def forward(self, data):
        # get attributes out of data object
        edge_index = data.edge_index
        z = data.z
        pos = data.pos
        
        # force is the negative gradient of energy with respect to position, so pos must be on the computational graph
        pos.requires_grad_(True)
        
        # calculate edge distances and turn them into a vector through Gaussian RBF
        idx1, idx2 = edge_index
        edge_attr = torch.norm(pos[idx1] - pos[idx2], p=2, dim=-1).view(-1, 1)
        gaussian_edge_attr = gaussian_rbf(edge_attr)
        
        # embed
        E_hat = self.embedding(z)
        E_hat = self.act(E_hat)
        
        # message passing x 3
        # message passing 1
        E_hat = self.propagate(edge_index, x=E_hat, edge_attr=gaussian_edge_attr)
        E_hat = self.act(E_hat)
        E_hat = self.atomwise_lin1(E_hat)
        E_hat = self.act(E_hat)
        
        # message passing 2
        E_hat = self.propagate(edge_index, x=E_hat, edge_attr=gaussian_edge_attr)
        E_hat = self.act(E_hat)
        E_hat = self.atomwise_lin1(E_hat)
        E_hat = self.act(E_hat)
        
        # message passing 3
        E_hat = self.propagate(edge_index, x=E_hat, edge_attr=gaussian_edge_attr)
        E_hat = self.act(E_hat)
        E_hat = self.atomwise_lin1(E_hat)
        E_hat = self.act(E_hat)

        # compression
        E_hat = self.compress_lin1(E_hat)
        E_hat = self.act(E_hat)
        E_hat = self.compress_lin2(E_hat)
        E_hat = self.act(E_hat)
        E_hat = global_add_pool(E_hat, data.batch)
        
        # calculate the energy prediction as the negative gradient of energy with respect to position, retaining the computational graph for backprop
        F_hat = -torch.autograd.grad(E_hat.sum(), pos, retain_graph=True)[0]
        
        # return a tuple of the predictions
        return E_hat, F_hat
    
    def message(self, x_j, edge_attr):
        # concatenate the vectors
        lin_in = torch.cat((x_j, edge_attr), dim=1).float()
        
        # pass them into the linear layer
        out = self.message_lin(lin_in)
        
        # return the output
        return out
    
    def update(self, aggr_out, x):
        # concatenate the vectors
        lin_in = torch.cat((aggr_out, x), dim=1).float()
        
        # pass them into the linear layer
        out = self.update_lin(lin_in)
        
        # return the output
        return out

In [4]:
# model hyperparameters
base_learning_rate = 0.001
num_epochs = 50
scheduler_mode = 'min'
scheduler_factor = 0.32
scheduler_patience = 1
scheduler_threshold = 0
rho = 1-1e-1

In [5]:
# initialize model
model = EGNN4()

# I couldn't think of a concise way to initialize optimizer, scheduler, and loss_fn based on the contents of config
# this is all for show anyway, but it would be nice to have a natural way of doing this that generalizes when I am selecting hyperparameters more carefully
optimizer = Adam(model.parameters(), lr=base_learning_rate)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode=scheduler_mode, factor=scheduler_factor, patience=scheduler_patience, threshold=scheduler_threshold)
loss_fn = MSELoss()

In [6]:
config = {
    'base_learning_rate': base_learning_rate,
    'num_epochs': num_epochs,
    'optimizer': 'Adam',
    'scheduler': 'ReduceLROnPlateau',
    'scheduler_mode': 'min',
    'scheduler_factor': scheduler_factor, 
    'scheduler_patience': scheduler_patience,
    'scheduler_threshold': scheduler_threshold,
    'training_loss_fn': 'MSELoss',
    'rho': rho
}

In [7]:
# load in dataset
dataset = MD17(root='../../data/EGNN2/benzene', name='benzene', pre_transform=None, transform=None)

# 80/10/10 split
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# build train, val, test datasets out of main dataset
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

# turn into DataLoaders for batching efficiency
train_loader = DataLoader(train_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)



In [8]:
# val statistics
val_mean_losses = []
val_mean_E_losses = []
val_mean_F_losses = []

# test statistics using the same loss function as training
test_squared_losses = []
test_E_squared_losses = []
test_F_squared_losses = []

# test statistics using MAE for comparison with other benchmarks
test_absolute_losses = []
test_E_absolute_losses = []
test_F_absolute_losses = []

# learning rates
learning_rates = [base_learning_rate]

# initialize wandb run
wandb.init(
    project = "EGNN",
    config = config,
)

In [10]:
# training loop occurs num_epochs times
for epoch in range(num_epochs):
    # TRAINING
    # track gradients
    model.train()
    
    # dummy variable to track loss every 100 batches
    i = 0
    
    # loop through loader
    for data in train_loader:
        # clear gradients
        optimizer.zero_grad()

        # target values
        E = data.energy
        F = data.force
        
        # predictions from the model
        E_hat, F_hat = model(data)
        
        # squared error for energy loss
        E_loss = (1 - rho) * loss_fn(torch.squeeze(E_hat), E)

        # a version of squared error for force loss
        F_loss = rho * CalcF_squared_loss(F_hat, F)
        
        # canonical loss
        loss = E_loss + F_loss
    
        # calculate gradients
        loss.backward()
        
        # update
        optimizer.step()
        
        # save loss every 100 goes
        if i%100 == 0:
            wandb.log({"train_losses": loss.item()})
            wandb.log({"E_train_losses": E_loss.item()})
            wandb.log({"F_train_losses": F_loss.item()})
            
            # save learning rate
            lr = optimizer.param_groups[0]['lr']
            wandb.log({"training_rates": lr})
        i+=1
    
    # VAL
    epoch_losses = []
    epoch_E_losses = []
    epoch_F_losses = []
    
    # do not track gradients
    model.eval()
    
    # loop through val loader
    for data in val_loader:
        # target values
        E = data.energy
        F = data.force
        
        # predictions from the model
        E_hat, F_hat = model(data)
        
        # squared error for energy loss
        E_loss = (1 - rho) * loss_fn(torch.squeeze(E_hat), E)
        
        # a version of squared error for force loss
        F_loss = rho * CalcF_squared_loss(F_hat, F)
        
        # canonical loss
        loss =  E_loss + F_loss
        
        # track F_loss, E_loss, canonical loss
        epoch_losses.append(loss.item())
        epoch_E_losses.append(E_loss.item())
        epoch_F_losses.append(F_loss.item())
    
    # calculate the mean losses from this epoch
    epoch_mean_loss = torch.mean(torch.tensor(epoch_losses)).item()
    epoch_mean_E_loss = torch.mean(torch.tensor(epoch_E_losses)).item()
    epoch_mean_F_loss = torch.mean(torch.tensor(epoch_F_losses)).item()
    
    # save the mean canonical loss from this epoch for comparison to that of other epochs to determine whether to save weights
    val_mean_losses.append(epoch_mean_loss)
    
    # log mean losses with wandb
    wandb.log({"epoch_mean_loss": epoch_mean_loss})
    wandb.log({"epoch_mean_E_loss": epoch_mean_E_loss})
    wandb.log({"epoch_mean_F_loss": epoch_mean_F_loss})
    
    # print out the results of the epoch
    print(f'EPOCH {epoch+1} OF {num_epochs} | VAL MEAN LOSS: {epoch_mean_loss}')
    
    # if this is our best val performance yet, save the weights
    if min(val_mean_losses) == epoch_mean_loss:
        torch.save(model, '../weights/EGNN4.pth')
        
    scheduler.step(epoch_mean_loss)

EPOCH 1 OF 50 | VAL MEAN LOSS: 1.3183922646931023e-06
EPOCH 2 OF 50 | VAL MEAN LOSS: 9.860318783694311e-08
EPOCH 3 OF 50 | VAL MEAN LOSS: 1.1990578308029853e-08
EPOCH 4 OF 50 | VAL MEAN LOSS: 1.6386486834107927e-08
EPOCH 5 OF 50 | VAL MEAN LOSS: 1.869754271410784e-07
EPOCH 6 OF 50 | VAL MEAN LOSS: 8.231103798550521e-09
EPOCH 7 OF 50 | VAL MEAN LOSS: 1.1994077731003472e-08
EPOCH 8 OF 50 | VAL MEAN LOSS: 1.0956646256943259e-08
EPOCH 9 OF 50 | VAL MEAN LOSS: 6.999739898816415e-09
EPOCH 10 OF 50 | VAL MEAN LOSS: 7.273564861520754e-09
EPOCH 11 OF 50 | VAL MEAN LOSS: 7.336788510059478e-09
EPOCH 12 OF 50 | VAL MEAN LOSS: 6.0434284243626735e-09
EPOCH 13 OF 50 | VAL MEAN LOSS: 6.040112410232723e-09
EPOCH 14 OF 50 | VAL MEAN LOSS: 6.035252742009334e-09
EPOCH 15 OF 50 | VAL MEAN LOSS: 6.031712906917619e-09
EPOCH 16 OF 50 | VAL MEAN LOSS: 6.027954579934658e-09
EPOCH 17 OF 50 | VAL MEAN LOSS: 6.024590160080834e-09
EPOCH 18 OF 50 | VAL MEAN LOSS: 6.0208149577078984e-09
EPOCH 19 OF 50 | VAL MEAN LOSS

In [11]:
# TEST
for data in test_loader:
    # target values
    E = data.energy
    F = data.force
    
    # predictions from the model
    E_hat, F_hat = model(data)
    
    # squared error for energy loss
    E_squared_loss = loss_fn(torch.squeeze(E_hat), E) * (1-rho)
    
    # a version of squared error for force loss
    F_squared_loss = CalcF_squared_loss(F_hat, F) * rho
    
    # canonical loss
    squared_loss = E_squared_loss + F_squared_loss
    
    # squared error for energy loss
    E_absolute_loss = (1 - rho) * torch.mean(torch.abs(torch.squeeze(E_hat)-E))
    
    # a version of squared error for force loss
    F_absolute_loss = rho * CalcF_absolute_loss(F_hat, F)
    
    # canonical loss
    absolute_loss = E_absolute_loss + F_absolute_loss
    print(absolute_loss)
    
    # save squared losses
    test_squared_losses.append(squared_loss.item())
    test_E_squared_losses.append(E_squared_loss.item())
    test_F_squared_losses.append(F_squared_loss.item())
    
    # save absolute losses
    test_absolute_losses.append(absolute_loss.item())
    test_E_absolute_losses.append(E_absolute_loss.item())
    test_F_absolute_losses.append(F_absolute_loss.item())

# calculate and log mean test losses
test_mean_squared_loss = torch.mean(torch.tensor(test_squared_losses)).item()
test_mean_E_squared_loss = torch.mean(torch.tensor(test_E_squared_losses)).item()
test_mean_F_squared_loss = torch.mean(torch.tensor(test_F_squared_losses)).item()

wandb.log({"test_mean_squared_loss": test_mean_squared_loss})
wandb.log({"test_mean_E_squared_loss": test_mean_E_squared_loss})
wandb.log({"test_mean_F_squared_loss": test_mean_F_squared_loss})

test_mean_absolute_loss = torch.mean(torch.tensor(test_absolute_losses)).item()
test_mean_E_absolute_loss = torch.mean(torch.tensor(test_E_absolute_losses)).item()
test_mean_F_absolute_loss = torch.mean(torch.tensor(test_F_absolute_losses)).item()

wandb.log({"test_mean_absolute_loss": test_mean_absolute_loss})
wandb.log({"test_mean_E_absolute_loss": test_mean_E_absolute_loss})
wandb.log({"test_mean_F_absolute_loss": test_mean_F_absolute_loss})

# print mean test losses
print(f'TEST MEAN SQUARED LOSS: {test_mean_squared_loss}')
print(f'TEST MEAN ABSOLUTE LOSS: {test_mean_squared_loss}')

tensor(9.5996e-05, grad_fn=<AddBackward0>)
tensor(9.2079e-05, grad_fn=<AddBackward0>)
tensor(9.2122e-05, grad_fn=<AddBackward0>)
tensor(9.5046e-05, grad_fn=<AddBackward0>)
tensor(9.3187e-05, grad_fn=<AddBackward0>)
tensor(9.7777e-05, grad_fn=<AddBackward0>)
tensor(9.3921e-05, grad_fn=<AddBackward0>)
tensor(9.7456e-05, grad_fn=<AddBackward0>)
tensor(9.5569e-05, grad_fn=<AddBackward0>)
tensor(9.2585e-05, grad_fn=<AddBackward0>)
tensor(9.3244e-05, grad_fn=<AddBackward0>)
tensor(9.1006e-05, grad_fn=<AddBackward0>)
tensor(9.8178e-05, grad_fn=<AddBackward0>)
tensor(9.2025e-05, grad_fn=<AddBackward0>)
tensor(9.3868e-05, grad_fn=<AddBackward0>)
tensor(9.4773e-05, grad_fn=<AddBackward0>)
tensor(8.8627e-05, grad_fn=<AddBackward0>)
tensor(9.0039e-05, grad_fn=<AddBackward0>)
tensor(9.2453e-05, grad_fn=<AddBackward0>)
tensor(9.3909e-05, grad_fn=<AddBackward0>)
tensor(9.1208e-05, grad_fn=<AddBackward0>)
tensor(9.2506e-05, grad_fn=<AddBackward0>)
tensor(9.2170e-05, grad_fn=<AddBackward0>)
tensor(9.88

In [12]:
wandb.finish()

0,1
E_train_losses,▂▁▄█▂▁▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
F_train_losses,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch_mean_E_loss,█▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch_mean_F_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch_mean_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_mean_E_absolute_loss,▁
test_mean_E_squared_loss,▁
test_mean_F_absolute_loss,▁
test_mean_F_squared_loss,▁
test_mean_absolute_loss,▁

0,1
E_train_losses,0.0
F_train_losses,0.0
epoch_mean_E_loss,0.0
epoch_mean_F_loss,0.0
epoch_mean_loss,0.0
test_mean_E_absolute_loss,0.0
test_mean_E_squared_loss,0.0
test_mean_F_absolute_loss,9e-05
test_mean_F_squared_loss,0.0
test_mean_absolute_loss,9e-05
