In [1]:
%load_ext autoreload
# Reload edited project modules (e.g. LightningModules) automatically before each cell executes.
%autoreload 2

In [2]:
# Name of the W&B sweep configured below. This prompts interactively, so Restart & Run All
# pauses here; surrounding whitespace is stripped so it does not leak into the sweep name.
run_name = input("W&B sweep name: ").strip()

 regime


In [2]:
# System imports
import math
import os
import sys

# External imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import torch
import wandb
import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from sklearn.metrics import auc

# Local imports: LightningModules is resolved two directories above the kernel's
# working directory, so the path must be extended before importing it.
sys.path.append('../..')
from LightningModules.GNNNodeEmbedding.Models.gnn_embedding import InteractionGNN

# NOTE(review): `device` is not used by the cells shown here (the Trainer picks the GPU
# via gpus=1); kept for ad-hoc use.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def kaiming_init(model, first_layer_prefix="layers.0"):
    """Re-initialise ``model``'s parameters in place with Kaiming-style normal weights.

    - Parameters whose name ends in ``.bias`` are set to zero.
    - Weights whose name starts with ``first_layer_prefix`` get std ``1 / sqrt(fan_in)``,
      because the first layer's input has no ReLU applied to it.
    - All other weights get std ``sqrt(2) / sqrt(fan_in)`` (He init for ReLU inputs).
    - Other 1-D parameters (e.g. normalisation scales) have no fan-in and are left as-is.

    ``fan_in`` is the element count of one output slice of the weight: ``shape[1]`` for a
    Linear weight, ``in_channels * kernel_size`` for convolution weights.

    Args:
        model: a ``torch.nn.Module``; its parameters are modified in place.
        first_layer_prefix: parameter-name prefix that identifies the first layer. The
            default keeps the original convention; pass a different prefix for models
            whose first layer is named differently.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name.endswith(".bias"):
                param.zero_()
                continue
            if param.dim() < 2:
                # 1-D weights (e.g. LayerNorm scale) have no fan-in; keep their default init.
                continue
            fan_in = param[0].numel()
            gain = 1.0 if name.startswith(first_layer_prefix) else math.sqrt(2)
            param.normal_(0, gain / math.sqrt(fan_in))

## Sweep

In [4]:
# Sweep search space and the baseline hyperparameters it overrides.
with open("gnn_node_embedding_sweep.yaml") as sweep_file:
    sweep_hparams = yaml.load(sweep_file, Loader=yaml.FullLoader)

with open("gnn_node_embedding_default.yaml") as default_file:
    default_hparams = yaml.load(default_file, Loader=yaml.FullLoader)

In [6]:
# Grid sweep over the search space loaded from gnn_node_embedding_sweep.yaml, ranked by
# the logged "pur" metric. The project name must agree with the one passed to
# wandb.sweep() when the sweep is launched.
sweep_configuration = {
    "name": run_name,
    "project": "ITk_barrel_gnn_embedding",
    "metric": {"name": "pur", "goal": "maximize"},
    "method": "grid",
    "parameters": sweep_hparams,
}

In [7]:
def training():
    """Run one W&B sweep trial: build an InteractionGNN from the sampled config and fit it.

    Intended as the ``function`` for ``wandb.agent``. Values sampled into ``wandb.config``
    override ``default_hparams``. The merged dict drives both the model and the Trainer,
    so a swept ``max_epochs`` actually takes effect.
    """
    wandb.init()
    # Sweep values take precedence over the YAML defaults.
    hparams = {**default_hparams, **wandb.config}
    model = InteractionGNN(hparams)

    # Optional: uncomment to re-initialise the weights before training.
    # kaiming_init(model)

    # Keep the two checkpoints with the highest logged "pur" value, plus last.ckpt.
    checkpoint_callback = ModelCheckpoint(
        monitor='pur',
        mode="max",
        save_top_k=2,
        save_last=True)

    logger = WandbLogger()
    trainer = Trainer(
        gpus=1,
        max_epochs=hparams["max_epochs"],
        log_every_n_steps=50,
        logger=logger,
        callbacks=[checkpoint_callback],
        default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_gnn_embedding/",
    )
    trainer.fit(model)

In [None]:
# Register the sweep with W&B and keep its id for the agent.
sweep_id = wandb.sweep(sweep_configuration, project="ITk_barrel_gnn_embedding")

# Run the sweep inside this kernel; the agent invokes training() for each configuration.
wandb.agent(sweep_id, function=training)

## Construct the PyTorch Lightning model

In [4]:
# Baseline hyperparameters for a single, non-sweep training run.
with open("gnn_node_embedding_default.yaml") as default_config_file:
    hparams = yaml.load(default_config_file, Loader=yaml.FullLoader)

In [5]:
# Build the InteractionGNN Lightning module from the default hyperparameters loaded above.
model = InteractionGNN(hparams)

## Metric Learning

In [6]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the two checkpoints with the highest logged "pur" value, and always write last.ckpt.
checkpoint_callback = ModelCheckpoint(
    save_last=True,
    save_top_k=2,
    mode="max",
    monitor="pur",
)

In [None]:
# Optional: uncomment to re-initialise the weights before training.
# kaiming_init(model)
logger = WandbLogger(project="ITk_gnn_embedding")
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    num_sanity_val_steps=2,
    log_every_n_steps=50,
    logger=logger,
    callbacks=[checkpoint_callback],
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_gnn_embedding/",
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name         | Type             | Params
--------------------------------------------------
0 | cos          | CosineSimilarity | 0     
1 | node_encoder | Sequential       | 136 K 
2 | edge_encoder | Sequential       | 658 K 
3 | gnn_blocks   | ModuleList       | 37.9 M
4 | output_layer | Sequential       | 66.6 K
--------------------------------------------------
38.7 M    Trainable params
0         Non-trainable params
38.7 M    Total params
154.883   Total estimated model params size (MB)


Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  1.19it/s]



                                                                      



Epoch 0:   0%|          | 0/1010 [00:00<?, ?it/s] 



Epoch 0:  99%|█████████▉| 1000/1010 [04:15<00:02,  3.92it/s, loss=0.854, v_num=h42n]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/10 [00:00<?, ?it/s][A
Epoch 0:  99%|█████████▉| 1002/1010 [04:16<00:02,  3.91it/s, loss=0.854, v_num=h42n]




Validating:  20%|██        | 2/10 [00:01<00:03,  2.05it/s][A
Epoch 0:  99%|█████████▉| 1004/1010 [04:16<00:01,  3.91it/s, loss=0.854, v_num=h42n]




Validating:  40%|████      | 4/10 [00:01<00:02,  2.81it/s][A




Epoch 0: 100%|█████████▉| 1006/1010 [04:17<00:01,  3.91it/s, loss=0.854, v_num=h42n]
Validating:  60%|██████    | 6/10 [00:02<00:01,  3.47it/s][A




Epoch 0: 100%|█████████▉| 1008/1010 [04:17<00:00,  3.91it/s, loss=0.854, v_num=h42n]




Validating:  80%|████████  | 8/10 [00:02<00:00,  3.56it/s][A




Epoch 0: 100%|██████████| 1010/1010 [04:18<00:00,  3.91it/s, loss=0.854, v_num=h42n]




Epoch 0: 100%|██████████| 1010/1010 [04:18<00:00,  3.90it/s, loss=0.854, v_num=h42n]
Epoch 1:   0%|          | 0/1010 [00:00<?, ?it/s, loss=0.854, v_num=h42n]           



Epoch 1:  27%|██▋       | 270/1010 [01:09<03:10,  3.88it/s, loss=0.852, v_num=h42n]

## Initialize the dual-encoder model from a trained checkpoint

In [None]:
# Finish any active W&B run (e.g. from the training cell above) so the next logger starts fresh.
wandb.finish()

with open("dual_embedding_default.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)

hparams["use_dual_encoder"] = True

# FIXME(review): VanillaDualEmbedding is not imported anywhere in this notebook; add its
# import from the LightningModules package before running this cell.
model = VanillaDualEmbedding(hparams)

CHECKPOINT_PATH = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_dual_embedding/3ijb4qnw/checkpoints/last.ckpt"
# Load onto the CPU so a GPU-saved checkpoint also opens on CPU-only kernels;
# load_state_dict copies the tensors into the model's own parameters anyway.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

# Warm-start encoder 2 from encoder 1: every key naming an encoder-1 submodule gets a
# twin key for the matching encoder-2 submodule holding the same tensor.
ENCODER_RENAMES = (
    ("input_layer1", "input_layer2"),
    ("layers1", "layers2"),
    ("output_layer1", "output_layer2"),
)


def _encoder2_key(key):
    """Return ``key`` with each encoder-1 submodule name replaced by its encoder-2 name."""
    for old, new in ENCODER_RENAMES:
        key = key.replace(old, new)
    return key


# Shallow copy so the loaded checkpoint dict itself is not mutated.
state_dict = dict(checkpoint["state_dict"])
for key, tensor in checkpoint["state_dict"].items():
    new_key = _encoder2_key(key)
    if new_key != key:
        state_dict[new_key] = tensor

model.load_state_dict(state_dict)
# NOTE: `checkpoint` still references the loaded tensors; `del checkpoint` too if memory is tight.
del state_dict

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# New callback for the dual-encoder run: best two checkpoints by "pur", plus last.ckpt.
checkpoint_callback = ModelCheckpoint(save_last=True, save_top_k=2, mode="max", monitor="pur")

In [None]:
# Train the dual-encoder model, logging to its own W&B project.
logger = WandbLogger(project="ITk_dual_embedding")
trainer = Trainer(
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/",
    gpus=1,
    max_epochs=hparams["max_epochs"],
    num_sanity_val_steps=2,
    log_every_n_steps=50,
    callbacks=[checkpoint_callback],
    logger=logger,
)
trainer.fit(model)