In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to imported .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import InMemoryDataset, Data
import torch_geometric.transforms as T
import numpy as np
import pickle as pkl
import utils
import data

In [3]:
# Molecule-size bounds handed to utils.create_dummy_batch; max_atoms is also
# passed to data.QM8Dataset and used as the padded length for masking.
min_atoms, max_atoms = 5, 30

# Passed to MDTNN as `num_atoms`, where it sizes the embedding table
# (num_atoms + 1 rows).
num_atoms = 110

# basis:     width of the per-atom embedding / coefficient vectors
# num_gauss: input feature count of the MDTNN.df linear layer
# hidden:    width of the hidden layers in the interaction block and MLP
basis, num_gauss, hidden = 3, 5, 6

In [60]:
# Build a synthetic batch for smoke-testing the model. The trailing 20 is
# presumably the number of molecules (the displayed `sizes` tensor has 20
# entries) — confirm against utils.create_dummy_batch.
Zs, Ds, sizes = utils.create_dummy_batch(min_atoms, max_atoms, num_atoms, 20)

In [5]:
# Display the `sizes` tensor returned alongside the dummy batch.
sizes

tensor([16, 16, 19, 29, 29, 10, 24, 12,  8, 15, 20, 17, 26, 12, 25, 12, 20, 20,
        13, 20])

In [6]:
class InteractionBlock(nn.Module):
    """Masked pairwise interaction term added to the per-atom coefficients.

    The block projects the coefficients into a hidden space, multiplies them
    element-wise with the (already projected) pairwise features, projects the
    product back to the coefficient width, applies ``tanh`` and sums the
    masked result over one atom axis.

    Args:
        basis: size of the last dimension of the coefficient tensor ``C``.
        hidden: size of the last dimension of ``D_hat`` and of the
            intermediate product.
    """

    def __init__(self, basis, hidden):
        super().__init__()
        # basis -> hidden projection applied to the coefficients.
        self.cf = nn.Linear(basis, hidden)
        # hidden -> basis projection applied to the pairwise product (no bias).
        self.fc = nn.Linear(hidden, basis, False)

    def forward(self, C, D_hat, sizes):
        """Compute the interaction update for ``C``.

        Args:
            C: coefficient tensor whose last dimension is ``basis``; assumed
                to be (..., atoms, basis) — TODO confirm against callers.
            D_hat: pairwise features whose last dimension is ``hidden``;
                assumed to be (..., atoms, atoms, hidden) — TODO confirm.
            sizes: per-sample sizes forwarded unchanged to ``utils.mask_2d``.

        Returns:
            The masked pairwise term summed over dimension -3.
        """
        X = self.cf(C)
        # Insert a singleton axis so the projected coefficients broadcast
        # against the pairwise features.
        X = X.unsqueeze(-2) * D_hat
        X = torch.tanh(self.fc(X))

        # Take the padded length from the input itself instead of the
        # notebook-level `max_atoms`, so the module does not rely on hidden
        # global state and works for any padded length.
        padded_len = C.shape[-2]
        mask = utils.mask_2d(sizes, padded_len)
        return (mask.unsqueeze(-1) * X).sum(-3)

In [87]:
class MDTNN(nn.Module):
    """Network that embeds Z, refines the embedding with repeated interaction
    passes, and sums masked per-atom scalar outputs into one value per sample.

    Args:
        basis: width of the per-atom embedding.
        num_atoms: largest index expected in ``Z``; the embedding table has
            ``num_atoms + 1`` rows (index 0 appears to be padding in the
            notebook's ``Zs`` — TODO confirm).
        num_gauss: size of the last dimension of the pairwise input ``D``.
        hidden: hidden width of the pairwise projection and the readout MLP.
        T: number of interaction passes. Inside ``__init__`` this name shadows
            the notebook's ``torch_geometric.transforms as T`` alias, which is
            harmless because the alias is not used here.
    """

    def __init__(self, basis, num_atoms, num_gauss, hidden, T=3):
        super().__init__()
        self.basis = basis
        self.T = T

        self.C_embed = nn.Embedding(num_atoms + 1, basis)
        # Projects the pairwise features to the width the interaction expects.
        self.df = nn.Linear(num_gauss, hidden)
        # A single interaction block is reused (shared weights) for all passes.
        self.interaction = InteractionBlock(basis, hidden)
        # Per-atom readout producing one scalar per atom.
        self.mlp = nn.Sequential(nn.Linear(basis, hidden),
                                 nn.Tanh(),
                                 nn.Linear(hidden, 1))

    def forward(self, Z, D, sizes):
        """Return one scalar per sample.

        Args:
            Z: integer index tensor fed to the embedding; assumed to be
                (batch, padded_atoms) — TODO confirm.
            D: pairwise features whose last dimension is ``num_gauss``.
            sizes: per-sample sizes forwarded to the masking helpers.
        """
        C = self.C_embed(Z)
        d_hat = self.df(D)

        # Residual refinement: add the interaction term T times.
        for _ in range(self.T):
            C = C + self.interaction(C, d_hat, sizes)

        # Drop only the trailing singleton produced by the last Linear layer.
        # A bare squeeze() would also remove any other size-1 dimension, and
        # with a padded length of 1 the result mis-broadcasts against the mask.
        E = self.mlp(C).squeeze(-1)
        # Use the padded length of the input rather than the notebook-level
        # `max_atoms`, so the model does not depend on hidden global state.
        mask = utils.mask_1d(sizes, Z.shape[-1])

        return (mask * E).sum(-1)

