In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np

In [2]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

# get rid of the degenerate molecules

In [3]:
from urllib import request
import tempfile
import os
at_url = "https://ndownloader.figshare.com/files/3195404"
tmpdir = tempfile.mkdtemp("gdb9")
tmp_path = os.path.join(tmpdir, "uncharacterized.txt")
request.urlretrieve(at_url, tmp_path)

evilmols = []
with open(tmp_path) as f:
    lines = f.readlines()
    for line in lines[9:-1]:
        evilmols.append(int(line.split()[0]))
evilgdbs = ['gdb_%d'%id for id in evilmols]

In [4]:
pre_filter = lambda d: (d.name not in evilgdbs)

In [5]:
dataset = QM9('../datasets/qm9_geometric/', pre_filter=pre_filter)#, pre_filter=pre_filter)

In [6]:
# actually QM9 already automatically gets rid of all the gad examples -.-

In [7]:
dataset = dataset.shuffle()
train_dataset = dataset[:110000]
valid_dataset = dataset[110000:120000]
test_dataset = dataset[120000:]
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=32)

In [13]:
data = next(iter(train_loader))

In [14]:
# investigate batch

In [15]:
data.batch

tensor([  0,   0,   0,  ..., 127, 127, 127])

In [16]:
(data.batch.detach().numpy() == 2).sum()

20

In [17]:
# node features:

In [18]:
batch.x # one_hot(type), atomic_number, aromatic, sp1, sp2, sp3, num_hs -> 5+1+1+1+1+1+1 = 11

tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 2.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])

In [19]:
batch.x.shape

torch.Size([2355, 11])

In [20]:
batch.pos

tensor([[ 0.3226,  1.2292, -0.3572],
        [ 0.1520, -0.1357, -0.0892],
        [ 1.1449, -1.0083, -0.8150],
        ...,
        [-1.9524, -0.1800,  0.7559],
        [-1.8315, -0.2886, -1.0171],
        [-1.2324, -5.8001,  0.9035]])

In [21]:
batch.z # atomic number

tensor([8, 6, 6,  ..., 1, 1, 1])

In [17]:
# edge features:

In [18]:
batch.edge_index

tensor([[   0,    1,    1,  ..., 2335, 2336, 2337],
        [   1,    0,    2,  ..., 2325, 2325, 2326]])

In [19]:
batch.edge_attr # one_hot(bond_type) -> 4 single, double, triple, aromatic

tensor([[0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        ...,
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [20]:
# targets 
batch.y[:, 7] # this target is U0

tensor([ -9901.8848, -13880.3682, -10429.8203, -12352.6895, -12386.7949,
        -11477.9453, -10970.4033, -11510.7705, -10397.1201, -10936.2412,
        -11947.5918, -10936.9834, -12418.4424,  -9990.6123, -11812.5859,
        -10598.2139, -12487.3730, -10201.7812, -12488.0674, -11375.9658,
        -10935.9336, -11203.6748, -10501.2666, -10365.5938, -11981.1465,
        -10969.7646,  -9959.0596, -13445.0566, -11406.6924, -10303.0293,
         -9899.7002, -11001.3623, -12892.2529, -10566.2627,  -8270.6357,
        -11542.9277, -11412.1387, -10442.9336, -11374.1250, -10800.6699,
        -12521.3145, -10969.0459, -11509.5977, -12553.3594, -11337.9521,
        -10800.9102, -11543.5244, -11372.9795, -11542.6953, -11947.9307,
         -9866.8457, -11406.8066, -10970.1709,  -9926.0098, -10396.0488,
         -9488.0020, -10397.0566, -10474.3574, -10566.3066, -11477.5166,
         -9797.6631, -12353.7471, -11510.5088, -12354.1094, -10442.4521,
        -10935.4834, -10532.0947, -13434.8857, -113

In [21]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

# Model

In [22]:
lr = 0.005
n_epochs = 10000
patience = 1
factor = 0.96
target_idx = 7

dataset.atomref(target_idx)

In [23]:
target_vec = []

In [24]:
for data in train_loader:
    data = data.to(device)
    atomU0s = torch.tensor(qm9.atomrefs[target_idx], device=device)[torch.argmax(data.x[:, :5], axis=1)]
    target_modular = scatter(atomU0s, data.batch, dim=-1, reduce='sum')
    target_vec += [(data.y[:, target_idx] - target_modular).detach().cpu().numpy()]
target_vec = np.concatenate(target_vec, axis=0)

In [25]:
target_mean = np.mean(target_vec)
target_std = np.std(target_vec)

In [26]:
def normalize(target, mean=target_mean, std=target_std):
    return (target - mean)/std

def denormalize(target, mean=target_mean, std=target_std):
    return target*std + mean

In [27]:
model = tgnn.SchNet(atomref=dataset.atomref(target_idx), mean=target_mean, std=target_std).to(device)
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=patience, factor=factor)

# Training

In [None]:
loss_hist = []
try:
    # For each epoch
    for epoch in range(n_epochs):
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            data.to(device)
            # prepare target by subtracting the modular estimate
            #if target_idx in qm9.atomrefs:
            #    atomU0s = torch.tensor(qm9.atomrefs[target_idx], device=device)[torch.argmax(data.x[:, :5], axis=1)]
            #    target_modular = scatter(atomU0s, data.batch, dim=-1, reduce='sum')
                #target_modular = torch.tensor([atomU0s[data.batch == idx].sum().item() for idx in range(data.y.shape[0])], device=device)
            #    target = data.y[:,target_idx] - target_modular
            #else:
            #    target = data.y[:, target_idx]
            target = data.y[:, target_idx]
            prediction = model(data.z, data.pos, data.batch)
            output = loss(prediction.view(-1), target)
            mae = (prediction.view(-1) - target).abs().mean()
            #output = loss(prediction.view(-1), normalize(target))
            #mae = (denormalize(prediction.view(-1)) - target).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)
            
            pbar.set_description('MAE/CA %2.6f'%(epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()

        lr_scheduler.step(epoch_loss)
        loss_hist += [epoch_loss]
        if epoch_loss/chemical_accuracy[target_idx] < 0.45:
            break

except KeyboardInterrupt:
    print('keyboard interrupt caught')

MAE/CA 1955.202223: : 860it [00:29, 28.84it/s] 
MAE/CA 154.254144: : 860it [00:29, 29.11it/s]
MAE/CA 134.656876: : 860it [00:30, 28.42it/s]
MAE/CA 124.570416: : 860it [00:29, 29.10it/s]
MAE/CA 99.660748: : 116it [00:04, 28.66it/s] 

In [None]:
model = model.to('cpu')
pbar = tqdm.tqdm(enumerate(test_loader, 0))
epoch_loss = 0
maes = []
for i, data in pbar:
    data.to('cpu')
    prediction = model(data.z, data.x, data.batch)
    mae = (denormalize(prediction.view(-1)) - data.y[:, target_idx]).abs()
    maes += [mae]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
