In [182]:
import torch
from torch import Tensor

from torch_geometric.datasets import MD17

from torch.nn import Module, Embedding, Linear, MSELoss, LeakyReLU, SiLU
from torch_geometric.nn import global_add_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data

from torch.optim import Adam
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.loader import DataLoader

import numpy as np

import wandb

In [183]:
# load in dataset
dataset = MD17(root='../../data/EGNN2/benzene', name='benzene', pre_transform=None, transform=None)

# 80/10/10 split
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# build train, val, test datasets out of main dataset
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

# turn into DataLoaders for batching efficiency
train_loader = DataLoader(train_dataset, batch_size=128)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)



In [184]:
class EGNN2(MessagePassing):
    def __init__(self):
        super().__init__()
        
        self.embedding = Embedding(118,32)
        
        self.message_lin = Linear(32 + 8, 32)
        self.update_lin = Linear(32 + 32, 32)
        
        self.compress_lin = Linear(32, 1)
        
    def forward(self, data):
        edge_index = data.edge_index
        z = data.z
        pos = data.pos
        pos.requires_grad_(True)
        
        idx1, idx2 = edge_index
        edge_distance = torch.norm(pos[idx1] - pos[idx2], p=2, dim=-1).view(-1, 1)
        gaussian_edge_attr = gaussian_rbf(edge_distance)
        
        E_hat = self.embedding(z)
        
        E_hat = self.propagate(edge_index, x=E_hat, edge_attr=gaussian_edge_attr)
        
        E_hat = self.compress_lin(E_hat)
                
        E_hat = global_add_pool(E_hat, data.batch)
        
        F_hat = -torch.autograd.grad(E_hat.sum(), pos, retain_graph=True)[0]
        
        return E_hat, F_hat
    
    def message(self, x_j, edge_attr):
        lin_in = torch.cat((x_j, edge_attr), dim=1).float()
        
        out = self.message_lin(lin_in)
        
        return out
    
    def update(self, aggr_out, x):
        lin_in = torch.cat((aggr_out, x), dim=1).float()
        
        return self.update_lin(lin_in)

In [185]:
def gaussian_rbf(x: Tensor) -> Tensor:
    cs = torch.tensor(np.arange(0,1.6,0.2))
    return torch.exp(torch.square((x - cs)) / -.005).float()

In [186]:
model = EGNN2()

base_learning_rate = 0.0001
num_epochs = 10

optimizer = Adam(model.parameters(), base_learning_rate)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, patience=1, threshold=0)
loss_fn = MSELoss()

val_mean_losses = []
val_mean_E_losses = []
val_mean_F_losses = []

test_losses = []
test_E_losses = []
test_F_losses = []

learning_rates = [base_learning_rate]

In [187]:
def CalcF_loss(F: Tensor, F_hat: Tensor) -> Tensor:
        # average square of the magnitude of the difference between the predicted and actual force vectors on each atom
        # also squared error, just a more complicated calculation
        # multiplied by 100 so that the losses for F and E begin on the same OOM
        F_error = F_hat - F
        F_squared_error = torch.square(F_error)
        F_atomwise_error_magnitudes = torch.sum(F_squared_error, dim=1)
        F_loss = torch.div(torch.sum(F_squared_error), F.size()[0]) * 100
        return F_loss

In [188]:
wandb.init(
    project = "EGNN",
    config = {
        "model": "RBF EGNN",
        "embed_dimension": 32,
        "optimizer": "Adam",
        "base_learning_rate": 0.0001,
        "dataset": "MD17",
        "scheduler": "ReduceLROnPlateau",
        "epochs": num_epochs
    }
)

0,1
E_train_losses,█▁▁▁▁▁▁▁▁▁▁
F_train_losses,█▆▄▃▂▁▁▁▁▁▁
train_losses,█▆▄▃▂▁▁▁▁▁▁
training_rates,▁▁▁▁▁▁▁▁▁▁▁

0,1
E_train_losses,0.0
F_train_losses,103405.40625
train_losses,103405.40625
training_rates,0.0001