In [88]:
# Instantiate the bare network from the notebook-level hyperparameters
# (T keeps its default of 3 interaction passes).
model = MDTNN(basis, num_atoms, num_gauss, hidden)

In [89]:
# Display the tensor currently bound to `Zs` (trailing zeros look like
# padding — TODO confirm).
Zs

tensor([[6, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [7, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [8, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [6, 6, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [6, 7, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [6, 8, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [6, 6, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [8, 6, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [6, 6, 6, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [6, 6, 7, 1

In [90]:
# Forward-pass smoke test. Ds is first passed through utils.transform_D with
# num_gauss (presumably an expansion to num_gauss features per pair — confirm
# in utils) so its last dimension matches what MDTNN.df expects.
model(Zs, utils.transform_D(Ds, num_gauss), sizes)

tensor([-1.0997, -0.2721,  0.8745,  0.6122, -0.3676,  1.0027,  0.2065,  0.6105,
         3.9669,  0.9653,  3.8406,  1.9240,  5.7070,  4.9746,  5.0244],
       grad_fn=<SumBackward1>)

In [91]:
# Pull the first batch from a DataLoader, rebinding Zs/Ds/sizes and adding `target`.
# NOTE(review): `dl` is not defined in any visible cell — it must be created
# before this cell for Restart & Run All to succeed.
Zs, Ds, sizes, target = next(iter(dl))

In [92]:
# Run the model on the fetched batch, passing Ds through utils.transform_D first.
pred = model(Zs, utils.transform_D(Ds, num_gauss), sizes)

In [93]:
# Mean squared error between prediction and target — the same criterion
# DTNNModule.training_step uses. A signed sum of residuals lets positive and
# negative errors cancel and is unbounded below, so it is not a usable loss.
loss = F.mse_loss(pred, target)

In [94]:
# Run one forward/backward pass with autograd anomaly detection switched on,
# so a failing backward op is reported together with the forward call that
# created it. This rebinds the notebook-level `pred` and `loss`.
with torch.autograd.set_detect_anomaly(True):
    pred = model(Zs, utils.transform_D(Ds, num_gauss), sizes)
    # Summing the predictions only provides a scalar to backpropagate from.
    loss = pred.sum()
    loss.backward()

In [95]:
import pytorch_lightning as pl

In [96]:
class DTNNModule(pl.LightningModule):
    """Lightning wrapper that trains an MDTNN with a mean-squared-error loss.

    Model hyperparameters and the dataset settings are read from the
    notebook-level variables (basis, num_atoms, num_gauss, hidden, max_atoms).
    """

    def __init__(self):
        super().__init__()
        self.dtnn = MDTNN(basis, num_atoms, num_gauss, hidden)

    def forward(self, Z, D, sizes):
        """Delegate to the wrapped MDTNN and return its output unchanged."""
        prediction = self.dtnn(Z, D, sizes)
        return prediction

    def train_dataloader(self):
        """Serve the 'E1-CC2' QM8 dataset in batches of 15."""
        train_set = data.QM8Dataset('E1-CC2', max_atoms, num_gauss)
        return DataLoader(train_set, batch_size=15)

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one batch and report it as 'train_loss'."""
        z_batch, d_batch, size_batch, y_batch = batch
        mse = F.mse_loss(self.forward(z_batch, d_batch, size_batch), y_batch)
        return {'loss': mse,
                'log': {'train_loss': mse}}

    def configure_optimizers(self):
        """Adam over all parameters with the optimizer's default settings."""
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

In [97]:
# Train the Lightning wrapper with a default-configured Trainer; the data
# comes from DTNNModule.train_dataloader.
# NOTE(review): this rebinds `model`, which earlier cells use for a bare
# MDTNN instance.
model = DTNNModule()
trainer = pl.Trainer()
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type  | Params
-------------------------------
0 | dtnn | MDTNN | 442   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..





1