In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
from torch.utils.data import Dataset

import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np
import wandb
import json
import os

In [2]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

In [23]:
wandb.init(project='QM9-transformer', entity='chrisxx')
config = wandb.config
config.lr = 0.0003
config.n_epochs = 2000
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8
config.target_idx = 7
config.batch_size = 128
config.n_train = 110000
config.n_valid = 10000
config.target_ratio = 0.4
config.store_starting_from_ratio = 1
config.required_improvement = 0.8
config.model_dir = '../models/qm9/gtransformer_big/'
config.dfs_codes = '../datasets/qm9_geometric/min_dfs_codes.json'

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
MSE,0.10969
_runtime,17931.0
_timestamp,1626373849.0
_step,218911.0
MAE,0.24247
MAE/CA,5.63893
learning rate,0.00018
TEST MAE,0.3035
TEST MAE/CA,7.05805


0,1
MSE,▆▆██▅▅▇▅▅▄▅▆▄▄▄▄▃▄▆▃▆▄▃▂▂▃▂▂▂▂▁▂▂▂▁▁▂▁▂▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
MAE,█▇█▆▆▅▅▅▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
MAE/CA,█▇█▆▆▅▅▅▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
learning rate,███▇▇▇▇▆▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁
TEST MAE,▁
TEST MAE/CA,▁


In [4]:
target_idx = config.target_idx

# Dataset

In [5]:
class QM9DFSCodes(Dataset):
    def __init__(self, dfs_codes, qm9_dataset, 
                 target_idx, vert_feats = ['x', 'pos'],
                 vertex_transform=None, edge_transform=None):
        self.dfs_codes = dfs_codes
        self.qm9_dataset = qm9_dataset
        self.target_idx = target_idx
        self.vertex_features = vert_feats
        self.max_vert = 29
        self.max_len = np.max([len(d['min_dfs_code']) for d in self.dfs_codes.values()])
        self.feat_dim = 2 + qm9_dataset[0].edge_attr.shape[1] # 2 for the dfs indices
        for feat in vert_feats: 
            self.feat_dim += 2*qm9_dataset[0][feat].shape[1]
        self.vertex_transform = vertex_transform
        self.edge_transform = edge_transform

    def __len__(self):
        return len(self.qm9_dataset)

    def __getitem__(self, idx):
        """
        graph_repr: [batch_size, max_edges, 2 + n_vert_feat + n_edge_feat + n_vert_feat]
        return vertex_features, graph_repr, target
        """
        data = self.qm9_dataset[idx]
        code_dict = self.dfs_codes[data.name]
        code = code_dict['min_dfs_code']
        eid2nr = code_dict['edge_id_2_edge_number']
        
        vert_feats = [data[k].detach().cpu().numpy() for k in ['x', 'pos']]
        vert_feats = np.concatenate(vert_feats, axis=1)
        edge_feats = data.edge_attr.detach().cpu().numpy()
        
        d = {'dfs_from':np.zeros(self.max_len), 
             'dfs_to':np.zeros(self.max_len),
             'feat_from':np.zeros((self.max_len, vert_feats.shape[1])), 
             'feat_to':np.zeros((self.max_len, vert_feats.shape[1])), 
             'feat_edge':np.zeros((self.max_len, edge_feats.shape[1])),
             'n_edges':len(code)*np.ones((1,), dtype=np.int32),
             'z':np.zeros(self.max_vert)}
        
        for idx, e_tuple in enumerate(code):
            e_idx = eid2nr[str(e_tuple[-2])]
            from_idx = e_tuple[-3]
            to_idx = e_tuple[-1]
            d['dfs_from'][idx] = e_tuple[0]
            d['dfs_to'][idx] = e_tuple[1]
            d['feat_from'][idx] = vert_feats[from_idx]
            d['feat_to'][idx] = vert_feats[to_idx]
            d['feat_edge'][idx] = edge_feats[e_idx]      
        
        d['z'][:len(data.z)] = data.z.detach().cpu().numpy()
        
        d_tensors = {}
        d_tensors['dfs_from'] = torch.IntTensor(d['dfs_from'])
        d_tensors['dfs_to'] = torch.IntTensor(d['dfs_to'])
        d_tensors['feat_from'] = torch.Tensor(d['feat_from'])
        d_tensors['feat_to'] = torch.Tensor(d['feat_to'])
        d_tensors['feat_edge'] = torch.Tensor(d['feat_edge'])
        d_tensors['z'] = torch.IntTensor(d['z'])
        
        if self.vertex_transform:
            d_tensors['feat_from'] = self.vertex_transform(d_tensors['feat_from'])
            d_tensors['feat_to'] = self.vertex_transform(d_tensors['feat_to'])
        if self.edge_transform:
            d_tensors['feat_edge'] = self.edge_transform(d_tensors['feat_edge'])
        
        d_tensors['target'] = data.y[:, self.target_idx]
        return d_tensors
    
    def shuffle(self):
        self.qm9_dataset = self.qm9_dataset.shuffle()
        return self


