In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np
import wandb

In [2]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD stands for Histogram of Distances, Angles and Dihedral angles

# [2] Reports the average value of MAE / Chemical Accuracy over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
# Description of every QM9 regression target, keyed by its column index in
# `data.y`. Each entry reads '<symbol>, <unit>, <description>'.
target_dict = {0: 'mu, D, Dipole moment',
               1: 'alpha, {a_0}^3, Isotropic polarizability',
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy',
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K',
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',
               # The unit was previously written as 'cal\(mol K)': an invalid
               # escape sequence and a wrong separator.
               11: 'c_{v}, cal/(mol K), Heat capacity at 298.15K'}

# Chemical accuracy per target index, used to normalise the MAE (MAE / CA).
# Every target defaults to 0.043 (the U0 value cited above from [1, Table 5]);
# the targets listed in the update use their own thresholds.
chemical_accuracy = {idx: 0.043 for idx in range(12)}
chemical_accuracy.update({0: 0.1,
                          1: 0.1,
                          5: 1.2,
                          6: 0.0012,
                          11: 0.050})

In [3]:
# Start the Weights & Biases run and record every hyperparameter of this
# experiment on its config object, which the rest of the notebook reads from.
wandb.init(project='QM9-GAT', entity='chrisxx')
config = wandb.config

hyperparameters = {
    # model size
    'hidden_dim': 256,
    'nlayers': 6,
    'nhead': 1,
    # optimisation and learning-rate schedule
    'lr': 0.0003,
    'n_epochs': 5000,
    'patience': 5,
    'factor': 0.95,
    'minimal_lr': 6e-8,
    # data
    'target_idx': 7,
    'batch_size': 128,
    'n_train': 110000,
    'n_valid': 10000,
    # stopping / checkpointing thresholds, expressed as MAE / chemical accuracy
    'target_ratio': 0.3,
    'store_starting_from_ratio': 1,
    'required_improvement': 0.8,
    # paths and loader workers
    'model_dir': '../models/qm9/GAT3_dfs/',
    'num_workers': 4,
    'dfs_codes': '../datasets/qm9_geometric_work/min_dfs_codes.json',
}
for hparam_name, hparam_value in hyperparameters.items():
    setattr(config, hparam_name, hparam_value)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.11.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [4]:
import json

# Read the DFS-code table (a JSON file whose path comes from the run config)
# into memory; entries are looked up by molecule name later on.
with open(config.dfs_codes, 'r') as codes_file:
    dfs_codes = json.loads(codes_file.read())

In [5]:
# Spot-check: display the stored DFS index list of a single molecule.
dfs_codes['gdb_55']['dfs_indices']

[5, 4, 9, 1, 13, 6, 7, 8, 10, 11, 12, 2, 3, 0, 14]

In [6]:
def int_2_one_hot(data, dfs_codes = dfs_codes):
    """Swap the integer-coded node features of ``data`` for one-hot encodings.

    Column 5 of ``data.x`` (the atomic number) is expanded to 100 classes and
    the last column (the hydrogen count) to 9 classes. The DFS indices stored
    under ``dfs_codes[data.name]['dfs_indices']`` are expanded to 29 classes
    and appended as extra node features. Columns 0-4 and 6..-2 are kept as is.

    Args:
        data: graph sample exposing ``x`` (node feature matrix) and ``name``.
        dfs_codes: mapping from molecule name to a dict holding 'dfs_indices'.

    Returns:
        The same ``data`` object, with ``data.x`` replaced in place.
    """
    raw = data.x
    one_hot = nn.functional.one_hot

    atomic_number_oh = one_hot(raw[:, 5].long(), 100)
    num_h_oh = one_hot(raw[:, -1].long(), 9)
    dfs_order = torch.LongTensor(dfs_codes[data.name]['dfs_indices'])
    dfs_order_oh = one_hot(dfs_order, 29)

    # Untouched columns first, then the three one-hot groups.
    feature_groups = (raw[:, :5], raw[:, 6:-1], atomic_number_oh, num_h_oh, dfs_order_oh)
    data.x = torch.cat(feature_groups, axis=1)
    return data

In [7]:
# Column of data.y to regress on; taken from the run configuration.
target_idx = config.target_idx

In [8]:
# Build the QM9 dataset from the local working directory, passing int_2_one_hot as its `transform`.
dataset = QM9('../datasets/qm9_geometric_work/', transform=int_2_one_hot)

In [9]:
# Shuffle once, then cut the dataset into contiguous train / validation / test
# slices: the first n_train samples, the next n_valid, and whatever remains.
dataset = dataset.shuffle()
train_end = config.n_train
valid_end = train_end + config.n_valid

train_dataset = dataset[:train_end]
valid_dataset = dataset[train_end:valid_end]
test_dataset = dataset[valid_end:]
config.n_test = len(test_dataset)

# Only the training loader shuffles and pins memory; the test loader uses a
# fixed batch size of 32.
worker_kwargs = dict(num_workers=config.num_workers)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, pin_memory=True, **worker_kwargs)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, **worker_kwargs)
test_loader = DataLoader(test_dataset, batch_size=32, **worker_kwargs)

In [10]:
import os
# Make sure the checkpoint directory exists before anything is saved into it.
os.makedirs(config.model_dir, exist_ok=True)

In [11]:
# Persist the dataset's current index order next to the checkpoints.
# NOTE(review): presumably so the train/valid/test split can be rebuilt later — confirm.
torch.save(dataset.indices(), config.model_dir+'dataset_indices.pt')

In [12]:
# Run on the first CUDA device when one is available and GPU use is requested
# (ngpu > 0); otherwise fall back to the CPU.
ngpu = 1
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')

# Model

In [13]:
# Container for the training-set target values gathered in the following cell.
target_vec = []

In [14]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
# Gather, over the whole training set, the target value of every molecule
# minus the sum of the per-atom reference values of its atoms. The mean and
# standard deviation of these residuals are used to rescale the model output.
#
# The per-batch results go into a list local to this cell, so the cell can be
# re-run on its own (previously it appended to `target_vec` and then rebound
# that name to an ndarray, which broke a second run).
residual_chunks = []
# Reference value per atom type for the chosen target; built once because it
# does not change between batches.
atomref_table = torch.tensor(qm9.atomrefs[target_idx], device=device)
for data in train_loader:
    data = data.to(device)
    # The first five feature columns are read as a one-hot atom type; argmax
    # turns them into an index into the reference table.
    atom_type = torch.argmax(data.x[:, :5], dim=1)
    atomU0s = atomref_table[atom_type]
    # Sum the per-atom references within each graph of the batch.
    target_modular = scatter(atomU0s, data.batch, dim=-1, reduce='sum')
    residual_chunks.append((data.y[:, target_idx] - target_modular).detach().cpu().numpy())
