In [1]:
%load_ext autoreload
%autoreload 2

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import InMemoryDataset, Data
import torch_geometric.transforms as T
import numpy as np
import pickle as pkl
import utils
import data

In [3]:
min_atoms = 5
max_atoms = 30
num_atoms = 110
basis = 3
num_gauss = 5
hidden = 6

In [4]:
Zs, Ds, sizes = utils.create_dummy_batch(min_atoms, max_atoms, num_atoms, 20)

In [5]:
sizes

tensor([16, 16, 19, 29, 29, 10, 24, 12,  8, 15, 20, 17, 26, 12, 25, 12, 20, 20,
        13, 20])

In [6]:
class InteractionBlock(nn.Module):
    def __init__(self, basis, hidden):
        super().__init__()
        self.cf = nn.Linear(basis, hidden)
        self.fc = nn.Linear(hidden, basis, False)
    
    def forward(self, C, D_hat, sizes):
        X = self.cf(C)
        X = X.unsqueeze(-2) * D_hat
        X = torch.tanh(self.fc(X))
        
        num_batch = C.shape[0] if len(C.shape) > 2 else 1
        mask = utils.mask_2d(sizes, max_atoms)
        return (mask.unsqueeze(-1) * X).sum(-3)

In [7]:
class MDTNN(nn.Module):
    def __init__(self, basis, num_atoms, num_gauss, hidden, T=3):
        super().__init__()
        self.basis = basis
        self.T = T
        
        self.C_embed = nn.Embedding(num_atoms + 1, basis)
        self.df = nn.Linear(num_gauss, hidden)
        self.interaction = InteractionBlock(basis, hidden)
        self.mlp = nn.Sequential(nn.Linear(basis, hidden),
                                 nn.Tanh(),
                                 nn.Linear(hidden, 1))
    
    def forward(self, Z, D, sizes):
        C = self.C_embed(Z)
        d_hat = self.df(D)
        
        print(C.shape)
        print(d_hat.shape)
        for _ in range(self.T):
            C += self.interaction(C, d_hat, sizes)
            
        E = self.mlp(C).squeeze()
        mask = utils.mask_1d(sizes, max_atoms)
        
        return (mask * E).sum(-1)#.squeeze()

In [8]:
model = MDTNN(basis, num_atoms, num_gauss, hidden)

In [9]:
Zs

tensor([[ 32,  42, 104,  84,  45,  33,  41,  63,  43,  62,  13,  15,  82,  26,
          47,  15,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  2,  92,  82,  68,  94,   6, 108,  78,   2,  64,  61,  22,   4,  59,
          43, 107,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [ 50,  35,   9,  97,  43,  24,  50,  86,  96,   9,  27,  75,  28,  95,
          11,  19,  40,  15,  37,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [ 13,  55,  15,  32, 104, 108,  92, 107,   6,  93,  25,  51,  84,  21,
          89,  63,  39,  21,  10,  56,  51, 109,  51,  33,   3,  25,  59,  50,
           8,   0],
        [ 42,  42, 110,   8,  68,  72,  27,  53,  30,  94,  82,  75,  87,  17,
          29,   6,  15,  96,  41,  74,  86,  79,  34,  68,  76,   9, 110,  10,
          22,   0],
        [ 76,  98,   6,  72,  19,  83,  56,  65,  16,  75,   0,   0,   0,   0,
           0,   0,   0,   0,   

In [10]:
model(Zs, utils.transform_D(Ds, num_gauss), sizes)

torch.Size([20, 30, 3])
torch.Size([20, 30, 30, 6])


tensor([ -4.2454,  -5.1538,  -6.1584, -11.2226, -11.7209,  -1.7913,  -7.7839,
         -2.8416,  -1.8101,  -3.7942,  -7.0755,  -5.5387, -10.7433,  -3.5129,
         -9.2853,  -2.5058,  -5.9024,  -6.7912,  -4.1861,  -7.3582],
       grad_fn=<SumBackward1>)

In [33]:
ds = data.QM8Dataset('E1-CC2', 30)
dl = DataLoader(ds, batch_size=15)

In [34]:
Zs, Ds, sizes, target = next(iter(dl))

In [35]:
model(Zs, utils.transform_D(Ds, num_gauss), sizes)

torch.Size([15, 30, 3])
torch.Size([15, 30, 30, 6])


tensor([-0.9837, -0.2447,  0.0298, -0.5211,  0.0952, -0.3821, -3.8352, -1.8950,
        -3.2366, -1.8696, -2.9710, -1.9652, -6.5276, -4.9831, -4.9447],
       grad_fn=<SumBackward1>)