In [189]:
for epoch in range(num_epochs):
    # TRAINING
    # track gradients
    model.train()
    
    i = 0
    
    # loop through loader
    for data in train_loader:
        # clear gradients again for good measure
        optimizer.zero_grad()

        # target values
        E = data.energy
        F = data.force
        
        # predictions from the model
        E_hat, F_hat = model(data)
        
        # squared error for energy loss
        E_loss = loss_fn(E_hat, E)

        # a version of squared error for force loss
        F_loss = CalcF_loss(F_hat, F)
        
        # canonical loss
        loss = F_loss + E_loss
    
        # calculate gradients
        loss.backward()
        
        # update
        optimizer.step()
        
        # save loss every 100 goes
        if i%100 == 0:
            wandb.log({"train_losses": loss.item()})
            wandb.log({"E_train_losses": E_loss.item()})
            wandb.log({"F_train_losses": F_loss.item()})
            
            # save learning rate
            lr = optimizer.param_groups[0]['lr']
            wandb.log({"training_rates": lr})
        i+=1
        
    
    # VAL
    epoch_losses = []
    epoch_E_losses = []
    epoch_F_losses = []
    
    # do not track gradients
    model.eval()
    
    # loop through val loader
    for data in val_loader:
        # target values
        E = data.energy
        F = data.force
        
        # predictions from the model
        E_hat, F_hat = model(data)
        
        # squared error for energy loss
        E_loss = loss_fn(E_hat, E)
        
        # a version of squared error for force loss
        F_loss = CalcF_loss(F_hat, F)
        
        # canonical loss
        loss = F_loss + E_loss
        
        # track F_loss, E_loss, canonical loss
        epoch_losses.append(loss.item())
        epoch_E_losses.append(E_loss.item())
        epoch_F_losses.append(F_loss.item())
    
    epoch_mean_loss = torch.mean(torch.tensor(epoch_losses)).item()
    epoch_mean_E_loss = torch.mean(torch.tensor(epoch_E_losses)).item()
    epoch_mean_F_loss = torch.mean(torch.tensor(epoch_F_losses)).item()
    
    val_mean_losses.append(epoch_mean_loss)
    val_mean_E_losses.append(epoch_mean_E_loss)
    val_mean_F_losses.append(epoch_mean_F_loss)
    
    wandb.log({"epoch_mean_loss": epoch_mean_loss})
    wandb.log({"epoch_mean_E_loss": epoch_mean_E_loss})
    wandb.log({"epoch_mean_F_loss": epoch_mean_F_loss})
    
    # print out the results of the epoch
    print(f'EPOCH {epoch+1} OF {num_epochs} | VAL MEAN LOSS: {epoch_mean_loss}')
    
    # if this is our best val performance yet, save the weights
    if min(val_mean_losses) == epoch_mean_loss:
        torch.save(model, '../weights/EGNN2.pth')
        
    scheduler.step(epoch_mean_loss)

# TEST
for data in test_loader:
    # target values
    E = data.energy
    F = data.force
    
    # predictions from the model
    E_hat, F_hat = model(data)
    
    # squared error for energy loss
    E_loss = loss_fn(E_hat, E)
    
    # a version of squared error for force loss
    F_loss = CalcF_loss(F_hat, F)
    
    # canonical loss
    loss = F_loss + E_loss
    
    # save losses
    test_losses.append(loss.item())
    test_E_losses.append(E_loss.item())
    test_F_losses.append(F_loss.item())

# save and print mean test loss
test_mean_loss = torch.mean(torch.tensor(test_losses)).item()
test_mean_E_loss = torch.mean(torch.tensor(test_E_losses)).item()
test_mean_F_loss = torch.mean(torch.tensor(test_F_losses)).item()

wandb.log({"test_mean_loss": test_mean_loss})
wandb.log({"test_mean_E_loss": test_mean_E_loss})
wandb.log({"test_mean_F_loss": test_mean_F_loss})

print(f'TEST MEAN LOSS: {test_mean_loss}')

wandb.finish()

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


EPOCH 1 OF 10 | VAL MEAN LOSS: 0.011949259787797928
EPOCH 2 OF 10 | VAL MEAN LOSS: 1.244987515747198e-06
EPOCH 3 OF 10 | VAL MEAN LOSS: 0.0008091051713563502
EPOCH 4 OF 10 | VAL MEAN LOSS: 3.4829427022486925e-05
EPOCH 5 OF 10 | VAL MEAN LOSS: 9.100762667912932e-07
EPOCH 6 OF 10 | VAL MEAN LOSS: 7.077346708683763e-06
EPOCH 7 OF 10 | VAL MEAN LOSS: 1.6483418221469037e-06
EPOCH 8 OF 10 | VAL MEAN LOSS: 9.278430752601707e-07
EPOCH 9 OF 10 | VAL MEAN LOSS: 1.006949332804652e-06
EPOCH 10 OF 10 | VAL MEAN LOSS: 9.248040555576154e-07


  return F.mse_loss(input, target, reduction=self.reduction)


TEST MEAN LOSS: 1.0036715138994623e-06


0,1
E_train_losses,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
F_train_losses,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch_mean_E_loss,▁▁█▁▁▁▁▁▁▁
epoch_mean_F_loss,█▁▁▁▁▁▁▁▁▁
epoch_mean_loss,█▁▁▁▁▁▁▁▁▁
test_mean_E_loss,▁
test_mean_F_loss,▁
test_mean_loss,▁
train_losses,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
training_rates,████████████████▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
E_train_losses,0.0
F_train_losses,0.0
epoch_mean_E_loss,0.0
epoch_mean_F_loss,0.0
epoch_mean_loss,0.0
test_mean_E_loss,0.0
test_mean_F_loss,0.0
test_mean_loss,0.0
train_losses,0.0
training_rates,0.0


In [191]:
wandb.finish()