In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
from torch.utils.data import Dataset

import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np
import wandb
import json
import os

import sys
sys.path = ['/home/chrisw/Documents/projects/2021/graph-transformer/src'] + sys.path
from dfs_transformer import EarlyStopping, PositionalEncoding
import math

In [2]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

In [3]:
wandb.init(project='QM9-transformer', entity='chrisxx')
config = wandb.config
config.lr = 0.0003#*0.1
config.n_epochs = 10000
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8
config.target_idx = 7
config.batch_size = 256
config.n_train = 110000
config.n_valid = 10000
config.target_ratio = 0.1
config.store_starting_from_ratio = 1
config.required_improvement = 0.8
config.nlayers=3
config.nhead=32
config.d_model=512
config.dim_feedforward=4*config.d_model
config.model_dir = '../models/qm9/gtransformer_medium3/'
#config.dfs_codes = '../datasets/qm9_geometric/min_dfs_codes.json'
config.dfs_codes = '../results/all/1/min_dfs_codes_0_to_130831.json'
config.num_workers = 4

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [4]:
target_idx = config.target_idx

# Dataset

In [5]:
def preprocess_vertex_features(features):
    feat_np = features.detach().cpu().numpy()
    # make atomic number a one hot
    atomic_number = nn.functional.one_hot(features[:, 5].long(), 100)
    # make num_h a one hot
    num_h = nn.functional.one_hot(features[:, -4].long(), 9)
    return torch.cat((features[:, :5], features[:, 6:-4], features[:, -3:], atomic_number, num_h), axis=1)

In [6]:
class QM9DFSCodes(Dataset):
    def __init__(self, dfs_codes, qm9_dataset, 
                 target_idx, vert_feats = ['x', 'pos'],
                 vertex_transform=None, edge_transform=None):
        self.dfs_codes = dfs_codes
        self.qm9_dataset = qm9_dataset
        self.target_idx = target_idx
        self.vertex_features = vert_feats
        self.max_vert = 29
        self.max_len = np.max([len(d['min_dfs_code']) for d in self.dfs_codes.values()])
        self.feat_dim = 2 + qm9_dataset[0].edge_attr.shape[1] # 2 for the dfs indices
        for feat in vert_feats: 
            self.feat_dim += 2*qm9_dataset[0][feat].shape[1]
        self.vertex_transform = vertex_transform
        self.edge_transform = edge_transform

    def __len__(self):
        return len(self.qm9_dataset)

    def __getitem__(self, idx):
        """
        graph_repr: [batch_size, max_edges, 2 + n_vert_feat + n_edge_feat + n_vert_feat]
        return vertex_features, graph_repr, target
        """
        data = self.qm9_dataset[idx]
        code_dict = self.dfs_codes[data.name]
        code = code_dict['min_dfs_code']
        eid2nr = code_dict['edge_id_2_edge_number']
        
        vert_feats = [data[k].detach().cpu().numpy() for k in ['x', 'pos']]
        vert_feats = np.concatenate(vert_feats, axis=1)
        edge_feats = data.edge_attr.detach().cpu().numpy()
        
        d = {'dfs_from':np.zeros(self.max_len), 
             'dfs_to':np.zeros(self.max_len),
             'feat_from':np.zeros((self.max_len, vert_feats.shape[1])), 
             'feat_to':np.zeros((self.max_len, vert_feats.shape[1])), 
             'feat_edge':np.zeros((self.max_len, edge_feats.shape[1])),
             'n_edges':len(code)*np.ones((1,), dtype=np.int32),
             'z':np.zeros(self.max_vert)}
        
        for idx, e_tuple in enumerate(code):
            e_idx = eid2nr[str(e_tuple[-2])]
            from_idx = e_tuple[-3]
            to_idx = e_tuple[-1]
            d['dfs_from'][idx] = e_tuple[0]
            d['dfs_to'][idx] = e_tuple[1]
            d['feat_from'][idx] = vert_feats[from_idx]
            d['feat_to'][idx] = vert_feats[to_idx]
            d['feat_edge'][idx] = edge_feats[e_idx]      
        
        d['z'][:len(data.z)] = data.z.detach().cpu().numpy()
        
        d_tensors = {}
        d_tensors['dfs_from'] = torch.IntTensor(d['dfs_from'])
        d_tensors['dfs_to'] = torch.IntTensor(d['dfs_to'])
        d_tensors['feat_from'] = torch.Tensor(d['feat_from'])
        d_tensors['feat_to'] = torch.Tensor(d['feat_to'])
        d_tensors['feat_edge'] = torch.Tensor(d['feat_edge'])
        d_tensors['z'] = torch.IntTensor(d['z'])
        
        if self.vertex_transform:
            d_tensors['feat_from'] = self.vertex_transform(d_tensors['feat_from'])
            d_tensors['feat_to'] = self.vertex_transform(d_tensors['feat_to'])
        if self.edge_transform:
            d_tensors['feat_edge'] = self.edge_transform(d_tensors['feat_edge'])
        
        d_tensors['target'] = data.y[:, self.target_idx]
        return d_tensors
    
    def shuffle(self):
        self.qm9_dataset = self.qm9_dataset.shuffle()
        return self


