In [176]:
import torch
from torch import Tensor

from torch_geometric.datasets import MD17

from torch.nn import Module, Embedding, Linear, MSELoss, LeakyReLU, SiLU
from torch_geometric.nn import global_add_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data

from torch.optim import Adam
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.loader import DataLoader

import wandb

In [177]:
benzene_dataset = MD17(root='../../data/EGNN2/benzene', name='benzene', pre_transform=None, transform=None)



for reference:
torch.max(benzene_dataset[:].edge_dist).item() == 0.07085

In [223]:
class EGNN2(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")
        
        self.act = SiLU()
    
        self.embedding = Embedding(118, 32)
        self.message_lin = Linear(32 + 7, 32)
        self.update_lin = Linear(32 + 32, 32)
        self.atom_lin_post = Linear(32, 32)
        self.atom_lin_red = Linear(32, 1)
        
    def forward(self, x: Data) -> Tensor:
        edge_index = x.edge_index
        pos = x.pos
        pos.requires_grad_(True)
        
        row, col = edge_index
        edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=-1).view(-1, 1)
        edge_attr = gaussian_rbf(edge_attr)
        
        E_hat = self.embedding(x.z)
        E_hat = self.act(E_hat)
        
        E_hat = self.propagate(x_j=E_hat, edge_attr=edge_attr, edge_index=edge_index)
    
        E_hat = self.atom_lin_post(E_hat)
        E_hat = self.act(E_hat)
        
        E_hat = self.atom_lin_red(E_hat)
        E_hat = self.act(E_hat)
        
        E_hat = global_add_pool(E_hat, x.batch)
        
        F_hat = -torch.autograd.grad(E_hat.sum(), pos, retain_graph=True)[0]
        
        return E_hat, F_hat

    def message(self, x_j: Tensor, edge_attr: Tensor, edge_index: Tensor) -> Tensor:
        print(f'x_j: {x_j}')
        print(f'edge_attr.view(-1): {edge_attr.view(-1)}')
        lin_in = torch.cat((x_j, edge_attr), dim=1)
        return self.message_lin(lin_in)
    
    def update(self, aggr_out: Tensor, x: Tensor) -> Tensor:
        return self.update_lin(torch.cat((aggr_out, x), dim=1))

In [224]:
def gaussian_rbf(x: Tensor) -> Tensor:
    cs = torch.tensor(numpy.arange(0,0.7,0.1))
    diffs = x - cs
    return torch.exp(torch.square(diffs) / -.005)

In [225]:
model = EGNN2()

In [226]:
model(data)

E_hat: tensor([[ 6.3611e-01,  1.3035e+00,  2.8216e-01, -2.6067e-01,  1.8976e-01,
         -2.6602e-01,  6.8586e-01, -1.2785e-01,  1.2170e+00, -5.5395e-02,
          1.6710e+00, -2.2195e-01, -1.9416e-01,  1.5043e-01, -2.7167e-01,
         -2.5344e-01,  8.0119e-02,  4.4342e-01,  3.4498e-01,  1.2750e-01,
         -5.7309e-02,  1.2929e+00, -2.4271e-01,  1.6236e-01, -2.6434e-01,
         -1.3608e-01, -5.3063e-02, -2.7827e-01, -2.6205e-01,  3.3865e-01,
          2.3582e-01,  6.1365e-01],
        [ 6.3611e-01,  1.3035e+00,  2.8216e-01, -2.6067e-01,  1.8976e-01,
         -2.6602e-01,  6.8586e-01, -1.2785e-01,  1.2170e+00, -5.5395e-02,
          1.6710e+00, -2.2195e-01, -1.9416e-01,  1.5043e-01, -2.7167e-01,
         -2.5344e-01,  8.0119e-02,  4.4342e-01,  3.4498e-01,  1.2750e-01,
         -5.7309e-02,  1.2929e+00, -2.4271e-01,  1.6236e-01, -2.6434e-01,
         -1.3608e-01, -5.3063e-02, -2.7827e-01, -2.6205e-01,  3.3865e-01,
          2.3582e-01,  6.1365e-01],
        [ 6.3611e-01,  1.3035e+00

TypeError: expected Tensor as element 0 in argument 0, but got type