target_vec = np.concatenate(residual_chunks, axis=0)

In [15]:
# Mean and standard deviation of the collected training targets; handed to the
# model so it can rescale its raw output.
target_mean, target_std = target_vec.mean(), target_vec.std()

In [16]:
class SelfAttentionReadout(nn.Module):
    """Transformer-encoder readout that summarises a sequence via a learned CLS token.

    A trainable token is prepended to the input sequence, the augmented
    sequence is passed through a stack of transformer encoder layers, and the
    encoder output at the token position is projected to ``dtarget`` values.

    Args:
        dtarget: size of the output vector per batch element.
        dmodel: feature width of the sequence elements and of the encoder.
        dim_feedforward: hidden width of each encoder layer's feed-forward part.
        nlayers: number of stacked encoder layers.
        nhead: number of attention heads per encoder layer.
    """
    def __init__(self, dtarget, dmodel=512, dim_feedforward=2048, nlayers=6, nhead=8):
        # Without this call nn.Module cannot register parameters or submodules.
        super(SelfAttentionReadout, self).__init__()
        self.dtarget = dtarget
        self.dmodel = dmodel
        self.dim_feedforward = dim_feedforward
        self.nhead = nhead
        self.nlayers = nlayers
        # Give the token a defined starting value; torch.empty would expose
        # whatever happens to be in memory (possibly inf/NaN).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.dmodel), requires_grad=True)
        nn.init.normal_(self.cls_token, std=0.02)
        encoder_layer = nn.TransformerEncoderLayer(d_model=self.dmodel,
                                                   nhead=self.nhead,
                                                   dim_feedforward=self.dim_feedforward)
        self.enc = nn.TransformerEncoder(encoder_layer, self.nlayers)
        self.fc_out = nn.Linear(self.dmodel, self.dtarget)

    def forward(self, sequence):
        """Return a ``[n_batch, dtarget]`` summary of ``sequence``.

        Args:
            sequence: tensor of shape ``[n_seq, n_batch, dmodel]``.
        """
        # Broadcast the single token across the batch and put it at position 0.
        cls_tokens = self.cls_token.expand(-1, sequence.shape[1], -1)
        augmented_sequence = torch.cat((cls_tokens, sequence), dim=0)
        transformer_out = self.enc(augmented_sequence)
        # Only the encoder output at the token position feeds the projection.
        out = self.fc_out(transformer_out[0])
        return out
        

class GATNN(nn.Module):
    """Graph attention network that regresses one scalar per molecular graph.

    Node features and 3D positions are each embedded linearly and summed,
    passed through a stack of GATv2 layers, and pooled per graph by a
    global-attention readout (``gate`` scores the nodes, ``fc_out`` maps node
    states to predictions). The pooled value is optionally rescaled with
    ``std``/``mean`` and shifted by the sum of learnable per-atom references.

    Args:
        vert_dim: width of the node feature vectors in ``data.x``.
        hidden_dim: width of the node states; must be divisible by ``nhead``.
        nhead: number of attention heads in each GATv2 layer.
        nlayers: number of GATv2 layers (at least one layer is always built).
        mean, std: if both are given, the output is mapped to ``out * std + mean``.
            They are stored as plain attributes, not as buffers, so they are
            not part of ``state_dict()``.
        atomref: optional tensor of per-atomic-number reference values used to
            initialise a trainable ``nn.Embedding(100, 1)``.
        max_vertices, max_edges: accepted but not used by this class.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nhead``.
    """
    def __init__(self, vert_dim, hidden_dim=config.hidden_dim, nhead=config.nhead, nlayers=config.nlayers,
                 mean=None, std=None, atomref=None,
                 max_vertices=29, max_edges=28):
        super(GATNN, self).__init__()
        # Each GATv2 layer concatenates its heads, so every head gets
        # hidden_dim // nhead channels to keep the layer output at hidden_dim.
        # Asking for hidden_dim channels per head only worked for nhead == 1.
        if hidden_dim % nhead != 0:
            raise ValueError('hidden_dim (%d) must be divisible by nhead (%d)' % (hidden_dim, nhead))
        head_dim = hidden_dim // nhead
        gat_layers = [(tgnn.GATv2Conv(hidden_dim, head_dim, heads=nhead), 'x, edge_index -> x')
                      for _ in range(max(nlayers, 1))]
        self.gat_layers = tgnn.Sequential('x, edge_index', gat_layers)

        # gate turns the hidden states into attention weights
        self.gate = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                  nn.ReLU(inplace=True),
                                  nn.Linear(hidden_dim, hidden_dim),
                                  nn.ReLU(inplace=True),
                                  nn.Linear(hidden_dim, 1)) # sequential instead of linear also good improvement
        # fc_out turns the hidden states into predictions
        self.fc_out = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(hidden_dim, 1)) # sequential instead of linear huge improvment
        self.readout = tgnn.GlobalAttention(self.gate, self.fc_out)

        self.x_emb = nn.Linear(vert_dim, hidden_dim)
        self.pos_emb = nn.Linear(3, hidden_dim)

        #TODO: maybe:
        #self.readout = SelfAttentionReadout(ro_dmodel, ro_dim_feedforward, ro_nlayers, ro_nhead)

        self.mean = mean
        self.std = std
        # The initial reference values are kept as a buffer; the embedding
        # below starts from them but is trainable.
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = nn.Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

    def forward(self, data):
        """Predict the target for every graph in the batch ``data``.

        Args:
            data: batch exposing ``x``, ``pos``, ``z``, ``edge_index`` and ``batch``.

        Returns:
            The pooled prediction per graph (callers flatten it with ``view(-1)``).
        """
        x, z, edge_index, batch = data.x, data.z, data.edge_index, data.batch
        h = self.x_emb(x) + self.pos_emb(data.pos)
        h = self.gat_layers(h, edge_index)
        out = self.readout(h, batch)

        # tricks from Schnet
        if self.mean is not None and self.std is not None:
            out = out * self.std + self.mean

        if self.atomref is not None:
            # Add the summed per-atom reference value of each graph.
            out = out + tgnn.global_add_pool(self.atomref(z), batch)

        return out

In [17]:
# The node-feature width is read off one collated training batch.
vert_dim = next(iter(train_loader)).x.shape[1]
model = GATNN(vert_dim,
              mean=target_mean,
              std=target_std,
              atomref=dataset.atomref(target_idx))

