In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to project source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Name given to the W&B sweep (stored as sweep_configuration["name"]).
# The prompt tells the reader what is being asked for; surrounding whitespace
# is stripped so an accidental leading/trailing space does not end up in the
# sweep name. NOTE: this cell blocks "Run All" until a value is entered.
run_name = input("W&B sweep name: ").strip()

 architecture


In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import wandb
import math
# Put the directory two levels above the working directory on the import path;
# this must happen before the LightningModules import below can resolve.
sys.path.append('../..')

# Project-local model class trained in this notebook.
from LightningModules.GNNEdgeClassification.Models.gnn import InteractionGNN

# Torch device string ("cuda" when a GPU is visible, otherwise "cpu").
# NOTE(review): not referenced by any other cell visible in this notebook.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpointing callback used by the training cells further down.
from pytorch_lightning.callbacks import ModelCheckpoint

In [3]:
def kaiming_init(model):
    """Re-initialise the parameters of ``model`` in place, Kaiming-normal style.

    Every entry of ``model.named_parameters()`` is handled as follows:

    * a parameter named ``bias`` (either at the root or as the last component
      of a dotted name, e.g. ``layers.3.bias``) is filled with zeros;
    * any other parameter with fewer than two dimensions is left untouched,
      because it has no ``shape[1]`` from which a fan-in could be read
      (previously this case raised ``IndexError``);
    * a parameter whose name starts with ``layers.0`` is drawn from a normal
      distribution with standard deviation ``1 / sqrt(fan_in)``;
    * every remaining parameter is drawn with standard deviation
      ``sqrt(2) / sqrt(fan_in)``.

    ``fan_in`` is ``param.shape[1]``.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        None. The parameters are modified in place.
    """
    # Initialisation must not be recorded by autograd.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name.rsplit(".", 1)[-1] == "bias":
                param.fill_(0)
                continue
            if param.dim() < 2:
                # No fan-in dimension available; keep the existing values.
                continue
            fan_in = param.shape[1]
            # The first layer does not have ReLU applied on its input, so it
            # does not get the sqrt(2) gain.
            gain = 1 if name.startswith("layers.0") else math.sqrt(2)
            param.normal_(0, gain / math.sqrt(fan_in))

In [4]:
def load_from_pretrained(model, path):
    """Warm-start ``model`` from a checkpoint, skipping its output layer.

    The file at ``path`` is read with ``torch.load`` and must contain a
    ``"state_dict"`` entry. Every entry whose key contains ``"output_layer"``
    is discarded; the rest is loaded into ``model`` non-strictly, so keys that
    are missing or unexpected on either side are ignored rather than raising.

    Args:
        model: Module exposing ``load_state_dict``.
        path: Location of the checkpoint file.

    Returns:
        The same ``model`` object, updated in place.
    """
    # Deserialise onto the CPU: without map_location the tensors are restored
    # to the device they were saved from, which fails on a host without that
    # device. load_state_dict copies the values into the model's own
    # parameters, so the model's device is unaffected.
    checkpoint = torch.load(path, map_location="cpu")
    pretrained_state = {
        key: value
        for key, value in checkpoint["state_dict"].items()
        if "output_layer" not in key
    }
    # strict=False is required because the output-layer entries were dropped.
    model.load_state_dict(pretrained_state, strict=False)

    return model

## Sweep

In [5]:
# Hyper-parameter grid for the sweep; passed on as the sweep's "parameters".
# NOTE(review): yaml.FullLoader is fine for trusted local config files; prefer
# yaml.safe_load if these files could ever come from an untrusted source.
with open("gnn_edge_classification_sweep.yaml") as f:
    sweep_hparams = yaml.load(stream=f, Loader=yaml.FullLoader)

# Baseline hyper-parameters that each sweep trial starts from.
with open("gnn_edge_classification_default.yaml") as f:
    default_hparams = yaml.load(stream=f, Loader=yaml.FullLoader)

