In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np
import wandb

In [2]:
# [0] reports MAE (in eV) divided by the chemical accuracy of the target variable U0.
# The chemical accuracy of U0 is 0.043 eV, see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD = Histogram of Distances, Angles and Dihedral angles

# [2] reports the average of MAE / chemical accuracy over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9

# Regression targets by index; the index is the column of data.y used below
# (data.y[:, target_idx]). Format: 'symbol, unit, description'.
target_dict = {0: 'mu, D, Dipole moment',
               1: 'alpha, {a_0}^3, Isotropic polarizability',
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy',
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K',
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',
               # '/' (not the invalid escape '\(') separates the unit parts.
               11: 'c_{v}, cal/(mol K), Heat capacity at 298.15K'}

# Chemical accuracy per target, in the unit listed in target_dict.
# Energy-like targets default to 0.043 eV (U0 value from [1, Table 5]);
# TODO confirm the overrides below against the same table.
chemical_accuracy = {idx: 0.043 for idx in range(12)}
chemical_accuracy.update({
    0: 0.1,     # mu, D
    1: 0.1,     # alpha, {a_0}^3
    5: 1.2,     # <R^2>, {a_0}^2
    6: 0.0012,  # ZPVE, eV
    11: 0.050,  # c_v, cal/(mol K)
})

In [3]:
wandb.init(project='QM9-GAT', entity='chrisxx')
config = wandb.config
# --- model ---
config.hidden_dim = 256
config.nlayers = 6
config.nhead = 1
# --- optimisation (patience/factor parameterise ReduceLROnPlateau) ---
config.lr = 0.0003
config.n_epochs = 5000
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8            # training stops once the learning rate drops below this
config.target_idx = 7               # key of target_dict / column of data.y (U_0)
config.batch_size = 128
# --- data split sizes (the remainder of QM9 is the test set) ---
config.n_train = 110000
config.n_valid = 10000
# --- stopping / checkpoint thresholds, in units of MAE / chemical accuracy ---
config.target_ratio = 0.3                 # stop training once below this ratio
config.store_starting_from_ratio = 1      # no checkpoint is written until the ratio is below this
config.required_improvement = 0.8         # a new checkpoint must beat best ratio * this factor
config.model_dir = '../models/qm9/GAT2/'
config.num_workers = 4
config.dfs_codes = '../datasets/qm9_geometric_work/min_dfs_codes.json'
# Seed the global RNGs so Restart & Run All is reproducible: torch.manual_seed covers
# weight initialisation and (presumably, being torch-based) dataset.shuffle() and
# DataLoader shuffling; numpy is seeded for any numpy-based randomness.
config.seed = 0
torch.manual_seed(config.seed)
np.random.seed(config.seed)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.11.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [4]:
import json
# Per-molecule DFS codes keyed by molecule name (e.g. 'gdb_55'); int_2_one_hot
# reads dfs_codes[data.name]['dfs_indices'].
# JSON is UTF-8 by specification, so do not depend on the platform's default encoding.
with open(config.dfs_codes, 'r', encoding='utf-8') as f:
    dfs_codes = json.load(f)

In [5]:
# Peek at the stored DFS indices of one molecule (gdb_55); int_2_one_hot one-hot
# encodes this list and appends it to the node features.
dfs_codes['gdb_55']['dfs_indices']

[5, 4, 9, 1, 13, 6, 7, 8, 10, 11, 12, 2, 3, 0, 14]

In [6]:
def int_2_one_hot(data, dfs_codes=dfs_codes, num_atomic_numbers=100, max_num_h=9, max_vertices=29):
    """Dataset transform: turn integer node features into one-hots and append DFS indices.

    Columns of ``data.x`` used here:
      [:, :5]   kept as is (presumably the atom-type one-hot; the atomref code
                indexes qm9.atomrefs with its argmax)
      [:, 5]    atomic number            -> one-hot of width ``num_atomic_numbers``
      [:, 6:-1] kept as is
      [:, -1]   number of attached H     -> one-hot of width ``max_num_h``
    The molecule's ``dfs_indices`` (looked up via ``data.name``) are one-hot
    encoded with width ``max_vertices`` and appended as extra columns.

    Args:
        data: a torch_geometric data object with ``x`` and ``name``.
        dfs_codes: mapping molecule name -> dict with a ``'dfs_indices'`` list that
            must have exactly one entry per node of ``data``.
        num_atomic_numbers, max_num_h, max_vertices: one-hot widths; the defaults
            are the previously hard-coded values.

    Returns:
        ``data`` with ``data.x`` replaced by the new feature matrix, in the dtype
        of the original ``x``.
    """
    features = data.x
    atomic_number = F.one_hot(features[:, 5].long(), num_atomic_numbers)
    num_h = F.one_hot(features[:, -1].long(), max_num_h)
    dfs_index = torch.as_tensor(dfs_codes[data.name]['dfs_indices'], dtype=torch.long)
    # NOTE(review): row i of this one-hot is the i-th entry of 'dfs_indices'; this
    # assumes entry i belongs to node i of `data` — confirm against how the JSON was built.
    dfs_one_hot = F.one_hot(dfs_index, max_vertices)
    # one_hot returns int64; cast explicitly rather than relying on torch.cat's
    # type promotion so the features keep x's floating dtype.
    dtype = features.dtype
    data.x = torch.cat(
        (features[:, :5], features[:, 6:-1],
         atomic_number.to(dtype), num_h.to(dtype), dfs_one_hot.to(dtype)),
        dim=1,
    )
    return data

