In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
from torch.utils.data import Dataset

import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np
import wandb
import json
import os


In [2]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

In [3]:
wandb.init(project='QM9', entity='chrisxx')
config = wandb.config
config.lr = 0.0003
config.n_epochs = 10
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8
config.target_idx = 7
config.batch_size = 128
config.n_train = 110000
config.n_valid = 10000
config.target_ratio = 0.4
config.store_starting_from_ratio = 1
config.required_improvement = 0.8
config.model_dir = '../models/qm9/gtransformer/'
config.dfs_codes = '../datasets/qm9_geometric/min_dfs_codes.json'

[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)


In [4]:
target_idx = config.target_idx

# Dataset

In [66]:
class QM9DFSCodes(Dataset):
    def __init__(self, dfs_codes, qm9_dataset, 
                 target_idx, vert_feats = ['x', 'pos'],
                 vertex_transform=None, edge_transform=None):
        self.dfs_codes = dfs_codes
        self.qm9_dataset = qm9_dataset
        self.target_idx = target_idx
        self.vertex_features = vert_feats
        self.max_len = np.max([len(d['min_dfs_code']) for d in self.dfs_codes.values()])
        self.feat_dim = 2 + qm9_dataset[0].edge_attr.shape[1] # 2 for the dfs indices
        for feat in vert_feats: 
            self.feat_dim += 2*qm9_dataset[0][feat].shape[1]
        self.vertex_transform = vertex_transform
        self.edge_transform = edge_transform

    def __len__(self):
        return len(self.qm9_dataset)

    def __getitem__(self, idx):
        """
        graph_repr: [batch_size, max_edges, 2 + n_vert_feat + n_edge_feat + n_vert_feat]
        return vertex_features, graph_repr, target
        """
        data = self.qm9_dataset[idx]
        code_dict = self.dfs_codes[data.name]
        code = code_dict['min_dfs_code']
        eid2nr = code_dict['edge_id_2_edge_number']
        
        vert_feats = [data[k].detach().cpu().numpy() for k in ['x', 'pos']]
        vert_feats = np.concatenate(vert_feats, axis=1)
        edge_feats = data.edge_attr.detach().cpu().numpy()
        
        d = {'dfs_from':np.zeros(self.max_len), 
             'dfs_to':np.zeros(self.max_len),
             'feat_from':np.zeros((self.max_len, vert_feats.shape[1])), 
             'feat_to':np.zeros((self.max_len, vert_feats.shape[1])), 
             'feat_edge':np.zeros((self.max_len, edge_feats.shape[1])),
             'n_edges':len(code)*np.ones((1,), dtype=np.int32)}
        
        for idx, e_tuple in enumerate(code):
            e_idx = eid2nr[str(e_tuple[-2])]
            from_idx = e_tuple[-3]
            to_idx = e_tuple[-1]
            d['dfs_from'][idx] = e_tuple[0]
            d['dfs_to'][idx] = e_tuple[1]
            d['feat_from'][idx] = vert_feats[from_idx]
            d['feat_to'][idx] = vert_feats[to_idx]
            d['feat_edge'][idx] = edge_feats[e_idx]      
        
        
        d_tensors = {key: torch.Tensor(val) for key, val in d.items()}
        
        if self.vertex_transform:
            d_tensors['feat_from'] = self.vertex_transform(d_tensors['feat_from'])
            d_tensors['feat_to'] = self.vertex_transform(d_tensors['feat_to'])
        if self.edge_transform:
            d_tensors['feat_edge'] = self.edge_transform(d_tensors['feat_edge'])
        
        d_tensors['target'] = data.y[:, self.target_idx]
        
        return d_tensors
    
    def shuffle(self):
        self.qm9_dataset = self.qm9_dataset.shuffle()
        return self


In [53]:
with open(config.dfs_codes, 'r') as f:
    dfs_codes = json.load(f)

In [54]:
dfs_codes