# Mean-squared error is the training objective; Adam starts at the configured
# learning rate, which is scaled by `factor` after `patience` epochs without
# improvement of the monitored value.
loss = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=config.lr)
plateau_kwargs = dict(mode='min', verbose=True, patience=config.patience, factor=config.factor)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_kwargs)



# Training

In [18]:
# Place the model on the device chosen earlier (GPU when available, otherwise CPU).
model = model.to(device)

In [None]:
# Train until one of: n_epochs reached, learning rate below minimal_lr, or the
# epoch's MAE / chemical-accuracy ratio below target_ratio. A checkpoint is
# written whenever that ratio drops below required_improvement times the best
# ratio checkpointed so far, and also when the run is interrupted by hand.
# NOTE(review): the scheduler, checkpointing and stopping criteria are all
# driven by the running *training* MAE; valid_loader is not consulted here.
loss_hist = []  # running training MAE of every finished epoch
min_mae = config.store_starting_from_ratio  # best MAE/CA ratio stored so far
target_ca = chemical_accuracy[target_idx]
epoch = -1  # keeps the interrupt handler well-defined before the first epoch starts
try:
    # For each epoch
    for epoch in range(config.n_epochs):
        model.train()
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            # Rebind explicitly instead of relying on in-place transfer.
            data = data.to(device)
            target = data.y[:, target_idx]
            prediction = model(data)
            flat_prediction = prediction.view(-1)
            output = loss(flat_prediction, target)
            mae = (flat_prediction - target).abs().mean()
            # Incremental mean of the per-batch MAE over the epoch so far.
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)

            pbar.set_description('Epoch %d: MAE/CA %2.6f'%(epoch+1, epoch_loss/target_ca))
            output.backward()
            optimizer.step()
            wandb.log({'MSE': output.item()})

        epoch_ratio = epoch_loss/target_ca
        # Read the learning rate before the scheduler may lower it.
        curr_lr = optimizer.param_groups[0]['lr']
        wandb.log({'MAE':epoch_loss,
                   'MAE/CA':epoch_ratio,
                   'learning rate':curr_lr})
        lr_scheduler.step(epoch_loss)
        loss_hist += [epoch_loss]

        if epoch_ratio < min_mae*config.required_improvement:
            min_mae = epoch_ratio
            torch.save(model.state_dict(), config.model_dir+'gat_epoch%d.pt'%(epoch+1))
        if curr_lr < config.minimal_lr:
            break
        if epoch_ratio < config.target_ratio:
            break

except KeyboardInterrupt:
    print('keyboard interrupt caught')
    # Keep the current weights so a manual stop does not lose the run.
    torch.save(model.state_dict(), config.model_dir+'gat_epoch%d.pt'%(epoch+1))

Epoch 1: MAE/CA 60.256000: : 860it [00:58, 14.81it/s] 
Epoch 2: MAE/CA 30.442681: : 860it [00:28, 30.05it/s]
Epoch 3: MAE/CA 25.512466: : 860it [00:26, 32.20it/s]
Epoch 4: MAE/CA 23.191192: : 860it [01:03, 13.47it/s]
Epoch 5: MAE/CA 21.200580: : 860it [01:05, 13.07it/s]
Epoch 6: MAE/CA 21.408619: : 860it [01:04, 13.35it/s]
Epoch 7: MAE/CA 19.857862: : 860it [01:03, 13.47it/s]
Epoch 8: MAE/CA 19.217431: : 860it [01:02, 13.76it/s]
Epoch 9: MAE/CA 17.680467: : 860it [01:02, 13.71it/s]
Epoch 10: MAE/CA 17.252968: : 860it [01:03, 13.49it/s]
Epoch 11: MAE/CA 16.733026: : 860it [01:03, 13.61it/s]
Epoch 12: MAE/CA 15.944528: : 860it [01:03, 13.61it/s]
Epoch 13: MAE/CA 15.328649: : 860it [01:03, 13.60it/s]
Epoch 14: MAE/CA 14.816600: : 860it [01:03, 13.48it/s]
Epoch 15: MAE/CA 14.462246: : 860it [01:02, 13.72it/s]
Epoch 16: MAE/CA 14.220037: : 860it [01:01, 14.08it/s]
Epoch 17: MAE/CA 13.624985: : 860it [01:01, 13.88it/s]
Epoch 18: MAE/CA 13.333091: : 860it [01:03, 13.45it/s]


Epoch 19: MAE/CA 12.979135: : 860it [01:02, 13.73it/s]
Epoch 20: MAE/CA 12.824632: : 860it [01:03, 13.46it/s]
Epoch 21: MAE/CA 12.559653: : 860it [01:04, 13.31it/s]
Epoch 22: MAE/CA 12.010566: : 860it [01:04, 13.29it/s]
Epoch 23: MAE/CA 11.900697: : 860it [01:04, 13.27it/s]
Epoch 24: MAE/CA 11.967046: : 860it [01:04, 13.27it/s]
Epoch 25: MAE/CA 11.044375: : 860it [01:03, 13.56it/s]
Epoch 26: MAE/CA 11.374483: : 860it [01:02, 13.68it/s]
Epoch 27: MAE/CA 10.985874: : 860it [01:03, 13.55it/s]
Epoch 28: MAE/CA 10.807180: : 860it [01:02, 13.76it/s]
Epoch 29: MAE/CA 10.763309: : 860it [01:02, 13.80it/s]
Epoch 30: MAE/CA 10.589708: : 860it [01:03, 13.52it/s]
Epoch 31: MAE/CA 10.690382: : 860it [01:02, 13.67it/s]
Epoch 32: MAE/CA 10.440558: : 860it [01:03, 13.64it/s]
Epoch 33: MAE/CA 10.380736: : 860it [01:03, 13.58it/s]
Epoch 34: MAE/CA 10.159466: : 860it [01:04, 13.44it/s]
Epoch 35: MAE/CA 9.826881: : 860it [01:03, 13.64it/s] 
Epoch 36: MAE/CA 9.987311: : 860it [01:02, 13.81it/s] 
Epoch 37: 

