In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets.qm9 import QM9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn 

In [2]:
import numpy as np

In [3]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9


# get rid of the degenerate molecules

In [4]:
from urllib import request
import tempfile
import os
at_url = "https://ndownloader.figshare.com/files/3195404"
tmpdir = tempfile.mkdtemp("gdb9")
tmp_path = os.path.join(tmpdir, "uncharacterized.txt")
request.urlretrieve(at_url, tmp_path)

evilmols = []
with open(tmp_path) as f:
    lines = f.readlines()
    for line in lines[9:-1]:
        evilmols.append(int(line.split()[0]))
evilgdbs = ['gdb_%d'%id for id in evilmols]

In [5]:
pre_filter = lambda d: (d.name not in evilgdbs)

In [6]:
dataset = QM9('../datasets/qm9_geometric/', pre_filter=pre_filter)#, pre_filter=pre_filter)

In [7]:
# actually QM9 already automatically gets rid of all the gad examples -.-

In [8]:
dataset = dataset.shuffle()
train_dataset = dataset[:110000]
valid_dataset = dataset[110000:120000]
test_dataset = dataset[120000:]
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)

In [9]:
batch = next(iter(train_loader))

In [10]:
# investigate batch

In [11]:
batch['batch']

tensor([  0,   0,   0,  ..., 127, 127, 127])

In [12]:
(batch['batch'].detach().numpy() == 2).sum()

22

In [13]:
# node features:

In [14]:
batch.x # one_hot(type), atomic_number, aromatic, sp1, sp2, sp3, num_hs -> 5+1+1+1+1+1+1 = 11

tensor([[0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])

In [15]:
batch.x.shape

torch.Size([2297, 11])

In [16]:
batch.pos

tensor([[-0.0049,  1.5117, -0.1215],
        [-0.0176,  0.0170,  0.0694],
        [ 1.2530, -0.6577, -0.0063],
        ...,
        [ 0.9842, -2.2380, -3.5032],
        [ 2.3951, -2.3642, -2.4359],
        [ 1.5303, -4.6011, -1.5787]])

In [17]:
batch.z # atomic number

tensor([6, 6, 8,  ..., 1, 1, 1])

In [18]:
# edge features:

In [19]:
batch.edge_index

tensor([[   0,    0,    0,  ..., 2294, 2295, 2296],
        [   1,    9,   10,  ..., 2278, 2278, 2279]])

In [20]:
batch.edge_attr # one_hot(bond_type) -> 4 single, double, triple, aromatic

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        ...,
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [21]:
# targets 
batch.y[:, 7] # this target is U0

tensor([-10532.7139, -11543.3477, -11003.2148, -12554.8350, -11542.6406,
         -9927.1943, -11375.0762, -10971.3789, -11811.9580, -10937.3184,
        -12521.1416, -10532.0996, -12521.8428, -11947.6113, -10499.3945,
        -10969.3662, -11340.6250, -10474.8252, -10441.5195, -11510.4395,
        -11511.0371, -11542.1904, -10530.3506,  -9902.7480, -12894.7676,
         -8832.9922, -11890.0645, -10968.8594, -13467.5547, -10395.8750,
        -12319.5137, -10566.7725,  -9555.1309,  -7854.6553,  -9868.3984,
         -9843.8320, -10936.9795, -12487.2910, -13500.3174, -12790.3838,
        -11948.7891, -10598.8477, -12014.4307, -10936.4736, -11480.0479,
        -11811.1992,  -8890.2744,  -8426.6377, -12013.2822, -10499.4492,
        -10473.8096, -11543.1904, -11542.8223, -11003.9316, -11510.0625,
        -10473.0088, -10462.4971, -10532.9219, -10501.2314, -11477.4102,
        -12352.7998, -13547.3076, -10531.9941,  -8920.7334, -11270.7197,
        -10766.8135, -10936.1865, -10396.8125, -119

In [22]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

# Model

In [23]:
# replicate SchNet result
device

device(type='cuda', index=0)

In [24]:
model = tgnn.SchNet()

# Training

In [25]:
model(batch.z, batch.x, batch.batch)

tensor([[  1.0953],
        [  2.0611],
        [ -5.3898],
        [  3.3855],
        [  2.8477],
        [ 10.7769],
        [ 12.4857],
        [  6.9972],
        [ 13.0189],
        [ 12.0838],
        [ 10.3522],
        [  0.3201],
        [ 12.2930],
        [ 12.6779],
        [ 10.0587],
        [  6.6310],
        [ 13.4541],
        [ 10.5241],
        [ 13.1216],
        [ 11.2627],
        [ 10.2725],
        [  2.8477],
        [  1.0953],
        [ 10.8157],
        [ 14.1934],
        [ 11.0020],
        [ 13.6529],
        [  6.9964],
        [ 16.0289],
        [  8.1846],
        [ 12.4389],
        [-12.9778],
        [-16.8434],
        [  8.9305],
        [ 12.5748],
        [ 11.4530],
        [ 11.9472],
        [ 14.1909],
        [ 14.5417],
        [ 13.0255],
        [ 13.2505],
        [-37.8239],
        [ -2.5261],
        [ 11.5590],
        [ 14.2496],
        [ 12.7822],
        [ 10.0682],
        [  7.7985],
        [ -4.8652],
        [  9.8292],
