In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn 
import tqdm
import numpy as np

In [2]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

# get rid of the degenerate molecules

In [3]:
from urllib import request
import tempfile
import os
at_url = "https://ndownloader.figshare.com/files/3195404"
tmpdir = tempfile.mkdtemp("gdb9")
tmp_path = os.path.join(tmpdir, "uncharacterized.txt")
request.urlretrieve(at_url, tmp_path)

evilmols = []
with open(tmp_path) as f:
    lines = f.readlines()
    for line in lines[9:-1]:
        evilmols.append(int(line.split()[0]))
evilgdbs = ['gdb_%d'%id for id in evilmols]

In [4]:
pre_filter = lambda d: (d.name not in evilgdbs)

In [5]:
dataset = QM9('../datasets/qm9_geometric/', pre_filter=pre_filter)#, pre_filter=pre_filter)

In [6]:
# actually QM9 already automatically gets rid of all the gad examples -.-

In [7]:
dataset = dataset.shuffle()
train_dataset = dataset[:110000]
valid_dataset = dataset[110000:120000]
test_dataset = dataset[120000:]
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=32)

In [8]:
batch = next(iter(train_loader))

In [9]:
# investigate batch

In [10]:
batch['batch']

tensor([  0,   0,   0,  ..., 127, 127, 127])

In [11]:
(batch['batch'].detach().numpy() == 2).sum()

23

In [12]:
# node features:

In [13]:
batch.x # one_hot(type), atomic_number, aromatic, sp1, sp2, sp3, num_hs -> 5+1+1+1+1+1+1 = 11