Epoch 38: MAE/CA 9.614575: : 860it [01:03, 13.49it/s]
Epoch 39: MAE/CA 9.654520: : 860it [01:03, 13.51it/s] 
Epoch 40: MAE/CA 9.618973: : 860it [01:03, 13.48it/s]
Epoch 41: MAE/CA 9.310701: : 860it [01:03, 13.64it/s]
Epoch 42: MAE/CA 9.261021: : 860it [01:03, 13.64it/s]
Epoch 43: MAE/CA 9.631921: : 860it [01:04, 13.43it/s] 
Epoch 44: MAE/CA 9.049483: : 860it [01:05, 13.11it/s]
Epoch 45: MAE/CA 9.218913: : 860it [01:03, 13.45it/s]
Epoch 46: MAE/CA 9.136587: : 860it [01:03, 13.52it/s]
Epoch 47: MAE/CA 8.949791: : 860it [01:03, 13.59it/s]
Epoch 48: MAE/CA 8.738943: : 860it [01:02, 13.75it/s]
Epoch 49: MAE/CA 8.639500: : 860it [01:02, 13.69it/s]
Epoch 50: MAE/CA 8.925526: : 860it [01:03, 13.49it/s]
Epoch 51: MAE/CA 8.518734: : 860it [01:03, 13.65it/s]
Epoch 52: MAE/CA 8.620298: : 860it [01:02, 13.76it/s]
Epoch 53: MAE/CA 8.599943: : 860it [01:02, 13.81it/s]
Epoch 54: MAE/CA 8.611248: : 860it [01:03, 13.46it/s]
Epoch 55: MAE/CA 8.590863: : 860it [01:02, 13.69it/s]


Epoch 56: MAE/CA 8.373694: : 860it [01:02, 13.75it/s]
Epoch 57: MAE/CA 8.325387: : 860it [01:04, 13.25it/s]
Epoch 58: MAE/CA 8.368658: : 860it [01:02, 13.76it/s]
Epoch 59: MAE/CA 8.127126: : 860it [01:03, 13.50it/s]
Epoch 60: MAE/CA 8.158320: : 860it [01:03, 13.53it/s]
Epoch 61: MAE/CA 8.180656: : 860it [01:02, 13.83it/s]
Epoch 62: MAE/CA 8.248116: : 860it [01:02, 13.80it/s]
Epoch 63: MAE/CA 8.112350: : 860it [01:02, 13.74it/s]
Epoch 64: MAE/CA 7.935727: : 860it [01:03, 13.51it/s]
Epoch 65: MAE/CA 7.896967: : 860it [01:02, 13.81it/s]
Epoch 66: MAE/CA 7.757702: : 860it [01:02, 13.68it/s]
Epoch 67: MAE/CA 7.922446: : 860it [01:03, 13.48it/s]
Epoch 68: MAE/CA 7.876334: : 860it [01:02, 13.81it/s]
Epoch 69: MAE/CA 7.776291: : 860it [01:00, 14.12it/s]
Epoch 70: MAE/CA 7.761118: : 860it [01:03, 13.49it/s]
Epoch 71: MAE/CA 7.744836: : 860it [01:02, 13.75it/s]
Epoch 72: MAE/CA 7.769797: : 860it [01:02, 13.77it/s]
Epoch 73: MAE/CA 7.527641: : 860it [01:02, 13.74it/s]
Epoch 74: MAE/CA 7.619354: :

Epoch 75: MAE/CA 7.517363: : 860it [01:03, 13.63it/s]
Epoch 76: MAE/CA 7.728577: : 860it [01:02, 13.77it/s]
Epoch 77: MAE/CA 7.326569: : 860it [01:03, 13.50it/s]
Epoch 78: MAE/CA 7.598956: : 860it [01:02, 13.79it/s]
Epoch 79: MAE/CA 7.327667: : 860it [01:02, 13.74it/s]
Epoch 80: MAE/CA 7.194400: : 860it [01:03, 13.59it/s]
Epoch 81: MAE/CA 7.242820: : 860it [01:02, 13.75it/s]
Epoch 82: MAE/CA 7.375974: : 860it [01:03, 13.58it/s]
Epoch 83: MAE/CA 7.189141: : 860it [01:04, 13.36it/s]
Epoch 84: MAE/CA 7.320348: : 860it [01:03, 13.47it/s]
Epoch 85: MAE/CA 7.299859: : 860it [01:03, 13.54it/s]
Epoch 86: MAE/CA 7.196783: : 860it [01:02, 13.77it/s]
Epoch 87: MAE/CA 7.213690: : 860it [01:03, 13.51it/s]
Epoch 88: MAE/CA 6.969705: : 860it [01:03, 13.52it/s]
Epoch 89: MAE/CA 7.198739: : 860it [01:03, 13.64it/s]
Epoch 90: MAE/CA 7.340590: : 860it [01:03, 13.60it/s]
Epoch 91: MAE/CA 7.088004: : 860it [01:03, 13.53it/s]
Epoch 92: MAE/CA 7.043235: : 860it [01:02, 13.79it/s]


Epoch 93: MAE/CA 6.749474: : 860it [01:03, 13.45it/s]
Epoch 94: MAE/CA 7.103979: : 860it [01:04, 13.38it/s]
Epoch 95: MAE/CA 6.839277: : 860it [01:03, 13.63it/s]
Epoch 96: MAE/CA 7.000394: : 860it [01:03, 13.60it/s]
Epoch 97: MAE/CA 6.924128: : 860it [01:03, 13.51it/s]
Epoch 98: MAE/CA 6.863960: : 860it [01:02, 13.79it/s]
Epoch 99: MAE/CA 6.747378: : 860it [01:02, 13.72it/s]
Epoch 100: MAE/CA 6.862455: : 860it [01:04, 13.38it/s]
Epoch 101: MAE/CA 6.646489: : 860it [01:02, 13.66it/s]
Epoch 102: MAE/CA 6.914526: : 860it [01:02, 13.67it/s]
Epoch 103: MAE/CA 6.679031: : 860it [01:02, 13.81it/s]
Epoch 104: MAE/CA 6.707541: : 860it [01:03, 13.49it/s]
Epoch 105: MAE/CA 6.603617: : 860it [01:02, 13.79it/s]
Epoch 106: MAE/CA 6.760024: : 860it [01:03, 13.64it/s]
Epoch 107: MAE/CA 6.609812: : 860it [01:03, 13.48it/s]
Epoch 108: MAE/CA 6.618617: : 860it [01:03, 13.59it/s]
Epoch 109: MAE/CA 6.601959: : 860it [01:02, 13.75it/s]
Epoch 110: MAE/CA 6.595075: : 860it [01:03, 13.54it/s]
Epoch 111: MAE/CA

