In [2]:
%load_ext autoreload
%autoreload 2

import os.path as osp

import torch

from torch_geometric.nn import SchNet
from torch_geometric.data import DataLoader
from torch_geometric.datasets import QM9

path = '../datasets/qm9_geometric_schnet/'
dataset = QM9(path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

for target in range(12):
    model, datasets = SchNet.from_qm9_pretrained(path, dataset, target)
    train_dataset, val_dataset, test_dataset = datasets

    model = model.to(device)
    loader = DataLoader(test_dataset, batch_size=256)

    maes = []
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data.z, data.pos, data.batch)
        mae = (pred.view(-1) - data.y[:, target]).abs()
        maes.append(mae)

    mae = torch.cat(maes, dim=0)

    print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}, {target_dict[target]}')
    print(f'Target: {target:02d}, MAE/CA: {mae.mean()/chemical_accuracy[target]:.5f}')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Target: 00, MAE: 0.02079 ± 0.03071, mu, D, Dipole moment
Target: 00, MAE/CA: 0.20793
Target: 01, MAE: 0.12105 ± 0.39941, alpha, {a_0}^3, Isotropic polarizability
Target: 01, MAE/CA: 1.21050
Target: 02, MAE: 0.04657 ± 0.04822, epsilon_{HOMO}, eV, Highest occupied molecular orbital energy
Target: 02, MAE/CA: 1.08298
Target: 03, MAE: 0.03818 ± 0.05109, epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy
Target: 03, MAE/CA: 0.88785
Target: 04, MAE: 0.07387 ± 0.07824, Delta, eV, Gap between HOMO and LUMO
Target: 04, MAE/CA: 1.71791
Target: 05, MAE: 0.15656 ± 0.34679, < R^2 >, {a_0}^2, Electronic spatial extent
Target: 05, MAE/CA: 0.13047
Target: 06, MAE: 0.00160 ± 0.00165, ZPVE, eV, Zero point vibrational energy
Target: 06, MAE/CA: 1.33331
Target: 07, MAE: 0.01200 ± 0.03197, U_0, eV, Internal energy at 0K
Target: 07, MAE/CA: 0.27908
Target: 08, MAE: 0.01195 ± 0.02293, U, eV, Internal energy at

In [3]:
loader = DataLoader(train_dataset, batch_size=256)

In [4]:
data = next(iter(loader))

In [5]:
data.y[:, 7]

tensor([-13500.2002,  -9869.9160, -11420.4141, -11003.3154, -10474.6338,
        -11476.8008, -12386.4258, -10968.1787, -10904.2373, -10462.1357,
        -11510.2188, -12521.2744, -11980.6719, -11948.8447, -12384.2314,
        -13466.5732, -10501.5967, -10430.1768, -11509.4023, -11285.2617,
        -11543.8369, -12926.9893, -11342.2207,  -9958.5605, -11576.7158,
         -9372.0371,  -9990.7021, -10499.7939, -11543.4561, -10598.2959,
        -10879.2910, -11915.9258,  -9990.1885, -12522.2549, -10936.3916,
        -10501.0518, -11440.4834, -11316.8711, -11881.3984, -11948.5596,
         -9521.3730, -10393.1055, -12215.1455, -11890.7637, -12353.0352,
        -11577.3203, -12488.8477, -11406.5771, -10970.8096, -12554.3760,
        -11001.4766, -12385.2832, -11373.2334, -12417.5645, -12384.1992,
        -10498.6953, -12249.3037, -11575.7588, -10531.3691, -11341.7100,
        -10432.2412, -11374.2295, -11452.7656, -10501.0742, -10865.5889,
         -9901.2715, -13466.6680, -11812.1357, -115