In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn 
import tqdm
import numpy as np

In [2]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

# get rid of the degenerate molecules

In [3]:
from urllib import request
import tempfile
import os
at_url = "https://ndownloader.figshare.com/files/3195404"
tmpdir = tempfile.mkdtemp("gdb9")
tmp_path = os.path.join(tmpdir, "uncharacterized.txt")
request.urlretrieve(at_url, tmp_path)

evilmols = []
with open(tmp_path) as f:
    lines = f.readlines()
    for line in lines[9:-1]:
        evilmols.append(int(line.split()[0]))
evilgdbs = ['gdb_%d'%id for id in evilmols]

In [4]:
pre_filter = lambda d: (d.name not in evilgdbs)

In [5]:
dataset = QM9('../datasets/qm9_geometric/', pre_filter=pre_filter)#, pre_filter=pre_filter)

In [6]:
# actually QM9 already automatically gets rid of all the gad examples -.-

In [7]:
dataset = dataset.shuffle()
train_dataset = dataset[:110000]
valid_dataset = dataset[110000:120000]
test_dataset = dataset[120000:]
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=32)

In [8]:
batch = next(iter(train_loader))

In [9]:
# investigate batch

In [10]:
batch['batch']

tensor([  0,   0,   0,  ..., 127, 127, 127])

In [11]:
(batch['batch'].detach().numpy() == 2).sum()

17

In [12]:
# node features:

In [13]:
batch.x # one_hot(type), atomic_number, aromatic, sp1, sp2, sp3, num_hs -> 5+1+1+1+1+1+1 = 11

tensor([[0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])

In [14]:
batch.x.shape

torch.Size([2312, 11])

In [15]:
batch.pos

tensor([[-0.3588,  1.5057,  0.0280],
        [-0.0104,  0.0281,  0.0266],
        [-1.1223, -0.7992, -0.2756],
        ...,
        [-0.6367,  2.0221,  0.3932],
        [ 2.5185,  3.9498, -0.0540],
        [ 0.7527,  4.0329,  0.4794]])

In [16]:
batch.z # atomic number

tensor([6, 6, 8,  ..., 1, 1, 1])

In [17]:
# edge features:

In [18]:
batch.edge_index

tensor([[   0,    0,    0,  ..., 2309, 2310, 2311],
        [   1,    9,   10,  ..., 2301, 2303, 2303]])

In [19]:
batch.edge_attr # one_hot(bond_type) -> 4 single, double, triple, aromatic

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        ...,
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [20]:
# targets 
batch.y[:, 7] # this target is U0

tensor([-12488.3291, -10903.4668, -12385.6064, -11982.4629, -10970.1846,
        -11948.2627, -10873.4121, -10971.1172, -12384.2012, -10442.4688,
        -10498.6143, -10565.3018, -11375.7627, -11980.6533,  -9521.0508,
        -11938.5840, -10967.2461,  -8393.6201,  -9867.2900, -10396.1357,
        -11575.9082,  -8425.9336, -11915.0820, -12860.6611, -11913.5195,
        -10935.9746,  -9620.8193, -10969.2451, -12385.6582, -16737.6719,
         -8830.1855, -11350.0332, -13548.5205, -11510.8398, -12352.2871,
        -11980.6904, -10833.9395, -12014.1719, -11811.1729, -10329.2812,
        -10394.2656, -10442.6445, -12489.1924, -12521.8496, -11845.3486,
         -9359.1309, -11373.7695, -11884.3643,  -9894.4697, -11476.3525,
        -10938.9883, -10396.7402, -12959.3730, -12823.4238, -10395.8770,
        -13332.7354,  -8923.0596, -11843.8457, -10498.7334, -11916.9492,
        -12457.6152, -10938.1016,  -9556.3135, -10847.7900, -11375.5811,
        -11811.9688,  -9430.5146, -10969.1514, -119

In [21]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

# Model

In [22]:
model = tgnn.SchNet().to(device)

In [23]:
lr = 0.001
n_epochs = 1000
patience = 3
factor = 0.5
target_idx = 7

In [24]:
target_vec = [d.y[0, target_idx] for d in train_dataset]
target_mean = np.mean(target_vec)
target_std = np.std(target_vec)

In [25]:
def normalize(target, mean=target_mean, std=target_std):
    return (target - mean)/std

def denormalize(target, mean=target_mean, std=target_std):
    return target*std + mean

In [26]:
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=patience, factor=factor)

# Training

In [None]:
loss_hist = []
try:
    # For each epoch
    for epoch in range(n_epochs):
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            data.to(device)
            target = normalize(data.y[:,target_idx])
            prediction = model(data.z, data.x, data.batch)
            output = loss(prediction.view(-1), target)
            mae = (denormalize(prediction.view(-1)) - target).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)
            
            pbar.set_description('MAE/CA %2.6f'%(epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()
        
        lr_scheduler.step(epoch_loss)
        loss_hist += [epoch_loss]

except KeyboardInterrupt:
    print('keyboard interrupt caught')

MAE/CA 264989.273908: : 860it [00:29, 29.57it/s]
MAE/CA 260139.460795: : 121it [00:04, 29.07it/s]

In [None]:
model = model.to('cpu')
pbar = tqdm.tqdm(enumerate(test_loader, 0))
epoch_loss = 0
maes = []
for i, data in pbar:
    data.to('cpu')
    prediction = model(data.z, data.x, data.batch)
    mae = (prediction.view(-1) - data.y[:, target_idx]).abs()
    maes += [mae]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