Epoch 112: MAE/CA 6.633593: : 860it [01:02, 13.69it/s]
Epoch 113: MAE/CA 6.623266: : 860it [01:03, 13.51it/s]
Epoch 114: MAE/CA 6.540719: : 860it [01:03, 13.51it/s]
Epoch 115: MAE/CA 6.674621: : 860it [01:02, 13.66it/s]
Epoch 116: MAE/CA 6.423968: : 860it [01:04, 13.41it/s]
Epoch 117: MAE/CA 6.676493: : 860it [01:04, 13.28it/s]
Epoch 118: MAE/CA 6.288312: : 860it [01:02, 13.66it/s]
Epoch 119: MAE/CA 6.430030: : 860it [01:02, 13.85it/s]
Epoch 120: MAE/CA 6.576562: : 860it [01:02, 13.69it/s]
Epoch 121: MAE/CA 6.412230: : 860it [01:02, 13.71it/s]
Epoch 122: MAE/CA 6.299045: : 860it [01:04, 13.42it/s]
Epoch 123: MAE/CA 6.516094: : 860it [01:03, 13.50it/s]
Epoch 124: MAE/CA 6.263135: : 860it [01:03, 13.45it/s]
Epoch 125: MAE/CA 6.362676: : 860it [01:04, 13.36it/s]
Epoch 126: MAE/CA 6.338649: : 860it [01:03, 13.53it/s]
Epoch 127: MAE/CA 6.267380: : 860it [01:15, 11.42it/s]
Epoch 128: MAE/CA 6.274619: : 860it [01:15, 11.36it/s]
Epoch 129: MAE/CA 6.171721: : 860it [01:16, 11.27it/s]


Epoch 130: MAE/CA 6.279845: : 860it [01:16, 11.17it/s]
Epoch 131: MAE/CA 6.286273: : 860it [01:17, 11.07it/s]
Epoch 132: MAE/CA 6.204394: : 860it [01:16, 11.25it/s]
Epoch 133: MAE/CA 6.277412: : 860it [01:16, 11.22it/s]
Epoch 134: MAE/CA 6.219138: : 860it [01:16, 11.28it/s]
Epoch 135: MAE/CA 6.111808: : 860it [01:17, 11.13it/s]
Epoch 136: MAE/CA 6.196810: : 860it [01:16, 11.18it/s]
Epoch 137: MAE/CA 6.144193: : 860it [01:16, 11.21it/s]
Epoch 138: MAE/CA 6.183145: : 860it [01:16, 11.20it/s]
Epoch 139: MAE/CA 6.096660: : 860it [01:17, 11.03it/s]
Epoch 140: MAE/CA 6.104882: : 860it [01:12, 11.93it/s]
Epoch 141: MAE/CA 5.981016: : 860it [01:03, 13.54it/s]
Epoch 142: MAE/CA 6.083397: : 860it [01:03, 13.46it/s]
Epoch 143: MAE/CA 6.155681: : 860it [01:02, 13.66it/s]
Epoch 144: MAE/CA 6.015691: : 860it [01:03, 13.53it/s]
Epoch 145: MAE/CA 5.941901: : 860it [01:03, 13.59it/s]
Epoch 146: MAE/CA 6.023347: : 860it [01:03, 13.46it/s]
Epoch 147: MAE/CA 6.029485: : 860it [01:02, 13.72it/s]
Epoch 148:

Epoch 149: MAE/CA 6.520690: : 860it [01:03, 13.63it/s] 
Epoch 150: MAE/CA 5.641190: : 860it [01:02, 13.77it/s]
Epoch 151: MAE/CA 5.858141: : 860it [01:02, 13.81it/s]
Epoch 152: MAE/CA 5.838656: : 860it [01:04, 13.43it/s]
Epoch 153: MAE/CA 6.106229: : 860it [01:03, 13.60it/s]
Epoch 154: MAE/CA 5.751780: : 860it [01:05, 13.16it/s]
Epoch 155: MAE/CA 5.870333: : 860it [01:18, 10.94it/s]
Epoch 156: MAE/CA 5.967849: : 860it [01:16, 11.27it/s]

Epoch   156: reducing learning rate of group 0 to 2.8500e-04.



Epoch 157: MAE/CA 5.719621: : 860it [01:22, 10.48it/s]
Epoch 158: MAE/CA 5.553429: : 860it [01:21, 10.51it/s]
Epoch 159: MAE/CA 5.892396: : 860it [01:21, 10.50it/s]
Epoch 160: MAE/CA 5.635871: : 860it [01:22, 10.37it/s]
Epoch 161: MAE/CA 5.599098: : 860it [01:22, 10.47it/s]
Epoch 162: MAE/CA 5.665477: : 860it [01:22, 10.45it/s]
Epoch 163: MAE/CA 5.668959: : 860it [01:22, 10.45it/s]
Epoch 164: MAE/CA 5.594486: : 860it [01:19, 10.77it/s]

Epoch   164: reducing learning rate of group 0 to 2.7075e-04.



Epoch 165: MAE/CA 5.486224: : 860it [01:19, 10.78it/s]
Epoch 166: MAE/CA 5.410480: : 860it [01:19, 10.75it/s]
Epoch 167: MAE/CA 5.678275: : 860it [01:20, 10.68it/s]
Epoch 168: MAE/CA 5.301132: : 860it [01:19, 10.79it/s]
Epoch 169: MAE/CA 5.605940: : 860it [01:19, 10.76it/s]
Epoch 170: MAE/CA 5.342075: : 860it [01:19, 10.78it/s]
Epoch 171: MAE/CA 5.369219: : 860it [01:20, 10.63it/s]
Epoch 172: MAE/CA 5.366205: : 860it [01:19, 10.77it/s]
Epoch 173: MAE/CA 5.474069: : 860it [01:19, 10.79it/s]
Epoch 174: MAE/CA 5.421426: : 860it [01:19, 10.80it/s]

Epoch   174: reducing learning rate of group 0 to 2.5721e-04.



Epoch 175: MAE/CA 5.198170: : 860it [01:20, 10.64it/s]
Epoch 176: MAE/CA 5.177603: : 860it [01:19, 10.76it/s]
Epoch 177: MAE/CA 5.150402: : 860it [01:19, 10.81it/s]
Epoch 178: MAE/CA 5.244477: : 860it [01:20, 10.74it/s]
Epoch 179: MAE/CA 5.254792: : 860it [01:20, 10.69it/s]
Epoch 180: MAE/CA 5.248251: : 860it [01:19, 10.82it/s]
Epoch 181: MAE/CA 5.037151: : 860it [01:19, 10.79it/s]
Epoch 182: MAE/CA 5.110723: : 860it [01:19, 10.80it/s]
Epoch 183: MAE/CA 5.167471: : 860it [01:20, 10.68it/s]
Epoch 184: MAE/CA 5.171020: : 860it [01:19, 10.83it/s]
Epoch 185: MAE/CA 5.093172: : 860it [01:19, 10.76it/s]
Epoch 186: MAE/CA 5.119510: : 860it [01:19, 10.80it/s]
Epoch 187: MAE/CA 5.130581: : 860it [01:20, 10.66it/s]

