In [1]:
# Reload edited project modules (e.g. LightningModules) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Name for the W&B sweep. Set SWEEP_RUN_NAME in the environment to skip the interactive
# prompt, so the notebook can also run unattended (Restart & Run All, papermill, ...).
import os

run_name = os.environ.get("SWEEP_RUN_NAME") or input("W&B sweep name: ").strip()

 regime


In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import wandb
import math
sys.path.append('../..')

from LightningModules.GNNEmbedding.Models.gnn_embedding import InteractionGNN

device = "cuda" if torch.cuda.is_available() else "cpu"
from pytorch_lightning.callbacks import ModelCheckpoint

In [3]:
def kaiming_init(model, first_layer_prefix="layers.0"):
    """Re-initialize ``model``'s parameters in place with Kaiming-style normal weights.

    Rules, applied per parameter from ``model.named_parameters()``:

    - names ending in ``.bias`` are set to zero;
    - other parameters with fewer than 2 dims (e.g. normalization scales) have no
      ``shape[1]`` fan-in, so they are left untouched instead of raising IndexError;
    - weights whose name starts with ``first_layer_prefix`` get std ``1 / sqrt(fan_in)``,
      because the first layer's input has no ReLU applied;
    - all remaining weights get std ``sqrt(2) / sqrt(fan_in)`` (He init for ReLU inputs).

    ``fan_in`` is ``param.shape[1]``, the input dimension of an ``nn.Linear`` weight of
    shape (out_features, in_features).

    Args:
        model: the ``torch.nn.Module`` to re-initialize.
        first_layer_prefix: parameter-name prefix identifying the first layer.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name.endswith(".bias"):
                param.zero_()
            elif param.dim() < 2:
                # No input dimension to scale by; keep the module's own initialization.
                continue
            else:
                gain = 1.0 if name.startswith(first_layer_prefix) else math.sqrt(2)
                param.normal_(0, gain / math.sqrt(param.shape[1]))

## Hyperparameter sweep (Weights & Biases)

In [5]:
# Hyperparameter search space for the sweep, plus the defaults every sweep run starts
# from (sweep values override the defaults key-by-key in `training`).
with open("gnn_embedding_sweep.yaml") as sweep_file:
    sweep_hparams = yaml.load(sweep_file, Loader=yaml.FullLoader)

with open("gnn_embedding_default.yaml") as default_file:
    default_hparams = yaml.load(default_file, Loader=yaml.FullLoader)

In [6]:
# W&B sweep definition: a grid search over the parameter values read from
# gnn_embedding_sweep.yaml, with runs ranked by the logged "pur" metric (higher is better).
sweep_configuration = {
    "name": run_name,
    # Same project name that wandb.sweep() is called with ("barrel", not "barrell").
    "project": "ITk_barrel_gnn_embedding",
    "metric": {"name": "pur", "goal": "maximize"},
    "method": "grid",
    "parameters": sweep_hparams,
}

In [7]:
def training():
    """Train a single W&B sweep trial.

    ``wandb.agent`` calls this once per sweep run. The run's hyperparameters
    (``wandb.config``) override ``default_hparams``, and the merged dict configures
    both the model and the trainer.
    """
    wandb.init()
    # Sweep values take precedence over defaults for keys present in both.
    run_hparams = {**default_hparams, **wandb.config}
    model = InteractionGNN(run_hparams)

    # Optional weight re-initialization (disabled): kaiming_init(model)

    checkpoint_callback = ModelCheckpoint(
        monitor="pur",
        mode="max",
        save_top_k=2,
        save_last=True,
    )

    # Created after wandb.init() so it logs into the active sweep run.
    logger = WandbLogger()
    trainer = Trainer(
        gpus=1,
        # Read from the merged hparams so a swept `max_epochs` is actually honored.
        max_epochs=run_hparams["max_epochs"],
        log_every_n_steps=50,
        logger=logger,
        callbacks=[checkpoint_callback],
        default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_gnn_embedding/",
    )
    trainer.fit(model)

In [None]:
# Register the sweep with W&B, then run an agent in this kernel that calls
# `training` for each run the sweep controller hands out.
sweep_id = wandb.sweep(sweep_configuration, project="ITk_barrel_gnn_embedding")
wandb.agent(sweep_id, function=training)

## Construct PyTorch Lightning model

In [4]:
# Default hyperparameters for a single, non-sweep training run.
with open("gnn_embedding_default.yaml") as default_file:
    default_yaml = default_file.read()
hparams = yaml.load(default_yaml, Loader=yaml.FullLoader)

In [5]:
# Build the interaction-GNN embedding model from the default hyperparameters loaded above.
model = InteractionGNN(hparams)

## Metric Learning

In [6]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the two checkpoints with the highest logged "pur" value, plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    monitor="pur",
    mode="max",
    save_top_k=2,
    save_last=True,
)

In [None]:
# To re-initialize the weights before training, uncomment: kaiming_init(model)
logger = WandbLogger(project="ITk_gnn_embedding")
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    num_sanity_val_steps=2,
    log_every_n_steps=50,
    logger=logger,
    callbacks=[checkpoint_callback],
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_gnn_embedding/",
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name         | Type             | Params
--------------------------------------------------
0 | cos          | CosineSimilarity | 0     
1 | node_encoder | Sequential       | 266 K 
2 | edge_encoder | Sequential       | 2.6 M 
3 | gnn_blocks   | ModuleList       | 117 M 
4 | output_layer | Sequential       | 535 K 
--------------------------------------------------
121 M     Trainable params
0         Non-trainable params
121 M     Total params
484.172   Total estimated model params size (MB)


Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  1.41it/s]



                                                                      



Epoch 0:   0%|          | 0/1010 [00:00<?, ?it/s] 



Epoch 0:  99%|█████████▉| 1000/1010 [05:15<00:03,  3.17it/s, loss=0.212, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 0:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.17it/s, loss=0.212, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.15it/s][A




Epoch 0:  99%|█████████▉| 1004/1010 [05:16<00:01,  3.17it/s, loss=0.212, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.68it/s][A




Epoch 0: 100%|█████████▉| 1006/1010 [05:17<00:01,  3.17it/s, loss=0.212, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.12it/s][A




Epoch 0: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.17it/s, loss=0.212, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.16it/s][A




Epoch 0: 100%|██████████| 1010/1010 [05:18<00:00,  3.17it/s, loss=0.212, v_num=czm9]




Epoch 0: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.212, v_num=czm9]
Epoch 1:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.212, v_num=czm9]           



Epoch 1:  99%|█████████▉| 1000/1010 [05:14<00:03,  3.18it/s, loss=0.209, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 1:  99%|█████████▉| 1002/1010 [05:15<00:02,  3.18it/s, loss=0.209, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.10it/s][A




Epoch 1:  99%|█████████▉| 1004/1010 [05:15<00:01,  3.18it/s, loss=0.209, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.66it/s][A




Epoch 1: 100%|█████████▉| 1006/1010 [05:16<00:01,  3.18it/s, loss=0.209, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.06it/s][A




Epoch 1: 100%|█████████▉| 1008/1010 [05:16<00:00,  3.18it/s, loss=0.209, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.10it/s][A




Epoch 1: 100%|██████████| 1010/1010 [05:17<00:00,  3.18it/s, loss=0.209, v_num=czm9]




Epoch 1: 100%|██████████| 1010/1010 [05:18<00:00,  3.18it/s, loss=0.209, v_num=czm9]
Epoch 2:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.209, v_num=czm9]           



Epoch 2:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.15it/s, loss=0.208, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 2:  99%|█████████▉| 1002/1010 [05:17<00:02,  3.15it/s, loss=0.208, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.12it/s][A




Epoch 2:  99%|█████████▉| 1004/1010 [05:18<00:01,  3.15it/s, loss=0.208, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.65it/s][A




Epoch 2: 100%|█████████▉| 1006/1010 [05:19<00:01,  3.15it/s, loss=0.208, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.17it/s][A




Epoch 2: 100%|█████████▉| 1008/1010 [05:19<00:00,  3.15it/s, loss=0.208, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.13it/s][A




Epoch 2: 100%|██████████| 1010/1010 [05:20<00:00,  3.15it/s, loss=0.208, v_num=czm9]




Epoch 2: 100%|██████████| 1010/1010 [05:20<00:00,  3.15it/s, loss=0.208, v_num=czm9]
Epoch 3:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.208, v_num=czm9]           



Epoch 3:  99%|█████████▉| 1000/1010 [05:15<00:03,  3.16it/s, loss=0.207, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 3:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.16it/s, loss=0.207, v_num=czm9]




Validating:  20%|██        | 2/10 [00:00<00:03,  2.19it/s][A




Epoch 3:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.207, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.64it/s][A




Epoch 3: 100%|█████████▉| 1006/1010 [05:17<00:01,  3.16it/s, loss=0.207, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.15it/s][A




Epoch 3: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.207, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.12it/s][A




Epoch 3: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.207, v_num=czm9]




Epoch 3: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.207, v_num=czm9]
Epoch 4:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.207, v_num=czm9]           



Epoch 4:  99%|█████████▉| 1000/1010 [05:15<00:03,  3.17it/s, loss=0.208, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 4:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:04,  1.93it/s][A




Epoch 4:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.37it/s][A




Epoch 4: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  2.97it/s][A




Epoch 4: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:03<00:00,  3.02it/s][A




Epoch 4: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.208, v_num=czm9]




Epoch 4: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.208, v_num=czm9]
Epoch 5:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.208, v_num=czm9]           



Epoch 5:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.16it/s, loss=0.208, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 5:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  20%|██        | 2/10 [00:00<00:03,  2.20it/s][A




Epoch 5:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.66it/s][A




Epoch 5: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.19it/s][A




Epoch 5: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.208, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.16it/s][A




Epoch 5: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.208, v_num=czm9]




Epoch 5: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.208, v_num=czm9]
Epoch 6:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.208, v_num=czm9]           



Epoch 6:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.16it/s, loss=0.206, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 6:  99%|█████████▉| 1002/1010 [05:17<00:02,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  20%|██        | 2/10 [00:00<00:03,  2.27it/s][A




Epoch 6:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.75it/s][A




Epoch 6: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.25it/s][A




Epoch 6: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.14it/s][A




Epoch 6: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.206, v_num=czm9]




Epoch 6: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.206, v_num=czm9]
Epoch 7:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.206, v_num=czm9]           



Epoch 7:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.16it/s, loss=0.209, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 7:  99%|█████████▉| 1002/1010 [05:17<00:02,  3.16it/s, loss=0.209, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.07it/s][A




Epoch 7:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.209, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.62it/s][A




Epoch 7: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.209, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.16it/s][A




Epoch 7: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.209, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.13it/s][A




Epoch 7: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.209, v_num=czm9]




Epoch 7: 100%|██████████| 1010/1010 [05:20<00:00,  3.16it/s, loss=0.209, v_num=czm9]
Epoch 8:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.209, v_num=czm9]           



Epoch 8:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.16it/s, loss=0.205, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 8:  99%|█████████▉| 1002/1010 [05:17<00:02,  3.16it/s, loss=0.205, v_num=czm9]




Validating:  20%|██        | 2/10 [00:00<00:03,  2.25it/s][A




Epoch 8:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.205, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.72it/s][A




Epoch 8: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.205, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.16it/s][A




Epoch 8: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.205, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.12it/s][A




Epoch 8: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.205, v_num=czm9]




Epoch 8: 100%|██████████| 1010/1010 [05:20<00:00,  3.16it/s, loss=0.205, v_num=czm9]
Epoch 9:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.205, v_num=czm9]           



Epoch 9:  99%|█████████▉| 1000/1010 [05:15<00:03,  3.17it/s, loss=0.206, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 9:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:04,  1.90it/s][A




Epoch 9:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.52it/s][A




Epoch 9: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.09it/s][A




Epoch 9: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.206, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:03<00:00,  3.12it/s][A




Epoch 9: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.206, v_num=czm9]




Epoch 9: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.206, v_num=czm9]
Epoch 10:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.206, v_num=czm9]          



Epoch 10:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.16it/s, loss=0.204, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 10:  99%|█████████▉| 1002/1010 [05:17<00:02,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:04,  1.98it/s][A




Epoch 10:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.56it/s][A




Epoch 10: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.05it/s][A




Epoch 10: 100%|█████████▉| 1008/1010 [05:19<00:00,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.11it/s][A




Epoch 10: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.204, v_num=czm9]




Epoch 10: 100%|██████████| 1010/1010 [05:20<00:00,  3.15it/s, loss=0.204, v_num=czm9]
Epoch 11:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.204, v_num=czm9]           



Epoch 11:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.15it/s, loss=0.204, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 11:  99%|█████████▉| 1002/1010 [05:17<00:02,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.08it/s][A




Epoch 11:  99%|█████████▉| 1004/1010 [05:18<00:01,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.61it/s][A




Epoch 11: 100%|█████████▉| 1006/1010 [05:19<00:01,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.13it/s][A




Epoch 11: 100%|█████████▉| 1008/1010 [05:19<00:00,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.10it/s][A




Epoch 11: 100%|██████████| 1010/1010 [05:20<00:00,  3.15it/s, loss=0.204, v_num=czm9]




Epoch 11: 100%|██████████| 1010/1010 [05:20<00:00,  3.15it/s, loss=0.204, v_num=czm9]
Epoch 12:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.204, v_num=czm9]           



Epoch 12:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.16it/s, loss=0.204, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 12:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  20%|██        | 2/10 [00:00<00:03,  2.22it/s][A




Epoch 12:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.63it/s][A




Epoch 12: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.13it/s][A




Epoch 12: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.11it/s][A




Epoch 12: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.204, v_num=czm9]




Epoch 12: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.204, v_num=czm9]
Epoch 13:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.204, v_num=czm9]           



Epoch 13:  99%|█████████▉| 1000/1010 [05:15<00:03,  3.17it/s, loss=0.204, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 13:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.17it/s, loss=0.204, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.15it/s][A




Epoch 13:  99%|█████████▉| 1004/1010 [05:16<00:01,  3.17it/s, loss=0.204, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.68it/s][A




Epoch 13: 100%|█████████▉| 1006/1010 [05:17<00:01,  3.17it/s, loss=0.204, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.16it/s][A




Epoch 13: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.17it/s, loss=0.204, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.17it/s][A




Epoch 13: 100%|██████████| 1010/1010 [05:18<00:00,  3.17it/s, loss=0.204, v_num=czm9]




Epoch 13: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.204, v_num=czm9]
Epoch 14:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.204, v_num=czm9]           



Epoch 14:  99%|█████████▉| 1000/1010 [05:16<00:03,  3.16it/s, loss=0.204, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 14:  99%|█████████▉| 1002/1010 [05:17<00:02,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.11it/s][A




Epoch 14:  99%|█████████▉| 1004/1010 [05:18<00:01,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.60it/s][A




Epoch 14: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.14it/s][A




Epoch 14: 100%|█████████▉| 1008/1010 [05:19<00:00,  3.15it/s, loss=0.204, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.14it/s][A




Epoch 14: 100%|██████████| 1010/1010 [05:20<00:00,  3.15it/s, loss=0.204, v_num=czm9]




Epoch 14: 100%|██████████| 1010/1010 [05:20<00:00,  3.15it/s, loss=0.204, v_num=czm9]
Epoch 15:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.204, v_num=czm9]           



Epoch 15:  99%|█████████▉| 1000/1010 [05:15<00:03,  3.16it/s, loss=0.204, v_num=czm9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 15:  99%|█████████▉| 1002/1010 [05:16<00:02,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.08it/s][A




Epoch 15:  99%|█████████▉| 1004/1010 [05:17<00:01,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.56it/s][A




Epoch 15: 100%|█████████▉| 1006/1010 [05:18<00:01,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  60%|██████    | 6/10 [00:02<00:01,  3.10it/s][A




Epoch 15: 100%|█████████▉| 1008/1010 [05:18<00:00,  3.16it/s, loss=0.204, v_num=czm9]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.11it/s][A




Epoch 15: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.204, v_num=czm9]




Epoch 15: 100%|██████████| 1010/1010 [05:19<00:00,  3.16it/s, loss=0.204, v_num=czm9]
Epoch 16:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.204, v_num=czm9]           



Epoch 16:  47%|████▋     | 473/1010 [02:27<02:47,  3.20it/s, loss=0.203, v_num=czm9]

## Initialize the dual-encoder embedding model from a trained checkpoint

In [None]:
# Close any active W&B run (e.g. the interrupted training above) before starting a new one.
wandb.finish()

with open("dual_embedding_default.yaml") as config_file:
    hparams = yaml.load(config_file, Loader=yaml.FullLoader)

hparams["use_dual_encoder"] = True

# NOTE(review): VanillaDualEmbedding is not imported anywhere in this notebook;
# import it from the project's LightningModules package before running this cell.
model = VanillaDualEmbedding(hparams)

CHECKPOINT_PATH = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_dual_embedding/3ijb4qnw/checkpoints/last.ckpt"
# Restore tensors onto the CPU whatever device they were saved from (loading onto a
# missing GPU fails); load_state_dict() below copies them into the model's parameters.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
state_dict = checkpoint["state_dict"]

# Copy every encoder-1 tensor onto the matching encoder-2 key so both encoders start
# from the same trained weights (existing encoder-2 entries, if any, are overwritten).
ENCODER_RENAMES = (
    ("input_layer1", "input_layer2"),
    ("layers1", "layers2"),
    ("output_layer1", "output_layer2"),
)
for key in list(state_dict):
    twin_key = key
    for old, new in ENCODER_RENAMES:
        twin_key = twin_key.replace(old, new)
    state_dict[twin_key] = state_dict[key]

model.load_state_dict(state_dict)
# state_dict is the same object as checkpoint["state_dict"]; drop both references
# so the loaded tensors can actually be freed.
del checkpoint, state_dict

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the top-2 checkpoints by logged "pur" value, plus last.ckpt from the final step.
checkpoint_callback = ModelCheckpoint(
    monitor="pur",
    mode="max",
    save_top_k=2,
    save_last=True,
)

In [None]:
# Fine-tune the dual-encoder model, logging to the W&B project "ITk_dual_embedding".
logger = WandbLogger(project="ITk_dual_embedding")
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    num_sanity_val_steps=2,
    log_every_n_steps=50,
    logger=logger,
    callbacks=[checkpoint_callback],
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/",
)
trainer.fit(model)