In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes (presumably to pick up live edits to the
# `utils` / `data` helpers imported below without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pickle as pkl
import utils
import data

In [3]:
from tqdm.auto import tqdm
# Empty tqdm's `_instances` registry when that attribute exists; the `{}` default
# turns the call into a no-op otherwise. Presumably a workaround for stale
# progress bars left behind by interrupted runs -- TODO confirm it is still needed.
getattr(tqdm, '_instances', {}).clear()

In [4]:
# Atom-count bounds handed to utils.create_dummy_batch; max_atoms is also the
# length passed to the mask helpers and to data.QM8Dataset further down.
min_atoms, max_atoms = 5, 30

# Passed to the model as num_atoms; its embedding table has num_atoms + 1 rows.
num_atoms = 110

# Layer widths: `basis` is the per-atom embedding size, `hidden` the width of
# the read-out MLP's middle layer.
basis, hidden = 30, 20

# Input width of the distance layer, nn.Linear(num_gauss, basis).
num_gauss = 5

In [5]:
# Synthetic batch for smoke-testing the model. The trailing literal 20 is presumably
# the number of molecules to generate (the `sizes` tensor displayed in the next cell
# has 20 entries) -- confirm against utils.create_dummy_batch.
Zs, Ds, sizes = utils.create_dummy_batch(min_atoms, max_atoms, num_atoms, 20)

In [6]:
# Display the third value returned by create_dummy_batch
# (presumably the number of real atoms per molecule -- verify in utils).
sizes

tensor([22, 14, 13, 27, 11, 26, 15, 19, 27,  5, 28, 26, 29, 10, 28, 24, 16, 18,
        10, 23])

In [7]:
class InteractionBlock(nn.Module):
    """One interaction pass combining atom coefficients with pairwise distance features.

    Args:
        basis: Width of the per-atom coefficient vectors going in and coming out.
        hidden: Output width of the ``cf`` projection. That projection is multiplied
            element-wise with ``D_hat`` in ``forward``, so it must broadcast against
            the last dimension of ``D_hat``.
    """

    def __init__(self, basis, hidden):
        super().__init__()
        self.cf = nn.Linear(basis, hidden)
        # Third positional argument of nn.Linear is `bias`, i.e. no bias term here.
        self.fc = nn.Linear(hidden, basis, False)

    def forward(self, C, D_hat, sizes):
        """Compute the masked sum of pair terms.

        Args:
            C: Atom coefficients with last dimension ``basis``
                (assumed ``(..., atoms, basis)`` -- TODO confirm against callers).
            D_hat: Pairwise features (assumed ``(..., atoms, atoms, hidden)``).
            sizes: Forwarded to ``utils.mask_2d`` to build the pair mask.

        Returns:
            The masked pair terms summed over dimension -3.
        """
        # Project the coefficients, add a pair axis, and weight by the distance features.
        pair_terms = self.cf(C).unsqueeze(-2) * D_hat
        pair_terms = torch.tanh(self.fc(pair_terms))

        # Size the mask from the tensor itself instead of the notebook-level
        # `max_atoms` global, so the block does not depend on hidden kernel state.
        # Assumes mask_2d's second argument is the padded atom count -- TODO confirm.
        padded_atoms = pair_terms.shape[-2]
        mask = utils.mask_2d(sizes, padded_atoms).to(pair_terms.device)
        return (mask.unsqueeze(-1) * pair_terms).sum(-3)

In [8]:
class MDTNN(nn.Module):
    """Network that embeds atoms, refines them with repeated interaction passes,
    and sums a per-atom scalar read-out into one value per molecule.

    Args:
        basis: Width of the atom embeddings and of the projected distance features.
        num_atoms: The embedding table is created with ``num_atoms + 1`` rows.
        num_gauss: Input width of the distance projection ``df``.
        hidden: Hidden width of the read-out MLP.
        T: Number of interaction passes; the same ``InteractionBlock`` (and thus the
            same weights) is applied on every pass.
    """

    def __init__(self, basis, num_atoms, num_gauss, hidden, T=3):
        super().__init__()
        self.basis = basis
        self.T = T

        self.C_embed = nn.Embedding(num_atoms + 1, basis)
        self.df = nn.Linear(num_gauss, basis)
        self.interaction = InteractionBlock(basis, basis)
        self.mlp = nn.Sequential(nn.Linear(basis, hidden),
                                 nn.Tanh(),
                                 nn.Linear(hidden, 1))

    def forward(self, Z, D, sizes):
        """Return the masked sum of per-atom outputs.

        Args:
            Z: Integer indices looked up in the embedding table.
            D: Distance features whose last dimension is ``num_gauss``.
            sizes: Forwarded to the interaction block and to ``utils.mask_1d``.

        Returns:
            Per-atom read-outs multiplied by the mask and summed over the last axis.
        """
        C = self.C_embed(Z)
        d_hat = self.df(D)

        # Residual refinement: each pass adds the interaction term to the coefficients.
        for _ in range(self.T):
            C = C + self.interaction(C, d_hat, sizes)

        # Drop only the trailing singleton produced by Linear(hidden, 1); a bare
        # .squeeze() would also collapse any other size-1 axis (e.g. a batch of one).
        E = self.mlp(C).squeeze(-1)

        # Size the mask from the tensor rather than the notebook-level `max_atoms`
        # global. Assumes mask_1d's second argument is the padded atom count -- TODO confirm.
        mask = utils.mask_1d(sizes, E.shape[-1]).to(E.device)
        return (mask * E).sum(-1)

In [9]:
# Build the bare network from the config-cell hyperparameters for a quick forward-pass check.
model = MDTNN(basis, num_atoms, num_gauss, hidden)

In [10]:
# Smoke test on the dummy batch: distances are first passed through
# utils.transform_D(Ds, num_gauss); the result is displayed as the cell output.
model(Zs, utils.transform_D(Ds, num_gauss), sizes)

tensor([ -3.8966,  -0.5899,  -4.6702,  -7.2142,  -2.6789,  -3.6381,  -0.8011,
         -7.7035,  -6.5794,  -1.2178,  -7.5376,   0.6271, -10.0708,  -3.2347,
          5.6723,  -3.4375,  -4.5307,   0.3071,  -4.5394,  -6.7996],
       grad_fn=<SumBackward1>)

In [11]:
import pytorch_lightning as pl

In [12]:
class DTNNModule(pl.LightningModule):
    """LightningModule that trains an ``MDTNN`` on one target of ``data.QM8Dataset``.

    Args:
        basis: Forwarded to ``MDTNN``.
        num_atoms: Forwarded to ``MDTNN``.
        num_gauss: Forwarded to ``MDTNN`` and also used when building the dataset,
            so the network's input width and the data always agree.
        hidden: Forwarded to ``MDTNN``.
        target: Passed to ``data.QM8Dataset`` as its first argument.
    """

    def __init__(self, basis, num_atoms, num_gauss, hidden, target):
        super().__init__()
        self.dtnn = MDTNN(basis, num_atoms, num_gauss, hidden)
        self.num_gauss = num_gauss
        self.target = target

    def forward(self, Z, D, sizes):
        """Delegate to the wrapped ``MDTNN``."""
        return self.dtnn(Z, D, sizes)

    def prepare_data(self):
        """Build the dataset and split it roughly 60/20/20 into train/test/validation."""
        # Use the Gaussian count this module was constructed with rather than the
        # notebook-level `num_gauss`, which may differ from the constructor argument.
        # `max_atoms` is still read from the notebook globals.
        self.dataset = data.QM8Dataset(self.target, max_atoms, self.num_gauss)
        total = len(self.dataset)
        holdout = int(total * 0.2)
        lengths = [total - 2 * holdout, holdout, holdout]
        # Order of the unpacked subsets: train, test, validation.
        self.train_dataset, self.test_dataset, self.valid_dataset = random_split(self.dataset, lengths)

    def train_dataloader(self):
        """Training loader with batch size 15."""
        return DataLoader(self.train_dataset, 15)

    def step(self, batch, batch_idx, loss_fn):
        """Run a forward pass on ``batch`` and return ``loss_fn(prediction, target)``.

        ``batch`` is unpacked as ``(Z, D, sizes, target)``.
        """
        Z, D, sizes, target = batch
        predict = self.forward(Z, D, sizes)
        return loss_fn(predict, target)

    def training_step(self, batch, batch_idx):
        """Minimise the MSE loss and report it as ``train_loss`` on the progress bar."""
        loss = self.step(batch, batch_idx, F.mse_loss)
        result = pl.TrainResult(minimize=loss)
        # Logged once: the former extra log_dict call repeated the same key with
        # default flags and was redundant with this explicit call.
        result.log('train_loss', loss, prog_bar=True)
        return result

    def val_dataloader(self):
        """Validation loader with batch size 50."""
        return DataLoader(self.valid_dataset, 50)

    def validation_step(self, batch, batch_idx):
        """Compute the L1 loss, use it for checkpointing, and log it as ``val_loss``."""
        loss = self.step(batch, batch_idx, F.l1_loss)

        result = pl.EvalResult(checkpoint_on=loss)
        result.log_dict({'val_loss': loss})
        return result

    def test_dataloader(self):
        """Test loader with batch size 50."""
        return DataLoader(self.test_dataset, 50)

    def test_step(self, batch, batch_idx):
        """Reuse the validation logic and report the metric as ``test_loss``."""
        result = self.validation_step(batch, batch_idx)
        result.rename_keys({'val_loss': 'test_loss'})
        return result

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-4."""
        return torch.optim.Adam(self.parameters(), 1e-4)

In [13]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
# NOTE(review): the Gaussian count is hard-coded to 11 here while the config cell
# sets num_gauss = 5 -- confirm which value this run is meant to use.
model = DTNNModule(basis, num_atoms, 11, hidden, 'E1-CC2')
trainer = pl.Trainer(
    # Request a GPU only when CUDA is actually available, so the cell also runs
    # on CPU-only machines.
    gpus=1 if torch.cuda.is_available() else 0,
    checkpoint_callback=ModelCheckpoint(),
    early_stop_callback=EarlyStopping(patience=10),
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type  | Params
-------------------------------
0 | dtnn | MDTNN | 6 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# Evaluate the trained module on the held-out test split
# (served by DTNNModule.test_dataloader and scored in test_step as `test_loss`).
trainer.test(model)

In [None]:
# The leftover `%debug` call was removed from this cell: it starts an interactive
# post-mortem debugger session, which does not belong in a notebook meant to run
# top to bottom. Invoke `%debug` by hand right after a failing cell when needed.