Epoch   187: reducing learning rate of group 0 to 2.4435e-04.



Epoch 188: MAE/CA 5.142686: : 860it [01:19, 10.80it/s]
Epoch 189: MAE/CA 4.922497: : 860it [01:19, 10.80it/s]
Epoch 190: MAE/CA 5.079593: : 860it [01:19, 10.80it/s]
Epoch 191: MAE/CA 4.739625: : 860it [01:20, 10.67it/s]
Epoch 192: MAE/CA 5.045322: : 860it [01:19, 10.79it/s]
Epoch 193: MAE/CA 4.940745: : 860it [01:19, 10.77it/s]
Epoch 194: MAE/CA 4.965248: : 860it [01:20, 10.67it/s]
Epoch 195: MAE/CA 4.862366: : 860it [01:20, 10.72it/s]
Epoch 196: MAE/CA 4.973715: : 860it [01:19, 10.77it/s]
Epoch 197: MAE/CA 4.987186: : 860it [01:19, 10.78it/s]

Epoch   197: reducing learning rate of group 0 to 2.3213e-04.



Epoch 198: MAE/CA 4.724604: : 860it [01:20, 10.66it/s]
Epoch 199: MAE/CA 4.864521: : 860it [01:19, 10.79it/s]
Epoch 200: MAE/CA 4.746220: : 860it [01:19, 10.82it/s]
Epoch 201: MAE/CA 4.660677: : 860it [01:19, 10.79it/s]
Epoch 202: MAE/CA 4.855687: : 860it [01:20, 10.66it/s]
Epoch 203: MAE/CA 4.730688: : 860it [01:19, 10.77it/s]
Epoch 204: MAE/CA 4.709053: : 860it [01:20, 10.75it/s]
Epoch 205: MAE/CA 4.695334: : 860it [01:19, 10.82it/s]
Epoch 206: MAE/CA 4.733123: : 860it [01:20, 10.67it/s]
Epoch 207: MAE/CA 4.757846: : 860it [01:19, 10.77it/s]

Epoch   207: reducing learning rate of group 0 to 2.2053e-04.



Epoch 208: MAE/CA 4.572071: : 860it [01:20, 10.74it/s]
Epoch 209: MAE/CA 4.538473: : 860it [01:19, 10.82it/s]
Epoch 210: MAE/CA 4.529856: : 860it [01:20, 10.65it/s]
Epoch 211: MAE/CA 4.598668: : 860it [01:19, 10.84it/s]
Epoch 212: MAE/CA 4.592218: : 860it [01:19, 10.78it/s]
Epoch 213: MAE/CA 4.461130: : 860it [01:19, 10.80it/s]
Epoch 214: MAE/CA 4.484048: : 860it [01:20, 10.64it/s]
Epoch 215: MAE/CA 4.592620: : 860it [01:20, 10.75it/s]
Epoch 216: MAE/CA 4.558271: : 860it [01:19, 10.76it/s]
Epoch 217: MAE/CA 4.516360: : 860it [01:20, 10.73it/s]
Epoch 218: MAE/CA 4.581836: : 860it [01:20, 10.63it/s]
Epoch 219: MAE/CA 4.553039: : 860it [01:20, 10.69it/s]

Epoch   219: reducing learning rate of group 0 to 2.0950e-04.



Epoch 220: MAE/CA 4.393767: : 860it [01:19, 10.76it/s]
Epoch 221: MAE/CA 4.341857: : 860it [01:20, 10.74it/s]
Epoch 222: MAE/CA 4.354998: : 860it [01:20, 10.63it/s]
Epoch 223: MAE/CA 4.319149: : 860it [01:19, 10.76it/s]
Epoch 224: MAE/CA 4.469215: : 860it [01:19, 10.76it/s]
Epoch 225: MAE/CA 4.259878: : 860it [01:20, 10.63it/s]
Epoch 226: MAE/CA 4.414581: : 860it [01:20, 10.74it/s]
Epoch 227: MAE/CA 4.293278: : 860it [01:19, 10.77it/s]
Epoch 228: MAE/CA 4.399116: : 860it [01:19, 10.76it/s]
Epoch 229: MAE/CA 4.234979: : 860it [01:20, 10.66it/s]
Epoch 230: MAE/CA 4.351281: : 860it [01:19, 10.78it/s]
Epoch 231: MAE/CA 4.406417: : 860it [01:19, 10.79it/s]
Epoch 232: MAE/CA 4.237685: : 860it [01:19, 10.77it/s]
Epoch 233: MAE/CA 4.249546: : 860it [01:20, 10.68it/s]
Epoch 234: MAE/CA 4.366280: : 860it [01:20, 10.75it/s]
Epoch 235: MAE/CA 4.275117: : 860it [01:19, 10.75it/s]

Epoch   235: reducing learning rate of group 0 to 1.9903e-04.



Epoch 236: MAE/CA 4.272588: : 860it [01:19, 10.82it/s]
Epoch 237: MAE/CA 4.073933: : 860it [01:20, 10.64it/s]
Epoch 238: MAE/CA 4.166023: : 860it [01:19, 10.77it/s]
Epoch 239: MAE/CA 4.188216: : 860it [01:19, 10.81it/s]
Epoch 240: MAE/CA 4.235964: : 860it [01:19, 10.75it/s]
Epoch 241: MAE/CA 4.032922: : 860it [01:20, 10.69it/s]
Epoch 242: MAE/CA 4.155029: : 860it [01:19, 10.76it/s]
Epoch 243: MAE/CA 4.211806: : 860it [01:19, 10.82it/s]
Epoch 244: MAE/CA 4.119874: : 860it [01:20, 10.74it/s]
Epoch 245: MAE/CA 4.090219: : 860it [01:20, 10.65it/s]
Epoch 246: MAE/CA 4.060995: : 860it [01:19, 10.83it/s]
Epoch 247: MAE/CA 4.143158: : 860it [01:19, 10.76it/s]

Epoch   247: reducing learning rate of group 0 to 1.8907e-04.



