In [58]:
import torch, sys, os
from torch.nn import Module, Linear, SiLU, Embedding
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_add_pool
sys.path.append(os.path.abspath('/Users/samharshe/Documents/Gerstein Lab/EGNN Pro/src/data/get'))
from data_get_utils import get_dataset, get_dataloaders
from model_utils import bessel_rbf, cosine_cutoff, sanity_check

In [59]:
# Dataset selection shared by the dataset and the dataloaders.
DATASET_KWARGS = dict(version='charizard', molecule='benzene')

dataset = get_dataset(**DATASET_KWARGS)

# Only the first loader returned by get_dataloaders is kept.
# NOTE(review): the 0.01 / 0.98 / 0.01 split looks like a tiny debug split — confirm.
dataloader, _, _ = get_dataloaders(
    **DATASET_KWARGS,
    train_split=0.01,
    val_split=0.98,
    test_split=0.01,
    batch_size=32,
)

# A single graph, for inspecting by hand.
data = dataset[0]



In [60]:
class PAINNMessage(MessagePassing):
    """PaiNN-style message block over a molecular graph.

    Node features are packed into one 2-D tensor so they can flow through
    ``MessagePassing.propagate``:

    - the first ``3 * num_features`` columns hold the equivariant features ``v``,
      flattened from shape ``(num_features, 3)``;
    - the last ``num_features`` columns hold the invariant features ``s``.

    Args:
        num_features: channels per node for both ``s`` and ``v`` (default 8).
        num_rbf: number of basis functions requested from ``bessel_rbf`` (default 20).
        r_cut: cutoff passed to ``bessel_rbf`` and ``cosine_cutoff`` (default 1.4415).
        num_elements: size of the atomic-number embedding table (default 118).
    """

    def __init__(self, num_features=8, num_rbf=20, r_cut=1.4415, num_elements=118):
        super().__init__(aggr='sum')

        self.num_features = num_features
        self.num_rbf = num_rbf
        self.r_cut = r_cut

        self.embedding = Embedding(num_elements, num_features)

        self.act = SiLU()

        # Creation order is kept as before so seeded initialisation is unchanged.
        # Scalar MLP: s_j -> hidden.
        self.linear_1 = Linear(num_features, num_features)
        # Radial filter: rbf(|r_ij|) -> 3 gates.
        self.linear_2 = Linear(num_rbf, 3 * num_features)
        # Scalar MLP: hidden -> 3 gates.
        self.linear_3 = Linear(num_features, 3 * num_features)

    def forward(self, data):
        """Run one round of message passing on ``data``.

        Args:
            data: graph object providing ``num_nodes``, ``z`` (fed to the
                embedding), ``edge_index``, ``unit_edge_vec`` and
                ``edge_vec_length``. The last two must already be attached to
                ``data``; this block does not compute them.

        Returns:
            Tensor of shape ``(num_nodes, 4 * num_features)``: the packed
            ``[v | s]`` features plus the summed incoming messages.
        """
        num_nodes = data.num_nodes

        # Invariant features come from the embedding.
        s = self.embedding(data.z)
        # Equivariant features start at zero.
        # Allocate from `s` so device and dtype match the embedding.
        v = s.new_zeros((num_nodes, self.num_features, 3))

        x = torch.cat((v.view(num_nodes, -1), s), dim=1)

        # Keyword names must match the parameter names of `message`.
        return self.propagate(
            edge_index=data.edge_index,
            x=x,
            unit_edge_vec=data.unit_edge_vec,
            edge_vec_length=data.edge_vec_length,
        )

    def message(self, x_j, unit_edge_vec, edge_vec_length):
        """Build the message sent along each edge from its source node j.

        Args:
            x_j: packed source-node features, ``(num_edges, 4 * num_features)``.
            unit_edge_vec: per-edge direction, combined with the gates as
                ``(num_edges, 3)``.
            edge_vec_length: per-edge distance; reshaped to ``(num_edges, 1)``.

        Returns:
            Packed ``[dv | ds]`` message of shape ``(num_edges, 4 * num_features)``.
        """
        nf = self.num_features
        v_j = x_j[:, :3 * nf].reshape(-1, nf, 3)
        s_j = x_j[:, 3 * nf:]

        # Scalar pathway: shape (E, 3 * nf).
        phi = self.linear_3(self.act(self.linear_1(s_j)))

        # Radial filter, smoothly damped by the cosine cutoff: shape (E, 3 * nf).
        dist = edge_vec_length.view(-1, 1)
        rbf = bessel_rbf(x=dist, n=self.num_rbf, r_cut=self.r_cut) * cosine_cutoff(x=dist, r_cut=self.r_cut)
        filt = self.linear_2(rbf)

        # Split into three gates of nf channels each.
        gate_v, ds, gate_r = torch.split(phi * filt, nf, dim=1)

        # Out-of-place broadcasting; the previous `unsqueeze_` mutated a view in place.
        dv = v_j * gate_v.unsqueeze(-1) + gate_r.unsqueeze(-1) * unit_edge_vec.unsqueeze(1)

        return torch.cat((dv.reshape(dv.size(0), -1), ds), dim=1)

    def update(self, aggr_out, x):
        """Residual update: incoming packed features plus summed messages.

        Returns a new tensor rather than modifying ``x`` in place.
        """
        return x + aggr_out