In [7]:
# Shortcut for the regression target: key of target_dict and column of data.y.
target_idx = config.target_idx

In [8]:
# Load QM9 from the local root; int_2_one_hot is passed as `transform`, so the
# one-hot node features are built per sample as it is read.
# NOTE(review): this root directory is also hard-coded in config.dfs_codes; keep them in sync.
dataset = QM9('../datasets/qm9_geometric_work/', transform=int_2_one_hot)

In [9]:
# Random split: the first n_train shuffled molecules train, the next n_valid
# validate, and the remainder is the test set.
dataset = dataset.shuffle()
train_dataset = dataset[:config.n_train]
valid_dataset = dataset[config.n_train:config.n_train + config.n_valid]
test_dataset = dataset[config.n_train + config.n_valid:]
config.n_test = len(test_dataset)
assert config.n_test > 0, 'n_train + n_valid leaves no molecules for the test set'
# All loaders share batch size and pinning (the test loader previously hard-coded
# batch_size=32). GATNN has no batch-level normalisation, so the evaluation batch
# size only affects speed, not the predictions.
loader_kwargs = dict(batch_size=config.batch_size, pin_memory=True, num_workers=config.num_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [10]:
import os
# Create the checkpoint directory; exist_ok=True makes re-running this cell safe.
os.makedirs(config.model_dir, exist_ok=True)

In [11]:
# Persist the shuffled index order of `dataset` so the train/valid/test slices
# taken above can be reconstructed later (e.g. when evaluating a saved checkpoint).
torch.save(dataset.indices(), config.model_dir+'dataset_indices.pt')

In [12]:
ngpu = 1  # number of GPUs to use; 0 forces the CPU
device = torch.device('cuda:0') if (torch.cuda.is_available() and ngpu > 0) else torch.device('cpu')

# Model

In [13]:
# Accumulator for per-molecule training residuals (target minus summed atom references).
target_vec = []

In [14]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
# Over the training set, collect each molecule's residual: target minus the sum of
# its per-atom reference values. Their mean/std standardise the model output.
target_vec = []  # reset here so re-running this cell neither fails nor duplicates data
# qm9.atomrefs only has entries for some targets (dataset.atomref() returns None for
# the others); use a zero reference for those instead of raising KeyError.
# Built once, outside the loop, since it does not depend on the batch.
atomref_values = torch.tensor(qm9.atomrefs.get(target_idx, [0.0] * 5), device=device)
with torch.no_grad():
    for data in train_loader:
        data = data.to(device)
        # argmax over data.x[:, :5] selects one of the 5 reference values per atom.
        atom_refs = atomref_values[torch.argmax(data.x[:, :5], dim=1)]
        target_modular = scatter(atom_refs, data.batch, dim=-1, reduce='sum')
        target_vec += [(data.y[:, target_idx] - target_modular).cpu().numpy()]
target_vec = np.concatenate(target_vec, axis=0)

In [15]:
# Mean and population standard deviation (ddof=0) of the training residuals;
# GATNN uses them to map its raw output back via out * std + mean.
target_mean, target_std = target_vec.mean(), target_vec.std()

In [16]:
class SelfAttentionReadout(nn.Module):
    """Transformer-based readout (not used by GATNN yet; see its TODO).

    A learnable CLS token is prepended to the input sequence, the result is run
    through an ``nn.TransformerEncoder``, and the encoder output at the CLS
    position is projected to ``dtarget`` values.

    Args:
        dtarget: output dimension.
        dmodel: token size (must be divisible by ``nhead``).
        dim_feedforward: hidden size of each encoder layer's feed-forward block.
        nlayers: number of encoder layers.
        nhead: number of attention heads.
    """

    def __init__(self, dtarget, dmodel=512, dim_feedforward=2048, nlayers=6, nhead=8):
        # Module.__init__ must run before any Parameter or submodule is assigned.
        super().__init__()
        self.dtarget = dtarget
        self.dmodel = dmodel
        self.dim_feedforward = dim_feedforward
        self.nhead = nhead
        self.nlayers = nlayers
        # torch.empty leaves the memory uninitialised; give the token a small
        # random initial value instead.
        self.cls_token = nn.Parameter(torch.empty(1, 1, self.dmodel), requires_grad=True)
        nn.init.normal_(self.cls_token, std=0.02)
        self.enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.dmodel,
                                                                    nhead=self.nhead,
                                                                    dim_feedforward=self.dim_feedforward),
                                         self.nlayers)
        self.fc_out = nn.Linear(self.dmodel, self.dtarget)

    def forward(self, sequence):
        """Map ``sequence`` of shape [n_seq, n_batch, dmodel] to [n_batch, dtarget]."""
        cls = self.cls_token.expand(-1, sequence.shape[1], -1)  # [1, n_batch, dmodel]
        augmented_sequence = torch.cat((cls, sequence), dim=0)
        transformer_out = self.enc(augmented_sequence)
        # Position 0 is the CLS token.
        return self.fc_out(transformer_out[0])
        