In [6]:
# Sweep definition handed to wandb.sweep(): an exhaustive grid over
# sweep_hparams, ranked by the logged "pur" metric (higher is better).
sweep_configuration = dict(
    name=run_name,
    project="ITk_barrel_gnn",
    metric=dict(name="pur", goal="maximize"),
    method="grid",
    parameters=sweep_hparams,
)

In [7]:
def training():
    """Run a single sweep trial: build, warm-start and fit an InteractionGNN.

    Takes no arguments so it can be passed to ``wandb.agent`` as the trial
    function. Reads the module-level ``default_hparams`` and overlays the
    values W&B selected for this trial (``wandb.config``); on a key clash the
    sweep value wins. The merged dictionary configures both the model and the
    trainer's ``max_epochs``, so a swept ``max_epochs`` is honoured instead of
    silently falling back to the default.

    Returns:
        None. Results are logged to W&B and checkpoints are written below the
        trainer's ``default_root_dir``.
    """
    wandb.init()
    run_hparams = {**default_hparams, **wandb.config}
    # Read before the model is built so the value cannot be affected by
    # anything the model does with the dictionary.
    max_epochs = run_hparams["max_epochs"]
    model = InteractionGNN(run_hparams)

    # Warm start from a previous run instead of a fresh kaiming_init(model);
    # load_from_pretrained skips the checkpoint's output_layer entries.
    model = load_from_pretrained(model, "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_barrel_gnn/ITk_barrel_gnn/3pp297g4/checkpoints/epoch=32-step=32999.ckpt")

    # Keep the two checkpoints with the highest monitored "pur" plus the most
    # recent one. NOTE(review): assumes the model logs a metric named "pur"
    # (the same name the sweep optimises) -- confirm against the model code.
    checkpoint_callback = ModelCheckpoint(
        monitor='pur',
        mode="max",
        save_top_k=2,
        save_last=True)

    logger = WandbLogger()
    trainer = Trainer(
        gpus=1,
        max_epochs=max_epochs,
        log_every_n_steps=50,
        logger=logger,
        callbacks=[checkpoint_callback],
        default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_barrel_gnn/",
    )
    trainer.fit(model)

In [None]:
# Register the sweep with W&B and keep the id it returns.
sweep_id = wandb.sweep(sweep_configuration, project="ITk_barrel_gnn")

# Start an agent in this kernel; it invokes training() for the runs of the sweep.
wandb.agent(sweep_id, function=training)

## Construct PyTorch Lightning model

In [5]:
# Hyper-parameters for the single (non-sweep) training run below.
with open("gnn_edge_classification_default.yaml") as f:
    hparams = yaml.load(stream=f, Loader=yaml.FullLoader)

In [6]:
# Build the GNN from the default hyper-parameters ...
model = InteractionGNN(hparams)
# ... then warm-start it from an earlier run's checkpoint (the checkpoint's
# output_layer entries are skipped by load_from_pretrained).
model = load_from_pretrained(
    model,
    "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_barrel_gnn/ITk_barrel_gnn/3pp297g4/checkpoints/epoch=32-step=32999.ckpt",
)

## GNN Edge Classification Training

In [7]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the two checkpoints with the highest monitored "sig_pur", plus the
# most recent one.
# NOTE(review): the sweep's training() monitors "pur" instead -- confirm which
# metric name the model actually logs.
checkpoint_callback = ModelCheckpoint(monitor="sig_pur", mode="max", save_top_k=2, save_last=True)

In [None]:
# The model keeps its pretrained weights here; kaiming_init(model) is
# deliberately not applied.
logger = WandbLogger(project="ITk_barrel_gnn")
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_barrel_gnn/",
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mexatrkx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name         | Type       | Params
--------------------------------------------
0 | node_encoder | Sequential | 34.3 K
1 | edge_encoder | Sequential | 34.7 K
2 | gnn_blocks   | ModuleList | 1.2 M 
3 | output_layer | Sequential | 82.8 K
--------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.400     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]