tensor([[0., 1., 0.,  ..., 0., 0., 3.],
        [0., 1., 0.,  ..., 0., 0., 2.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])

In [14]:
batch.x.shape

torch.Size([2355, 11])

In [15]:
batch.pos

tensor([[-1.5166, -0.1828, -0.1133],
        [ 0.0056, -0.0814,  0.0492],
        [ 0.7901, -0.6233, -1.1367],
        ...,
        [-0.9217, -3.9321,  0.6493],
        [-0.0643, -3.2753, -3.3043],
        [-1.0692, -1.9978, -3.4906]])

In [16]:
batch.z # atomic number

tensor([6, 6, 6,  ..., 1, 1, 1])

In [17]:
# edge features:

In [18]:
batch.edge_index

tensor([[   0,    0,    0,  ..., 2352, 2353, 2354],
        [   1,    9,   10,  ..., 2344, 2346, 2346]])

In [19]:
batch.edge_attr # one_hot(bond_type) -> 4 single, double, triple, aromatic

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        ...,
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [20]:
# targets 
batch.y[:, 7] # this target is U0

tensor([-11575.9463,  -9587.5293, -10566.7832, -11510.2109, -11375.6621,
        -11445.5293, -11444.1992, -11575.5596, -12926.6602,  -9866.0889,
        -11877.7139, -10598.5664, -11542.7969, -11575.9111,  -9811.1553,
         -9869.2227, -15526.9238, -10969.3643, -11238.6992, -10937.6680,
         -9360.2490, -12386.6045, -11512.3291,  -9496.2627,  -9555.1729,
        -11003.2412, -11946.9248,  -9530.2148, -10473.3555, -11419.0947,
        -13531.1602, -10834.4434, -11512.4375, -11810.7637, -12924.9268,
        -10396.0742, -11451.7021, -10736.2578, -11844.6309, -10936.7773,
        -11510.2256, -11947.7939, -12521.3105, -11812.1807, -11980.0410,
        -12570.1621, -10598.2256, -12320.2129, -11418.9541, -10306.8242,
        -11374.4512, -11949.1504, -10534.5605, -10531.8994, -10880.1689,
        -10969.2012, -10802.2783, -11916.2344, -12490.4453, -11844.0215,
         -9302.2051, -11915.2686, -11948.2002, -10533.2012, -12351.8398,
        -11512.9512, -12215.5850, -11542.5713, -115

In [21]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

# Model

In [22]:
model = tgnn.SchNet().to(device)

In [12]:
lr = 0.001
n_epochs = 1000
patience = 3
factor = 0.5
target_idx = 7

In [19]:
target_vec = [d.y[0, target_idx] for d in train_dataset]
target_mean = np.mean(target_vec)
target_std = np.std(target_vec)

In [21]:
def normalize(target, mean=target_mean, std=target_std):
    return (target - mean)/std

def denormalize(target, mean=target_mean, std=target_std):
    return target*std + mean

In [24]:
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=patience, factor=factor)

# Training

In [25]:
loss_hist = []
try:
    # For each epoch
    for epoch in range(n_epochs):
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            data.to(device)
            target = data.y[:,target_idx]
            prediction = model(data.z, data.x, data.batch)
            output = loss(prediction.view(-1), target)
            mae = (prediction.view(-1) - target).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)
            
            pbar.set_description('MAE/CA %2.6f'%(epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()
        
        lr_scheduler.step(epoch_loss)
        loss_hist += [epoch_loss]

except KeyboardInterrupt:
    print('keyboard interrupt caught')

MAE/CA 8777.794774: : 860it [00:37, 22.95it/s] 
MAE/CA 2884.960354: : 860it [00:29, 29.48it/s]
MAE/CA 2388.744362: : 860it [00:29, 29.40it/s]
MAE/CA 16800.808096: : 860it [00:27, 30.79it/s]
MAE/CA 3363.206487: : 860it [00:28, 30.08it/s] 
MAE/CA 1447.077370: : 860it [00:29, 29.64it/s]
MAE/CA 1379.656415: : 860it [00:28, 29.77it/s]
MAE/CA 1458.398715: : 860it [00:29, 29.26it/s]
MAE/CA 1516.325249: : 860it [00:28, 29.68it/s]
MAE/CA 1458.079929: : 860it [00:28, 30.44it/s]
MAE/CA 1601.710123: : 860it [00:29, 29.60it/s]
MAE/CA 1235.034499: : 3it [00:00, 29.04it/s]

Epoch    11: reducing learning rate of group 0 to 5.0000e-04.


MAE/CA 467.910249: : 860it [00:29, 29.63it/s]
MAE/CA 710.823362: : 860it [00:28, 30.65it/s]
MAE/CA 1569.816446: : 860it [00:28, 29.81it/s]
MAE/CA 599.948470: : 860it [00:28, 29.89it/s]
MAE/CA 830.270400: : 860it [00:28, 30.52it/s]
MAE/CA 424.466806: : 3it [00:00, 27.87it/s]

Epoch    16: reducing learning rate of group 0 to 2.5000e-04.


MAE/CA 349.809189: : 860it [00:28, 29.73it/s]
MAE/CA 495.679308: : 860it [00:28, 29.66it/s]
MAE/CA 633.943061: : 860it [00:29, 29.34it/s]
MAE/CA 639.004844: : 860it [00:28, 30.54it/s]
MAE/CA 645.258402: : 860it [00:28, 29.91it/s]
MAE/CA 214.008849: : 3it [00:00, 28.81it/s]

Epoch    21: reducing learning rate of group 0 to 1.2500e-04.


MAE/CA 180.152085: : 860it [00:28, 30.19it/s]
MAE/CA 275.282232: : 860it [00:29, 29.53it/s]
MAE/CA 368.335017: : 860it [00:28, 29.92it/s]
MAE/CA 371.273347: : 860it [00:29, 29.54it/s]
MAE/CA 356.008333: : 860it [00:28, 29.95it/s]
MAE/CA 171.942482: : 3it [00:00, 28.21it/s]

Epoch    26: reducing learning rate of group 0 to 6.2500e-05.


MAE/CA 116.553973: : 860it [00:29, 29.60it/s]
MAE/CA 174.559714: : 860it [00:28, 29.95it/s]
MAE/CA 192.935333: : 860it [00:28, 29.88it/s]
MAE/CA 200.037031: : 860it [00:27, 31.77it/s]
MAE/CA 198.601638: : 860it [00:26, 32.70it/s]
MAE/CA 153.292323: : 4it [00:00, 32.32it/s]

Epoch    31: reducing learning rate of group 0 to 3.1250e-05.


MAE/CA 90.658371: : 860it [00:28, 30.67it/s]
MAE/CA 118.184504: : 860it [00:29, 29.42it/s]
MAE/CA 120.318822: : 860it [00:27, 30.94it/s]
MAE/CA 124.571374: : 860it [00:29, 29.60it/s]
MAE/CA 111.899427: : 860it [00:28, 30.02it/s]
MAE/CA 63.624818: : 3it [00:00, 29.23it/s]

Epoch    36: reducing learning rate of group 0 to 1.5625e-05.


MAE/CA 74.925204: : 860it [00:28, 29.94it/s]
MAE/CA 79.298381: : 860it [00:28, 29.89it/s] 
MAE/CA 82.931687: : 860it [00:28, 30.22it/s]
MAE/CA 82.021807: : 860it [00:29, 29.49it/s]
MAE/CA 81.990715: : 860it [00:28, 30.24it/s]
MAE/CA 61.705937: : 3it [00:00, 27.07it/s]

Epoch    41: reducing learning rate of group 0 to 7.8125e-06.


MAE/CA 61.370633: : 860it [00:28, 29.85it/s]
MAE/CA 62.341280: : 860it [00:28, 29.74it/s]
MAE/CA 65.144089: : 860it [00:28, 29.82it/s]
MAE/CA 65.714317: : 860it [00:28, 29.89it/s]
MAE/CA 61.973558: : 860it [00:29, 29.61it/s]
MAE/CA 71.464967: : 4it [00:00, 32.21it/s]

Epoch    46: reducing learning rate of group 0 to 3.9063e-06.


MAE/CA 53.911062: : 860it [00:27, 31.01it/s]
MAE/CA 55.180577: : 860it [00:27, 30.83it/s]
MAE/CA 55.629490: : 860it [00:28, 29.68it/s]
MAE/CA 55.081618: : 860it [00:28, 30.13it/s]
MAE/CA 54.512288: : 860it [00:29, 29.48it/s]
MAE/CA 50.823404: : 3it [00:00, 29.58it/s]

Epoch    51: reducing learning rate of group 0 to 1.9531e-06.


MAE/CA 50.357352: : 860it [00:28, 30.32it/s]
MAE/CA 51.211475: : 860it [00:28, 30.47it/s]
MAE/CA 49.884306: : 860it [00:28, 29.85it/s]
MAE/CA 50.192335: : 860it [00:28, 29.73it/s]
MAE/CA 50.877346: : 860it [00:28, 29.73it/s]
MAE/CA 49.986689: : 860it [00:29, 29.62it/s]
MAE/CA 49.878298: : 860it [00:28, 30.07it/s]
MAE/CA 49.696817: : 860it [00:29, 29.37it/s]
MAE/CA 49.385212: : 860it [00:29, 29.61it/s]
MAE/CA 50.198666: : 860it [00:28, 30.38it/s]
MAE/CA 49.350295: : 860it [00:28, 30.08it/s]
MAE/CA 48.710841: : 860it [00:29, 29.65it/s]
MAE/CA 49.520917: : 860it [00:28, 29.79it/s]
MAE/CA 49.074475: : 860it [00:29, 29.58it/s]
MAE/CA 49.108368: : 860it [00:29, 29.45it/s]
MAE/CA 49.026972: : 860it [00:28, 30.22it/s]
MAE/CA 49.650784: : 3it [00:00, 28.94it/s]

Epoch    67: reducing learning rate of group 0 to 9.7656e-07.


MAE/CA 46.554962: : 860it [00:29, 29.44it/s]
MAE/CA 46.455642: : 860it [00:28, 30.20it/s]
MAE/CA 46.322601: : 860it [00:28, 29.76it/s]
MAE/CA 46.030036: : 860it [00:28, 29.70it/s]
MAE/CA 46.586703: : 860it [00:28, 30.32it/s]
MAE/CA 45.962794: : 860it [00:28, 29.93it/s]
MAE/CA 46.081492: : 860it [00:28, 30.17it/s]
MAE/CA 45.802411: : 860it [00:28, 30.02it/s]
MAE/CA 45.981452: : 860it [00:28, 29.86it/s]
MAE/CA 46.047284: : 860it [00:27, 30.79it/s]
MAE/CA 45.831753: : 860it [00:27, 31.24it/s]
MAE/CA 45.673474: : 860it [00:27, 31.64it/s]
MAE/CA 45.729802: : 860it [00:28, 30.44it/s]
MAE/CA 45.517220: : 860it [00:26, 32.59it/s]
MAE/CA 45.611710: : 860it [00:27, 31.30it/s]
MAE/CA 45.427708: : 860it [00:28, 30.15it/s]
MAE/CA 45.334091: : 860it [00:28, 30.27it/s]
MAE/CA 45.284536: : 860it [00:28, 30.19it/s]
MAE/CA 45.266596: : 860it [00:28, 30.17it/s]
MAE/CA 45.251322: : 860it [00:28, 30.18it/s]
MAE/CA 45.083523: : 860it [00:28, 30.08it/s]
MAE/CA 44.741050: : 860it [00:27, 31.03it/s]
MAE/CA 44.

Epoch   102: reducing learning rate of group 0 to 4.8828e-07.


MAE/CA 42.926700: : 860it [00:28, 30.31it/s]
MAE/CA 42.887575: : 860it [00:28, 30.21it/s]
MAE/CA 42.899702: : 860it [00:28, 30.16it/s]
MAE/CA 42.863640: : 860it [00:28, 30.49it/s]
MAE/CA 42.887905: : 860it [00:28, 30.15it/s]
MAE/CA 42.909999: : 860it [00:27, 31.01it/s]
MAE/CA 42.624647: : 860it [00:28, 30.35it/s]
MAE/CA 42.743920: : 860it [00:28, 30.32it/s]
MAE/CA 42.778102: : 860it [00:28, 30.55it/s]
MAE/CA 42.644396: : 860it [00:28, 30.31it/s]
MAE/CA 42.782712: : 860it [00:28, 30.19it/s]
MAE/CA 44.514022: : 3it [00:00, 28.29it/s]

Epoch   113: reducing learning rate of group 0 to 2.4414e-07.


MAE/CA 42.007517: : 860it [00:28, 30.06it/s]
MAE/CA 41.938400: : 860it [00:28, 30.05it/s]
MAE/CA 41.970625: : 860it [00:28, 30.26it/s]
MAE/CA 41.929128: : 860it [00:28, 30.51it/s]
MAE/CA 41.870809: : 860it [00:28, 30.12it/s]
MAE/CA 41.914580: : 860it [00:28, 30.34it/s]
MAE/CA 41.948352: : 860it [00:28, 30.20it/s]
MAE/CA 41.992890: : 860it [00:27, 30.72it/s]
MAE/CA 41.878291: : 860it [00:27, 31.21it/s]
MAE/CA 41.222757: : 4it [00:00, 32.07it/s]

Epoch   122: reducing learning rate of group 0 to 1.2207e-07.


MAE/CA 41.546769: : 860it [00:28, 30.64it/s]
MAE/CA 41.488531: : 860it [00:28, 30.70it/s]
MAE/CA 41.499975: : 860it [00:27, 31.80it/s]
MAE/CA 41.491273: : 860it [00:27, 30.78it/s]
MAE/CA 41.409851: : 860it [00:28, 30.14it/s]
MAE/CA 41.550671: : 860it [00:28, 30.07it/s]
MAE/CA 41.424678: : 860it [00:28, 30.66it/s]
MAE/CA 41.449548: : 860it [00:28, 30.08it/s]
MAE/CA 41.464037: : 860it [00:28, 30.06it/s]
MAE/CA 36.421310: : 3it [00:00, 28.77it/s]

Epoch   131: reducing learning rate of group 0 to 6.1035e-08.


MAE/CA 41.238484: : 860it [00:28, 30.64it/s]
MAE/CA 41.235602: : 860it [00:28, 30.06it/s]
MAE/CA 41.211464: : 860it [00:28, 30.32it/s]
MAE/CA 41.249015: : 860it [00:28, 30.19it/s]
MAE/CA 41.208383: : 860it [00:27, 31.10it/s]
MAE/CA 41.186187: : 860it [00:27, 30.78it/s]
MAE/CA 41.243497: : 860it [00:28, 30.24it/s]
MAE/CA 41.250133: : 860it [00:28, 30.22it/s]
MAE/CA 41.180584: : 860it [00:28, 30.30it/s]
MAE/CA 41.233380: : 860it [00:28, 30.07it/s]
MAE/CA 41.179825: : 860it [00:28, 30.17it/s]
MAE/CA 41.201961: : 860it [00:28, 30.46it/s]
MAE/CA 41.189457: : 860it [00:28, 30.25it/s]
MAE/CA 39.181924: : 3it [00:00, 28.44it/s]

Epoch   144: reducing learning rate of group 0 to 3.0518e-08.


MAE/CA 41.091128: : 860it [00:28, 30.14it/s]
MAE/CA 41.085791: : 860it [00:28, 30.41it/s]
MAE/CA 41.074552: : 860it [00:27, 30.83it/s]
MAE/CA 41.078271: : 860it [00:28, 30.37it/s]
MAE/CA 41.092762: : 860it [00:27, 31.26it/s]
MAE/CA 41.062130: : 860it [00:28, 30.14it/s]
MAE/CA 41.080654: : 860it [00:28, 30.21it/s]
MAE/CA 41.073961: : 860it [00:28, 30.13it/s]
MAE/CA 41.058163: : 860it [00:28, 30.25it/s]
MAE/CA 41.058636: : 860it [00:28, 30.01it/s]
MAE/CA 37.512934: : 3it [00:00, 29.19it/s]

Epoch   154: reducing learning rate of group 0 to 1.5259e-08.


MAE/CA 40.984478: : 860it [00:28, 30.22it/s]
MAE/CA 41.011334: : 860it [00:28, 30.30it/s]
MAE/CA 40.999743: : 860it [00:28, 30.10it/s]
MAE/CA 41.002900: : 860it [00:28, 30.24it/s]
MAE/CA 41.001175: : 860it [00:28, 30.06it/s]
MAE/CA 41.006427: : 860it [00:28, 30.17it/s]
MAE/CA 40.985799: : 860it [00:28, 30.21it/s]
MAE/CA 40.992688: : 860it [00:27, 31.39it/s]
MAE/CA 41.002267: : 860it [00:27, 31.20it/s]
MAE/CA 41.005550: : 860it [00:27, 31.46it/s]
MAE/CA 40.995493: : 860it [00:28, 30.24it/s]
MAE/CA 40.996130: : 860it [00:28, 30.27it/s]
MAE/CA 40.969406: : 860it [00:28, 30.04it/s]
MAE/CA 40.993145: : 860it [00:27, 31.34it/s]
MAE/CA 40.988431: : 860it [00:28, 30.40it/s]
MAE/CA 40.983035: : 860it [00:27, 31.69it/s]
MAE/CA 40.968317: : 860it [00:27, 31.24it/s]
MAE/CA 41.015442: : 860it [00:27, 31.04it/s]
MAE/CA 40.978364: : 860it [00:28, 30.69it/s]
MAE/CA 40.986892: : 860it [00:27, 30.90it/s]
MAE/CA 40.992684: : 860it [00:28, 30.19it/s]
MAE/CA 40.972017: : 860it [00:28, 30.29it/s]
MAE/CA 40.

MAE/CA 40.821637: : 860it [00:28, 30.20it/s]
MAE/CA 40.833683: : 860it [00:28, 30.68it/s]
MAE/CA 40.837081: : 860it [00:28, 30.55it/s]
MAE/CA 40.829015: : 860it [00:28, 30.42it/s]
MAE/CA 40.793577: : 860it [00:26, 31.97it/s]
MAE/CA 40.825576: : 860it [00:28, 30.19it/s]
MAE/CA 40.813993: : 860it [00:28, 30.17it/s]
MAE/CA 40.818861: : 860it [00:28, 29.97it/s]
MAE/CA 40.815820: : 860it [00:28, 30.15it/s]
MAE/CA 40.819121: : 860it [00:28, 30.06it/s]
MAE/CA 40.800072: : 860it [00:26, 32.07it/s]
MAE/CA 40.827596: : 860it [00:28, 30.59it/s]
MAE/CA 40.820761: : 860it [00:28, 30.29it/s]
MAE/CA 40.798005: : 860it [00:27, 31.43it/s]
MAE/CA 40.821764: : 860it [00:28, 30.14it/s]
MAE/CA 40.806848: : 860it [00:28, 30.22it/s]
MAE/CA 40.787601: : 860it [00:28, 30.26it/s]
MAE/CA 40.812266: : 860it [00:28, 30.58it/s]
MAE/CA 40.813903: : 860it [00:28, 30.16it/s]
MAE/CA 40.800989: : 860it [00:28, 30.15it/s]
MAE/CA 40.800358: : 860it [00:28, 30.03it/s]
MAE/CA 40.804379: : 860it [00:28, 30.08it/s]
MAE/CA 40.

MAE/CA 40.638441: : 860it [00:27, 30.87it/s]
MAE/CA 40.631763: : 860it [00:28, 30.48it/s]
MAE/CA 40.636997: : 860it [00:28, 30.13it/s]
MAE/CA 40.642409: : 860it [00:28, 29.86it/s]
MAE/CA 40.638884: : 860it [00:28, 29.77it/s]
MAE/CA 40.642665: : 860it [00:29, 28.81it/s]
MAE/CA 40.648631: : 860it [00:29, 28.89it/s]
MAE/CA 40.626875: : 860it [00:29, 29.53it/s]
MAE/CA 40.634348: : 860it [00:29, 29.24it/s]
MAE/CA 40.620158: : 860it [00:29, 29.35it/s]
MAE/CA 40.621501: : 860it [00:29, 29.07it/s]
MAE/CA 40.632107: : 860it [00:29, 29.25it/s]
MAE/CA 40.645292: : 860it [00:29, 29.59it/s]
MAE/CA 40.647370: : 860it [00:28, 29.84it/s]
MAE/CA 40.619040: : 860it [00:28, 29.67it/s]
MAE/CA 40.617071: : 860it [00:28, 29.78it/s]
MAE/CA 40.622986: : 860it [00:29, 29.28it/s]
MAE/CA 40.618622: : 860it [00:29, 29.64it/s]
MAE/CA 40.628799: : 860it [00:29, 29.02it/s]
MAE/CA 40.617287: : 860it [00:28, 29.75it/s]
MAE/CA 40.618477: : 860it [00:28, 29.69it/s]
MAE/CA 40.618466: : 860it [00:29, 29.03it/s]
MAE/CA 40.

keyboard interrupt caught





In [26]:
model = model.to('cpu')
pbar = tqdm.tqdm(enumerate(test_loader, 0))
epoch_loss = 0
maes = []
for i, data in pbar:
    data.to('cpu')
    prediction = model(data.z, data.x, data.batch)
    mae = (prediction.view(-1) - data.y[:, target_idx]).abs()
    maes += [mae]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])


88it [00:06, 13.28it/s]


KeyboardInterrupt: 