In [7]:
with open(config.dfs_codes, 'r') as f:
    dfs_codes = json.load(f)

In [8]:
dset = QM9('../datasets/qm9_geometric/')



In [9]:
dset = dset.shuffle()
train_qm9 = DataLoader(dset[:config.n_train], batch_size=config.batch_size)
train_dataset = QM9DFSCodes(dfs_codes, dset[:config.n_train], target_idx, vertex_transform=preprocess_vertex_features)
valid_dataset = QM9DFSCodes(dfs_codes, dset[config.n_train:config.n_train+config.n_valid], target_idx, vertex_transform=preprocess_vertex_features) 
test_dataset = QM9DFSCodes(dfs_codes, dset[config.n_train+config.n_valid:], target_idx, vertex_transform=preprocess_vertex_features) 
config.n_test = len(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, pin_memory=True, num_workers=config.num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=config.num_workers)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=config.num_workers)

In [10]:
os.makedirs(config.model_dir, exist_ok=True)

In [11]:
torch.save(dset.indices(), config.model_dir+'dataset_indices.pt')

In [12]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

# Model

In [13]:


class MoleculeTransformer(nn.Module):
    def __init__(self, vert_dim, edge_dim, d_model=512, nhead=8, nlayers=4, dim_feedforward=2048, mean=None, std=None, atomref=None,
                 max_vertices=29, max_edges=28):
        """
        transfomer model is some type of transformer that 
        """
        super(MoleculeTransformer, self).__init__()
        # atomic masses could be used as additional features
        # see https://github.com/rusty1s/pytorch_geometric/blob/97d3177dc43858f66c07bb66d7dc12506b986199/torch_geometric/nn/models/schnet.py#L113
        self.vert_dim = vert_dim
        self.edge_dim = edge_dim
        self.d_model = d_model
        self.nhead = nhead
        self.nlayers = nlayers
        self.dim_feedforward = dim_feedforward
        self.max_vertices = max_vertices
        self.max_edges = max_edges
        
        self.emb_seq = PositionalEncoding(d_model, max_len=max_edges) 
        
        self.emb_dfs = nn.Embedding(self.max_vertices, d_model // 2)
        self.emb_vertex = nn.Linear(self.vert_dim, d_model // 2)
        self.emb_edge = nn.Linear(self.edge_dim, d_model)        
        
        self.cls_token = nn.Parameter(torch.empty(1, 1, self.d_model), requires_grad=True)
        self.enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.nhead, dim_feedforward=self.dim_feedforward), self.nlayers)
        
        self.fc_out = nn.Linear(self.d_model, 1)
        
        self.mean = mean
        self.std = std
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = nn.Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)
            
        nn.init.normal_(self.cls_token, mean=.0, std=.5)
    
    def forward(self, data):
        z = data['z']
        dfs_from_emb = self.emb_dfs(data['dfs_from'])
        dfs_to_emb = self.emb_dfs(data['dfs_to'])
        dfs_emb = torch.cat((dfs_from_emb, dfs_to_emb), -1)
        from_emb = self.emb_vertex(data['feat_from'])
        to_emb = self.emb_vertex(data['feat_to'])
        feat_emb = torch.cat((from_emb, to_emb), -1)
        edge_emb = self.emb_edge(data['feat_edge'])
        batch = dfs_emb + feat_emb + edge_emb # batch_dim x seq_dim x n_model
        batch = batch.permute(1, 0, 2) # seq_dim x batch_dim x n_model
        batch = self.emb_seq(batch * math.sqrt(self.d_model))
        batch = torch.cat((self.cls_token.expand(-1, batch.shape[1], -1), batch), dim=0)
        transformer_out = self.enc(batch)
        out = self.fc_out(transformer_out[0]) 
        
        # tricks from Schnet
        if self.mean is not None and self.std is not None:
            out = out * self.std + self.mean
        
        if self.atomref is not None:
            out = out + torch.sum(self.atomref(z), axis=1)
        
        return out
        
        

