In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
import pandas as pd
import torch.nn.functional as F
import sys
import yaml
import functools
from ml_collections import ConfigDict
from ogb.graphproppred import PygGraphPropPredDataset

sys.path = ['../../src'] + sys.path
from dfs_transformer import DFSCodeSeq2SeqFC, Deepchem2TorchGeometric, Trainer, to_cuda

Using backend: pytorch


In [3]:
def compute_roc(model, loader, evaluator):
    with torch.no_grad():
        preds = []
        ys = []
        for i, data in tqdm.tqdm(enumerate(testloader)):
            data = [to_cuda(d, device) for d in data]
            pred = model(*data[:-1])
            preds += [pred.cpu()]
            ys += [data[-1].cpu()]
        preds = torch.cat(preds, dim=0)
        ys = torch.cat(ys, dim=0)
        return evaluator.eval({'y_true':ys, 'y_pred':preds})['rocauc']

In [4]:
fname = '../../config/selfattn/finetune_ogb.yaml'
with open(fname) as file:
    config = ConfigDict(yaml.load(file, Loader=yaml.FullLoader))

In [5]:
config

accumulate_grads: 2
batch_size: 50
clip_gradient: 0.5
decay_factor: 0.8
es_improvement: 0.0
es_path: null
es_patience: 10
fingerprint: cls
gpu_id: 0
load_last: true
lr: 0.0003
lr_head: 0.003
lr_patience: 3
lr_pretrained: 0.0003
minimal_lr: 6.0e-08
n_classes: 349
n_epochs: 25
n_frozen: 5
path: ../../results/ogbn_mag/timeout1/
pretrained_class: DFSCodeSeq2SeqFC
pretrained_dir: null
pretrained_entity: dfstransformer
pretrained_model: rnd2min
pretrained_project: ogbn-mag
pretrained_yaml: null
require_min_dfs_code: false
seed: 123
strict: true
use_local_yaml: false
weight_decay: 0.1

In [6]:
config.pretrained_project = 'pubchem'
config.pretrained_model = 'rnd2min2-10M-euler'
config.es_period = 300
config.lr = 0.000003 # 0.000003 war gut
config.alpha = 0
config.fingerprint = "cls3"

In [7]:
config.require_min_dfs_code = False

In [8]:
onlyRandom = not config.require_min_dfs_code

In [9]:
mol_csv = pd.read_csv('../../datasets/ogbg_molhiv/mol.csv')

dataset = PygGraphPropPredDataset(name = "ogbg-molhiv") 
split_idx = dataset.get_idx_split() 

# check whether we get the correct splits

In [10]:
for split in ["train", "valid", "test"]:
    csv_labels = mol_csv["HIV_active"][split_idx[split].numpy()].to_numpy()
    ogb_labels = np.asarray([d.y.item() for d in dataset[split_idx[split]]])
    if (ogb_labels == csv_labels).sum() == len(ogb_labels):
        print("All %s labels are identical."%split)

All train labels are identical.
All valid labels are identical.
All test labels are identical.


In [11]:
train_smiles = mol_csv["smiles"][split_idx["train"].numpy()].to_numpy()
train_labels = mol_csv["HIV_active"][split_idx["train"].numpy()].to_numpy()
valid_smiles = mol_csv["smiles"][split_idx["valid"].numpy()].to_numpy()
valid_labels = mol_csv["HIV_active"][split_idx["valid"].numpy()].to_numpy()
test_smiles = mol_csv["smiles"][split_idx["test"].numpy()].to_numpy()
test_labels = mol_csv["HIV_active"][split_idx["test"].numpy()].to_numpy()

In [12]:
loaddir = "../../results/mymoleculenet_plus_features/hiv/1/" # ogbg uses other smiles than deepchem...
loaddir = None
train = Deepchem2TorchGeometric(train_smiles, train_labels, loaddir=loaddir, onlyRandom=onlyRandom)
valid = Deepchem2TorchGeometric(valid_smiles, valid_labels, loaddir=loaddir, onlyRandom=onlyRandom)
test = Deepchem2TorchGeometric(test_smiles, test_labels, loaddir=loaddir, onlyRandom=onlyRandom)