Epoch 248: MAE/CA 3.991725: : 860it [01:20, 10.72it/s]
Epoch 249: MAE/CA 3.928540: : 860it [01:20, 10.65it/s]
Epoch 250: MAE/CA 4.066718: : 860it [01:20, 10.66it/s]
Epoch 251: MAE/CA 4.000110: : 860it [01:19, 10.77it/s]
Epoch 252: MAE/CA 3.967311: : 860it [01:19, 10.76it/s]
Epoch 253: MAE/CA 3.949108: : 860it [01:20, 10.66it/s]
Epoch 254: MAE/CA 4.048011: : 860it [01:19, 10.75it/s]
Epoch 255: MAE/CA 3.882695: : 860it [01:19, 10.80it/s]
Epoch 256: MAE/CA 3.932272: : 860it [01:20, 10.66it/s]
Epoch 257: MAE/CA 3.982217: : 860it [01:19, 10.76it/s]
Epoch 258: MAE/CA 3.966105: : 860it [01:19, 10.79it/s]
Epoch 259: MAE/CA 3.936063: : 860it [01:19, 10.79it/s]
Epoch 260: MAE/CA 4.003151: : 860it [01:20, 10.62it/s]
Epoch 261: MAE/CA 3.967426: : 860it [01:19, 10.80it/s]

Epoch   261: reducing learning rate of group 0 to 1.7962e-04.



Epoch 262: MAE/CA 3.791327: : 860it [01:19, 10.79it/s]
Epoch 263: MAE/CA 3.890183: : 860it [01:19, 10.76it/s]
Epoch 264: MAE/CA 3.729225: : 860it [01:20, 10.66it/s]
Epoch 265: MAE/CA 3.922931: : 860it [01:19, 10.77it/s]
Epoch 266: MAE/CA 3.728891: : 860it [01:20, 10.75it/s]
Epoch 267: MAE/CA 3.776467: : 860it [01:19, 10.78it/s]
Epoch 268: MAE/CA 3.769789: : 860it [01:20, 10.66it/s]
Epoch 269: MAE/CA 3.796067: : 860it [01:19, 10.77it/s]
Epoch 270: MAE/CA 3.858106: : 860it [01:19, 10.79it/s]

Epoch   270: reducing learning rate of group 0 to 1.7064e-04.



Epoch 271: MAE/CA 3.633713: : 860it [01:19, 10.79it/s]
Epoch 272: MAE/CA 3.757411: : 860it [01:20, 10.64it/s]
Epoch 273: MAE/CA 3.693877: : 860it [01:19, 10.81it/s]
Epoch 274: MAE/CA 3.570971: : 860it [01:19, 10.78it/s]
Epoch 275: MAE/CA 3.708316: : 860it [01:19, 10.77it/s]
Epoch 276: MAE/CA 3.652211: : 860it [01:20, 10.64it/s]
Epoch 277: MAE/CA 3.670701: : 860it [01:19, 10.76it/s]
Epoch 278: MAE/CA 3.680078: : 860it [01:19, 10.77it/s]
Epoch 279: MAE/CA 3.634240: : 860it [01:19, 10.76it/s]
Epoch 280: MAE/CA 3.601908: : 860it [01:20, 10.67it/s]

Epoch   280: reducing learning rate of group 0 to 1.6211e-04.



Epoch 281: MAE/CA 3.574792: : 860it [01:19, 10.77it/s]
Epoch 282: MAE/CA 3.550205: : 860it [01:19, 10.76it/s]
Epoch 283: MAE/CA 3.595402: : 860it [01:19, 10.78it/s]
Epoch 284: MAE/CA 3.520083: : 860it [01:20, 10.65it/s]
Epoch 285: MAE/CA 3.570973: : 860it [01:19, 10.75it/s]
Epoch 286: MAE/CA 3.459769: : 860it [01:19, 10.77it/s]
Epoch 287: MAE/CA 3.616994: : 860it [01:20, 10.71it/s]
Epoch 288: MAE/CA 3.479140: : 860it [01:20, 10.62it/s]
Epoch 289: MAE/CA 3.549014: : 860it [01:19, 10.79it/s]
Epoch 290: MAE/CA 3.555449: : 860it [01:19, 10.77it/s]
Epoch 291: MAE/CA 3.448370: : 860it [01:21, 10.61it/s]
Epoch 292: MAE/CA 3.582504: : 860it [01:19, 10.76it/s]
Epoch 293: MAE/CA 3.527053: : 860it [01:19, 10.79it/s]
Epoch 294: MAE/CA 3.440645: : 860it [01:19, 10.78it/s]
Epoch 295: MAE/CA 3.537594: : 860it [01:20, 10.67it/s]
Epoch 296: MAE/CA 3.453886: : 860it [01:19, 10.79it/s]
Epoch 297: MAE/CA 3.529444: : 860it [01:19, 10.78it/s]
Epoch 298: MAE/CA 3.356612: : 860it [01:19, 10.77it/s]


Epoch 299: MAE/CA 3.512502: : 860it [01:20, 10.63it/s]
Epoch 300: MAE/CA 3.474078: : 860it [01:19, 10.85it/s]
Epoch 301: MAE/CA 3.473067: : 860it [01:19, 10.80it/s]
Epoch 302: MAE/CA 3.477623: : 860it [01:19, 10.78it/s]
Epoch 303: MAE/CA 3.492175: : 860it [01:20, 10.65it/s]
Epoch 304: MAE/CA 3.436614: : 860it [01:19, 10.81it/s]

Epoch   304: reducing learning rate of group 0 to 1.5400e-04.



Epoch 305: MAE/CA 3.284576: : 860it [01:19, 10.80it/s]
Epoch 306: MAE/CA 3.344412: : 860it [01:19, 10.80it/s]
Epoch 307: MAE/CA 3.348460: : 860it [01:20, 10.67it/s]
Epoch 308: MAE/CA 3.358799: : 860it [01:20, 10.69it/s]
Epoch 309: MAE/CA 3.465260: : 860it [01:19, 10.79it/s]
Epoch 310: MAE/CA 3.341576: : 860it [01:19, 10.79it/s]
Epoch 311: MAE/CA 3.195715: : 860it [01:20, 10.64it/s]
Epoch 312: MAE/CA 3.320591: : 860it [01:19, 10.77it/s]
Epoch 313: MAE/CA 3.362424: : 860it [01:19, 10.76it/s]
Epoch 314: MAE/CA 3.348461: : 860it [01:19, 10.81it/s]
Epoch 315: MAE/CA 3.346867: : 860it [01:20, 10.64it/s]
Epoch 316: MAE/CA 3.340816: : 860it [01:20, 10.73it/s]
Epoch 317: MAE/CA 3.232975: : 860it [01:19, 10.78it/s]

Epoch   317: reducing learning rate of group 0 to 1.4630e-04.



