In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np
import wandb

In [2]:
from torch_geometric.nn.models.schnet import GaussianSmearing
from dfs_transformer import EarlyStopping

In [3]:
# [0] Reports MAE in eV / Chemical Accuracy of the target variable U0. 
# The chemical accuracy of U0 is 0.043 see [1, Table 5].

# Reproduced table [0]
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD means HDAD (Histogram of distances, anglesand dihedral angles)

# [2] Reports the average value of MAE / Chemical Accuracy of over all targets
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9
target_dict = {0: 'mu, D, Dipole moment', 
               1: 'alpha, {a_0}^3, Isotropic polarizability', 
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy', 
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K', 
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',  
               11: 'c_{v}, cal\(mol K), Heat capacity at 298.15K'}

chemical_accuracy = {idx:0.043 for idx in range(12)}
chemical_accuracy[0] = 0.1
chemical_accuracy[1] = 0.1
chemical_accuracy[5] = 1.2
chemical_accuracy[6] = 0.0012
chemical_accuracy[11] = 0.050

In [4]:
wandb.init(project='QM9-GAT', entity='chrisxx')
config = wandb.config
config.hidden_dim = 128
config.nlayers = 3
config.nhead = 1
config.lr = 0.0003
config.n_epochs = 5000
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8
config.target_idx = 7
config.batch_size = 256
config.valid_patience = 100
config.valid_minimal_improvement = 0.005
config.n_train = 110000
config.n_valid = 10000
config.target_ratio = 0.3
config.store_starting_from_ratio = 1
config.required_improvement = 0.8
config.model_dir = '../models/qm9/MPNN/DFS/1/'
config.num_workers = 4
config.dfs_codes = '../datasets/qm9_geometric_work/min_dfs_codes.json'
config.use_pos = False
config.use_dist = True
config.comment = "Actually, in this version I modified how the readout function is used. Instead of computing "\
"the target property as the weighted sum of target properties, I compute a hidden representation first and the "\
"prediction subsequently."

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.11.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [5]:
dfs_codes = None
if config.dfs_codes is not None:
    import json
    with open(config.dfs_codes, 'r') as f:
        dfs_codes = json.load(f)

In [6]:
def transform(data, dfs_codes = dfs_codes, edge_transform = GaussianSmearing(0, 10, 50), use_dist=config.use_dist):
    features = data.x
    # make atomic number a one hot
    atomic_number = nn.functional.one_hot(features[:, 5].long(), 100)
    # make num_h a one hot
    num_h = nn.functional.one_hot(features[:, -1].long(), 9)
    data.x = torch.cat((features[:, :5], features[:, 6:-1], atomic_number, num_h), axis=1)
    if dfs_codes is not None:
        dfs_indices = nn.functional.one_hot(torch.LongTensor(dfs_codes[data.name]['dfs_indices']), 29)
        data.x = torch.cat((data.x, dfs_indices), axis=1)
    if use_dist:
        row, col = data.edge_index
        edge_weights = (data.pos[row] - data.pos[col]).norm(dim=-1)
        dist_feats = edge_transform(edge_weights)
        data.edge_attr = torch.cat((data.edge_attr, dist_feats), axis=1)
    return data

In [7]:
target_idx = config.target_idx

In [8]:
dataset = QM9('../datasets/qm9_geometric_work/', transform=transform)

In [9]:
dataset = dataset.shuffle()
train_dataset = dataset[:config.n_train]
valid_dataset = dataset[config.n_train:config.n_train+config.n_valid]
test_dataset = dataset[config.n_train+config.n_valid:]
config.n_test = len(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, pin_memory=True, num_workers=config.num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, pin_memory=True, num_workers=config.num_workers)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=config.num_workers)

In [10]:
import os
os.makedirs(config.model_dir, exist_ok=True)

In [11]:
torch.save(dataset.indices(), config.model_dir+'dataset_indices.pt')

In [12]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

In [13]:
def score(loader, model, device=device):
    model = model.to(device)
    pbar = tqdm.tqdm(enumerate(loader, 0))
    maes = []
    for i, data in pbar:
        data.to(device)
        prediction = model(data)
        mae = (prediction.view(-1) - data.y[:, target_idx]).abs()
        maes += [mae.detach().cpu()]
    maes = torch.cat(maes, dim=0)
    mae = maes.mean().item()
    return mae/chemical_accuracy[config.target_idx]

# Model

In [14]:
target_vec = []

In [15]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
for data in train_loader:
    data = data.to(device)
    atomU0s = torch.tensor(qm9.atomrefs[target_idx], device=device)[torch.argmax(data.x[:, :5], axis=1)]
    target_modular = scatter(atomU0s, data.batch, dim=-1, reduce='sum')
    target_vec += [(data.y[:, target_idx] - target_modular).detach().cpu().numpy()]
target_vec = np.concatenate(target_vec, axis=0)

In [16]:
target_mean = np.mean(target_vec)
target_std = np.std(target_vec)

In [17]:
class GATNN(nn.Module):
    def __init__(self, vert_dim, edge_dim, hidden_dim=config.hidden_dim, nhead=config.nhead, nlayers=config.nlayers,
                 #ro_dmodel=128, ro_dim_feedforward=512, ro_nlayers=6, ro_nhead=8, 
                 mean=None, std=None, atomref=None,
                 max_vertices=29, max_edges=28, use_pos=False):
        """
        transfomer model is some type of transformer that 
        """
        super(GATNN, self).__init__()
        self.hidden_dim = hidden_dim
        gat_layers = []
        for _ in range(nlayers):
            msg = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), 
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim))
            gat_layers += [(tgnn.TransformerConv(hidden_dim, hidden_dim), 'x, edge_index, edge_attr -> x')]
        
        self.gat_layers = tgnn.Sequential('x, edge_index, edge_attr', gat_layers)
        # gate turns the hidden states into attention weights
        self.gate = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), 
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, 1)) # sequential instead of linear also good improvement
        # fc_out turns the hidden states into predictions
        self.fc_out = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), 
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, 1)) # sequential instead of linear huge improvment
        self.readout = tgnn.GlobalAttention(self.gate, None)
        
        self.x_emb = nn.Linear(vert_dim, hidden_dim)
        if use_pos:
            self.pos_emb = nn.Linear(3, hidden_dim)
        else:
            self.pos_emb = None
        self.edge_emb = nn.Linear(edge_dim, hidden_dim)
        
        #TODO: maybe:
        #self.readout = SelfAttentionReadout(ro_dmodel, ro_dim_feedforward, ro_nlayers, ro_nhead)
        
        self.mean = mean
        self.std = std
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = nn.Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)
    
    
    def forward(self, data):
        x, z, edge_index, batch = data.x, data.z, data.edge_index, data.batch
        h = self.x_emb(x) 
        if self.pos_emb is not None:
            h += self.pos_emb(data.pos)
        edge_attr = self.edge_emb(data.edge_attr).view(-1, 1, self.hidden_dim)
        h = self.gat_layers(h, edge_index, edge_attr) 
        out = self.readout(h, batch)
        out = self.fc_out(out)
        
        #batch = batch.permute(1, 0, 2) # seq_dim x batch_dim x n_model
        #transformer_out = self.enc(batch)
        #out = self.fc_out(transformer_out[0]) 
        
        # tricks from Schnet
        if self.mean is not None and self.std is not None:
            out = out * self.std + self.mean
                
        if self.atomref is not None:
            out = out + tgnn.global_add_pool(self.atomref(z), batch)
        
        return out

In [18]:
data = next(iter(train_loader))
model = GATNN(data.x.shape[1], data.edge_attr.shape[1], atomref=dataset.atomref(target_idx), 
              mean=target_mean, std=target_std, use_pos=config.use_pos)
loss = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=config.lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=config.patience, factor=config.factor)
early_stopping = EarlyStopping(patience=config.valid_patience, delta=config.valid_minimal_improvement,
                              path=config.model_dir+'checkpoint.pt')



# Training

In [19]:
model = model.to(device)