In [13]:
def collate_fn(dlist, alpha=config.alpha):
    node_batch = [] 
    edge_batch = []
    y_batch = []
    rnd_code_batch = []
    for d in dlist:
        node_batch += [d.node_features.clone()]
        edge_batch += [d.edge_features.clone()]
        rnd_code, rnd_index = dfs_code.rnd_dfs_code_from_torch_geometric(d, d.z.numpy().tolist(), 
                                                                         np.argmax(d.edge_attr.numpy(), axis=1).tolist())
        rnd_code = torch.tensor(np.asarray(rnd_code), dtype=torch.long)
        rnd_code_batch += [rnd_code]
        y_batch += [d.y.clone()]
    y = torch.cat(y_batch).unsqueeze(1)
    y = (1-alpha)*y + alpha/2
    return rnd_code_batch, node_batch, edge_batch, y

In [14]:
coll_val = functools.partial(collate_fn, alpha=0)

In [15]:
trainloader = DataLoader(train, shuffle=True, batch_size=config.batch_size, collate_fn=collate_fn, num_workers=8)
validloader = DataLoader(valid, shuffle=False, batch_size=config.batch_size, collate_fn=coll_val, num_workers=8)
testloader = DataLoader(test, shuffle=False, batch_size=config.batch_size, collate_fn=coll_val, num_workers=8)

In [16]:
name = "rnd2min2-10M-cls3"
mode = "online"

In [17]:
# download pretrained model
run = wandb.init(mode=mode, 
                 project=config.pretrained_project, 
                 entity=config.pretrained_entity, 
                 job_type="inference")
model_at = run.use_artifact(config.pretrained_model + ":latest")
model_dir = model_at.download()
run.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
2021-09-28 14:55:12.000978: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/cuda/extras/CUPTI/lib64/:/opt/intel/lib:/opt/intel/mkl/lib/intel64:/opt/intel:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/bin/x86-64_linux:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/python/3.7/x86-64_linux:/opt/intel/clck_latest/lib:/opt/intel/daal/lib:/opt/intel/intelpython3/lib:/opt/intel/ipp/lib:/opt/intel/itac_2019/lib:/opt/intel/itac_latest/lib:/opt/in