class GATNN(nn.Module):
    """Graph-level regressor: stacked GATv2 layers followed by an attention readout.

    Node features ``data.x`` and coordinates ``data.pos`` are linearly embedded to
    ``hidden_dim`` and summed, passed through ``nlayers`` GATv2Conv layers over
    ``data.edge_index``, and pooled per graph with ``GlobalAttention`` into one
    value per molecule. SchNet-style post-processing then rescales the output
    with ``std``/``mean`` and adds the summed per-atom reference values.
    """

    def __init__(self, vert_dim, hidden_dim=config.hidden_dim, nhead=config.nhead, nlayers=config.nlayers,
                 #ro_dmodel=128, ro_dim_feedforward=512, ro_nlayers=6, ro_nhead=8,
                 mean=None, std=None, atomref=None,
                 max_vertices=29, max_edges=28, act=None):
        """
        Args:
            vert_dim: number of columns of ``data.x``.
            hidden_dim: width of the embeddings and of every GAT layer.
            nhead: attention heads per GAT layer. Heads are averaged, so every
                layer keeps width ``hidden_dim`` for any ``nhead``.
            nlayers: number of GATv2Conv layers.
            mean, std: if both are given, the readout ``out`` becomes ``out * std + mean``.
            atomref: optional tensor (broadcastable to [100, 1]) used to initialise a
                trainable ``nn.Embedding(100, 1)`` indexed by ``data.z``; its values
                are summed per graph and added to the output.
            max_vertices, max_edges: unused; kept for interface compatibility.
            act: optional activation module inserted between consecutive GAT layers
                (e.g. ``nn.ELU()``). The default ``None`` reproduces the original
                architecture, which has no nonlinearity between the GAT layers.
        """
        super(GATNN, self).__init__()
        gat_layers = []
        for layer_idx in range(nlayers):
            if act is not None and layer_idx > 0:
                gat_layers.append((act, 'x -> x'))
            # concat=False averages the heads. With the default concat=True the output
            # width would be hidden_dim * nhead, breaking the next layer and the readout
            # for nhead > 1. For nhead == 1 both settings give identical results.
            gat_layers.append((tgnn.GATv2Conv(hidden_dim, hidden_dim, heads=nhead, concat=False),
                               'x, edge_index -> x'))

        self.gat_layers = tgnn.Sequential('x, edge_index', gat_layers)
        # GlobalAttention(gate_nn=fc_out, nn=fc_out2): one output value per graph.
        self.fc_out = nn.Linear(hidden_dim, 1)
        self.fc_out2 = nn.Linear(hidden_dim, 1)
        self.readout = tgnn.GlobalAttention(self.fc_out, self.fc_out2)

        self.x_emb = nn.Linear(vert_dim, hidden_dim)
        # NOTE(review): raw Cartesian coordinates are embedded directly, so the model
        # is not invariant to rotating or translating the molecule.
        self.pos_emb = nn.Linear(3, hidden_dim)

        # TODO: maybe replace the readout with a transformer:
        # self.readout = SelfAttentionReadout(1, ro_dmodel, ro_dim_feedforward, ro_nlayers, ro_nhead)

        # NOTE(review): mean/std are plain attributes, not buffers, so they are not in
        # state_dict(); reloading a checkpoint requires passing the same values again.
        self.mean = mean
        self.std = std
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = nn.Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

    def forward(self, data):
        """Predict one value per graph of the batch ``data`` (shape [num_graphs, 1])."""
        z, edge_index, batch = data.z, data.edge_index, data.batch
        h = self.x_emb(data.x) + self.pos_emb(data.pos)
        h = self.gat_layers(h, edge_index)
        out = self.readout(h, batch)

        # tricks from SchNet: undo the target standardisation, then add the per-atom
        # reference values summed over each graph.
        if self.mean is not None and self.std is not None:
            out = out * self.std + self.mean

        if self.atomref is not None:
            out = out + tgnn.global_add_pool(self.atomref(z), batch)

        return out