In [14]:
target_vec = []

In [15]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
for data in train_qm9:
    data = data.to(device)
    atomU0s = torch.tensor(qm9.atomrefs[target_idx], device=device)[torch.argmax(data.x[:, :5], axis=1)]
    target_modular = scatter(atomU0s, data.batch, dim=-1, reduce='sum')
    target_vec += [(data.y[:, target_idx] - target_modular).detach().cpu().numpy()]
target_vec = np.concatenate(target_vec, axis=0)

In [16]:
target_mean = np.mean(target_vec)
target_std = np.std(target_vec)

In [17]:
d = next(iter(train_loader))
vert_dim = d['feat_from'].shape[-1]
edge_dim = d['feat_edge'].shape[-1]
model = MoleculeTransformer(vert_dim, edge_dim, nlayers=config.nlayers, nhead=config.nhead, 
                            d_model=config.d_model, dim_feedforward=config.dim_feedforward, 
                            atomref=dset.atomref(target_idx), mean=target_mean, std=target_std)

loss = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=config.lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=config.patience, factor=config.factor)

# Training

In [18]:
model = model.to(device)

In [None]:
loss_hist = []
min_mae = config.store_starting_from_ratio
try:
    # For each epoch
    for epoch in range(config.n_epochs):
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            data = {key:d.to(device) for key, d in data.items()} 
            target = data['target']
            prediction = model(data)
            output = loss(prediction.view(-1), target.view(-1))
            mae = (prediction.view(-1) - target.view(-1)).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)
            
            pbar.set_description('Epoch %d: MAE/CA %2.6f'%(epoch+1, epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()
            wandb.log({'MSE': output.item()})
        curr_lr = list(optimizer.param_groups)[0]['lr']
        wandb.log({'MAE':epoch_loss, 
                   'MAE/CA':epoch_loss/chemical_accuracy[target_idx],
                   'learning rate':curr_lr})
        lr_scheduler.step(epoch_loss)
        loss_hist += [epoch_loss] 

        if epoch_loss/chemical_accuracy[target_idx] < min_mae*config.required_improvement:
            min_mae = epoch_loss/chemical_accuracy[target_idx]
            torch.save(model.state_dict(), config.model_dir+'schnet_epoch%d.pt'%(epoch+1))
        if curr_lr < config.minimal_lr:
            break
        if epoch_loss/chemical_accuracy[target_idx] < config.target_ratio:
            break

except KeyboardInterrupt:
    print('keyboard interrupt caught')
    torch.save(model.state_dict(), config.model_dir+'schnet_epoch%d.pt'%(epoch+1))

Epoch 1: MAE/CA 63.017904: : 430it [00:43,  9.90it/s] 
Epoch 2: MAE/CA 28.057313: : 430it [00:44,  9.65it/s]
Epoch 3: MAE/CA 23.745692: : 430it [00:45,  9.42it/s]
Epoch 4: MAE/CA 22.335142: : 430it [00:45,  9.48it/s]
Epoch 5: MAE/CA 20.803338: : 430it [00:45,  9.38it/s]
Epoch 6: MAE/CA 20.581358: : 430it [00:45,  9.47it/s]
Epoch 7: MAE/CA 18.438804: : 430it [00:45,  9.49it/s]
Epoch 8: MAE/CA 18.631341: : 430it [00:45,  9.41it/s]
Epoch 9: MAE/CA 17.291390: : 430it [00:44,  9.66it/s]
Epoch 10: MAE/CA 17.136621: : 430it [00:43,  9.81it/s]
Epoch 11: MAE/CA 16.120770: : 430it [00:43,  9.81it/s]
Epoch 12: MAE/CA 15.394222: : 430it [00:43,  9.85it/s]
Epoch 13: MAE/CA 15.063672: : 430it [00:43,  9.88it/s]
Epoch 14: MAE/CA 14.838182: : 430it [00:43,  9.89it/s]
Epoch 15: MAE/CA 14.081883: : 430it [00:43,  9.88it/s]
Epoch 16: MAE/CA 13.824629: : 430it [00:43,  9.88it/s]
Epoch 17: MAE/CA 13.226781: : 430it [00:43,  9.90it/s]
Epoch 18: MAE/CA 12.636732: : 430it [00:43,  9.86it/s]


Epoch 19: MAE/CA 12.598561: : 430it [00:43,  9.87it/s]
Epoch 20: MAE/CA 12.128147: : 430it [00:43,  9.88it/s]
Epoch 21: MAE/CA 12.012003: : 430it [00:43,  9.88it/s]
Epoch 22: MAE/CA 11.831824: : 430it [00:43,  9.91it/s]
Epoch 23: MAE/CA 11.333926: : 430it [00:41, 10.29it/s]
Epoch 24: MAE/CA 11.195457: : 430it [00:41, 10.24it/s]
Epoch 25: MAE/CA 13.835238: : 430it [00:41, 10.24it/s]
Epoch 26: MAE/CA 10.686425: : 430it [00:42, 10.23it/s]
Epoch 27: MAE/CA 10.265665: : 430it [00:41, 10.25it/s]
Epoch 28: MAE/CA 10.013017: : 430it [00:41, 10.28it/s]
Epoch 29: MAE/CA 10.059041: : 430it [00:41, 10.25it/s]
Epoch 30: MAE/CA 9.485958: : 430it [00:41, 10.28it/s]
Epoch 31: MAE/CA 9.956751: : 430it [00:41, 10.25it/s]
Epoch 32: MAE/CA 9.485880: : 430it [00:41, 10.28it/s]
Epoch 33: MAE/CA 9.333159: : 430it [00:41, 10.24it/s]
Epoch 34: MAE/CA 15.360440: : 430it [00:41, 10.28it/s]
Epoch 35: MAE/CA 10.628624: : 430it [00:41, 10.29it/s]
Epoch 36: MAE/CA 9.712707: : 430it [00:41, 10.25it/s]


Epoch 37: MAE/CA 9.070042: : 430it [00:41, 10.28it/s]
Epoch 38: MAE/CA 8.978870: : 430it [00:41, 10.33it/s]
Epoch 39: MAE/CA 8.895007: : 430it [00:42, 10.23it/s]
Epoch 40: MAE/CA 8.821952: : 430it [00:41, 10.29it/s]
Epoch 41: MAE/CA 9.172558: : 430it [00:41, 10.31it/s]
Epoch 42: MAE/CA 8.976397: : 430it [00:41, 10.32it/s]
Epoch 43: MAE/CA 8.590737: : 430it [00:41, 10.27it/s]
Epoch 44: MAE/CA 8.441087: : 430it [00:41, 10.29it/s]
Epoch 45: MAE/CA 8.272364: : 430it [00:41, 10.27it/s]
Epoch 46: MAE/CA 8.287269: : 430it [00:41, 10.26it/s]
Epoch 47: MAE/CA 8.192595: : 430it [00:41, 10.29it/s]
Epoch 48: MAE/CA 8.256614: : 430it [00:41, 10.25it/s]
Epoch 49: MAE/CA 10.004226: : 430it [00:41, 10.25it/s]
Epoch 50: MAE/CA 8.202182: : 430it [00:41, 10.26it/s]
Epoch 51: MAE/CA 7.515144: : 430it [00:41, 10.28it/s]
Epoch 52: MAE/CA 7.455879: : 430it [00:41, 10.26it/s]
Epoch 53: MAE/CA 7.508843: : 430it [00:41, 10.29it/s]
Epoch 54: MAE/CA 7.501104: : 430it [00:41, 10.26it/s]
Epoch 55: MAE/CA 7.664743: 

Epoch 56: MAE/CA 7.577377: : 430it [00:41, 10.29it/s]
Epoch 57: MAE/CA 7.338995: : 430it [00:41, 10.26it/s]
Epoch 58: MAE/CA 7.155141: : 430it [00:41, 10.26it/s]
Epoch 59: MAE/CA 7.225821: : 430it [00:41, 10.29it/s]
Epoch 60: MAE/CA 7.199831: : 430it [00:41, 10.29it/s]
Epoch 61: MAE/CA 7.015299: : 430it [00:41, 10.28it/s]
Epoch 62: MAE/CA 7.111458: : 430it [00:41, 10.29it/s]
Epoch 63: MAE/CA 7.019242: : 430it [00:41, 10.29it/s]
Epoch 64: MAE/CA 6.747651: : 430it [00:41, 10.27it/s]
Epoch 65: MAE/CA 6.780707: : 430it [00:42, 10.23it/s]
Epoch 66: MAE/CA 7.011889: : 430it [00:41, 10.29it/s]
Epoch 67: MAE/CA 6.549718: : 430it [00:41, 10.28it/s]
Epoch 68: MAE/CA 6.674056: : 430it [00:41, 10.33it/s]
Epoch 69: MAE/CA 6.801798: : 430it [00:41, 10.26it/s]
Epoch 70: MAE/CA 6.373411: : 430it [00:41, 10.28it/s]
Epoch 71: MAE/CA 6.447154: : 430it [00:41, 10.27it/s]
Epoch 72: MAE/CA 7.074656: : 430it [00:41, 10.26it/s]
Epoch 73: MAE/CA 6.944672: : 430it [00:41, 10.30it/s]


Epoch 74: MAE/CA 6.710000: : 430it [00:41, 10.28it/s]
Epoch 75: MAE/CA 6.478446: : 430it [00:41, 10.27it/s]
Epoch 76: MAE/CA 6.134095: : 430it [00:41, 10.29it/s]
Epoch 77: MAE/CA 6.103857: : 430it [00:41, 10.28it/s]
Epoch 78: MAE/CA 6.392762: : 430it [00:41, 10.30it/s]
Epoch 79: MAE/CA 6.101596: : 430it [00:41, 10.31it/s]
Epoch 80: MAE/CA 6.049206: : 430it [00:41, 10.28it/s]
Epoch 81: MAE/CA 6.069285: : 430it [00:41, 10.30it/s]
Epoch 82: MAE/CA 6.865337: : 430it [00:41, 10.32it/s]
Epoch 83: MAE/CA 6.053768: : 430it [00:41, 10.30it/s]
Epoch 84: MAE/CA 6.027477: : 430it [00:41, 10.25it/s]
Epoch 85: MAE/CA 5.910259: : 430it [00:41, 10.29it/s]
Epoch 86: MAE/CA 5.886834: : 430it [00:41, 10.31it/s]
Epoch 87: MAE/CA 5.698193: : 430it [00:41, 10.29it/s]
Epoch 88: MAE/CA 5.749831: : 430it [00:41, 10.30it/s]
Epoch 89: MAE/CA 5.708868: : 430it [00:41, 10.31it/s]
Epoch 90: MAE/CA 5.903149: : 430it [00:41, 10.26it/s]
Epoch 91: MAE/CA 5.835208: : 430it [00:41, 10.27it/s]
Epoch 92: MAE/CA 5.556279: :

Epoch 93: MAE/CA 5.817783: : 430it [00:41, 10.27it/s]
Epoch 94: MAE/CA 5.587600: : 430it [00:41, 10.29it/s]
Epoch 95: MAE/CA 5.567751: : 430it [00:41, 10.29it/s]
Epoch 96: MAE/CA 6.326374: : 430it [00:41, 10.27it/s]
Epoch 97: MAE/CA 8.811915: : 430it [00:41, 10.28it/s]
Epoch 98: MAE/CA 6.772980: : 430it [00:42, 10.01it/s]

Epoch    98: reducing learning rate of group 0 to 2.8500e-04.



Epoch 99: MAE/CA 5.835675: : 430it [00:44,  9.71it/s]
Epoch 100: MAE/CA 5.790828: : 430it [00:44,  9.74it/s]
Epoch 101: MAE/CA 5.567065: : 430it [00:44,  9.64it/s]
Epoch 102: MAE/CA 5.478778: : 430it [00:44,  9.69it/s]
Epoch 103: MAE/CA 5.349838: : 430it [00:44,  9.61it/s]
Epoch 104: MAE/CA 5.270382: : 430it [00:44,  9.68it/s]
Epoch 105: MAE/CA 5.626650: : 430it [00:43,  9.86it/s]
Epoch 106: MAE/CA 5.594584: : 430it [00:43,  9.95it/s]
Epoch 107: MAE/CA 5.374212: : 430it [00:43,  9.90it/s]
Epoch 108: MAE/CA 5.529977: : 430it [00:43,  9.93it/s]
Epoch 109: MAE/CA 5.469444: : 430it [00:43,  9.87it/s]
Epoch 110: MAE/CA 5.263515: : 430it [00:43,  9.90it/s]
Epoch 111: MAE/CA 5.165210: : 430it [00:43,  9.87it/s]
Epoch 112: MAE/CA 5.432117: : 430it [00:43,  9.99it/s]
Epoch 113: MAE/CA 5.478494: : 430it [00:43,  9.93it/s]
Epoch 114: MAE/CA 5.170859: : 430it [00:43,  9.93it/s]
Epoch 115: MAE/CA 5.273289: : 430it [00:43,  9.95it/s]
Epoch 116: MAE/CA 5.207467: : 430it [00:43,  9.86it/s]


Epoch 117: MAE/CA 5.140483: : 430it [00:43,  9.94it/s]
Epoch 118: MAE/CA 5.041036: : 430it [00:43,  9.90it/s]
Epoch 119: MAE/CA 5.223862: : 430it [00:43,  9.87it/s]
Epoch 120: MAE/CA 5.008644: : 430it [00:43,  9.91it/s]
Epoch 121: MAE/CA 4.943980: : 430it [00:43,  9.93it/s]
Epoch 122: MAE/CA 4.906287: : 430it [00:43,  9.84it/s]
Epoch 123: MAE/CA 5.505883: : 430it [00:43,  9.85it/s]
Epoch 124: MAE/CA 5.298546: : 430it [00:43,  9.89it/s]
Epoch 125: MAE/CA 5.003893: : 430it [00:43,  9.92it/s]
Epoch 126: MAE/CA 4.932430: : 430it [00:43,  9.89it/s]
Epoch 127: MAE/CA 4.819454: : 430it [00:43,  9.98it/s]
Epoch 128: MAE/CA 6.323909: : 430it [00:43,  9.86it/s]
Epoch 129: MAE/CA 4.918481: : 430it [00:43,  9.95it/s]
Epoch 130: MAE/CA 4.757665: : 430it [00:43,  9.84it/s]
Epoch 131: MAE/CA 4.615905: : 430it [00:43,  9.86it/s]
Epoch 132: MAE/CA 4.652309: : 430it [00:43,  9.89it/s]
Epoch 133: MAE/CA 4.703271: : 430it [00:43,  9.93it/s]
Epoch 134: MAE/CA 4.665003: : 430it [00:43,  9.83it/s]
Epoch 135:

Epoch 136: MAE/CA 5.519937: : 430it [00:44,  9.75it/s]
Epoch 137: MAE/CA 4.982547: : 430it [00:43,  9.87it/s]

Epoch   137: reducing learning rate of group 0 to 2.7075e-04.



Epoch 138: MAE/CA 4.540966: : 430it [00:43,  9.83it/s]
Epoch 139: MAE/CA 4.559812: : 430it [00:43,  9.83it/s]
Epoch 140: MAE/CA 4.505383: : 430it [00:43,  9.88it/s]
Epoch 141: MAE/CA 4.611263: : 430it [00:43,  9.85it/s]
Epoch 142: MAE/CA 4.492968: : 430it [00:43,  9.83it/s]
Epoch 143: MAE/CA 4.718575: : 430it [00:43,  9.84it/s]
Epoch 144: MAE/CA 5.069453: : 430it [00:43,  9.79it/s]
Epoch 145: MAE/CA 4.825599: : 430it [00:43,  9.88it/s]
Epoch 146: MAE/CA 4.654704: : 430it [00:43,  9.94it/s]
Epoch 147: MAE/CA 4.506820: : 430it [00:43,  9.86it/s]
Epoch 148: MAE/CA 4.505009: : 430it [00:43,  9.85it/s]

Epoch   148: reducing learning rate of group 0 to 2.5721e-04.



Epoch 149: MAE/CA 4.372544: : 430it [00:43,  9.93it/s]
Epoch 150: MAE/CA 4.394717: : 430it [00:43,  9.84it/s]
Epoch 151: MAE/CA 4.591254: : 430it [00:43,  9.88it/s]
Epoch 152: MAE/CA 4.638151: : 430it [00:43,  9.91it/s]
Epoch 153: MAE/CA 4.447993: : 430it [00:43,  9.94it/s]
Epoch 154: MAE/CA 4.345435: : 430it [00:43,  9.92it/s]
Epoch 155: MAE/CA 4.246084: : 430it [00:43,  9.98it/s]
Epoch 156: MAE/CA 4.338842: : 430it [00:43,  9.89it/s]
Epoch 157: MAE/CA 4.291998: : 430it [00:43,  9.86it/s]
Epoch 158: MAE/CA 4.439103: : 430it [00:43,  9.81it/s]
Epoch 159: MAE/CA 4.487752: : 430it [00:43,  9.81it/s]
Epoch 160: MAE/CA 4.351913: : 430it [00:43,  9.85it/s]
Epoch 161: MAE/CA 4.245516: : 430it [00:43,  9.87it/s]
Epoch 162: MAE/CA 4.198900: : 430it [00:43,  9.89it/s]
Epoch 163: MAE/CA 4.205074: : 430it [00:43,  9.97it/s]
Epoch 164: MAE/CA 4.169624: : 430it [00:43,  9.88it/s]
Epoch 165: MAE/CA 4.214996: : 430it [00:43,  9.86it/s]
Epoch 166: MAE/CA 4.155288: : 218it [00:22,  9.69it/s]

In [None]:
pbar = tqdm.tqdm(enumerate(test_loader, 0))
epoch_loss = 0
maes = []
for i, data in pbar:
    data = {key:d.to(device) for key, d in data.items()} 
    target = data['target']
    prediction = model(data)
    mae = (prediction.view(-1) - target.view(-1)).abs()
    maes += [mae.detach().cpu()]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
wandb.log({'TEST MAE':mae, 'TEST MAE/CA':mae/chemical_accuracy[target_idx]})