Validation sanity check:  50%|█████     | 1/2 [00:02<00:02,  2.35s/it]



                                                                      



Epoch 0:  99%|█████████▉| 1000/1010 [33:48<00:20,  2.03s/it, loss=0.0577, v_num=gcp9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A




Epoch 0:  99%|█████████▉| 1002/1010 [33:50<00:16,  2.03s/it, loss=0.0577, v_num=gcp9]




Validating:  20%|██        | 2/10 [00:02<00:09,  1.25s/it][A




Epoch 0:  99%|█████████▉| 1004/1010 [33:51<00:12,  2.02s/it, loss=0.0577, v_num=gcp9]




Validating:  40%|████      | 4/10 [00:04<00:06,  1.02s/it][A




Epoch 0: 100%|█████████▉| 1006/1010 [33:53<00:08,  2.02s/it, loss=0.0577, v_num=gcp9]




Validating:  60%|██████    | 6/10 [00:05<00:03,  1.21it/s][A




Epoch 0: 100%|█████████▉| 1008/1010 [33:55<00:04,  2.02s/it, loss=0.0577, v_num=gcp9]




Validating:  80%|████████  | 8/10 [00:07<00:01,  1.18it/s][A




Epoch 0: 100%|██████████| 1010/1010 [33:56<00:00,  2.02s/it, loss=0.0577, v_num=gcp9]




Epoch 0: 100%|██████████| 1010/1010 [33:58<00:00,  2.02s/it, loss=0.0577, v_num=gcp9]
Epoch 1:  99%|█████████▉| 1000/1010 [33:48<00:20,  2.03s/it, loss=0.0545, v_num=gcp9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A




Epoch 1:  99%|█████████▉| 1002/1010 [33:50<00:16,  2.03s/it, loss=0.0545, v_num=gcp9]




Validating:  20%|██        | 2/10 [00:02<00:09,  1.20s/it][A




Epoch 1:  99%|█████████▉| 1004/1010 [33:51<00:12,  2.02s/it, loss=0.0545, v_num=gcp9]




Validating:  40%|████      | 4/10 [00:04<00:06,  1.01s/it][A




Epoch 1: 100%|█████████▉| 1006/1010 [33:53<00:08,  2.02s/it, loss=0.0545, v_num=gcp9]




Validating:  60%|██████    | 6/10 [00:05<00:03,  1.21it/s][A




Epoch 1: 100%|█████████▉| 1008/1010 [33:55<00:04,  2.02s/it, loss=0.0545, v_num=gcp9]




Validating:  80%|████████  | 8/10 [00:07<00:01,  1.18it/s][A




Epoch 1: 100%|██████████| 1010/1010 [33:57<00:00,  2.02s/it, loss=0.0545, v_num=gcp9]




Epoch 1: 100%|██████████| 1010/1010 [33:58<00:00,  2.02s/it, loss=0.0545, v_num=gcp9]
Epoch 2:  99%|█████████▉| 1000/1010 [33:47<00:20,  2.03s/it, loss=0.0524, v_num=gcp9]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A




Epoch 2:  99%|█████████▉| 1002/1010 [33:49<00:16,  2.03s/it, loss=0.0524, v_num=gcp9]




Validating:  20%|██        | 2/10 [00:02<00:09,  1.22s/it][A




Epoch 2:  99%|█████████▉| 1004/1010 [33:50<00:12,  2.02s/it, loss=0.0524, v_num=gcp9]




Validating:  40%|████      | 4/10 [00:04<00:06,  1.02s/it][A




Epoch 2: 100%|█████████▉| 1006/1010 [33:52<00:08,  2.02s/it, loss=0.0524, v_num=gcp9]




Validating:  60%|██████    | 6/10 [00:05<00:03,  1.21it/s][A




Epoch 2: 100%|█████████▉| 1008/1010 [33:54<00:04,  2.02s/it, loss=0.0524, v_num=gcp9]




Validating:  80%|████████  | 8/10 [00:07<00:01,  1.18it/s][A




Epoch 2: 100%|██████████| 1010/1010 [33:55<00:00,  2.02s/it, loss=0.0524, v_num=gcp9]




Epoch 2: 100%|██████████| 1010/1010 [33:57<00:00,  2.02s/it, loss=0.0524, v_num=gcp9]
Epoch 3:  99%|█████████▉| 1000/1010 [33:46<00:20,  2.03s/it, loss=0.052, v_num=gcp9] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A




Epoch 3:  99%|█████████▉| 1002/1010 [33:48<00:16,  2.02s/it, loss=0.052, v_num=gcp9]




Validating:  20%|██        | 2/10 [00:02<00:09,  1.23s/it][A




Epoch 3:  99%|█████████▉| 1004/1010 [33:49<00:12,  2.02s/it, loss=0.052, v_num=gcp9]




Validating:  40%|████      | 4/10 [00:04<00:06,  1.03s/it][A




Epoch 3: 100%|█████████▉| 1006/1010 [33:51<00:08,  2.02s/it, loss=0.052, v_num=gcp9]




Validating:  60%|██████    | 6/10 [00:06<00:03,  1.08it/s][A




Epoch 3: 100%|█████████▉| 1008/1010 [33:53<00:04,  2.02s/it, loss=0.052, v_num=gcp9]




Validating:  80%|████████  | 8/10 [00:08<00:01,  1.12it/s][A




Epoch 3: 100%|██████████| 1010/1010 [33:54<00:00,  2.01s/it, loss=0.052, v_num=gcp9]




Epoch 3: 100%|██████████| 1010/1010 [33:56<00:00,  2.02s/it, loss=0.052, v_num=gcp9]
Epoch 4:  53%|█████▎    | 537/1010 [17:57<15:49,  2.01s/it, loss=0.0479, v_num=gcp9]

## Initialize from trained model

In [None]:
import wandb
# Close any W&B run left open by the cells above so the next logger starts a
# fresh run.
wandb.finish()

# NOTE: this rebinds `hparams` (previously the GNN hyper-parameters).
with open("dual_embedding_default.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)
    
hparams["use_dual_encoder"] = True

# NOTE(review): VanillaDualEmbedding is not imported in any cell visible in
# this notebook; on a fresh kernel this line raises NameError until the
# matching import is added.
model = VanillaDualEmbedding(hparams)

# Deserialise onto the CPU so the load does not depend on the device the
# checkpoint was saved from; load_state_dict copies the values into the model.
checkpoint = torch.load("/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_dual_embedding/3ijb4qnw/checkpoints/last.ckpt", map_location="cpu")
state_dict = checkpoint["state_dict"]
# Iterate over a snapshot of the keys because entries are added in the loop.
names = [i for i in state_dict]
for i in names:
    state = state_dict[i]
    # Mirror each "...1" entry onto the corresponding "...2" name so both sets
    # of layers receive the same weights; names without these substrings map
    # onto themselves and are simply rewritten unchanged.
    i = i.replace("input_layer1", "input_layer2")
    i = i.replace("layers1", "layers2")
    i = i.replace("output_layer1", "output_layer2")
    state_dict[i] = state

# Strict load: every model parameter must be present in the mirrored dict.
model.load_state_dict(state_dict)
del state_dict

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the two checkpoints with the highest monitored "pur", plus the most
# recent one. This rebinds the checkpoint_callback defined for the GNN run.
checkpoint_callback = ModelCheckpoint(monitor="pur", mode="max", save_top_k=2, save_last=True)

In [None]:
# Log to the dual-embedding W&B project and write checkpoints under the
# embedding root directory.
logger = WandbLogger(project="ITk_dual_embedding")
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/",
)
trainer.fit(model)