{'gdb_1': {'min_dfs_code': [[0, 1, 0, 0, 1, 4, 3, 0],
   [1, 2, 1, 0, 0, 0, 0, 1],
   [1, 3, 1, 0, 0, 0, 1, 2],
   [1, 4, 1, 0, 0, 0, 2, 3]],
  'edge_id_2_edge_number': {'3': 7, '0': 0, '1': 1, '2': 2}},
 'gdb_2': {'min_dfs_code': [[0, 1, 0, 0, 2, 3, 2, 0],
   [1, 2, 2, 0, 0, 0, 0, 1],
   [1, 3, 2, 0, 0, 0, 1, 2]],
  'edge_id_2_edge_number': {'2': 5, '0': 0, '1': 1}},
 'gdb_3': {'min_dfs_code': [[0, 1, 0, 0, 3, 2, 1, 0],
   [1, 2, 3, 0, 0, 0, 0, 1]],
  'edge_id_2_edge_number': {'1': 3, '0': 0}},
 'gdb_4': {'min_dfs_code': [[0, 1, 0, 0, 1, 3, 1, 0],
   [1, 2, 1, 2, 1, 0, 0, 1],
   [2, 3, 1, 0, 0, 1, 2, 2]],
  'edge_id_2_edge_number': {'1': 5, '0': 0, '2': 3}},
 'gdb_5': {'min_dfs_code': [[0, 1, 0, 0, 1, 2, 1, 0],
   [1, 2, 1, 2, 2, 0, 0, 1]],
  'edge_id_2_edge_number': {'1': 3, '0': 0}},
 'gdb_6': {'min_dfs_code': [[0, 1, 0, 0, 1, 3, 2, 0],
   [1, 2, 1, 0, 0, 0, 1, 2],
   [1, 3, 1, 1, 3, 0, 0, 1]],
  'edge_id_2_edge_number': {'2': 5, '1': 1, '0': 0}},
 'gdb_7': {'min_dfs_code': [[0, 1, 

In [55]:
dset = QM9('../datasets/qm9_geometric/')

In [67]:
dset = dset.shuffle()
train_dataset = QM9DFSCodes(dfs_codes, dset[:config.n_train], target_idx)
valid_dataset = QM9DFSCodes(dfs_codes, dset[config.n_train:config.n_train+config.n_valid], target_idx) 
test_dataset = QM9DFSCodes(dfs_codes, dset[config.n_train+config.n_valid:], target_idx) 
config.n_test = len(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, pin_memory=False)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size)
test_loader = DataLoader(test_dataset, batch_size=32)

In [68]:
os.makedirs(config.model_dir, exist_ok=True)

In [69]:
d = next(iter(train_loader))

In [81]:
torch.save(dset.indices(), config.model_dir+'dataset_indices.pt')

In [9]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

# Model

In [10]:
target_vec = []

In [11]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
for data in train_loader:
    data = data.to(device)
    atomU0s = torch.tensor(qm9.atomrefs[target_idx], device=device)[torch.argmax(data.x[:, :5], axis=1)]
    target_modular = scatter(atomU0s, data.batch, dim=-1, reduce='sum')
    target_vec += [(data.y[:, target_idx] - target_modular).detach().cpu().numpy()]
target_vec = np.concatenate(target_vec, axis=0)

In [12]:
target_mean = np.mean(target_vec)
target_std = np.std(target_vec)

In [13]:
model = tgnn.SchNet(hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0,
                    readout='add', atomref=dataset.atomref(target_idx), mean=target_mean, std=target_std)
loss = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=config.lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=config.patience, factor=config.factor)

# Training

In [14]:
model = model.to(device)

In [15]:
loss_hist = []
min_mae = config.store_starting_from_ratio
try:
    # For each epoch
    for epoch in range(config.n_epochs):
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            data.to(device)
            target = data.y[:, target_idx]
            prediction = model(data.z, data.pos, data.batch)
            output = loss(prediction.view(-1), target)
            mae = (prediction.view(-1) - target).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)
            
            pbar.set_description('Epoch %d: MAE/CA %2.6f'%(epoch+1, epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()
            wandb.log({'MSE': output.item()})
        curr_lr = list(optimizer.param_groups)[0]['lr']
        wandb.log({'MAE':epoch_loss, 
                   'MAE/CA':epoch_loss/chemical_accuracy[target_idx],
                   'learning rate':curr_lr})
        lr_scheduler.step(epoch_loss)
        loss_hist += [epoch_loss] 

        if epoch_loss/chemical_accuracy[target_idx] < min_mae*config.required_improvement:
            min_mae = epoch_loss/chemical_accuracy[target_idx]
            torch.save(model.state_dict(), config.model_dir+'schnet_epoch%d.pt'%(epoch+1))
        if curr_lr < config.minimal_lr:
            break
        if epoch_loss/chemical_accuracy[target_idx] < config.target_ratio:
            break

except KeyboardInterrupt:
    print('keyboard interrupt caught')
    torch.save(model.state_dict(), config.model_dir+'schnet_epoch%d.pt'%(epoch+1))

Epoch 1: MAE/CA 475.639698: : 860it [00:56, 15.21it/s] 
Epoch 2: MAE/CA 114.884713: : 860it [00:29, 29.43it/s]
Epoch 3: MAE/CA 90.428354: : 860it [00:29, 29.20it/s] 
Epoch 4: MAE/CA 81.045134: : 860it [00:29, 28.73it/s] 
Epoch 5: MAE/CA 74.642198: : 860it [00:29, 29.17it/s]
Epoch 6: MAE/CA 65.803200: : 860it [00:30, 28.56it/s]
Epoch 7: MAE/CA 62.368607: : 860it [00:29, 29.37it/s] 
Epoch 8: MAE/CA 64.920311: : 860it [00:29, 28.91it/s]
Epoch 9: MAE/CA 62.342441: : 860it [00:32, 26.46it/s]
Epoch 10: MAE/CA 55.558642: : 860it [00:32, 26.31it/s]


In [16]:
pbar = tqdm.tqdm(enumerate(test_loader, 0))
epoch_loss = 0
maes = []
for i, data in pbar:
    data.to(device)
    prediction = model(data.z, data.pos, data.batch)
    mae = (prediction.view(-1) - data.y[:, target_idx]).abs()
    maes += [mae.detach().cpu()]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
wandb.log({'TEST MAE':mae, 'TEST MAE/CA':mae/chemical_accuracy[target_idx]})


339it [00:07, 44.22it/s]

0.9423131346702576 21.914258945819945