In [None]:
loss_hist = []
min_mae = config.store_starting_from_ratio
try:
    # For each epoch
    for epoch in range(config.n_epochs):
        # For each batch in the dataloader
        pbar = tqdm.tqdm(enumerate(train_loader, 0))
        epoch_loss = 0
        for i, data in pbar:
            model.zero_grad()
            data.to(device)
            target = data.y[:, target_idx]
            prediction = model(data)
            output = loss(prediction.view(-1), target)
            mae = (prediction.view(-1) - target).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)
            
            pbar.set_description('Epoch %d: MAE/CA %2.6f'%(epoch+1, epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()
            wandb.log({'MSE': output.item()})
        
        valid_loss = score(valid_loader, model)
        curr_lr = list(optimizer.param_groups)[0]['lr']
        wandb.log({'MAE':epoch_loss, 
                   'MAE/CA':epoch_loss/chemical_accuracy[target_idx],
                   'learning rate':curr_lr,
                   'MAE/CA valid':valid_loss})
        
        lr_scheduler.step(epoch_loss)
        early_stopping(valid_loss, model)
        loss_hist += [epoch_loss] 
        
        if early_stopping.early_stop:
            break
        
        if curr_lr < config.minimal_lr:
            break


except KeyboardInterrupt:
    print('keyboard interrupt caught')
    torch.save(model.state_dict(), config.model_dir+'gat_epoch%d.pt'%(epoch+1))

Epoch 1: MAE/CA 51.056676: : 430it [00:42, 10.07it/s] 
40it [00:01, 23.38it/s]
Epoch 2: MAE/CA 16.820828: : 430it [00:31, 13.54it/s]
40it [00:01, 22.47it/s]
Epoch 3: MAE/CA 14.125775: : 430it [00:31, 13.54it/s]
40it [00:01, 22.66it/s]
Epoch 4: MAE/CA 13.199750: : 430it [00:31, 13.60it/s]
40it [00:01, 23.51it/s]
Epoch 5: MAE/CA 12.056473: : 430it [00:31, 13.57it/s]
40it [00:01, 22.82it/s]

EarlyStopping counter: 1 out of 100



Epoch 6: MAE/CA 11.880204: : 430it [00:31, 13.73it/s]
40it [00:01, 23.88it/s]
Epoch 7: MAE/CA 10.673598: : 430it [00:30, 14.03it/s]
40it [00:01, 22.98it/s]
Epoch 8: MAE/CA 10.932510: : 430it [00:31, 13.53it/s]
40it [00:01, 23.16it/s]

EarlyStopping counter: 1 out of 100



Epoch 9: MAE/CA 9.775317: : 430it [00:31, 13.70it/s]
40it [00:01, 22.90it/s]

EarlyStopping counter: 2 out of 100



Epoch 10: MAE/CA 10.054907: : 430it [00:31, 13.65it/s]
40it [00:01, 24.13it/s]

EarlyStopping counter: 3 out of 100



Epoch 11: MAE/CA 9.806014: : 430it [00:31, 13.65it/s] 
40it [00:01, 24.56it/s]
Epoch 12: MAE/CA 9.273586: : 430it [00:31, 13.78it/s]
40it [00:01, 24.87it/s]

EarlyStopping counter: 1 out of 100



Epoch 13: MAE/CA 9.268512: : 430it [00:31, 13.77it/s]
40it [00:01, 23.96it/s]

EarlyStopping counter: 2 out of 100



Epoch 14: MAE/CA 9.191908: : 430it [00:31, 13.77it/s]
40it [00:01, 24.50it/s]

EarlyStopping counter: 3 out of 100



Epoch 15: MAE/CA 8.648076: : 430it [00:31, 13.78it/s]
40it [00:01, 24.13it/s]
Epoch 16: MAE/CA 8.699489: : 430it [00:31, 13.69it/s]
40it [00:01, 24.20it/s]
Epoch 17: MAE/CA 8.549317: : 430it [00:31, 13.81it/s]
40it [00:01, 24.71it/s]
Epoch 18: MAE/CA 8.227106: : 430it [00:31, 13.78it/s]
40it [00:01, 23.25it/s]

EarlyStopping counter: 1 out of 100



Epoch 19: MAE/CA 8.211718: : 430it [00:31, 13.68it/s]
40it [00:01, 23.90it/s]
Epoch 20: MAE/CA 7.815810: : 430it [00:30, 13.92it/s]
40it [00:01, 24.66it/s]

EarlyStopping counter: 1 out of 100



Epoch 21: MAE/CA 7.873553: : 430it [00:30, 13.92it/s]
40it [00:01, 23.82it/s]

EarlyStopping counter: 2 out of 100



Epoch 22: MAE/CA 7.830062: : 430it [00:31, 13.65it/s]
40it [00:01, 23.50it/s]

EarlyStopping counter: 3 out of 100



Epoch 23: MAE/CA 7.366191: : 430it [00:31, 13.83it/s]
40it [00:01, 24.30it/s]

EarlyStopping counter: 4 out of 100



Epoch 24: MAE/CA 7.298763: : 430it [00:31, 13.64it/s]
40it [00:01, 24.21it/s]

EarlyStopping counter: 5 out of 100



Epoch 25: MAE/CA 7.195370: : 430it [00:31, 13.65it/s]
40it [00:01, 22.56it/s]
Epoch 26: MAE/CA 7.251029: : 430it [00:31, 13.77it/s]
40it [00:01, 24.40it/s]
Epoch 27: MAE/CA 7.025701: : 430it [00:31, 13.79it/s]
40it [00:01, 24.47it/s]
Epoch 28: MAE/CA 6.819836: : 430it [00:31, 13.75it/s]
40it [00:01, 23.55it/s]

EarlyStopping counter: 1 out of 100



Epoch 29: MAE/CA 6.594952: : 430it [00:31, 13.73it/s]
40it [00:01, 23.34it/s]

EarlyStopping counter: 2 out of 100



Epoch 30: MAE/CA 6.896762: : 430it [00:31, 13.75it/s]
40it [00:01, 24.19it/s]
Epoch 31: MAE/CA 6.694578: : 430it [00:31, 13.75it/s]
40it [00:01, 23.87it/s]

EarlyStopping counter: 1 out of 100



Epoch 32: MAE/CA 6.423549: : 430it [00:31, 13.74it/s]
40it [00:01, 24.77it/s]

EarlyStopping counter: 2 out of 100



Epoch 33: MAE/CA 6.745142: : 430it [00:30, 13.90it/s]
40it [00:01, 24.68it/s]

EarlyStopping counter: 3 out of 100



Epoch 34: MAE/CA 6.425286: : 430it [00:30, 14.08it/s]
40it [00:01, 24.73it/s]


EarlyStopping counter: 4 out of 100


Epoch 35: MAE/CA 6.184722: : 430it [00:31, 13.75it/s]
40it [00:01, 24.09it/s]

EarlyStopping counter: 5 out of 100



Epoch 36: MAE/CA 6.348657: : 430it [00:31, 13.68it/s]
40it [00:01, 24.04it/s]
Epoch 37: MAE/CA 6.016967: : 430it [00:31, 13.74it/s]
40it [00:01, 22.51it/s]
Epoch 38: MAE/CA 6.192207: : 430it [00:31, 13.55it/s]
40it [00:01, 24.46it/s]

EarlyStopping counter: 1 out of 100



Epoch 39: MAE/CA 6.310489: : 430it [00:31, 13.64it/s]
40it [00:01, 23.46it/s]

EarlyStopping counter: 2 out of 100



Epoch 40: MAE/CA 5.880712: : 430it [00:31, 13.73it/s]
40it [00:01, 23.61it/s]
Epoch 41: MAE/CA 5.687949: : 430it [00:31, 13.81it/s]
40it [00:01, 23.74it/s]

EarlyStopping counter: 1 out of 100



Epoch 42: MAE/CA 6.088658: : 430it [00:31, 13.80it/s]
40it [00:01, 25.63it/s]
Epoch 43: MAE/CA 5.689010: : 430it [00:31, 13.80it/s]
40it [00:01, 23.68it/s]

EarlyStopping counter: 1 out of 100



Epoch 44: MAE/CA 5.826251: : 430it [00:31, 13.65it/s]
40it [00:01, 24.34it/s]

EarlyStopping counter: 2 out of 100



Epoch 45: MAE/CA 5.646966: : 430it [00:31, 13.64it/s]
40it [00:01, 23.00it/s]

EarlyStopping counter: 3 out of 100



Epoch 46: MAE/CA 5.773576: : 430it [00:31, 13.76it/s]
40it [00:01, 24.31it/s]

EarlyStopping counter: 4 out of 100



Epoch 47: MAE/CA 5.569738: : 430it [00:31, 13.68it/s]
40it [00:01, 24.47it/s]

EarlyStopping counter: 5 out of 100



Epoch 48: MAE/CA 5.571433: : 430it [00:31, 13.82it/s]
40it [00:01, 24.25it/s]
Epoch 49: MAE/CA 5.325124: : 430it [00:31, 13.75it/s]
40it [00:01, 23.24it/s]

EarlyStopping counter: 1 out of 100



Epoch 50: MAE/CA 5.486603: : 430it [00:31, 13.74it/s]
40it [00:01, 24.27it/s]
Epoch 51: MAE/CA 5.270191: : 430it [00:31, 13.80it/s]
40it [00:01, 23.46it/s]

EarlyStopping counter: 1 out of 100



Epoch 52: MAE/CA 5.440021: : 430it [00:31, 13.83it/s]
40it [00:01, 24.14it/s]

EarlyStopping counter: 2 out of 100



Epoch 53: MAE/CA 5.374889: : 430it [00:31, 13.69it/s]
40it [00:01, 22.80it/s]

EarlyStopping counter: 3 out of 100



Epoch 54: MAE/CA 5.292015: : 430it [00:31, 13.70it/s]
40it [00:01, 24.00it/s]
Epoch 55: MAE/CA 5.157748: : 430it [00:31, 13.81it/s]
40it [00:01, 22.79it/s]

EarlyStopping counter: 1 out of 100



Epoch 56: MAE/CA 5.015539: : 430it [00:31, 13.70it/s]
40it [00:01, 24.30it/s]

EarlyStopping counter: 2 out of 100



Epoch 57: MAE/CA 5.051798: : 430it [00:31, 13.68it/s]
40it [00:01, 24.09it/s]
Epoch 58: MAE/CA 5.010640: : 430it [00:31, 13.65it/s]
40it [00:01, 23.75it/s]

EarlyStopping counter: 1 out of 100



Epoch 59: MAE/CA 5.078619: : 430it [00:31, 13.72it/s]
40it [00:01, 23.56it/s]

EarlyStopping counter: 2 out of 100



Epoch 60: MAE/CA 4.930680: : 430it [00:31, 13.82it/s]
40it [00:01, 24.47it/s]

EarlyStopping counter: 3 out of 100



Epoch 61: MAE/CA 4.855141: : 430it [00:29, 14.36it/s]
40it [00:01, 23.33it/s]
Epoch 62: MAE/CA 5.009504: : 430it [00:31, 13.78it/s]
40it [00:01, 24.86it/s]

EarlyStopping counter: 1 out of 100



Epoch 63: MAE/CA 4.786706: : 430it [00:31, 13.79it/s]
40it [00:01, 23.70it/s]

EarlyStopping counter: 2 out of 100



Epoch 64: MAE/CA 5.014591: : 430it [00:31, 13.83it/s]
40it [00:01, 24.13it/s]

EarlyStopping counter: 3 out of 100



Epoch 65: MAE/CA 4.625597: : 430it [00:31, 13.52it/s]
40it [00:01, 21.39it/s]

EarlyStopping counter: 4 out of 100



Epoch 66: MAE/CA 4.776923: : 430it [00:31, 13.62it/s]
40it [00:01, 23.57it/s]
Epoch 67: MAE/CA 4.687299: : 430it [00:31, 13.69it/s]
40it [00:01, 24.77it/s]

EarlyStopping counter: 1 out of 100



Epoch 68: MAE/CA 4.493776: : 430it [00:31, 13.74it/s]
40it [00:01, 24.25it/s]

EarlyStopping counter: 2 out of 100



Epoch 69: MAE/CA 4.705773: : 430it [00:31, 13.63it/s]
40it [00:01, 23.81it/s]

EarlyStopping counter: 3 out of 100



Epoch 70: MAE/CA 4.389283: : 430it [00:31, 13.75it/s]
40it [00:01, 23.85it/s]

EarlyStopping counter: 4 out of 100



Epoch 71: MAE/CA 4.680752: : 430it [00:31, 13.70it/s]
40it [00:01, 24.03it/s]

EarlyStopping counter: 5 out of 100



Epoch 72: MAE/CA 4.562340: : 430it [00:31, 13.83it/s]
40it [00:01, 25.24it/s]

EarlyStopping counter: 6 out of 100



Epoch 73: MAE/CA 4.475966: : 430it [00:31, 13.82it/s]
40it [00:01, 23.26it/s]
Epoch 74: MAE/CA 4.332178: : 430it [00:31, 13.66it/s]
40it [00:01, 23.82it/s]

EarlyStopping counter: 1 out of 100



Epoch 75: MAE/CA 4.491509: : 430it [00:31, 13.55it/s]
40it [00:01, 24.65it/s]

EarlyStopping counter: 2 out of 100



Epoch 76: MAE/CA 4.526780: : 430it [00:31, 13.69it/s]
40it [00:01, 23.17it/s]
Epoch 77: MAE/CA 4.666842: : 430it [00:31, 13.66it/s]
40it [00:01, 24.23it/s]

EarlyStopping counter: 1 out of 100



Epoch 78: MAE/CA 4.256204: : 430it [00:31, 13.62it/s]
40it [00:01, 24.83it/s]
Epoch 79: MAE/CA 4.326361: : 430it [00:31, 13.58it/s]
40it [00:01, 23.08it/s]

EarlyStopping counter: 1 out of 100



Epoch 80: MAE/CA 4.274476: : 430it [00:31, 13.61it/s]
40it [00:01, 25.25it/s]
Epoch 81: MAE/CA 4.275934: : 430it [00:31, 13.64it/s]
40it [00:01, 23.67it/s]

EarlyStopping counter: 1 out of 100



Epoch 82: MAE/CA 4.188626: : 430it [00:31, 13.69it/s]
40it [00:01, 22.82it/s]

EarlyStopping counter: 2 out of 100



Epoch 83: MAE/CA 4.236641: : 430it [00:31, 13.71it/s]
40it [00:01, 24.52it/s]

EarlyStopping counter: 3 out of 100



Epoch 84: MAE/CA 4.235566: : 430it [00:31, 13.72it/s]
40it [00:01, 24.57it/s]

EarlyStopping counter: 4 out of 100



Epoch 85: MAE/CA 4.206530: : 430it [00:31, 13.73it/s]
40it [00:01, 23.99it/s]

EarlyStopping counter: 5 out of 100



Epoch 86: MAE/CA 4.105124: : 430it [00:31, 13.68it/s]
40it [00:01, 26.29it/s]
Epoch 87: MAE/CA 4.007649: : 430it [00:31, 13.64it/s]
40it [00:01, 26.03it/s]

EarlyStopping counter: 1 out of 100



Epoch 88: MAE/CA 4.147226: : 430it [00:30, 14.09it/s]
40it [00:01, 24.00it/s]

EarlyStopping counter: 2 out of 100



Epoch 89: MAE/CA 4.073855: : 430it [00:31, 13.70it/s]
40it [00:01, 24.93it/s]

EarlyStopping counter: 3 out of 100



Epoch 90: MAE/CA 4.093028: : 430it [00:31, 13.72it/s]
40it [00:01, 24.38it/s]

EarlyStopping counter: 4 out of 100



Epoch 91: MAE/CA 4.037698: : 430it [00:31, 13.58it/s]
40it [00:01, 24.29it/s]

EarlyStopping counter: 5 out of 100



Epoch 92: MAE/CA 4.008429: : 430it [00:31, 13.54it/s]
40it [00:01, 23.53it/s]

EarlyStopping counter: 6 out of 100



Epoch 93: MAE/CA 3.994234: : 430it [00:31, 13.64it/s]
40it [00:01, 24.64it/s]

EarlyStopping counter: 7 out of 100



Epoch 94: MAE/CA 3.827675: : 430it [00:31, 13.77it/s]
40it [00:01, 23.62it/s]

EarlyStopping counter: 8 out of 100



Epoch 95: MAE/CA 4.170381: : 430it [00:31, 13.61it/s]
40it [00:01, 24.50it/s]

EarlyStopping counter: 9 out of 100



Epoch 96: MAE/CA 3.750955: : 430it [00:31, 13.73it/s]
40it [00:01, 24.52it/s]

EarlyStopping counter: 10 out of 100



Epoch 97: MAE/CA 4.069860: : 430it [00:31, 13.60it/s]
40it [00:01, 24.23it/s]
Epoch 98: MAE/CA 3.860284: : 430it [00:30, 13.88it/s]
40it [00:01, 24.47it/s]
Epoch 99: MAE/CA 3.690528: : 430it [00:31, 13.71it/s]
40it [00:01, 23.80it/s]

EarlyStopping counter: 1 out of 100



Epoch 100: MAE/CA 3.786409: : 430it [00:31, 13.79it/s]
40it [00:01, 23.73it/s]

EarlyStopping counter: 2 out of 100



Epoch 101: MAE/CA 3.906533: : 430it [00:31, 13.65it/s]
40it [00:01, 23.09it/s]

EarlyStopping counter: 3 out of 100



Epoch 102: MAE/CA 3.719567: : 430it [00:31, 13.69it/s]
40it [00:01, 24.58it/s]

EarlyStopping counter: 4 out of 100



Epoch 103: MAE/CA 3.870143: : 430it [00:31, 13.52it/s]
40it [00:01, 23.82it/s]

EarlyStopping counter: 5 out of 100



Epoch 104: MAE/CA 3.835426: : 430it [00:31, 13.75it/s]
40it [00:01, 25.15it/s]

EarlyStopping counter: 6 out of 100



Epoch 105: MAE/CA 3.792246: : 430it [00:31, 13.69it/s]
40it [00:01, 24.96it/s]

Epoch   105: reducing learning rate of group 0 to 2.8500e-04.
EarlyStopping counter: 7 out of 100



Epoch 106: MAE/CA 3.506047: : 430it [00:31, 13.73it/s]
40it [00:01, 24.78it/s]

EarlyStopping counter: 8 out of 100



Epoch 107: MAE/CA 3.635181: : 430it [00:30, 13.88it/s]
40it [00:01, 24.47it/s]
Epoch 108: MAE/CA 3.732739: : 430it [00:31, 13.75it/s]
40it [00:01, 24.35it/s]

EarlyStopping counter: 1 out of 100



Epoch 109: MAE/CA 3.653960: : 430it [00:31, 13.71it/s]
40it [00:01, 24.77it/s]

EarlyStopping counter: 2 out of 100



Epoch 110: MAE/CA 3.675824: : 430it [00:31, 13.62it/s]
40it [00:01, 24.42it/s]

EarlyStopping counter: 3 out of 100



Epoch 111: MAE/CA 3.507761: : 430it [00:31, 13.69it/s]
40it [00:01, 25.05it/s]

EarlyStopping counter: 4 out of 100



Epoch 112: MAE/CA 3.704929: : 430it [00:31, 13.65it/s]
40it [00:01, 24.07it/s]

Epoch   112: reducing learning rate of group 0 to 2.7075e-04.
EarlyStopping counter: 5 out of 100



Epoch 113: MAE/CA 3.473368: : 430it [00:31, 13.71it/s]
40it [00:01, 23.86it/s]
Epoch 114: MAE/CA 3.392414: : 430it [00:31, 13.72it/s]
40it [00:01, 24.40it/s]

EarlyStopping counter: 1 out of 100



Epoch 115: MAE/CA 3.523479: : 430it [00:30, 13.93it/s]
40it [00:01, 24.82it/s]

EarlyStopping counter: 2 out of 100



Epoch 116: MAE/CA 3.445903: : 430it [00:31, 13.73it/s]
40it [00:01, 24.28it/s]

EarlyStopping counter: 3 out of 100



Epoch 117: MAE/CA 3.388356: : 430it [00:31, 13.83it/s]
40it [00:01, 24.64it/s]

EarlyStopping counter: 4 out of 100



Epoch 118: MAE/CA 3.620874: : 430it [00:31, 13.73it/s]
40it [00:01, 25.84it/s]

EarlyStopping counter: 5 out of 100



Epoch 119: MAE/CA 3.356104: : 430it [00:31, 13.69it/s]
40it [00:01, 24.05it/s]

EarlyStopping counter: 6 out of 100



Epoch 120: MAE/CA 3.318996: : 430it [00:31, 13.68it/s]
40it [00:01, 21.87it/s]

EarlyStopping counter: 7 out of 100



Epoch 121: MAE/CA 3.597140: : 430it [00:31, 13.71it/s]
40it [00:01, 24.01it/s]

EarlyStopping counter: 8 out of 100



Epoch 122: MAE/CA 3.295219: : 430it [00:31, 13.53it/s]
40it [00:01, 23.48it/s]

EarlyStopping counter: 9 out of 100



Epoch 123: MAE/CA 3.484341: : 430it [00:31, 13.84it/s]
40it [00:01, 24.41it/s]

EarlyStopping counter: 10 out of 100



Epoch 124: MAE/CA 3.214213: : 430it [00:31, 13.81it/s]
40it [00:01, 24.10it/s]
Epoch 125: MAE/CA 3.302120: : 430it [00:31, 13.81it/s]
40it [00:01, 24.77it/s]
Epoch 126: MAE/CA 3.405917: : 430it [00:31, 13.65it/s]
40it [00:01, 22.95it/s]

EarlyStopping counter: 1 out of 100



Epoch 127: MAE/CA 3.308014: : 430it [00:31, 13.46it/s]
40it [00:01, 23.18it/s]

EarlyStopping counter: 2 out of 100



Epoch 128: MAE/CA 3.385019: : 430it [00:31, 13.71it/s]
40it [00:01, 22.76it/s]

EarlyStopping counter: 3 out of 100



Epoch 129: MAE/CA 3.370674: : 430it [00:18, 23.44it/s]
40it [00:01, 24.38it/s]

EarlyStopping counter: 4 out of 100



Epoch 130: MAE/CA 3.227556: : 430it [00:17, 24.37it/s]
40it [00:01, 24.49it/s]

Epoch   130: reducing learning rate of group 0 to 2.5721e-04.
EarlyStopping counter: 5 out of 100



Epoch 131: MAE/CA 3.209274: : 430it [00:17, 24.24it/s]
40it [00:01, 23.65it/s]
Epoch 132: MAE/CA 3.254900: : 430it [00:18, 23.68it/s]
40it [00:01, 23.93it/s]

EarlyStopping counter: 1 out of 100



Epoch 133: MAE/CA 3.114929: : 430it [00:17, 24.24it/s]
40it [00:01, 24.47it/s]

EarlyStopping counter: 2 out of 100



Epoch 134: MAE/CA 3.287216: : 430it [00:18, 23.81it/s]
40it [00:01, 24.21it/s]

EarlyStopping counter: 3 out of 100



Epoch 135: MAE/CA 3.061480: : 430it [00:16, 25.45it/s]
40it [00:01, 26.23it/s]

EarlyStopping counter: 4 out of 100



Epoch 136: MAE/CA 3.267881: : 430it [00:17, 25.09it/s]
40it [00:01, 24.45it/s]

EarlyStopping counter: 5 out of 100



Epoch 137: MAE/CA 3.187589: : 430it [00:18, 23.67it/s]
40it [00:01, 24.60it/s]

EarlyStopping counter: 6 out of 100



Epoch 138: MAE/CA 3.054163: : 430it [00:17, 25.25it/s]
40it [00:01, 24.54it/s]
Epoch 139: MAE/CA 3.137940: : 430it [00:16, 25.32it/s]
40it [00:01, 24.25it/s]

EarlyStopping counter: 1 out of 100



Epoch 140: MAE/CA 3.224710: : 430it [00:17, 25.10it/s]
40it [00:01, 24.01it/s]

EarlyStopping counter: 2 out of 100



Epoch 141: MAE/CA 3.036719: : 430it [00:17, 24.54it/s]
40it [00:01, 24.42it/s]

EarlyStopping counter: 3 out of 100



Epoch 142: MAE/CA 3.112848: : 430it [00:17, 25.06it/s]
40it [00:01, 24.39it/s]

EarlyStopping counter: 4 out of 100



Epoch 143: MAE/CA 3.201986: : 430it [00:17, 24.98it/s]
40it [00:01, 25.18it/s]

EarlyStopping counter: 5 out of 100



Epoch 144: MAE/CA 3.100089: : 430it [00:17, 24.77it/s]
40it [00:01, 24.94it/s]

EarlyStopping counter: 6 out of 100



Epoch 145: MAE/CA 3.063809: : 430it [00:17, 24.56it/s]
40it [00:01, 23.25it/s]

EarlyStopping counter: 7 out of 100



Epoch 146: MAE/CA 3.059191: : 430it [00:17, 24.46it/s]
40it [00:01, 24.67it/s]

EarlyStopping counter: 8 out of 100



Epoch 147: MAE/CA 3.067408: : 430it [00:16, 25.32it/s]
40it [00:01, 24.37it/s]

Epoch   147: reducing learning rate of group 0 to 2.4435e-04.
EarlyStopping counter: 9 out of 100



Epoch 148: MAE/CA 3.025206: : 430it [00:17, 25.03it/s]
40it [00:01, 24.20it/s]

EarlyStopping counter: 10 out of 100



Epoch 149: MAE/CA 2.963595: : 430it [00:17, 25.12it/s]
40it [00:01, 24.85it/s]

EarlyStopping counter: 11 out of 100



Epoch 150: MAE/CA 3.023752: : 430it [00:17, 24.64it/s]
40it [00:01, 24.76it/s]

EarlyStopping counter: 12 out of 100



Epoch 151: MAE/CA 2.951676: : 430it [00:17, 24.88it/s]
40it [00:01, 23.57it/s]

EarlyStopping counter: 13 out of 100



Epoch 152: MAE/CA 3.014209: : 430it [00:17, 25.13it/s]
40it [00:01, 23.89it/s]

EarlyStopping counter: 14 out of 100



Epoch 153: MAE/CA 2.929643: : 430it [00:16, 25.31it/s]
40it [00:01, 25.15it/s]

EarlyStopping counter: 15 out of 100



Epoch 154: MAE/CA 3.050562: : 430it [00:16, 25.32it/s]
40it [00:01, 24.65it/s]

EarlyStopping counter: 16 out of 100



Epoch 155: MAE/CA 2.929206: : 430it [00:17, 23.98it/s]
40it [00:01, 24.53it/s]

EarlyStopping counter: 17 out of 100



Epoch 156: MAE/CA 3.020544: : 430it [00:18, 23.53it/s]
40it [00:01, 26.00it/s]

EarlyStopping counter: 18 out of 100



Epoch 157: MAE/CA 2.907901: : 430it [00:18, 23.62it/s]
40it [00:01, 23.23it/s]
Epoch 158: MAE/CA 2.901098: : 430it [00:18, 23.45it/s]
40it [00:01, 22.23it/s]

EarlyStopping counter: 1 out of 100



Epoch 159: MAE/CA 2.845489: : 430it [00:18, 23.62it/s]
40it [00:01, 24.09it/s]

EarlyStopping counter: 2 out of 100



Epoch 160: MAE/CA 2.932276: : 430it [00:18, 23.84it/s]
40it [00:01, 23.88it/s]

EarlyStopping counter: 3 out of 100



Epoch 161: MAE/CA 2.847161: : 430it [00:17, 24.21it/s]
40it [00:01, 24.89it/s]

EarlyStopping counter: 4 out of 100



Epoch 162: MAE/CA 2.918435: : 430it [00:17, 24.41it/s]
40it [00:01, 25.04it/s]

EarlyStopping counter: 5 out of 100



Epoch 163: MAE/CA 2.966485: : 430it [00:16, 25.33it/s]
40it [00:01, 25.70it/s]

EarlyStopping counter: 6 out of 100



Epoch 164: MAE/CA 2.913134: : 430it [00:17, 24.92it/s]
40it [00:01, 23.85it/s]

EarlyStopping counter: 7 out of 100



Epoch 165: MAE/CA 2.749992: : 430it [00:17, 24.63it/s]
40it [00:01, 25.00it/s]

EarlyStopping counter: 8 out of 100



Epoch 166: MAE/CA 3.046068: : 430it [00:17, 25.23it/s]
40it [00:01, 25.00it/s]
Epoch 167: MAE/CA 2.896968: : 430it [00:17, 25.15it/s]
40it [00:01, 24.06it/s]

EarlyStopping counter: 1 out of 100



Epoch 168: MAE/CA 2.810080: : 430it [00:17, 24.78it/s]
40it [00:01, 25.23it/s]

EarlyStopping counter: 2 out of 100



Epoch 169: MAE/CA 2.731114: : 430it [00:17, 25.21it/s]
40it [00:01, 25.84it/s]

EarlyStopping counter: 3 out of 100



Epoch 170: MAE/CA 2.890232: : 430it [00:17, 25.15it/s]
40it [00:01, 25.54it/s]

EarlyStopping counter: 4 out of 100



Epoch 171: MAE/CA 2.856817: : 430it [00:16, 25.34it/s]
40it [00:01, 24.22it/s]

EarlyStopping counter: 5 out of 100



Epoch 172: MAE/CA 2.794129: : 430it [00:17, 24.64it/s]
40it [00:01, 24.73it/s]

EarlyStopping counter: 6 out of 100



Epoch 173: MAE/CA 2.840183: : 430it [00:17, 24.49it/s]
40it [00:01, 24.76it/s]

EarlyStopping counter: 7 out of 100



Epoch 174: MAE/CA 2.789772: : 430it [00:17, 24.78it/s]
40it [00:01, 24.18it/s]

EarlyStopping counter: 8 out of 100



Epoch 175: MAE/CA 2.829245: : 430it [00:17, 25.25it/s]
40it [00:01, 24.40it/s]

Epoch   175: reducing learning rate of group 0 to 2.3213e-04.
EarlyStopping counter: 9 out of 100



Epoch 176: MAE/CA 2.713301: : 430it [00:17, 25.18it/s]
40it [00:02, 14.14it/s]
Epoch 177: MAE/CA 2.707118: : 430it [00:17, 24.92it/s]
40it [00:01, 25.27it/s]

EarlyStopping counter: 1 out of 100



Epoch 178: MAE/CA 2.862145: : 430it [00:16, 25.30it/s]
40it [00:01, 24.78it/s]

EarlyStopping counter: 2 out of 100



Epoch 179: MAE/CA 2.729830: : 430it [00:17, 25.03it/s]
40it [00:01, 24.82it/s]

EarlyStopping counter: 3 out of 100



Epoch 180: MAE/CA 2.687071: : 430it [00:17, 25.03it/s]
40it [00:01, 24.95it/s]

EarlyStopping counter: 4 out of 100



Epoch 181: MAE/CA 2.786810: : 430it [00:16, 25.30it/s]
40it [00:01, 24.48it/s]

EarlyStopping counter: 5 out of 100



Epoch 182: MAE/CA 2.724288: : 430it [00:17, 25.14it/s]
40it [00:01, 25.61it/s]

EarlyStopping counter: 6 out of 100



Epoch 183: MAE/CA 2.759600: : 430it [00:16, 25.31it/s]
40it [00:02, 14.21it/s]

EarlyStopping counter: 7 out of 100



Epoch 184: MAE/CA 2.682467: : 430it [00:17, 24.67it/s]
40it [00:02, 14.83it/s]

EarlyStopping counter: 8 out of 100



Epoch 185: MAE/CA 2.722952: : 430it [00:16, 25.52it/s]
40it [00:02, 14.12it/s]


EarlyStopping counter: 9 out of 100


Epoch 186: MAE/CA 2.675010: : 430it [00:16, 25.58it/s]
40it [00:01, 25.56it/s]

EarlyStopping counter: 10 out of 100



Epoch 187: MAE/CA 2.635967: : 430it [00:17, 24.99it/s]
40it [00:01, 25.72it/s]

EarlyStopping counter: 11 out of 100



Epoch 188: MAE/CA 2.667073: : 430it [00:16, 25.51it/s]
40it [00:02, 15.84it/s]

EarlyStopping counter: 12 out of 100



Epoch 189: MAE/CA 2.719820: : 430it [00:17, 25.24it/s]
40it [00:02, 14.98it/s]
Epoch 190: MAE/CA 2.704260: : 430it [00:17, 24.89it/s]
40it [00:02, 15.66it/s]

EarlyStopping counter: 1 out of 100



Epoch 191: MAE/CA 2.624062: : 430it [00:17, 25.13it/s]
40it [00:02, 14.54it/s]

EarlyStopping counter: 2 out of 100



Epoch 192: MAE/CA 2.662725: : 430it [00:16, 25.32it/s]
40it [00:02, 14.10it/s]

EarlyStopping counter: 3 out of 100



Epoch 193: MAE/CA 2.742556: : 430it [00:16, 25.43it/s]
40it [00:02, 14.04it/s]

EarlyStopping counter: 4 out of 100



Epoch 194: MAE/CA 2.599143: : 430it [00:16, 25.50it/s]
40it [00:02, 14.31it/s]

EarlyStopping counter: 5 out of 100



Epoch 195: MAE/CA 2.614086: : 430it [00:17, 24.05it/s]
40it [00:02, 14.66it/s]

EarlyStopping counter: 6 out of 100



Epoch 196: MAE/CA 2.608348: : 430it [00:16, 25.61it/s]
40it [00:02, 14.31it/s]

EarlyStopping counter: 7 out of 100



Epoch 197: MAE/CA 2.696375: : 430it [00:16, 25.38it/s]
40it [00:02, 14.86it/s]

EarlyStopping counter: 8 out of 100



Epoch 198: MAE/CA 2.679876: : 430it [00:17, 25.10it/s]
40it [00:02, 14.16it/s]

EarlyStopping counter: 9 out of 100



Epoch 199: MAE/CA 2.585896: : 430it [00:17, 25.25it/s]
40it [00:02, 14.13it/s]

EarlyStopping counter: 10 out of 100



Epoch 200: MAE/CA 2.603906: : 430it [00:17, 25.20it/s]
40it [00:02, 14.19it/s]

EarlyStopping counter: 11 out of 100



Epoch 201: MAE/CA 2.696156: : 430it [00:17, 24.81it/s]
40it [00:02, 14.88it/s]

EarlyStopping counter: 12 out of 100



Epoch 202: MAE/CA 2.555949: : 430it [00:17, 25.12it/s]
40it [00:02, 14.60it/s]

EarlyStopping counter: 13 out of 100



Epoch 203: MAE/CA 2.526095: : 430it [00:17, 25.09it/s]
40it [00:02, 14.88it/s]

EarlyStopping counter: 14 out of 100



Epoch 204: MAE/CA 2.617174: : 430it [00:16, 25.32it/s]
40it [00:02, 15.25it/s]

EarlyStopping counter: 15 out of 100



Epoch 205: MAE/CA 2.656280: : 430it [00:16, 25.40it/s]
40it [00:02, 14.26it/s]

EarlyStopping counter: 16 out of 100



Epoch 206: MAE/CA 2.591353: : 430it [00:17, 25.22it/s]
40it [00:02, 14.62it/s]

EarlyStopping counter: 17 out of 100



Epoch 207: MAE/CA 2.652307: : 430it [00:16, 25.60it/s]
40it [00:02, 15.00it/s]

EarlyStopping counter: 18 out of 100



Epoch 208: MAE/CA 2.556341: : 430it [00:16, 25.51it/s]
40it [00:02, 15.25it/s]

EarlyStopping counter: 19 out of 100



Epoch 209: MAE/CA 2.580421: : 430it [00:16, 25.49it/s]
40it [00:02, 14.44it/s]

Epoch   209: reducing learning rate of group 0 to 2.2053e-04.
EarlyStopping counter: 20 out of 100



Epoch 210: MAE/CA 2.575459: : 430it [00:16, 25.50it/s]
40it [00:02, 14.85it/s]

EarlyStopping counter: 21 out of 100



Epoch 211: MAE/CA 2.515465: : 430it [00:16, 25.57it/s]
40it [00:02, 14.60it/s]

EarlyStopping counter: 22 out of 100



Epoch 212: MAE/CA 2.430225: : 430it [00:16, 25.43it/s]
40it [00:02, 14.30it/s]

EarlyStopping counter: 23 out of 100



Epoch 213: MAE/CA 2.614027: : 430it [00:16, 25.53it/s]
40it [00:02, 14.45it/s]
Epoch 214: MAE/CA 2.474699: : 430it [00:16, 25.40it/s]
40it [00:02, 14.70it/s]

EarlyStopping counter: 1 out of 100



Epoch 215: MAE/CA 2.460871: : 430it [00:17, 24.99it/s]
40it [00:02, 14.80it/s]

EarlyStopping counter: 2 out of 100



Epoch 216: MAE/CA 2.500460: : 430it [00:17, 25.26it/s]
40it [00:02, 15.50it/s]

EarlyStopping counter: 3 out of 100



Epoch 217: MAE/CA 2.402102: : 430it [00:17, 24.43it/s]
40it [00:02, 14.21it/s]

EarlyStopping counter: 4 out of 100



Epoch 218: MAE/CA 2.603653: : 430it [00:17, 25.17it/s]
40it [00:03, 11.61it/s]

EarlyStopping counter: 5 out of 100



Epoch 219: MAE/CA 2.491737: : 430it [00:16, 25.41it/s]
40it [00:02, 14.23it/s]

EarlyStopping counter: 6 out of 100



Epoch 220: MAE/CA 2.548538: : 430it [00:17, 25.22it/s]
40it [00:02, 14.64it/s]

EarlyStopping counter: 7 out of 100



Epoch 221: MAE/CA 2.427133: : 430it [00:17, 25.11it/s]
40it [00:02, 15.25it/s]

EarlyStopping counter: 8 out of 100



Epoch 222: MAE/CA 2.489855: : 430it [00:16, 25.46it/s]
40it [00:02, 14.53it/s]
Epoch 223: MAE/CA 2.602513: : 430it [00:17, 25.05it/s]
40it [00:02, 14.51it/s]

Epoch   223: reducing learning rate of group 0 to 2.0950e-04.
EarlyStopping counter: 1 out of 100



Epoch 224: MAE/CA 2.428069: : 430it [00:16, 25.64it/s]
40it [00:02, 14.95it/s]

EarlyStopping counter: 2 out of 100



Epoch 225: MAE/CA 2.399656: : 430it [00:16, 25.58it/s]
40it [00:02, 14.38it/s]

EarlyStopping counter: 3 out of 100



Epoch 226: MAE/CA 2.340371: : 430it [00:17, 25.05it/s]
40it [00:02, 14.38it/s]

EarlyStopping counter: 4 out of 100



Epoch 227: MAE/CA 2.458892: : 430it [00:16, 25.40it/s]
40it [00:02, 15.05it/s]

EarlyStopping counter: 5 out of 100



Epoch 228: MAE/CA 2.425190: : 430it [00:17, 25.07it/s]
40it [00:02, 15.24it/s]

EarlyStopping counter: 6 out of 100



Epoch 229: MAE/CA 2.440820: : 430it [00:16, 25.32it/s]
40it [00:02, 13.96it/s]

EarlyStopping counter: 7 out of 100



Epoch 230: MAE/CA 2.378590: : 430it [00:16, 25.73it/s]
40it [00:02, 14.75it/s]

EarlyStopping counter: 8 out of 100



Epoch 231: MAE/CA 2.402764: : 430it [00:16, 25.38it/s]
40it [00:02, 14.83it/s]

EarlyStopping counter: 9 out of 100



Epoch 232: MAE/CA 2.448604: : 430it [00:16, 25.42it/s]
40it [00:02, 14.53it/s]

Epoch   232: reducing learning rate of group 0 to 1.9903e-04.
EarlyStopping counter: 10 out of 100



Epoch 233: MAE/CA 2.309558: : 430it [00:16, 25.49it/s]
40it [00:02, 15.34it/s]

EarlyStopping counter: 11 out of 100



Epoch 234: MAE/CA 2.362324: : 430it [00:16, 25.63it/s]
40it [00:02, 15.04it/s]

EarlyStopping counter: 12 out of 100



Epoch 235: MAE/CA 2.295194: : 430it [00:17, 24.83it/s]
40it [00:02, 14.43it/s]

EarlyStopping counter: 13 out of 100



Epoch 236: MAE/CA 2.426047: : 430it [00:16, 25.73it/s]
40it [00:02, 14.43it/s]
Epoch 237: MAE/CA 2.346087: : 430it [00:16, 25.50it/s]
40it [00:02, 14.64it/s]

EarlyStopping counter: 1 out of 100



Epoch 238: MAE/CA 2.322334: : 430it [00:17, 25.23it/s]
40it [00:02, 14.69it/s]

EarlyStopping counter: 2 out of 100



Epoch 239: MAE/CA 2.390645: : 430it [00:16, 25.35it/s]
40it [00:02, 14.87it/s]

EarlyStopping counter: 3 out of 100



Epoch 240: MAE/CA 2.364107: : 430it [00:16, 25.31it/s]
40it [00:02, 14.85it/s]

EarlyStopping counter: 4 out of 100



Epoch 241: MAE/CA 2.404588: : 430it [00:17, 24.82it/s]
40it [00:02, 14.34it/s]

Epoch   241: reducing learning rate of group 0 to 1.8907e-04.
EarlyStopping counter: 5 out of 100



Epoch 242: MAE/CA 2.202068: : 430it [00:17, 25.16it/s]
40it [00:02, 13.97it/s]

EarlyStopping counter: 6 out of 100



Epoch 243: MAE/CA 2.271748: : 430it [00:16, 25.64it/s]
40it [00:02, 14.63it/s]

EarlyStopping counter: 7 out of 100



Epoch 244: MAE/CA 2.300096: : 430it [00:16, 25.35it/s]
40it [00:02, 14.65it/s]

EarlyStopping counter: 8 out of 100



Epoch 245: MAE/CA 2.356076: : 430it [00:16, 25.73it/s]
40it [00:03, 11.69it/s]

EarlyStopping counter: 9 out of 100



Epoch 246: MAE/CA 2.212048: : 430it [00:15, 27.08it/s]
40it [00:01, 26.11it/s]
Epoch 247: MAE/CA 2.289263: : 430it [00:16, 26.75it/s]
40it [00:01, 25.67it/s]

EarlyStopping counter: 1 out of 100



Epoch 248: MAE/CA 2.369892: : 430it [00:16, 26.72it/s]
40it [00:01, 25.45it/s]

Epoch   248: reducing learning rate of group 0 to 1.7962e-04.
EarlyStopping counter: 2 out of 100



Epoch 249: MAE/CA 2.208396: : 430it [00:15, 27.38it/s]
40it [00:01, 24.74it/s]
Epoch 250: MAE/CA 2.180564: : 430it [00:15, 27.23it/s]
40it [00:01, 24.93it/s]

EarlyStopping counter: 1 out of 100



Epoch 251: MAE/CA 2.260451: : 430it [00:15, 27.07it/s]
40it [00:01, 24.56it/s]

EarlyStopping counter: 2 out of 100



Epoch 252: MAE/CA 2.244921: : 430it [00:16, 26.80it/s]
40it [00:01, 24.52it/s]

EarlyStopping counter: 3 out of 100



Epoch 253: MAE/CA 2.278845: : 430it [00:15, 27.02it/s]
40it [00:01, 24.96it/s]

EarlyStopping counter: 4 out of 100



Epoch 254: MAE/CA 2.254585: : 430it [00:16, 26.83it/s]
40it [00:01, 24.40it/s]

EarlyStopping counter: 5 out of 100



Epoch 255: MAE/CA 2.203030: : 430it [00:16, 26.58it/s]
40it [00:01, 25.06it/s]

EarlyStopping counter: 6 out of 100



Epoch 256: MAE/CA 2.360352: : 430it [00:15, 26.89it/s]
40it [00:01, 24.60it/s]

Epoch   256: reducing learning rate of group 0 to 1.7064e-04.
EarlyStopping counter: 7 out of 100



Epoch 257: MAE/CA 2.159711: : 430it [00:15, 26.98it/s]
40it [00:01, 24.44it/s]

EarlyStopping counter: 8 out of 100



Epoch 258: MAE/CA 2.186368: : 430it [00:16, 26.87it/s]
40it [00:01, 24.89it/s]

EarlyStopping counter: 9 out of 100



Epoch 259: MAE/CA 2.188439: : 430it [00:15, 26.95it/s]
40it [00:01, 25.17it/s]

EarlyStopping counter: 10 out of 100



Epoch 260: MAE/CA 2.104491: : 430it [00:15, 26.99it/s]
40it [00:01, 24.64it/s]

EarlyStopping counter: 11 out of 100



Epoch 261: MAE/CA 2.242811: : 430it [00:16, 26.78it/s]
40it [00:01, 25.46it/s]

EarlyStopping counter: 12 out of 100



Epoch 262: MAE/CA 2.139707: : 430it [00:16, 26.76it/s]
40it [00:01, 24.04it/s]

EarlyStopping counter: 13 out of 100



Epoch 263: MAE/CA 2.236565: : 430it [00:15, 26.88it/s]
40it [00:01, 25.11it/s]

EarlyStopping counter: 14 out of 100



Epoch 264: MAE/CA 2.158317: : 430it [00:16, 26.77it/s]
40it [00:01, 25.30it/s]

EarlyStopping counter: 15 out of 100



Epoch 265: MAE/CA 2.205991: : 430it [00:16, 26.81it/s]
40it [00:01, 25.58it/s]

EarlyStopping counter: 16 out of 100



Epoch 266: MAE/CA 2.157681: : 430it [00:15, 27.00it/s]
40it [00:01, 26.05it/s]

Epoch   266: reducing learning rate of group 0 to 1.6211e-04.
EarlyStopping counter: 17 out of 100



Epoch 267: MAE/CA 2.082505: : 430it [00:16, 26.40it/s]
40it [00:01, 27.67it/s]

EarlyStopping counter: 18 out of 100



Epoch 268: MAE/CA 2.141266: : 430it [00:16, 26.45it/s]
40it [00:01, 25.20it/s]
Epoch 269: MAE/CA 2.104575: : 430it [00:16, 26.60it/s]
40it [00:01, 25.34it/s]

EarlyStopping counter: 1 out of 100



Epoch 270: MAE/CA 2.148257: : 430it [00:16, 26.86it/s]
40it [00:01, 27.28it/s]

EarlyStopping counter: 2 out of 100



Epoch 271: MAE/CA 2.093110: : 430it [00:16, 26.71it/s]
40it [00:01, 24.81it/s]

EarlyStopping counter: 3 out of 100



Epoch 272: MAE/CA 2.133964: : 430it [00:16, 26.81it/s]
40it [00:01, 25.16it/s]

EarlyStopping counter: 4 out of 100



Epoch 273: MAE/CA 2.154788: : 430it [00:16, 26.78it/s]
40it [00:01, 24.61it/s]

Epoch   273: reducing learning rate of group 0 to 1.5400e-04.
EarlyStopping counter: 5 out of 100



Epoch 274: MAE/CA 2.153877: : 430it [00:16, 26.57it/s]
40it [00:01, 25.53it/s]
Epoch 275: MAE/CA 2.025373: : 430it [00:15, 26.93it/s]
40it [00:01, 24.33it/s]

EarlyStopping counter: 1 out of 100



Epoch 276: MAE/CA 2.133260: : 430it [00:16, 26.61it/s]
40it [00:01, 24.77it/s]
Epoch 277: MAE/CA 2.090588: : 430it [00:16, 26.72it/s]
40it [00:01, 24.66it/s]

EarlyStopping counter: 1 out of 100



Epoch 278: MAE/CA 2.048921: : 430it [00:15, 26.91it/s]
40it [00:01, 25.36it/s]

EarlyStopping counter: 2 out of 100



Epoch 279: MAE/CA 2.027895: : 430it [00:15, 26.90it/s]
40it [00:01, 24.51it/s]
Epoch 280: MAE/CA 2.026287: : 430it [00:15, 26.99it/s]
40it [00:01, 23.95it/s]

EarlyStopping counter: 1 out of 100



Epoch 281: MAE/CA 2.138320: : 430it [00:15, 27.34it/s]
40it [00:01, 26.21it/s]

Epoch   281: reducing learning rate of group 0 to 1.4630e-04.
EarlyStopping counter: 2 out of 100



Epoch 282: MAE/CA 2.035783: : 430it [00:15, 26.88it/s]
40it [00:01, 25.08it/s]

EarlyStopping counter: 3 out of 100



Epoch 283: MAE/CA 2.013393: : 430it [00:15, 27.01it/s]
40it [00:01, 25.53it/s]

EarlyStopping counter: 4 out of 100



Epoch 284: MAE/CA 2.026752: : 430it [00:15, 26.97it/s]
40it [00:01, 25.21it/s]

EarlyStopping counter: 5 out of 100



Epoch 285: MAE/CA 2.090559: : 430it [00:16, 26.75it/s]
40it [00:01, 24.31it/s]
Epoch 286: MAE/CA 2.043931: : 430it [00:15, 27.10it/s]
40it [00:01, 24.89it/s]

EarlyStopping counter: 1 out of 100



Epoch 287: MAE/CA 1.995539: : 430it [00:15, 26.89it/s]
40it [00:01, 23.67it/s]

EarlyStopping counter: 2 out of 100



Epoch 288: MAE/CA 2.013283: : 430it [00:15, 27.09it/s]
40it [00:01, 24.86it/s]

EarlyStopping counter: 3 out of 100



Epoch 289: MAE/CA 2.046038: : 430it [00:15, 27.26it/s]
40it [00:01, 25.13it/s]

EarlyStopping counter: 4 out of 100



Epoch 290: MAE/CA 1.989990: : 430it [00:15, 27.28it/s]
40it [00:01, 24.50it/s]

EarlyStopping counter: 5 out of 100



Epoch 291: MAE/CA 2.055014: : 430it [00:15, 26.96it/s]
40it [00:01, 23.94it/s]

EarlyStopping counter: 6 out of 100



Epoch 292: MAE/CA 1.991931: : 430it [00:16, 26.86it/s]
40it [00:01, 24.93it/s]

EarlyStopping counter: 7 out of 100



Epoch 293: MAE/CA 2.021488: : 430it [00:15, 27.26it/s]
40it [00:01, 25.00it/s]

EarlyStopping counter: 8 out of 100



Epoch 294: MAE/CA 2.015802: : 430it [00:16, 26.78it/s]
40it [00:01, 22.78it/s]

EarlyStopping counter: 9 out of 100



Epoch 295: MAE/CA 1.959959: : 430it [00:16, 26.44it/s]
40it [00:01, 25.24it/s]

EarlyStopping counter: 10 out of 100



Epoch 296: MAE/CA 2.091736: : 430it [00:17, 24.96it/s]
40it [00:01, 24.43it/s]

EarlyStopping counter: 11 out of 100



Epoch 297: MAE/CA 1.975742: : 430it [00:16, 26.19it/s]
40it [00:01, 24.67it/s]

EarlyStopping counter: 12 out of 100



Epoch 298: MAE/CA 2.031029: : 430it [00:15, 26.90it/s]
40it [00:01, 25.22it/s]

EarlyStopping counter: 13 out of 100



Epoch 299: MAE/CA 1.978921: : 430it [00:15, 26.91it/s]
40it [00:01, 25.13it/s]

EarlyStopping counter: 14 out of 100



Epoch 300: MAE/CA 2.014600: : 430it [00:15, 27.01it/s]
40it [00:01, 24.99it/s]

EarlyStopping counter: 15 out of 100



Epoch 301: MAE/CA 1.953140: : 430it [00:15, 26.94it/s]
40it [00:01, 24.62it/s]

EarlyStopping counter: 16 out of 100



Epoch 302: MAE/CA 1.978627: : 430it [00:15, 26.92it/s]
40it [00:01, 23.38it/s]

EarlyStopping counter: 17 out of 100



Epoch 303: MAE/CA 1.973945: : 430it [00:16, 26.78it/s]
40it [00:01, 24.52it/s]

EarlyStopping counter: 18 out of 100



Epoch 304: MAE/CA 2.003341: : 430it [00:16, 25.57it/s]
40it [00:01, 23.50it/s]

EarlyStopping counter: 19 out of 100



Epoch 305: MAE/CA 1.998891: : 430it [00:16, 26.21it/s]
40it [00:01, 23.25it/s]

EarlyStopping counter: 20 out of 100



Epoch 306: MAE/CA 1.979557: : 430it [00:16, 26.29it/s]
40it [00:01, 25.33it/s]

EarlyStopping counter: 21 out of 100



Epoch 307: MAE/CA 1.975441: : 430it [00:17, 24.71it/s]
40it [00:01, 23.88it/s]

Epoch   307: reducing learning rate of group 0 to 1.3899e-04.
EarlyStopping counter: 22 out of 100



Epoch 308: MAE/CA 1.940856: : 430it [00:16, 25.72it/s]
40it [00:01, 25.80it/s]

EarlyStopping counter: 23 out of 100



Epoch 309: MAE/CA 1.981035: : 430it [00:16, 26.58it/s]
40it [00:01, 25.30it/s]

EarlyStopping counter: 24 out of 100



Epoch 310: MAE/CA 1.975952: : 430it [00:16, 25.95it/s]
40it [00:01, 27.20it/s]

EarlyStopping counter: 25 out of 100



Epoch 311: MAE/CA 1.981911: : 430it [00:16, 26.87it/s]
40it [00:01, 25.04it/s]

EarlyStopping counter: 26 out of 100



Epoch 312: MAE/CA 1.950753: : 430it [00:16, 26.80it/s]
40it [00:01, 26.70it/s]

EarlyStopping counter: 27 out of 100



Epoch 313: MAE/CA 1.919905: : 430it [00:16, 26.57it/s]
40it [00:01, 25.14it/s]

EarlyStopping counter: 28 out of 100



Epoch 314: MAE/CA 1.947467: : 430it [00:16, 26.59it/s]
40it [00:01, 25.17it/s]
Epoch 315: MAE/CA 1.897072: : 430it [00:16, 26.59it/s]
40it [00:01, 25.20it/s]

EarlyStopping counter: 1 out of 100



Epoch 316: MAE/CA 1.926393: : 430it [00:16, 26.60it/s]
40it [00:01, 24.38it/s]

EarlyStopping counter: 2 out of 100



Epoch 317: MAE/CA 2.078797: : 430it [00:16, 26.77it/s]
40it [00:01, 23.99it/s]

EarlyStopping counter: 3 out of 100



Epoch 318: MAE/CA 1.875763: : 430it [00:16, 26.37it/s]
40it [00:01, 24.95it/s]

EarlyStopping counter: 4 out of 100



Epoch 319: MAE/CA 1.904916: : 430it [00:16, 26.85it/s]
40it [00:01, 25.28it/s]

EarlyStopping counter: 5 out of 100



Epoch 320: MAE/CA 1.947269: : 430it [00:15, 27.15it/s]
40it [00:01, 25.28it/s]

EarlyStopping counter: 6 out of 100



Epoch 321: MAE/CA 1.909533: : 430it [00:15, 26.92it/s]
40it [00:01, 25.88it/s]

EarlyStopping counter: 7 out of 100



Epoch 322: MAE/CA 1.928878: : 430it [00:15, 26.88it/s]
40it [00:01, 24.23it/s]

EarlyStopping counter: 8 out of 100



Epoch 323: MAE/CA 1.952372: : 430it [00:16, 26.84it/s]
40it [00:01, 24.41it/s]

EarlyStopping counter: 9 out of 100



Epoch 324: MAE/CA 1.968314: : 430it [00:16, 26.77it/s]
40it [00:01, 24.34it/s]

Epoch   324: reducing learning rate of group 0 to 1.3204e-04.
EarlyStopping counter: 10 out of 100



Epoch 325: MAE/CA 1.870881: : 430it [00:16, 26.76it/s]
40it [00:01, 25.72it/s]

EarlyStopping counter: 11 out of 100



Epoch 326: MAE/CA 1.906882: : 430it [00:19, 22.00it/s]
40it [00:02, 19.71it/s]
Epoch 327: MAE/CA 1.845772: : 430it [00:16, 25.46it/s]
40it [00:01, 24.52it/s]

EarlyStopping counter: 1 out of 100



Epoch 328: MAE/CA 1.915860: : 430it [00:16, 26.67it/s]
40it [00:01, 24.33it/s]

EarlyStopping counter: 2 out of 100



Epoch 329: MAE/CA 1.859237: : 430it [00:16, 26.53it/s]
40it [00:01, 24.92it/s]

EarlyStopping counter: 3 out of 100



Epoch 330: MAE/CA 1.880798: : 430it [00:16, 25.81it/s]
40it [00:01, 23.91it/s]

EarlyStopping counter: 4 out of 100



Epoch 331: MAE/CA 1.962586: : 430it [00:20, 20.85it/s]
40it [00:01, 20.89it/s]

EarlyStopping counter: 5 out of 100



Epoch 332: MAE/CA 1.853800: : 430it [00:22, 18.85it/s]
40it [00:01, 21.75it/s]

EarlyStopping counter: 6 out of 100



Epoch 333: MAE/CA 1.893738: : 430it [00:22, 18.91it/s]
40it [00:02, 19.50it/s]

Epoch   333: reducing learning rate of group 0 to 1.2544e-04.
EarlyStopping counter: 7 out of 100



Epoch 334: MAE/CA 1.851326: : 430it [00:23, 18.44it/s]
40it [00:01, 22.96it/s]

EarlyStopping counter: 8 out of 100



Epoch 335: MAE/CA 1.820022: : 430it [00:22, 19.32it/s]
40it [00:01, 22.00it/s]

EarlyStopping counter: 9 out of 100



Epoch 336: MAE/CA 1.825808: : 430it [00:22, 19.47it/s]
40it [00:01, 23.88it/s]

EarlyStopping counter: 10 out of 100



Epoch 337: MAE/CA 1.835091: : 430it [00:22, 19.25it/s]
40it [00:01, 21.26it/s]
Epoch 338: MAE/CA 1.877686: : 430it [00:22, 19.46it/s]
40it [00:01, 22.51it/s]

EarlyStopping counter: 1 out of 100



Epoch 339: MAE/CA 1.877391: : 430it [00:22, 19.44it/s]
40it [00:01, 22.17it/s]

EarlyStopping counter: 2 out of 100



Epoch 340: MAE/CA 1.879940: : 430it [00:22, 19.07it/s]
40it [00:01, 21.51it/s]

EarlyStopping counter: 3 out of 100



Epoch 341: MAE/CA 1.835001: : 430it [00:22, 18.93it/s]
40it [00:01, 20.83it/s]

Epoch   341: reducing learning rate of group 0 to 1.1916e-04.
EarlyStopping counter: 4 out of 100



Epoch 342: MAE/CA 1.798097: : 430it [00:22, 19.27it/s]
40it [00:01, 20.12it/s]

EarlyStopping counter: 5 out of 100



Epoch 343: MAE/CA 1.890413: : 430it [00:23, 18.68it/s]
40it [00:01, 21.47it/s]

EarlyStopping counter: 6 out of 100



Epoch 344: MAE/CA 1.807644: : 430it [00:22, 19.44it/s]
40it [00:02, 18.96it/s]

EarlyStopping counter: 7 out of 100



Epoch 345: MAE/CA 1.814641: : 430it [00:23, 18.42it/s]
40it [00:01, 22.07it/s]

EarlyStopping counter: 8 out of 100



Epoch 346: MAE/CA 1.825876: : 430it [00:22, 19.05it/s]
40it [00:01, 20.60it/s]

EarlyStopping counter: 9 out of 100



Epoch 347: MAE/CA 1.773750: : 273it [00:14, 19.12it/s]

In [None]:
pbar = tqdm.tqdm(enumerate(test_loader, 0))
epoch_loss = 0
maes = []
for i, data in pbar:
    data.to(device)
    prediction = model(data)
    mae = (prediction.view(-1) - data.y[:, target_idx]).abs()
    maes += [mae.detach().cpu()]
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
wandb.log({'TEST MAE':mae, 'TEST MAE/CA':mae/chemical_accuracy[target_idx]})
