In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules (e.g. `utils`, `data`) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import InMemoryDataset, Data
import torch_geometric.transforms as T
import numpy as np
import pickle as pkl
import utils
import data

In [3]:
# Notebook-level hyperparameters; later cells read these as globals.

# Passed to utils.create_dummy_batch together with max_atoms
# (presumably the lower/upper bound on atoms per molecule — confirm in utils).
min_atoms = 5
# Also passed to data.QM8Dataset; the displayed Zs rows are 30 entries wide,
# so this appears to be the padded molecule length.
max_atoms = 30
# Passed to utils.create_dummy_batch and used to size the atom embedding
# table (num_atoms + 1 rows).
num_atoms = 110
# Width of each per-atom feature vector (embedding dimension).
basis = 3
# Width of the distance features fed to the model (input size of the `df`
# linear layer; also passed to utils.transform_D and data.QM8Dataset).
num_gauss = 5
# Width of the hidden layers.
hidden = 6

In [4]:
# Build a synthetic batch for smoke-testing the model.
# NOTE(review): the literal 20 looks like the number of molecules in the batch
# (the displayed `sizes` has 20 entries) — confirm in utils.create_dummy_batch.
Zs, Ds, sizes = utils.create_dummy_batch(min_atoms, max_atoms, num_atoms, 20)

In [5]:
# Inspect the per-molecule entry of `sizes` (one value per dummy molecule;
# presumably the number of real, non-padded atoms — verify against utils).
sizes

tensor([ 7, 21, 11, 24, 21, 15,  6, 25,  7, 28, 25, 22, 20, 27,  5, 16, 13, 21,
        29, 11])

In [6]:
class InteractionBlock(nn.Module):
    """Pairwise interaction update for per-atom feature vectors.

    Each atom's feature vector is projected into the hidden space, multiplied
    element-wise with the pairwise distance features, projected back to the
    feature space, passed through ``tanh`` and summed over the partner-atom
    axis. Entries belonging to padding are zeroed with the mask returned by
    ``utils.mask_2d``.
    """

    def __init__(self, basis, hidden):
        """
        Args:
            basis: width of each per-atom feature vector.
            hidden: width of the intermediate space shared with ``D_hat``.
        """
        super().__init__()
        self.cf = nn.Linear(basis, hidden)
        # Back-projection to the feature space, deliberately without a bias term.
        self.fc = nn.Linear(hidden, basis, bias=False)

    def forward(self, C, D_hat, sizes):
        """Compute the interaction term to be added to ``C``.

        Args:
            C: per-atom features; the unsqueeze/sum axes below imply a layout of
                ``(..., atoms, basis)``.
            D_hat: pairwise distance features, expected to broadcast against
                ``(..., atoms, 1, hidden)`` — presumably
                ``(..., atoms, atoms, hidden)``; TODO confirm against the caller.
            sizes: per-molecule atom counts, forwarded to ``utils.mask_2d``.

        Returns:
            Tensor with the same layout as ``C``.
        """
        # Read the padded molecule length from the input itself instead of a
        # notebook-level global, so the block works for any padding length.
        padded_len = C.shape[-2]

        pair_msg = self.cf(C).unsqueeze(-2) * D_hat
        pair_msg = torch.tanh(self.fc(pair_msg))

        # assumes utils.mask_2d returns a (batch, padded_len, padded_len) mask
        # of valid atom pairs — TODO confirm in utils.
        pair_mask = utils.mask_2d(sizes, padded_len)
        return (pair_mask.unsqueeze(-1) * pair_msg).sum(-3)

In [7]:
class MDTNN(nn.Module):
    """Embeds atoms, refines them with a shared interaction block and sums
    per-atom scalar outputs into one prediction per molecule.

    The same ``InteractionBlock`` instance is applied ``T`` times (weights are
    shared across refinement steps).
    """

    def __init__(self, basis, num_atoms, num_gauss, hidden, T=3):
        """
        Args:
            basis: width of each per-atom feature vector.
            num_atoms: largest atom id; the embedding table has
                ``num_atoms + 1`` rows (presumably row 0 is reserved for
                padding — confirm against the data pipeline).
            num_gauss: width of the incoming distance features.
            hidden: width of the hidden layers.
            T: number of interaction (refinement) passes.
        """
        super().__init__()
        self.basis = basis
        self.T = T

        self.C_embed = nn.Embedding(num_atoms + 1, basis)
        self.df = nn.Linear(num_gauss, hidden)
        self.interaction = InteractionBlock(basis, hidden)
        self.mlp = nn.Sequential(nn.Linear(basis, hidden),
                                 nn.Tanh(),
                                 nn.Linear(hidden, 1))

    def forward(self, Z, D, sizes):
        """Predict one scalar per molecule.

        Args:
            Z: integer atom ids, last axis indexes atoms (padded).
            D: distance features with ``num_gauss`` entries on the last axis.
            sizes: per-molecule atom counts, used for masking.

        Returns:
            Tensor with one value per molecule (the atom axis is summed out).
        """
        C = self.C_embed(Z)
        d_hat = self.df(D)

        for _ in range(self.T):
            C = C + self.interaction(C, d_hat, sizes)

        # Drop only the trailing size-1 output axis. A bare squeeze() would also
        # remove a size-1 atom axis and make the mask product broadcast into a
        # (batch, batch) matrix.
        E = self.mlp(C).squeeze(-1)
        # Padded length is taken from the input rather than a notebook global.
        mask = utils.mask_1d(sizes, Z.shape[-1])

        return (mask * E).sum(-1)