[34m[1mwandb[0m: Downloading large artifact rnd2min2-10M-euler:latest, 95.63MB. 2 files... Done. 0:0:0


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [18]:
with open(model_dir+"/config.yaml") as file:
    mconfig = ConfigDict(yaml.load(file, Loader=yaml.FullLoader))

In [19]:
config.model = mconfig

In [20]:
run = wandb.init(mode=mode, project="ogbg-hiv", entity="dfstransformer", 
                 name=name, config=config.to_dict(), job_type="evaluation")

[34m[1mwandb[0m: wandb version 0.12.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
2021-09-28 14:55:18.584598: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/cuda/extras/CUPTI/lib64/:/opt/intel/lib:/opt/intel/mkl/lib/intel64:/opt/intel:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/bin/x86-64_linux:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/python/3.7/x86-64_linux:/opt/intel/clck_latest/lib:/opt/intel/daal/lib:/opt/intel/intelpython3/lib:/opt/intel/ipp/lib:/opt/intel/itac_2019/lib:/opt/intel/itac_latest/lib:/opt/intel/mkl/lib:/opt/intel/mkl_/lib:/opt/intel/mpirt/lib:/opt/intel/tbb/lib:/opt/intel/clck/2019.0/lib:/opt/intel/compilers_and_libraries_2019/linux/lib:/opt/intel/compilers_and_libraries/linux/lib:/opt/intel/itac/2019.0.018/lib:/opt/intel/itac_2019/intel64/lib:/opt/intel/

In [21]:
m = mconfig.model
t = config

In [22]:
ce = nn.CrossEntropyLoss(ignore_index=-1)
bce = nn.BCEWithLogitsLoss()    

In [23]:
class TransformerPlusHead(nn.Module):
    def __init__(self, encoder, n_encoding, n_classes, fingerprint='cls'):
        super(TransformerPlusHead, self).__init__()
        self.encoder = encoder
        self.head = nn.Linear(n_encoding, n_classes)
        self.fingerprint = fingerprint
    
    def forward(self, C, N, E):
        features = self.encoder.encode(C, N, E, method=self.fingerprint)
        output = self.head(features)
        return output
        

In [24]:
from ogb.graphproppred import Evaluator

evaluator = Evaluator(name = 'ogbg-molhiv')

In [25]:
print(evaluator.expected_input_format)
print(evaluator.expected_output_format)

==== Expected input format of Evaluator for ogbg-molhiv
{'y_true': y_true, 'y_pred': y_pred}
- y_true: numpy ndarray or torch tensor of shape (num_graph, num_task)
- y_pred: numpy ndarray or torch tensor of shape (num_graph, num_task)
where y_pred stores score values (for computing AUC score),
num_task is 1, and each row corresponds to one graph.
nan values in y_true are ignored during evaluation.

==== Expected output format of Evaluator for ogbg-molhiv
{'rocauc': rocauc}
- rocauc (float): ROC-AUC score averaged across 1 task(s)



In [26]:
def loss(pred, y, ce=bce):
    return ce(pred, y)

def acc(pred, y):
    y_pred = (pred > 0.5).squeeze()
    y = (y > 0.5)
    return (y_pred == y.squeeze()).sum()/len(y)
    

In [27]:
scorer = functools.partial(compute_roc, loader=validloader, evaluator=evaluator)

In [28]:
device = torch.device('cuda:%d'%t.gpu_id if torch.cuda.is_available()  else 'cpu')
encoder = DFSCodeSeq2SeqFC(**m)
    
if t.load_last and model_dir is not None:
    encoder.load_state_dict(torch.load(model_dir+'/checkpoint.pt', map_location=device))

In [29]:
model = TransformerPlusHead(encoder, m.emb_dim*5*m.n_class_tokens*m.nlayers, 1, fingerprint=t.fingerprint)

In [30]:
param_groups = [
    {'amsgrad': False,
     'betas': (0.9,0.98),
     'eps': 1e-09,
     'lr': 0.000003,
     'params': model.encoder.parameters(),
     'weight_decay': 0},
    {'amsgrad': False,
     'betas': (0.9, 0.999),
     'eps': 1e-08,
     'lr': 0.00075,
     'params': model.head.parameters(),
     'weight_decay': 0}
]

In [31]:
del t.model

In [32]:
t

accumulate_grads: 2
alpha: 0
batch_size: 50
clip_gradient: 0.5
decay_factor: 0.8
es_improvement: 0.0
es_path: null
es_patience: 10
es_period: 300
fingerprint: cls3
gpu_id: 0
load_last: true
lr: 3.0e-06
lr_head: 0.003
lr_patience: 3
lr_pretrained: 0.0003
minimal_lr: 6.0e-08
n_classes: 349
n_epochs: 25
n_frozen: 5
path: ../../results/ogbn_mag/timeout1/
pretrained_class: DFSCodeSeq2SeqFC
pretrained_dir: null
pretrained_entity: dfstransformer
pretrained_model: rnd2min2-10M-euler
pretrained_project: pubchem
pretrained_yaml: null
require_min_dfs_code: false
seed: 123
strict: true
use_local_yaml: false
weight_decay: 0.1

In [33]:
trainer = Trainer(model, trainloader, loss, validloader=validloader, metrics={'acc': acc}, scorer=scorer, wandb_run = run, param_groups=param_groups, **t)

In [34]:
trainer.fit()

Epoch 1: loss 0.155405 0.9600:  45%|███████████████████████████████████████████████████████████████▉                                                                             | 299/659 [01:04<01:24,  4.26it/s]
0it [00:00, ?it/s][A
1it [00:00,  5.46it/s][A
2it [00:00,  4.83it/s][A
3it [00:00,  6.26it/s][A
6it [00:00, 10.78it/s][A
8it [00:00, 12.41it/s][A
10it [00:01,  9.78it/s][A
13it [00:01, 12.98it/s][A
15it [00:01, 12.48it/s][A
17it [00:01,  9.22it/s][A
20it [00:01, 12.24it/s][A
22it [00:02, 12.24it/s][A
24it [00:02, 12.03it/s][A
26it [00:02,  9.67it/s][A
28it [00:02, 11.19it/s][A
30it [00:02, 10.20it/s][A
32it [00:03, 11.17it/s][A
34it [00:03,  9.62it/s][A
36it [00:03, 11.09it/s][A
38it [00:03, 10.87it/s][A
40it [00:03, 11.58it/s][A
42it [00:04,  8.81it/s][A
45it [00:04, 11.44it/s][A
47it [00:04, 11.07it/s][A
49it [00:04,  8.29it/s][A
52it [00:04, 10.93it/s][A
54it [00:05, 10.17it/s][A
56it [00:05, 11.26it/s][A
58it [00:05,  9.67it/s][A
61it [00:05, 12

EarlyStopping counter: 1 out of 10


Epoch 3: loss 0.129180 0.9800:  73%|██████████████████████████████████████████████████████████████████████████████████████████████████████▉                                      | 481/659 [01:56<00:42,  4.15it/s]
0it [00:00, ?it/s][A
1it [00:00,  3.05it/s][A
2it [00:00,  3.91it/s][A
3it [00:00,  5.15it/s][A
6it [00:00, 10.33it/s][A
8it [00:00, 11.49it/s][A
10it [00:01,  9.14it/s][A
12it [00:01, 10.16it/s][A
14it [00:01,  9.98it/s][A
16it [00:01, 10.48it/s][A
18it [00:01, 10.27it/s][A
20it [00:02, 10.51it/s][A
22it [00:02,  9.52it/s][A
24it [00:02,  9.97it/s][A
26it [00:02, 10.45it/s][A
28it [00:02, 10.33it/s][A
30it [00:03,  8.73it/s][A
32it [00:03,  9.72it/s][A
34it [00:03, 10.07it/s][A
36it [00:03,  9.77it/s][A
38it [00:04,  9.29it/s][A
40it [00:04,  9.92it/s][A
42it [00:04,  9.11it/s][A
43it [00:04,  9.18it/s][A
45it [00:04,  7.70it/s][A
47it [00:05,  9.60it/s][A
49it [00:05,  9.66it/s][A
51it [00:05,  9.62it/s][A
53it [00:05,  8.07it/s][A
55it [00:05,  9

EarlyStopping counter: 2 out of 10


Epoch 3: loss 0.125732 1.0000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659/659 [02:44<00:00,  4.00it/s]
Epoch 4: loss 0.116595 0.9600:  19%|██████████████████████████                                                                                                                   | 122/659 [00:28<02:19,  3.86it/s]
0it [00:00, ?it/s][A
1it [00:00,  1.22it/s][A
3it [00:01,  3.32it/s][A
5it [00:01,  5.68it/s][A
7it [00:01,  7.94it/s][A
9it [00:01,  9.20it/s][A
11it [00:01,  9.13it/s][A
13it [00:01, 11.00it/s][A
15it [00:01, 12.79it/s][A
17it [00:02, 12.73it/s][A
19it [00:02, 12.04it/s][A
21it [00:02, 10.60it/s][A
23it [00:02,  9.86it/s][A
25it [00:02,  8.87it/s][A
26it [00:03,  8.87it/s][A
28it [00:03,  9.23it/s][A
29it [00:03,  9.24it/s][A
30it [00:03,  8.72it/s][A
32it [00:03,  9.49it/s][A
34it [00:03,  9.87it/s][A
35it [00:04,  9.75it/s][A
37it [00:04, 10.02i

EarlyStopping counter: 3 out of 10


Epoch 4: loss 0.123224 0.9800:  64%|██████████████████████████████████████████████████████████████████████████████████████████▎                                                  | 422/659 [01:46<00:47,  5.03it/s]
0it [00:00, ?it/s][A
1it [00:00,  2.62it/s][A
2it [00:00,  3.46it/s][A
3it [00:00,  4.65it/s][A
5it [00:00,  7.67it/s][A
7it [00:00, 10.20it/s][A
9it [00:01, 11.81it/s][A
11it [00:01,  9.32it/s][A
13it [00:01, 10.85it/s][A
15it [00:01, 12.48it/s][A
17it [00:01, 13.01it/s][A
19it [00:02, 11.34it/s][A
21it [00:02, 10.66it/s][A
23it [00:02, 10.40it/s][A
25it [00:02,  9.75it/s][A
27it [00:02,  9.93it/s][A
29it [00:03, 10.11it/s][A
31it [00:03,  9.37it/s][A
33it [00:03,  9.83it/s][A
35it [00:03,  9.62it/s][A
36it [00:03,  9.48it/s][A
37it [00:03,  9.33it/s][A
39it [00:04,  9.41it/s][A
40it [00:04,  9.32it/s][A
41it [00:04,  9.17it/s][A
42it [00:04,  8.87it/s][A
43it [00:04,  8.68it/s][A
45it [00:04,  9.54it/s][A
47it [00:04,  9.85it/s][A
48it [00:05,  9.

EarlyStopping counter: 4 out of 10


Epoch 4: loss 0.122457 1.0000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659/659 [02:48<00:00,  3.90it/s]
Epoch 5: loss 0.120863 0.9800:  10%|█████████████▌                                                                                                                                | 63/659 [00:15<02:02,  4.87it/s]
0it [00:00, ?it/s][A
1it [00:00,  5.57it/s][A
2it [00:00,  6.87it/s][A
3it [00:00,  7.86it/s][A
4it [00:00,  8.18it/s][A
5it [00:00,  6.26it/s][A
7it [00:00,  8.90it/s][A
9it [00:01,  7.76it/s][A
11it [00:01,  9.60it/s][A
13it [00:01,  9.39it/s][A
15it [00:01,  9.85it/s][A
17it [00:02,  5.62it/s][A
19it [00:02,  7.02it/s][A
21it [00:02,  7.39it/s][A
23it [00:02,  8.22it/s][A
25it [00:03,  6.87it/s][A
27it [00:03,  8.01it/s][A
29it [00:03,  8.54it/s][A
30it [00:03,  8.60it/s][A
31it [00:03,  8.82it/s][A
32it [00:04,  8.76it/s][A
33it [00:04,  7.38it/

EarlyStopping counter: 5 out of 10


Epoch 5: loss 0.118956 1.0000:  55%|█████████████████████████████████████████████████████████████████████████████▋                                                               | 363/659 [01:35<01:10,  4.22it/s]
0it [00:00, ?it/s][A
1it [00:00,  2.81it/s][A
3it [00:00,  7.04it/s][A
5it [00:00,  9.19it/s][A
7it [00:00,  8.97it/s][A
9it [00:01,  7.16it/s][A
10it [00:01,  7.11it/s][A
12it [00:01,  8.87it/s][A
14it [00:01, 10.26it/s][A
16it [00:01,  9.30it/s][A
18it [00:02,  7.88it/s][A
20it [00:02,  9.06it/s][A
22it [00:02, 10.29it/s][A
24it [00:02,  9.51it/s][A
26it [00:03,  7.45it/s][A
28it [00:03,  8.79it/s][A
30it [00:03,  9.39it/s][A
32it [00:03,  9.37it/s][A
34it [00:04,  7.45it/s][A
36it [00:04,  8.62it/s][A
38it [00:04,  9.81it/s][A
40it [00:04,  9.67it/s][A
42it [00:04,  7.78it/s][A
44it [00:05,  9.10it/s][A
46it [00:05, 10.32it/s][A
48it [00:05, 10.07it/s][A
50it [00:05,  7.52it/s][A
52it [00:06,  8.86it/s][A
54it [00:06,  9.87it/s][A
56it [00:06,  9

EarlyStopping counter: 1 out of 10


Epoch 6: loss 0.118302 1.0000:  46%|█████████████████████████████████████████████████████████████████                                                                            | 304/659 [01:26<01:23,  4.27it/s]
0it [00:00, ?it/s][A
1it [00:00,  4.97it/s][A
2it [00:00,  5.75it/s][A
4it [00:00,  6.93it/s][A
5it [00:00,  5.09it/s][A
7it [00:01,  6.88it/s][A
9it [00:01,  8.46it/s][A
10it [00:01,  7.82it/s][A
12it [00:01,  9.16it/s][A
13it [00:01,  6.41it/s][A
15it [00:02,  7.68it/s][A
17it [00:02,  8.74it/s][A
18it [00:02,  8.27it/s][A
20it [00:02,  9.52it/s][A
22it [00:02,  7.39it/s][A
24it [00:03,  8.55it/s][A
26it [00:03,  8.13it/s][A
28it [00:03,  9.03it/s][A
29it [00:03,  6.84it/s][A
30it [00:03,  7.23it/s][A
32it [00:04,  8.44it/s][A
34it [00:04,  8.09it/s][A
36it [00:04,  9.07it/s][A
37it [00:04,  6.95it/s][A
39it [00:04,  8.31it/s][A
41it [00:05,  9.45it/s][A
43it [00:05,  8.29it/s][A
45it [00:05,  7.08it/s][A
47it [00:05,  8.06it/s][A
49it [00:06,  8.

EarlyStopping counter: 2 out of 10


Epoch 6: loss 0.117847 0.9400:  92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏           | 604/659 [02:51<00:14,  3.73it/s]
0it [00:00, ?it/s][A
1it [00:00,  2.54it/s][A
3it [00:00,  5.89it/s][A
4it [00:00,  6.79it/s][A
5it [00:00,  7.54it/s][A
6it [00:00,  8.16it/s][A
8it [00:01,  7.82it/s][A
9it [00:01,  6.30it/s][A
11it [00:01,  7.88it/s][A
12it [00:01,  8.28it/s][A
13it [00:01,  8.57it/s][A
14it [00:01,  8.08it/s][A
15it [00:02,  8.52it/s][A
16it [00:02,  8.14it/s][A
17it [00:02,  3.42it/s][A
19it [00:03,  5.04it/s][A
21it [00:03,  6.51it/s][A
22it [00:03,  6.81it/s][A
24it [00:03,  7.54it/s][A
25it [00:03,  5.68it/s][A
27it [00:04,  6.98it/s][A
29it [00:04,  7.95it/s][A
30it [00:04,  7.62it/s][A
31it [00:04,  7.96it/s][A
32it [00:04,  7.78it/s][A
33it [00:04,  5.43it/s][A
35it [00:05,  6.93it/s][A
36it [00:05,  7.45it/s][A
37it [00:05,  7.93it/s][A
38it [00:05,  8.0

EarlyStopping counter: 3 out of 10


Epoch 6: loss 0.117795 1.0000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659/659 [03:16<00:00,  3.35it/s]
Epoch 7: loss 0.122291 0.9400:  37%|████████████████████████████████████████████████████▍                                                                                        | 245/659 [01:03<01:57,  3.51it/s]
0it [00:00, ?it/s][A
1it [00:00,  2.83it/s][A
2it [00:00,  3.11it/s][A
3it [00:00,  4.46it/s][A
5it [00:00,  6.71it/s][A
6it [00:01,  6.95it/s][A
8it [00:01,  7.73it/s][A
9it [00:01,  5.32it/s][A
11it [00:01,  6.78it/s][A
13it [00:02,  7.91it/s][A
15it [00:02,  8.81it/s][A
16it [00:02,  8.56it/s][A
17it [00:02,  5.60it/s][A
19it [00:02,  6.91it/s][A
21it [00:03,  7.96it/s][A
22it [00:03,  8.29it/s][A
23it [00:03,  8.35it/s][A
24it [00:03,  8.07it/s][A
25it [00:03,  5.47it/s][A
27it [00:03,  7.00it/s][A
28it [00:04,  7.43it/s][A
29it [00:04,  7.65it/

EarlyStopping counter: 4 out of 10


Epoch 7: loss 0.116626 1.0000:  83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                        | 545/659 [02:34<00:27,  4.16it/s]
0it [00:00, ?it/s][A
1it [00:00,  2.42it/s][A
3it [00:00,  5.67it/s][A
4it [00:00,  6.45it/s][A
6it [00:00,  8.01it/s][A
7it [00:01,  7.52it/s][A
8it [00:01,  6.95it/s][A
9it [00:01,  5.58it/s][A
10it [00:01,  5.59it/s][A
11it [00:01,  6.22it/s][A
13it [00:01,  7.49it/s][A
15it [00:02,  8.51it/s][A
16it [00:02,  8.43it/s][A
17it [00:02,  6.07it/s][A
18it [00:02,  5.78it/s][A
20it [00:03,  7.14it/s][A
21it [00:03,  7.64it/s][A
22it [00:03,  7.92it/s][A
24it [00:03,  8.43it/s][A
25it [00:03,  7.79it/s][A
26it [00:03,  5.73it/s][A
27it [00:04,  6.27it/s][A
29it [00:04,  7.58it/s][A
30it [00:04,  8.01it/s][A
31it [00:04,  8.23it/s][A
32it [00:04,  8.19it/s][A
33it [00:04,  7.92it/s][A
34it [00:05,  5.38it/s][A
36it [00:05,  6.89it/s][A
38it [00:05,  7.9

EarlyStopping counter: 5 out of 10


Epoch 7: loss 0.116787 1.0000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659/659 [03:14<00:00,  3.38it/s]
Epoch 8: loss 0.115287 1.0000:  28%|███████████████████████████████████████▊                                                                                                     | 186/659 [00:49<02:16,  3.46it/s]
0it [00:00, ?it/s][A
1it [00:00,  2.24it/s][A
2it [00:00,  3.81it/s][A
3it [00:00,  5.10it/s][A
4it [00:00,  6.32it/s][A
5it [00:00,  6.99it/s][A
6it [00:01,  7.76it/s][A
7it [00:01,  7.70it/s][A
8it [00:01,  5.52it/s][A
9it [00:01,  5.50it/s][A
10it [00:01,  5.64it/s][A
11it [00:01,  6.48it/s][A
13it [00:02,  7.58it/s][A
14it [00:02,  8.05it/s][A
15it [00:02,  8.42it/s][A
16it [00:02,  7.67it/s][A
17it [00:02,  5.39it/s][A
18it [00:02,  5.56it/s][A
19it [00:03,  6.33it/s][A
20it [00:03,  7.02it/s][A
21it [00:03,  7.48it/s][A
22it [00:03,  8.00it/s]

EarlyStopping counter: 6 out of 10


Epoch 8: loss 0.116649 0.9400:  74%|███████████████████████████████████████████████████████████████████████████████████████████████████████▉                                     | 486/659 [02:22<00:42,  4.03it/s]
0it [00:00, ?it/s][A
1it [00:00,  1.86it/s][A
2it [00:00,  3.37it/s][A
3it [00:00,  4.65it/s][A
4it [00:00,  5.74it/s][A
5it [00:01,  5.77it/s][A
6it [00:01,  6.40it/s][A
7it [00:01,  6.84it/s][A
8it [00:01,  6.74it/s][A
9it [00:01,  3.91it/s][A
10it [00:02,  4.64it/s][A
11it [00:02,  5.11it/s][A
12it [00:02,  5.86it/s][A
13it [00:02,  6.55it/s][A
14it [00:02,  6.90it/s][A
15it [00:02,  7.20it/s][A
16it [00:02,  6.52it/s][A
17it [00:03,  3.90it/s][A
18it [00:03,  4.62it/s][A
19it [00:03,  5.42it/s][A
20it [00:03,  6.18it/s][A
21it [00:03,  6.60it/s][A
22it [00:03,  7.19it/s][A
23it [00:04,  7.44it/s][A
24it [00:04,  6.68it/s][A
25it [00:04,  4.30it/s][A
26it [00:04,  5.11it/s][A
27it [00:04,  5.88it/s][A
28it [00:05,  6.60it/s][A
29it [00:05,  7.11i

EarlyStopping counter: 7 out of 10


Epoch 8: loss 0.114548 1.0000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659/659 [03:28<00:00,  3.15it/s]
Epoch 9: loss 0.123741 0.9400:  19%|███████████████████████████▏                                                                                                                 | 127/659 [00:40<02:20,  3.78it/s]
0it [00:00, ?it/s][A
1it [00:00,  4.04it/s][A
2it [00:00,  4.44it/s][A
3it [00:00,  4.97it/s][A
4it [00:00,  3.82it/s][A
5it [00:01,  3.88it/s][A
6it [00:01,  4.69it/s][A
7it [00:01,  5.23it/s][A
8it [00:01,  5.78it/s][A
9it [00:01,  6.09it/s][A
10it [00:02,  5.22it/s][A
11it [00:02,  5.73it/s][A
12it [00:02,  5.49it/s][A
13it [00:02,  3.84it/s][A
14it [00:02,  4.49it/s][A
15it [00:03,  5.11it/s][A
16it [00:03,  5.69it/s][A
17it [00:03,  6.14it/s][A
18it [00:03,  5.43it/s][A
19it [00:03,  5.98it/s][A
20it [00:03,  5.71it/s][A
21it [00:04,  4.21it/s]

EarlyStopping counter: 8 out of 10


Epoch 9: loss 0.116890 0.9400:  65%|███████████████████████████████████████████████████████████████████████████████████████████▎                                                 | 427/659 [02:29<01:11,  3.26it/s]
0it [00:00, ?it/s][A
1it [00:00,  4.03it/s][A
2it [00:00,  4.52it/s][A
3it [00:00,  4.76it/s][A
4it [00:00,  4.33it/s][A
5it [00:01,  5.02it/s][A
6it [00:01,  5.19it/s][A
7it [00:01,  5.89it/s][A
8it [00:01,  4.67it/s][A
9it [00:01,  5.37it/s][A
10it [00:01,  5.38it/s][A
11it [00:02,  5.93it/s][A
12it [00:02,  5.73it/s][A
13it [00:02,  5.44it/s][A
14it [00:02,  5.71it/s][A
15it [00:02,  6.10it/s][A
16it [00:03,  3.13it/s][A
17it [00:03,  3.83it/s][A
18it [00:03,  4.11it/s][A
19it [00:03,  4.70it/s][A
20it [00:04,  4.85it/s][A
21it [00:04,  4.51it/s][A
22it [00:04,  4.76it/s][A
23it [00:04,  5.32it/s][A
24it [00:04,  4.66it/s][A
25it [00:05,  5.19it/s][A
26it [00:05,  5.27it/s][A
27it [00:05,  5.59it/s][A
28it [00:05,  5.57it/s][A
29it [00:05,  5.11i

EarlyStopping counter: 9 out of 10


Epoch 9: loss 0.112950 1.0000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659/659 [03:59<00:00,  2.75it/s]
Epoch 10: loss 0.127389 0.9800:  10%|██████████████▌                                                                                                                              | 68/659 [00:22<03:41,  2.67it/s]
0it [00:00, ?it/s][A
1it [00:00,  1.40it/s][A
2it [00:00,  2.66it/s][A
3it [00:01,  3.67it/s][A
4it [00:01,  4.47it/s][A
5it [00:01,  4.89it/s][A
6it [00:01,  5.56it/s][A
7it [00:01,  6.09it/s][A
8it [00:02,  3.73it/s][A
9it [00:02,  4.30it/s][A
10it [00:02,  4.83it/s][A
11it [00:02,  5.37it/s][A
12it [00:02,  5.95it/s][A
13it [00:02,  6.35it/s][A
14it [00:02,  6.56it/s][A
15it [00:03,  6.79it/s][A
16it [00:03,  3.67it/s][A
17it [00:03,  4.30it/s][A
18it [00:03,  4.88it/s][A
19it [00:04,  5.44it/s][A
20it [00:04,  5.81it/s][A
21it [00:04,  6.22it/s]

EarlyStopping counter: 10 out of 10





In [35]:
model.load_state_dict(torch.load(trainer.es_path+'checkpoint.pt'))

<All keys matched successfully>

In [37]:
run.log({'Valid ROCAUC': compute_roc(model, validloader, evaluator)})
run.log({'Test ROCAUC': compute_roc(model, testloader, evaluator)})

83it [00:14,  5.69it/s]
83it [00:15,  5.49it/s]


#store config and model
with open(t.es_path+'config.yaml', 'w') as f:
    yaml.dump(config.to_dict(), f, default_flow_style=False)
if name is not None and mode != "offline":
    trained_model_artifact = wandb.Artifact(name, type="model", description="trained selfattn model")
    trained_model_artifact.add_dir(t.es_path)
    run.log_artifact(trained_model_artifact)

In [38]:
exit()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
batch-loss,0.03179
loss,0.12739
batch-acc,0.98
acc,0.96609
learning rate,0.0
_runtime,3039.0
_timestamp,1632836756.0
_step,6031.0
valid-score,0.75576
train-loss,0.12739


0,1
batch-loss,█▃▃▂▆▃▄▃▁▄▅▃▁▂▁▅▅▂▂▂▂▄▁▃▂▄▄▄▃▂█▃▅▂▄▂▅▆▁▆
loss,█▆▆▆▅▅▄▄▄▄▄▄▄▁▃▄▄▄▄▄▄▄▅▄▄▄▂▄▄▄▄▄▃▃▄▃▄▄▃▅
batch-acc,▁▆▆▆▃▆▅▆█▁▃▆█▆▆▅▃▆▆▆█▆█▆▆▅▅▅▅▆▁▆▅▅▆▆▅▃▆▁
acc,▂▂▂▂▂▂▃▃▃▄▃▃▃█▅▄▃▄▅▄▅▅▁▅▅▄▄▄▄▄▄▄▅▅▄▇▄▄▅▃
learning rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
valid-score,▁▂▃▇▆▆▆▆▆█▇▇▆▄▆▅▄▁▇▂
train-loss,█▅▄▃▃▂▂▁▁▄