In [61]:
# Instantiate the message block defined above for a quick smoke test.
# `PAINNBlock` is not defined anywhere in this notebook, so the previous line
# only worked through leftover kernel state.
model = PAINNMessage()

In [62]:
# Smoke-test the model with the project helper `sanity_check`, imported from model_utils.
# NOTE(review): the recorded output of this cell is an AttributeError on
# `unit_edge_vec`, i.e. the graph object the model received has no such
# attribute. It must be attached to the data before the model can run.
sanity_check(model=model)



AttributeError: 'GlobalStorage' object has no attribute 'unit_edge_vec'

In [None]:
class PAINNUpdate(MessagePassing):
    """PaiNN-style update block on packed ``[v | s]`` node features.

    Layout of ``x``:

    - the first ``3 * num_features`` columns hold the equivariant features ``v``,
      flattened from shape ``(num_features, 3)``;
    - the last ``num_features`` columns hold the invariant features ``s``.

    NOTE(review): in the PaiNN paper the update block is applied atom-wise, with
    no neighbour aggregation. Here it runs through ``propagate``, so each node
    receives the *sum of its neighbours'* updates. Confirm this is intended.

    Args:
        num_features: channels per node for both ``s`` and ``v`` (default 128).
    """

    def __init__(self, num_features=128):
        super().__init__()

        self.num_features = num_features

        self.act = SiLU()

        # Channel mixing on v. A bias on vector features would break rotation equivariance.
        self.U = Linear(num_features, num_features, bias=False)
        self.V = Linear(num_features, num_features, bias=False)

        self.linear_1 = Linear(2 * num_features, num_features)
        self.linear_2 = Linear(num_features, 3 * num_features)

    def forward(self, x, edge_index):
        """Apply one update round; returns a tensor with the same shape as ``x``."""
        return self.propagate(edge_index=edge_index, x=x)

    def message(self, x_j):
        """Compute the packed ``[dv | ds]`` update contributed by source node j.

        The roles of the three gates follow the PaiNN paper (a_vv, a_sv, a_ss).
        The earlier draft computed them but left them unused, so confirm this
        assignment.
        """
        nf = self.num_features
        v_j = x_j[:, :3 * nf].reshape(-1, nf, 3)
        s_j = x_j[:, 3 * nf:]

        # Linear acts on the last dim, so move the channel axis there and back:
        # (E, nf, 3) -> (E, 3, nf) -> mix channels -> (E, nf, 3).
        Uv = self.U(v_j.transpose(1, 2)).transpose(1, 2)
        Vv = self.V(v_j.transpose(1, 2)).transpose(1, 2)

        # Rotation-invariant input: per-channel norm over the spatial axis, shape (E, nf).
        Vv_norm = torch.norm(Vv, p=2, dim=-1)

        stack = torch.cat((Vv_norm, s_j), dim=1)
        a = self.linear_2(self.act(self.linear_1(stack)))
        a_vv, a_sv, a_ss = torch.split(a, nf, dim=1)

        dv = a_vv.unsqueeze(-1) * Uv
        # Scalar update: gated channel-wise dot product <Uv, Vv> plus a direct scalar term.
        ds = a_sv * (Uv * Vv).sum(dim=-1) + a_ss

        return torch.cat((dv.reshape(dv.size(0), -1), ds), dim=1)

    def update(self, aggr_out, x):
        """Residual update; returns a new tensor instead of mutating the caller's ``x``."""
        return x + aggr_out

In [None]:
class PAINN(Module):
    """Top-level PaiNN model: embedding, three interaction blocks, prediction head.

    NOTE(review): ``PAINNBlock`` and ``PAINNPrediction`` are not defined in this
    notebook. The constructor and call signatures used below are assumed; TODO confirm.
    """

    def __init__(self):
        super().__init__()

        self.embedding = Embedding(118, 128)
        self.block_1 = PAINNBlock()
        self.block_2 = PAINNBlock()
        self.block_3 = PAINNBlock()
        self.prediction = PAINNPrediction()

    def forward(self, data):
        """Embed atoms, run the three blocks in sequence, and predict.

        Args:
            data: graph object providing ``edge_index``, ``pos``, ``z`` and ``num_nodes``.

        Returns:
            ``(F_hat, E_hat)``, exactly as returned by ``self.prediction``.
        """
        edge_index = data.edge_index
        # Enable autograd w.r.t. positions (presumably for forces as -dE/dpos).
        # NOTE(review): pos is not passed to any submodule below yet, so no
        # gradient w.r.t. pos can form until edge attributes are derived from it.
        data.pos.requires_grad_(True)

        # Embed atomic numbers, not the whole data object.
        s = self.embedding(data.z)
        # Allocate from `s` so device and dtype match the embedding.
        v = s.new_zeros((data.num_nodes, 128, 3))

        x = torch.cat((v.view(v.size(0), -1), s), dim=1)

        # TODO: pass edge attributes (e.g. unit_edge_vec, edge_vec_length) once PAINNBlock accepts them.
        x = self.block_1(x=x, edge_index=edge_index)
        x = self.block_2(x=x, edge_index=edge_index)
        x = self.block_3(x=x, edge_index=edge_index)

        F_hat, E_hat = self.prediction(x)

        return F_hat, E_hat