In [17]:
# The input width is constant across molecules, so read it from one transformed
# sample instead of starting the DataLoader workers.
# Move the model to the device *before* building the optimizer, as the
# torch.optim documentation recommends.
model = GATNN(train_dataset[0].x.shape[1], atomref=dataset.atomref(target_idx),
              mean=target_mean, std=target_std).to(device)
loss = nn.MSELoss(reduction='mean')  # training objective; MAE is only tracked for reporting
optimizer = optim.Adam(model.parameters(), lr=config.lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True,
                                                    patience=config.patience, factor=config.factor)



# Training

In [18]:
# Move parameters and buffers to the selected device (a no-op if they are already there).
model = model.to(device)

In [None]:
def evaluate_mae(loader):
    """Mean absolute error (target units) of `model` on `loader` for `target_idx`.

    Runs without gradients in eval mode and puts the model back into train mode.
    """
    model.eval()
    abs_err_sum, n_graphs = 0.0, 0
    with torch.no_grad():
        for batch_data in loader:
            batch_data = batch_data.to(device)
            pred = model(batch_data).view(-1)
            abs_err_sum += (pred - batch_data.y[:, target_idx]).abs().sum().item()
            n_graphs += batch_data.num_graphs
    model.train()
    return abs_err_sum / max(n_graphs, 1)


ca = chemical_accuracy[target_idx]
loss_hist = []   # training MAE per epoch
valid_hist = []  # validation MAE per epoch
# Best validation MAE/CA ratio seen so far (a ratio, not eV); checkpoints are only
# written once the ratio is below this starting threshold.
min_mae = config.store_starting_from_ratio
epoch = -1  # defined up front so the interrupt handler works even before epoch 1
try:
    for epoch in range(config.n_epochs):
        model.train()
        pbar = tqdm.tqdm(enumerate(train_loader, 0), total=len(train_loader))
        epoch_loss = 0  # running mean of the per-batch training MAE
        for i, data in pbar:
            model.zero_grad()
            data = data.to(device)
            target = data.y[:, target_idx]
            prediction = model(data).view(-1)
            output = loss(prediction, target)
            mae = (prediction - target).abs().mean()
            epoch_loss = (epoch_loss * i + mae.item()) / (i + 1)

            pbar.set_description('Epoch %d: MAE/CA %2.6f' % (epoch + 1, epoch_loss / ca))
            output.backward()
            optimizer.step()
            wandb.log({'MSE': output.item()})

        valid_mae = evaluate_mae(valid_loader)
        curr_lr = optimizer.param_groups[0]['lr']
        wandb.log({'MAE': epoch_loss,
                   'MAE/CA': epoch_loss / ca,
                   'VALID MAE': valid_mae,
                   'VALID MAE/CA': valid_mae / ca,
                   'learning rate': curr_lr})
        # Schedule, checkpoint and stop on the held-out validation MAE: the running
        # training MAE is optimistically biased and would ignore overfitting.
        lr_scheduler.step(valid_mae)
        loss_hist += [epoch_loss]
        valid_hist += [valid_mae]

        if valid_mae / ca < min_mae * config.required_improvement:
            min_mae = valid_mae / ca
            torch.save(model.state_dict(), config.model_dir + 'gat_epoch%d.pt' % (epoch + 1))
        if curr_lr < config.minimal_lr:
            break
        if valid_mae / ca < config.target_ratio:
            break

except KeyboardInterrupt:
    print('keyboard interrupt caught')
    torch.save(model.state_dict(), config.model_dir + 'gat_epoch%d.pt' % (epoch + 1))

Epoch 1: MAE/CA 57.410989: : 860it [01:07, 12.72it/s] 
Epoch 2: MAE/CA 32.545458: : 25it [00:01, 14.35it/s]

In [None]:
# Test-set MAE of the model's current weights (the end of training, not
# necessarily the best checkpoint saved above).
model.eval()
maes = []
with torch.no_grad():
    for data in tqdm.tqdm(test_loader):
        data = data.to(device)
        # GATNN.forward takes the whole batch object, not its individual fields.
        prediction = model(data)
        maes += [(prediction.view(-1) - data.y[:, target_idx]).abs().cpu()]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae / chemical_accuracy[target_idx])
wandb.log({'TEST MAE': mae, 'TEST MAE/CA': mae / chemical_accuracy[target_idx]})
