In [8]:
import torch
from torch import Tensor

from torch_geometric.datasets import MD17

from torch.nn import Module, Embedding, Linear, MSELoss, LeakyReLU, SiLU
from torch_geometric.nn import global_add_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data

from torch.optim import Adam
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.loader import DataLoader

import numpy as np

import wandb

In [9]:
# Load the MD17 benzene dataset, stored under the given root directory.
# No transform or pre_transform is applied to the samples.
benzene_dataset = MD17(
    root='../../data/EGNN2/benzene',
    name='benzene',
    transform=None,
    pre_transform=None,
)



In [158]:
class EGNN2(MessagePassing):
    """Single-round message-passing model that predicts an energy and its forces.

    Node features come from an embedding of ``data.z``; edge features are a
    Gaussian radial-basis expansion of the distance between the two endpoint
    positions. After one round of message passing, every node is reduced to a
    scalar, the scalars are summed per graph to give the energy ``E_hat``, and
    the forces are obtained as ``F_hat = -dE_hat/dpos`` via autograd.

    Note: ``forward`` calls the module-level function ``gaussian_rbf``, which is
    defined in a separate cell of this notebook and must be executed first.
    """

    def __init__(self):
        # 'add' is sum aggregation of incoming messages, stated explicitly so
        # the model does not depend on the library default.
        super().__init__(aggr='add')

        # Lookup table with 118 rows (valid indices 0..117) of 32 features,
        # indexed by data.z (presumably atomic numbers - confirm).
        self.embedding = Embedding(118, 32)

        # Message input = 32 neighbour features + 8 RBF features; the 8 must
        # match the number of centres produced by gaussian_rbf.
        self.message_lin = Linear(32 + 8, 32)
        # Update input = 32 aggregated message features + 32 node features.
        self.update_lin = Linear(32 + 32, 32)

        # Reduces each node's 32 features to a single scalar contribution.
        self.compress_lin = Linear(32, 1)

    def forward(self, data):
        """Compute the energy and the forces for ``data``.

        Args:
            data: Object exposing ``edge_index``, ``z``, ``pos`` and ``batch``
                attributes (e.g. a ``torch_geometric`` ``Data`` or batch).

        Returns:
            Tuple ``(E_hat, F_hat)``: ``E_hat`` is the pooled energy with one
            row per graph, ``F_hat`` has the same shape as ``data.pos``.

        Side effects:
            ``data.pos`` is switched to ``requires_grad=True`` in place so the
            energy can be differentiated with respect to the positions.
        """
        edge_index = data.edge_index
        z = data.z
        pos = data.pos
        pos.requires_grad_(True)

        # Distance between the two endpoints of every edge, as a column vector.
        idx1, idx2 = edge_index
        edge_distance = torch.norm(pos[idx1] - pos[idx2], p=2, dim=-1).view(-1, 1)
        gaussian_edge_attr = gaussian_rbf(edge_distance)

        E_hat = self.embedding(z)

        # One round of message passing (message -> aggregate -> update).
        E_hat = self.propagate(edge_index, x=E_hat, edge_attr=gaussian_edge_attr)

        E_hat = self.compress_lin(E_hat)

        # Sum the per-node scalars of each graph into one energy per graph.
        E_hat = global_add_pool(E_hat, data.batch)

        # Forces are the negative gradient of the energy w.r.t. positions.
        # While training, create_graph=True keeps F_hat attached to the autograd
        # graph so that a loss on the forces back-propagates into the model
        # parameters; without it the forces are detached constants.
        # retain_graph=True keeps the energy graph alive for a later backward().
        F_hat = -torch.autograd.grad(
            E_hat.sum(),
            pos,
            create_graph=self.training,
            retain_graph=True,
        )[0]

        return E_hat, F_hat

    def message(self, x_j, edge_attr):
        """Build one message per edge from neighbour features and edge features.

        ``x_j`` and ``edge_attr`` are concatenated along dim 1, cast to float32
        and passed through ``message_lin``.
        """
        lin_in = torch.cat((x_j, edge_attr), dim=1).float()

        out = self.message_lin(lin_in)

        return out

    def update(self, aggr_out, x):
        """Combine aggregated messages with the node's own features.

        ``aggr_out`` and ``x`` are concatenated along dim 1, cast to float32
        and passed through ``update_lin``.
        """
        lin_in = torch.cat((aggr_out, x), dim=1).float()

        return self.update_lin(lin_in)
        

In [159]:
# Instantiate the (untrained) model for a quick smoke test.
e = EGNN2()

In [160]:
# Run one forward pass; the cell output is the (E_hat, F_hat) tuple from EGNN2.forward.
# NOTE(review): `data` is not defined in any cell visible here - presumably a
# single sample such as benzene_dataset[0]; confirm and define it explicitly.
e(data)

(tensor([[-0.2419]], grad_fn=<SumBackward1>),
 tensor([[-0.0000,  0.4872, -0.0000],
         [ 0.4219,  0.2439, -0.0000],
         [ 0.4219, -0.2439, -0.0000],
         [-0.0000, -0.4872, -0.0000],
         [-0.4219, -0.2439, -0.0000],
         [-0.4219,  0.2439, -0.0000],
         [-0.0000, -0.0011, -0.0000],
         [-0.0009, -0.0005, -0.0000],
         [-0.0009,  0.0005, -0.0000],
         [-0.0000,  0.0011, -0.0000],
         [ 0.0009,  0.0005, -0.0000],
         [ 0.0009, -0.0005, -0.0000]]))

For reference, the largest edge distance in the benzene dataset:  
`torch.max(benzene_dataset[:].edge_dist).item() == 0.1417`

In [154]:
def gaussian_rbf(x: Tensor) -> Tensor:
    """Expand values in a fixed basis of 8 Gaussian radial basis functions.

    Every value ``d`` in ``x`` is mapped to ``exp(-(d - c)**2 / 0.005)`` for
    each centre ``c`` in ``0.0, 0.2, ..., 1.4``.

    Args:
        x: Tensor that broadcasts against a vector of 8 centres, e.g. a column
            of distances with shape ``(num_edges, 1)``.

    Returns:
        float32 tensor of the broadcast shape (last dimension 8), located on
        the same device as ``x``.
    """
    # Build the centres on x's device and in x's floating dtype. A plain
    # torch.tensor(np.arange(...)) is always a CPU float64 tensor, which breaks
    # for CUDA inputs and silently promotes all the arithmetic to float64.
    compute_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    cs = torch.as_tensor(np.arange(0, 1.6, 0.2), dtype=compute_dtype, device=x.device)
    return torch.exp(torch.square(x - cs) / -.005).float()