In [8]:
# Instantiate the bare network from the notebook-level hyperparameters
# (T keeps its default of 3 interaction passes).
model = MDTNN(basis, num_atoms, num_gauss, hidden)

In [9]:
# Inspect the dummy atom-id matrix: one row per molecule, trailing zeros are padding.
Zs

tensor([[101,   7,  52,  55,   5,  52,  35,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [ 88,  18,  56,  62,  65,  29,  52,  46,  37,  23,  18,  56,  29, 107,
          76,  49,  66,  12,  13,  25,  31,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [ 50,  81,  96,  20,  26, 104,  76,  41,  97,  74,  50,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [ 93,  21,  40,  52,  23,  90,  70,  83,  30,  66,  86,  27,  29,  21,
          76,  90,  15,  72,  25, 103,  84,  94,  32,  36,   0,   0,   0,   0,
           0,   0],
        [ 78,  69,  96,  10,  64,  47,  77,  51,  65,  88,  11,  87,   2,  20,
          21,  40,  54,  31,  94,  67,  58,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [ 83,   7, 109,  62,  61,  31, 105,  41,   7, 100,  69,  16,  77,  36,
         101,   0,   0,   0,   

In [10]:
# Smoke test: run a forward pass on the dummy batch; yields one value per molecule.
# utils.transform_D turns the raw Ds into the num_gauss-wide features the model
# expects (presumably a Gaussian expansion of distances — confirm in utils).
model(Zs, utils.transform_D(Ds, num_gauss), sizes)

tensor([ 5.4549,  2.1654,  9.1737,  9.2142, 13.4600,  8.7968,  2.9231,  6.5949,
         3.0562,  8.3927,  8.8697, -1.3868, 11.5769, 14.8065,  3.2132,  9.8461,
         9.2581,  0.4544, 22.2138,  7.0638], grad_fn=<SumBackward1>)

In [13]:
import pytorch_lightning as pl

In [14]:
class DTNNModule(pl.LightningModule):
    """Lightning wrapper that trains an ``MDTNN`` with an MSE objective.

    Model hyperparameters and dataset settings are read from the notebook-level
    globals (``basis``, ``num_atoms``, ``num_gauss``, ``hidden``, ``max_atoms``).
    """

    def __init__(self):
        super().__init__()
        self.dtnn = MDTNN(basis, num_atoms, num_gauss, hidden)

    def forward(self, Z, D, sizes):
        """Delegate to the wrapped network."""
        return self.dtnn(Z, D, sizes)

    def train_dataloader(self):
        """Serve the QM8 dataset in batches of 15.

        'E1-CC2' is presumably the name of the regression target — confirm in
        ``data.QM8Dataset``.
        """
        qm8 = data.QM8Dataset('E1-CC2', max_atoms, num_gauss)
        return DataLoader(qm8, batch_size=15)

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one batch and expose it for logging."""
        atom_ids, dist_feats, n_atoms, y_true = batch
        y_pred = self.forward(atom_ids, dist_feats, n_atoms)
        mse = F.mse_loss(y_pred, y_true)
        # The nested 'log' dict is how the loss is reported as 'train_loss'.
        return {
            'loss': mse,
            'log': {'train_loss': mse},
        }

    def configure_optimizers(self):
        """Adam over all parameters with its default settings."""
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

In [16]:
# Train with a default-configured Lightning Trainer.
# NOTE(review): this rebinds `model`, which in an earlier cell referred to the
# bare MDTNN instance; after this cell it is the Lightning wrapper.
model = DTNNModule()
trainer = pl.Trainer()
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type  | Params
-------------------------------
0 | dtnn | MDTNN | 442   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..





1