Epoch 318: MAE/CA 3.189315: : 860it [01:20, 10.66it/s]
Epoch 319: MAE/CA 3.253997: : 860it [01:20, 10.68it/s]
Epoch 320: MAE/CA 3.272986: : 860it [01:19, 10.77it/s]
Epoch 321: MAE/CA 3.193888: : 860it [01:19, 10.81it/s]
Epoch 322: MAE/CA 3.170852: : 860it [01:20, 10.65it/s]
Epoch 323: MAE/CA 3.251288: : 860it [01:19, 10.79it/s]
Epoch 324: MAE/CA 3.173084: : 860it [01:20, 10.74it/s]
Epoch 325: MAE/CA 3.207097: : 860it [01:18, 10.90it/s]
Epoch 326: MAE/CA 3.269140: : 860it [01:20, 10.64it/s]
Epoch 327: MAE/CA 3.145248: : 860it [01:19, 10.80it/s]
Epoch 328: MAE/CA 3.171751: : 860it [01:19, 10.75it/s]
Epoch 329: MAE/CA 3.199033: : 860it [01:19, 10.76it/s]
Epoch 330: MAE/CA 3.321155: : 860it [01:20, 10.63it/s]
Epoch 331: MAE/CA 3.151899: : 860it [01:19, 10.78it/s]
Epoch 332: MAE/CA 3.200796: : 860it [01:20, 10.74it/s]
Epoch 333: MAE/CA 3.140020: : 860it [01:19, 10.80it/s]
Epoch 334: MAE/CA 3.171432: : 860it [01:20, 10.62it/s]
Epoch 335: MAE/CA 3.197544: : 860it [01:20, 10.69it/s]


Epoch 336: MAE/CA 3.142333: : 860it [01:19, 10.77it/s]
Epoch 337: MAE/CA 3.133630: : 860it [01:20, 10.75it/s]
Epoch 338: MAE/CA 3.132390: : 860it [01:20, 10.67it/s]
Epoch 339: MAE/CA 3.205374: : 860it [01:20, 10.66it/s]
Epoch 340: MAE/CA 3.228975: : 860it [01:19, 10.78it/s]
Epoch 341: MAE/CA 3.134292: : 860it [01:20, 10.75it/s]
Epoch 342: MAE/CA 3.238355: : 860it [01:20, 10.68it/s]
Epoch 343: MAE/CA 3.126815: : 860it [01:19, 10.82it/s]
Epoch 344: MAE/CA 3.174801: : 860it [01:19, 10.78it/s]
Epoch 345: MAE/CA 3.177414: : 860it [01:19, 10.78it/s]
Epoch 346: MAE/CA 3.098965: : 860it [01:20, 10.62it/s]
Epoch 347: MAE/CA 3.134194: : 860it [01:19, 10.82it/s]
Epoch 348: MAE/CA 3.108588: : 860it [01:20, 10.73it/s]
Epoch 349: MAE/CA 3.140088: : 860it [01:20, 10.72it/s]
Epoch 350: MAE/CA 3.085098: : 860it [01:20, 10.69it/s]
Epoch 351: MAE/CA 3.101060: : 860it [01:19, 10.82it/s]
Epoch 352: MAE/CA 3.120108: : 860it [01:19, 10.79it/s]
Epoch 353: MAE/CA 3.148896: : 860it [01:20, 10.63it/s]
Epoch 354:

Epoch 355: MAE/CA 3.092982: : 860it [01:19, 10.78it/s]
Epoch 356: MAE/CA 3.093204: : 860it [01:20, 10.72it/s]

Epoch   356: reducing learning rate of group 0 to 1.3899e-04.



Epoch 357: MAE/CA 2.921906: : 860it [01:20, 10.66it/s]
Epoch 358: MAE/CA 2.980883: : 860it [01:19, 10.76it/s]
Epoch 359: MAE/CA 3.046183: : 860it [01:20, 10.74it/s]
Epoch 360: MAE/CA 2.977850: : 860it [01:20, 10.75it/s]
Epoch 361: MAE/CA 3.035957: : 860it [01:20, 10.64it/s]
Epoch 362: MAE/CA 3.046167: : 860it [01:19, 10.78it/s]
Epoch 363: MAE/CA 3.146569: : 860it [01:19, 10.77it/s]

Epoch   363: reducing learning rate of group 0 to 1.3204e-04.



Epoch 364: MAE/CA 2.881751: : 860it [01:19, 10.76it/s]
Epoch 365: MAE/CA 2.932005: : 860it [01:20, 10.65it/s]
Epoch 366: MAE/CA 2.900537: : 860it [01:19, 10.75it/s]
Epoch 367: MAE/CA 2.911046: : 860it [01:19, 10.76it/s]
Epoch 368: MAE/CA 2.929006: : 860it [01:20, 10.75it/s]
Epoch 369: MAE/CA 2.943218: : 860it [01:20, 10.66it/s]
Epoch 370: MAE/CA 2.925309: : 860it [01:19, 10.82it/s]

Epoch   370: reducing learning rate of group 0 to 1.2544e-04.



Epoch 371: MAE/CA 2.858246: : 860it [01:19, 10.76it/s]
Epoch 372: MAE/CA 2.793229: : 860it [01:19, 10.83it/s]
Epoch 373: MAE/CA 2.816554: : 860it [01:20, 10.65it/s]
Epoch 374: MAE/CA 2.787700: : 860it [01:19, 10.75it/s]
Epoch 375: MAE/CA 2.828139: : 860it [01:19, 10.78it/s]
Epoch 376: MAE/CA 2.914360: : 860it [01:19, 10.81it/s]
Epoch 377: MAE/CA 2.786612: : 860it [01:21, 10.53it/s]
Epoch 378: MAE/CA 2.851270: : 860it [01:23, 10.36it/s]
Epoch 379: MAE/CA 2.752133: : 860it [01:22, 10.45it/s]
Epoch 380: MAE/CA 2.802840: : 860it [01:22, 10.48it/s]
Epoch 381: MAE/CA 2.804523: : 540it [00:52, 10.83it/s]

In [None]:
# Evaluate on the held-out test split: collect the absolute error of every
# molecule, then report the mean (MAE) and its ratio to chemical accuracy.
model.eval()
abs_errors = []
pbar = tqdm.tqdm(test_loader)
with torch.no_grad():  # no gradients are needed for evaluation
    for data in pbar:
        data = data.to(device)
        # GATNN.forward takes the whole batch object, not its separate fields.
        prediction = model(data)
        abs_errors.append((prediction.view(-1) - data.y[:, target_idx]).abs().cpu())
maes = torch.cat(abs_errors, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
wandb.log({'TEST MAE':mae, 'TEST MAE/CA':mae/chemical_accuracy[target_idx]})