In [6]:
with open(config.dfs_codes, 'r') as f:
    dfs_codes = json.load(f)

In [7]:
dset = QM9('../datasets/qm9_geometric/')

In [8]:
dset = dset.shuffle()
train_qm9 = DataLoader(dset[:config.n_train], batch_size=config.batch_size)
train_dataset = QM9DFSCodes(dfs_codes, dset[:config.n_train], target_idx)
valid_dataset = QM9DFSCodes(dfs_codes, dset[config.n_train:config.n_train+config.n_valid], target_idx) 
test_dataset = QM9DFSCodes(dfs_codes, dset[config.n_train+config.n_valid:], target_idx) 
config.n_test = len(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, pin_memory=True, num_workers=0)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=0)

In [9]:
os.makedirs(config.model_dir, exist_ok=True)

In [24]:
torch.save(dset.indices(), config.model_dir+'dataset_indices.pt')

[34m[1mwandb[0m: Network error resolved after 0:05:52.857406, resuming normal operation.


In [11]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

# Model

In [12]:
class MoleculeTransformer(nn.Module):
    def __init__(self, vert_dim, edge_dim, d_model=512, nhead=8, nlayers=4, dim_feedforward=2048, mean=None, std=None, atomref=None,
                 max_vertices=29, max_edges=28):
        """
        transfomer model is some type of transformer that 
        """
        super(MoleculeTransformer, self).__init__()
        # atomic masses could be used as additional features
        # see https://github.com/rusty1s/pytorch_geometric/blob/97d3177dc43858f66c07bb66d7dc12506b986199/torch_geometric/nn/models/schnet.py#L113
        self.vert_dim = vert_dim
        self.edge_dim = edge_dim
        self.d_model = d_model
        self.nhead = nhead
        self.nlayers = nlayers
        self.dim_feedforward = dim_feedforward
        self.max_vertices = max_vertices
        self.max_edges = max_edges
        
        self.emb_dfs = nn.Embedding(self.max_vertices, d_model // 2)
        self.emb_vertex = nn.Linear(self.vert_dim, d_model // 2)
        self.emb_edge = nn.Linear(self.edge_dim, d_model)        
        
        self.cls_token = nn.Parameter(torch.empty(1, 1, self.d_model), requires_grad=True)
        self.enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.nhead, dim_feedforward=self.dim_feedforward), self.nlayers)
        
        self.fc_out = nn.Linear(self.d_model, 1)
        
        self.mean = mean
        self.std = std
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = nn.Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)
            
        nn.init.normal_(self.cls_token, mean=.0, std=.5)
    
    def forward(self, data):
        z = data['z']
        dfs_from_emb = self.emb_dfs(data['dfs_from'])
        dfs_to_emb = self.emb_dfs(data['dfs_to'])
        dfs_emb = torch.cat((dfs_from_emb, dfs_to_emb), -1)
        from_emb = self.emb_vertex(data['feat_from'])
        to_emb = self.emb_vertex(data['feat_to'])
        feat_emb = torch.cat((from_emb, to_emb), -1)
        edge_emb = self.emb_edge(data['feat_edge'])
        batch = dfs_emb + feat_emb + edge_emb # batch_dim x seq_dim x n_model
        batch = batch.permute(1, 0, 2) # seq_dim x batch_dim x n_model
        batch = torch.cat((self.cls_token.expand(-1, batch.shape[1], -1), batch), dim=0)
        
        transformer_out = self.enc(batch)
        out = self.fc_out(transformer_out[0]) 
        
        # tricks from Schnet
        if self.mean is not None and self.std is not None:
            out = out * self.std + self.mean
        
        if self.atomref is not None:
            out = out + torch.sum(self.atomref(z), axis=1)
        
        return out
        
        

In [13]:
target_vec = []

In [14]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
for data in train_qm9:
    data = data.to(device)
    atomU0s = torch.tensor(qm9.atomrefs[target_idx], device=device)[torch.argmax(data.x[:, :5], axis=1)]
    target_modular = scatter(atomU0s, data.batch, dim=-1, reduce='sum')
    target_vec += [(data.y[:, target_idx] - target_modular).detach().cpu().numpy()]
target_vec = np.concatenate(target_vec, axis=0)

In [15]:
target_mean = np.mean(target_vec)
target_std = np.std(target_vec)

In [16]:
d = next(iter(train_loader))
vert_dim = d['feat_from'].shape[-1]
edge_dim = d['feat_edge'].shape[-1]
model = MoleculeTransformer(vert_dim, edge_dim, atomref=dset.atomref(target_idx), mean=target_mean, std=target_std)

loss = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=config.lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=config.patience, factor=config.factor)

# Training

In [17]:
model = model.to(device)
loss_hist = []

In [21]:
min_mae = config.store_starting_from_ratio
try:
    # For each epoch
    for epoch in range(config.n_epochs):
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            data = {key:d.to(device) for key, d in data.items()} 
            target = data['target']
            prediction = model(data)
            output = loss(prediction.view(-1), target.view(-1))
            mae = (prediction.view(-1) - target.view(-1)).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)
            
            pbar.set_description('Epoch %d: MAE/CA %2.6f'%(epoch+1, epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()
            wandb.log({'MSE': output.item()})
        curr_lr = list(optimizer.param_groups)[0]['lr']
        wandb.log({'MAE':epoch_loss, 
                   'MAE/CA':epoch_loss/chemical_accuracy[target_idx],
                   'learning rate':curr_lr})
        lr_scheduler.step(epoch_loss)
        loss_hist += [epoch_loss] 

        if epoch_loss/chemical_accuracy[target_idx] < min_mae*config.required_improvement:
            min_mae = epoch_loss/chemical_accuracy[target_idx]
            torch.save(model.state_dict(), config.model_dir+'gtransformer_epoch%d.pt'%(epoch+1))
        if curr_lr < config.minimal_lr:
            break
        if epoch_loss/chemical_accuracy[target_idx] < config.target_ratio:
            break

except KeyboardInterrupt:
    print('keyboard interrupt caught')
    torch.save(model.state_dict(), config.model_dir+'gtransformer_epoch%d.pt'%(epoch+1))

Epoch 1: MAE/CA 15.805389: : 860it [01:02, 13.74it/s]
Epoch 2: MAE/CA 15.169381: : 860it [01:05, 13.21it/s]
Epoch 3: MAE/CA 14.145570: : 860it [01:04, 13.40it/s]
Epoch 4: MAE/CA 14.015864: : 860it [01:00, 14.30it/s]
Epoch 5: MAE/CA 14.914134: : 860it [01:03, 13.61it/s]
Epoch 6: MAE/CA 13.536842: : 860it [01:05, 13.18it/s]
Epoch 7: MAE/CA 13.866200: : 860it [01:01, 14.02it/s]
Epoch 8: MAE/CA 13.278941: : 860it [00:59, 14.48it/s]
Epoch 9: MAE/CA 14.147046: : 860it [01:03, 13.65it/s]
Epoch 10: MAE/CA 13.836421: : 860it [01:03, 13.60it/s]
Epoch 11: MAE/CA 13.987762: : 860it [01:09, 12.40it/s]
Epoch 12: MAE/CA 12.753078: : 860it [01:12, 11.88it/s]
Epoch 13: MAE/CA 12.444085: : 860it [01:03, 13.63it/s]
Epoch 14: MAE/CA 14.526193: : 860it [01:00, 14.22it/s]
Epoch 15: MAE/CA 16.034590: : 860it [01:06, 12.90it/s]
Epoch 16: MAE/CA 13.333688: : 860it [01:07, 12.77it/s]
Epoch 17: MAE/CA 12.821490: : 860it [01:07, 12.71it/s]
Epoch 18: MAE/CA 13.096951: : 860it [01:07, 12.74it/s]
Epoch 19: MAE/CA 13

Epoch    29: reducing learning rate of group 0 to 2.8500e-04.


Epoch 20: MAE/CA 12.591199: : 860it [01:06, 13.03it/s]
Epoch 21: MAE/CA 12.906681: : 860it [01:07, 12.67it/s]
Epoch 22: MAE/CA 13.126674: : 860it [01:06, 12.91it/s]
Epoch 23: MAE/CA 12.135311: : 860it [01:06, 12.99it/s]
Epoch 24: MAE/CA 13.478188: : 860it [01:06, 12.93it/s]
Epoch 25: MAE/CA 13.308612: : 860it [01:06, 12.94it/s]
Epoch 26: MAE/CA 12.776156: : 860it [01:06, 12.91it/s]
Epoch 27: MAE/CA 12.942582: : 860it [01:06, 12.92it/s]
Epoch 28: MAE/CA 12.129993: : 860it [01:06, 12.94it/s]
Epoch 29: MAE/CA 12.159268: : 860it [01:06, 12.95it/s]
Epoch 30: MAE/CA 12.166283: : 860it [01:06, 12.97it/s]
Epoch 31: MAE/CA 12.184331: : 860it [01:06, 12.99it/s]
Epoch 32: MAE/CA 12.573886: : 860it [01:06, 12.92it/s]
Epoch 33: MAE/CA 12.591450: : 860it [01:06, 12.87it/s]
Epoch 34: MAE/CA 11.577575: : 860it [01:06, 12.85it/s]
Epoch 35: MAE/CA 11.775059: : 860it [01:07, 12.81it/s]
Epoch 36: MAE/CA 11.858796: : 860it [01:06, 12.84it/s]
Epoch 37: MAE/CA 12.562565: : 860it [01:06, 12.89it/s]
Epoch 38: 

Epoch    50: reducing learning rate of group 0 to 2.7075e-04.


Epoch 41: MAE/CA 11.194960: : 860it [01:06, 12.91it/s]
Epoch 42: MAE/CA 11.395279: : 860it [01:07, 12.82it/s]
Epoch 43: MAE/CA 11.799602: : 860it [01:07, 12.79it/s]
Epoch 44: MAE/CA 11.703280: : 860it [01:07, 12.78it/s]
Epoch 45: MAE/CA 11.046879: : 860it [01:06, 12.84it/s]
Epoch 46: MAE/CA 11.037964: : 860it [01:06, 12.85it/s]
Epoch 47: MAE/CA 10.983130: : 860it [01:06, 12.88it/s]
Epoch 48: MAE/CA 11.185805: : 860it [01:07, 12.80it/s]
Epoch 49: MAE/CA 11.178795: : 860it [01:08, 12.63it/s]
Epoch 50: MAE/CA 13.010411: : 860it [01:07, 12.67it/s]
Epoch 51: MAE/CA 12.212091: : 860it [01:07, 12.66it/s]
Epoch 52: MAE/CA 12.035996: : 860it [01:08, 12.61it/s]
Epoch 53: MAE/CA 11.374151: : 860it [01:07, 12.68it/s]


Epoch    63: reducing learning rate of group 0 to 2.5721e-04.


Epoch 54: MAE/CA 11.177121: : 860it [01:07, 12.69it/s]
Epoch 55: MAE/CA 10.895792: : 860it [01:11, 11.97it/s]
Epoch 56: MAE/CA 10.840669: : 860it [01:15, 11.39it/s]
Epoch 57: MAE/CA 12.272431: : 860it [01:07, 12.83it/s]
Epoch 58: MAE/CA 12.009669: : 860it [00:59, 14.53it/s]
Epoch 59: MAE/CA 11.249029: : 860it [00:59, 14.40it/s]
Epoch 60: MAE/CA 10.568789: : 860it [01:00, 14.22it/s]
Epoch 61: MAE/CA 10.609090: : 860it [01:02, 13.71it/s]
Epoch 62: MAE/CA 10.409435: : 860it [01:07, 12.77it/s]
Epoch 63: MAE/CA 10.924024: : 860it [01:07, 12.70it/s]
Epoch 64: MAE/CA 11.340211: : 860it [01:07, 12.65it/s]
Epoch 65: MAE/CA 10.792006: : 860it [01:07, 12.65it/s]
Epoch 66: MAE/CA 10.790958: : 860it [01:07, 12.71it/s]
Epoch 67: MAE/CA 10.186450: : 860it [01:07, 12.65it/s]
Epoch 68: MAE/CA 9.985707: : 860it [01:07, 12.66it/s] 
Epoch 69: MAE/CA 10.582305: : 860it [01:07, 12.70it/s]
Epoch 70: MAE/CA 10.287564: : 860it [01:07, 12.77it/s]
Epoch 71: MAE/CA 10.474260: : 860it [01:07, 12.69it/s]
Epoch 72: 

Epoch    84: reducing learning rate of group 0 to 2.4435e-04.


Epoch 75: MAE/CA 10.395594: : 860it [01:08, 12.63it/s]
Epoch 76: MAE/CA 10.268546: : 860it [01:08, 12.53it/s]
Epoch 77: MAE/CA 10.187668: : 860it [01:08, 12.57it/s]
Epoch 78: MAE/CA 10.092585: : 860it [01:07, 12.82it/s]
Epoch 79: MAE/CA 10.092655: : 860it [01:08, 12.56it/s]
Epoch 80: MAE/CA 9.707719: : 860it [01:07, 12.78it/s]
Epoch 81: MAE/CA 10.367082: : 860it [01:07, 12.77it/s]
Epoch 82: MAE/CA 10.247466: : 860it [01:07, 12.81it/s]
Epoch 83: MAE/CA 10.134599: : 860it [01:06, 12.85it/s]
Epoch 84: MAE/CA 11.052890: : 860it [01:07, 12.72it/s]
Epoch 85: MAE/CA 10.329376: : 860it [01:06, 12.90it/s]
Epoch 86: MAE/CA 9.652295: : 860it [01:07, 12.76it/s]
Epoch 87: MAE/CA 9.586045: : 860it [01:07, 12.76it/s]
Epoch 88: MAE/CA 9.604527: : 860it [01:07, 12.73it/s]
Epoch 89: MAE/CA 9.720873: : 860it [01:06, 12.84it/s]
Epoch 90: MAE/CA 9.884996: : 860it [01:07, 12.82it/s]
Epoch 91: MAE/CA 10.020887: : 860it [01:07, 12.81it/s]
Epoch 92: MAE/CA 9.884328: : 860it [01:06, 12.85it/s] 
Epoch 93: MAE/CA

Epoch   103: reducing learning rate of group 0 to 2.3213e-04.


Epoch 94: MAE/CA 9.814396: : 860it [01:06, 12.88it/s] 
Epoch 95: MAE/CA 9.644099: : 860it [01:06, 12.84it/s]
Epoch 96: MAE/CA 9.479100: : 860it [01:06, 12.84it/s]
Epoch 97: MAE/CA 9.576264: : 860it [01:07, 12.77it/s]
Epoch 98: MAE/CA 9.324957: : 860it [01:05, 13.08it/s]
Epoch 99: MAE/CA 9.514146: : 860it [01:09, 12.35it/s]
Epoch 100: MAE/CA 9.507185: : 860it [01:07, 12.78it/s]
Epoch 101: MAE/CA 9.607712: : 860it [01:14, 11.48it/s]
Epoch 102: MAE/CA 9.195068: : 860it [01:16, 11.28it/s]
Epoch 103: MAE/CA 9.415773: : 860it [01:10, 12.15it/s]
Epoch 104: MAE/CA 9.411882: : 860it [01:00, 14.32it/s]
Epoch 105: MAE/CA 9.184112: : 860it [01:04, 13.24it/s]
Epoch 106: MAE/CA 9.093210: : 860it [01:06, 12.84it/s]
Epoch 107: MAE/CA 9.060607: : 860it [01:07, 12.81it/s]
Epoch 108: MAE/CA 9.186903: : 860it [01:07, 12.82it/s]
Epoch 109: MAE/CA 8.883504: : 860it [01:06, 12.98it/s]
Epoch 110: MAE/CA 8.961036: : 860it [01:07, 12.81it/s]
Epoch 111: MAE/CA 8.904660: : 860it [01:07, 12.79it/s]
Epoch 112: MAE/

Epoch   139: reducing learning rate of group 0 to 2.2053e-04.


Epoch 130: MAE/CA 8.294313: : 860it [00:59, 14.44it/s]
Epoch 131: MAE/CA 8.338695: : 860it [00:56, 15.19it/s]
Epoch 132: MAE/CA 8.569461: : 860it [01:00, 14.16it/s]
Epoch 133: MAE/CA 8.250676: : 860it [00:58, 14.59it/s]
Epoch 134: MAE/CA 8.272770: : 860it [00:58, 14.61it/s]
Epoch 135: MAE/CA 8.193269: : 860it [00:58, 14.61it/s]
Epoch 136: MAE/CA 8.107286: : 860it [00:58, 14.60it/s]
Epoch 137: MAE/CA 8.224704: : 860it [00:58, 14.61it/s]
Epoch 138: MAE/CA 8.138376: : 860it [01:00, 14.30it/s]
Epoch 139: MAE/CA 8.098367: : 860it [00:59, 14.54it/s]
Epoch 140: MAE/CA 8.285211: : 860it [00:59, 14.48it/s]
Epoch 141: MAE/CA 7.987214: : 860it [00:58, 14.61it/s]
Epoch 142: MAE/CA 7.984403: : 860it [00:59, 14.36it/s]
Epoch 143: MAE/CA 7.946938: : 860it [00:59, 14.46it/s]
Epoch 144: MAE/CA 7.963986: : 860it [00:59, 14.56it/s]
Epoch 145: MAE/CA 8.223763: : 860it [00:59, 14.54it/s]
Epoch 146: MAE/CA 8.257481: : 860it [00:59, 14.52it/s]
Epoch 147: MAE/CA 8.464110: : 860it [00:59, 14.52it/s]
Epoch 148:

Epoch   168: reducing learning rate of group 0 to 2.0950e-04.


Epoch 159: MAE/CA 7.844792: : 860it [01:14, 11.54it/s]
Epoch 160: MAE/CA 7.995304: : 860it [01:17, 11.08it/s]
Epoch 161: MAE/CA 7.630594: : 860it [01:16, 11.30it/s]
Epoch 162: MAE/CA 7.590954: : 860it [01:15, 11.32it/s]
Epoch 163: MAE/CA 7.435776: : 860it [01:16, 11.21it/s]
Epoch 164: MAE/CA 7.499196: : 860it [01:15, 11.35it/s]


Epoch   174: reducing learning rate of group 0 to 1.9903e-04.


Epoch 165: MAE/CA 7.425297: : 860it [01:16, 11.25it/s]
Epoch 166: MAE/CA 7.426188: : 860it [01:05, 13.14it/s]
Epoch 167: MAE/CA 7.399407: : 860it [01:04, 13.27it/s]
Epoch 168: MAE/CA 7.388739: : 860it [01:05, 13.06it/s]
Epoch 169: MAE/CA 7.205552: : 860it [01:02, 13.76it/s]
Epoch 170: MAE/CA 7.265141: : 860it [01:03, 13.64it/s]
Epoch 171: MAE/CA 7.411328: : 860it [01:06, 13.00it/s]
Epoch 172: MAE/CA 7.602295: : 860it [01:08, 12.57it/s]
Epoch 173: MAE/CA 7.110268: : 860it [01:08, 12.48it/s]
Epoch 174: MAE/CA 6.963226: : 860it [01:09, 12.42it/s]
Epoch 175: MAE/CA 7.023363: : 860it [01:08, 12.53it/s]
Epoch 176: MAE/CA 7.209128: : 860it [01:10, 12.20it/s]
Epoch 177: MAE/CA 7.155317: : 860it [01:08, 12.48it/s]
Epoch 178: MAE/CA 7.268568: : 860it [01:09, 12.44it/s]
Epoch 179: MAE/CA 7.140032: : 860it [01:09, 12.37it/s]
Epoch 180: MAE/CA 7.204929: : 860it [01:09, 12.36it/s]


Epoch   190: reducing learning rate of group 0 to 1.8907e-04.


Epoch 181: MAE/CA 7.036961: : 860it [01:09, 12.40it/s]
Epoch 182: MAE/CA 6.935127: : 860it [01:07, 12.76it/s]
Epoch 183: MAE/CA 6.848633: : 860it [01:10, 12.16it/s]
Epoch 184: MAE/CA 6.801398: : 860it [01:08, 12.53it/s]
Epoch 185: MAE/CA 6.862222: : 860it [01:08, 12.49it/s]
Epoch 186: MAE/CA 6.760225: : 860it [01:09, 12.37it/s]
Epoch 187: MAE/CA 6.860900: : 860it [01:09, 12.41it/s]
Epoch 188: MAE/CA 6.805523: : 860it [01:09, 12.41it/s]
Epoch 189: MAE/CA 6.795055: : 860it [01:09, 12.40it/s]
Epoch 190: MAE/CA 6.849457: : 860it [01:09, 12.30it/s]
Epoch 191: MAE/CA 6.877143: : 860it [01:10, 12.13it/s]
Epoch 192: MAE/CA 7.015489: : 860it [01:09, 12.29it/s]


Epoch   202: reducing learning rate of group 0 to 1.7962e-04.


Epoch 193: MAE/CA 7.562793: : 860it [01:09, 12.42it/s]
Epoch 194: MAE/CA 6.875301: : 860it [01:09, 12.41it/s]
Epoch 195: MAE/CA 6.733657: : 860it [01:09, 12.41it/s]
Epoch 196: MAE/CA 6.859292: : 860it [01:10, 12.20it/s]
Epoch 197: MAE/CA 6.576509: : 860it [01:08, 12.48it/s]
Epoch 198: MAE/CA 6.659021: : 860it [01:09, 12.45it/s]
Epoch 199: MAE/CA 6.738879: : 860it [01:09, 12.33it/s]
Epoch 200: MAE/CA 6.647497: : 860it [01:08, 12.50it/s]
Epoch 201: MAE/CA 6.632392: : 860it [01:10, 12.25it/s]
Epoch 202: MAE/CA 6.502913: : 860it [01:09, 12.43it/s]
Epoch 203: MAE/CA 6.583929: : 860it [01:09, 12.29it/s]
Epoch 204: MAE/CA 6.476869: : 860it [01:08, 12.57it/s]
Epoch 205: MAE/CA 6.662472: : 860it [01:10, 12.21it/s]
Epoch 206: MAE/CA 6.582253: : 860it [01:08, 12.62it/s]
Epoch 207: MAE/CA 6.763137: : 860it [01:07, 12.71it/s]
Epoch 208: MAE/CA 6.401024: : 860it [01:08, 12.50it/s]
Epoch 209: MAE/CA 6.419488: : 860it [01:09, 12.38it/s]
Epoch 210: MAE/CA 6.389044: : 860it [01:08, 12.47it/s]
Epoch 211:

keyboard interrupt caught


In [22]:
pbar = tqdm.tqdm(enumerate(test_loader, 0))
epoch_loss = 0
maes = []
for i, data in pbar:
    data = {key:d.to(device) for key, d in data.items()} 
    target = data['target']
    prediction = model(data)
    mae = (prediction.view(-1) - target.view(-1)).abs()
    maes += [mae.detach().cpu()]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
wandb.log({'TEST MAE':mae, 'TEST MAE/CA':mae/chemical_accuracy[target_idx]})


339it [00:04, 78.20it/s]

0.3034960627555847 7.05804797106011



