In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell execution and edits under Modules/ are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Imports
Import `model_selector`, which takes a model **name/ID** as input and initializes the model according to the default configuration YAML file in **Modules/MODEL_NAME/Configs**. It also accepts an optional dictionary containing the parameters one would like to overwrite.

Here's a list of the available model names/IDs:
- 1: Edge Classifier Based on Interaction Network (EC-IN)
- 2: Node Embeddings Network Based on Interaction Network (Embedding-IN)
- 3: Node Embeddings Network Based on Hierarchical GNN with GMM clustering (Embedding-HGNN-GMM)
- 4: Bipartite Edge Classifier Based on Hierarchical GNN with GMM clustering (BC-HGNN-GMM)

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import wandb
import math
sys.path.append('..')

from Modules.training_utils import model_selector

device = "cuda" if torch.cuda.is_available() else "cpu"
from pytorch_lightning.callbacks import ModelCheckpoint

from Modules.tracking_utils import eval_metrics

In [3]:
def kaiming_init(model):
    """Re-initialise the parameters of ``model`` in place.

    Every parameter yielded by ``model.named_parameters()`` is handled by name:

    * names ending in ``".bias"`` are filled with zeros;
    * names ending in ``"0.weight"`` are drawn from ``N(0, 1 / fan_in)``
      (no ReLU gain, since the first layer's input has not been through a ReLU);
    * every other parameter with at least two dimensions is drawn from
      ``N(0, 2 / fan_in)``.

    ``fan_in`` is ``param.shape[1]``. Non-bias parameters with fewer than two
    dimensions have no such axis and are left untouched.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None. The parameters are modified in place.
    """
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            param.data.fill_(0)
            continue

        # Explicit guard instead of catching IndexError from ``param.shape[1]``:
        # 0-D / 1-D parameters have no fan-in axis, so they are skipped.
        if param.dim() < 2:
            continue

        fan_in = param.shape[1]
        if name.endswith("0.weight"):  # The first layer does not have ReLU applied on its input
            param.data.normal_(0, 1 / math.sqrt(fan_in))
        else:
            param.data.normal_(0, math.sqrt(2) / math.sqrt(fan_in))

In [4]:
def load_from_pretrained(model, path = None, ckpt = None):
    """Copy the weights stored in a checkpoint into ``model``.

    Args:
        model: module whose ``load_state_dict`` is called.
        path: file path of a checkpoint to read with ``torch.load``. Only used
            when ``ckpt`` is not given.
        ckpt: an already-loaded checkpoint mapping. Takes precedence over ``path``.

    Returns:
        The same ``model`` object, with matching entries of
        ``ckpt["state_dict"]`` loaded into it.

    Raises:
        ValueError: if neither ``path`` nor ``ckpt`` is provided.
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
    """
    import warnings

    if ckpt is None:
        if path is None:
            raise ValueError("load_from_pretrained requires either `path` or `ckpt`")
        # Deserialise onto the CPU so a checkpoint written on a GPU node can be
        # read anywhere; load_state_dict copies values into the model's own tensors.
        ckpt = torch.load(path, map_location="cpu")

    # strict=False tolerates partial matches, so report what was skipped instead
    # of dropping it silently.
    incompatible = model.load_state_dict(ckpt["state_dict"], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            "load_from_pretrained: missing keys {}, unexpected keys {}".format(
                list(incompatible.missing_keys), list(incompatible.unexpected_keys)
            )
        )

    return model

In [5]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpointing policy shared by the Trainer cells below: rank checkpoints by the
# logged `val_loss` (lower is better), keep the best two, and also save the most
# recent one (the later cells read it back as `last.ckpt`).
checkpoint_callback = ModelCheckpoint(
    save_last=True,
    save_top_k=2,
    monitor='val_loss',
    mode="min",
)

In [6]:
# Directory used as the Trainer's default_root_dir and as the prefix of every
# "<run id>/checkpoints/..." path built below. Set the TRACKML_ROOT_PATH
# environment variable to run somewhere other than the default location.
ROOT_PATH = os.path.join(
    os.environ.get(
        "TRACKML_ROOT_PATH",
        "/global/cfs/cdirs/m3443/usr/ryanliu/TrackML/TrackML_1GeV/",
    ),
    "",  # joining with "" guarantees the trailing separator the path formatting relies on
)

### Training A New Model

In [7]:
# Ask interactively which model to build (see the ID/name list in the Imports section).
model_name = input("input model ID/name")
# No override dictionary is passed here, only the model name/ID.
model = model_selector(model_name)
# Re-initialise the new model's parameters in place with kaiming_init (defined above).
kaiming_init(model)

input model ID/name 4


In [None]:
# Train the freshly initialised model, logging to the "TrackML_1GeV" W&B project.
# NOTE(review): `gpus=` is deprecated in newer pytorch_lightning releases in favour
# of `accelerator=`/`devices=` - confirm the installed version before changing it.
logger = WandbLogger(project="TrackML_1GeV")
trainer = Trainer(
    gpus=1,
    max_epochs=model.hparams["max_epochs"],
    gradient_clip_val=0.5,
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    default_root_dir=ROOT_PATH,
)
trainer.fit(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mliuryan30[0m ([33mexatrkx[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type                 | Params
----------------------------------------------------------------
0 | ignn_block             | InteractionGNNBlock  | 2.0 M 
1 | hgnn_block             | HierarchicalGNNBlock | 5.5 M 
2 | bipartite_output_layer | Sequential           | 132 K 
----------------------------------------------------------------
7.7 M     Trainable params
0         Non-trainable params
7.7 M     Total params
30.726    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



### Resume an interrupted training

In [None]:
# Resume a run from its `last.ckpt`, rebuilding the model from the
# hyper-parameters stored in that checkpoint.
training_id = input("input the wandb run ID to resume the run")
model_path = "{}{}/checkpoints/last.ckpt".format(ROOT_PATH, training_id)

# Only the stored hyper-parameters are read from this copy (trainer.fit re-reads
# the file through ckpt_path), so deserialise it onto the CPU.
ckpt = torch.load(model_path, map_location="cpu")
resume_hparams = ckpt["hyper_parameters"]
model = model_selector(resume_hparams["model"], resume_hparams)

# Passing the run id makes the logger attach to the existing W&B run.
logger = WandbLogger(project="TrackML_1GeV", id = training_id)
# The former `accumulator = GradientAccumulationScheduler(...)` line was removed:
# the class was never imported (NameError on a fresh kernel) and the object was
# never handed to the Trainer, so it had no effect on training.
trainer = Trainer(
    gpus=1,
    max_epochs=resume_hparams["max_epochs"],
    gradient_clip_val=0.5,
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    default_root_dir=ROOT_PATH,
)
trainer.fit(model, ckpt_path=model_path)

### Test
Running test on test dataset

In [None]:
def _latest_epoch_checkpoint(filenames):
    """Return the entry of ``filenames`` whose ``epoch=<N>`` prefix has the largest N.

    Names that do not start with ``epoch=<digits>`` are ignored. On ties the
    first such name in iteration order wins.

    Raises:
        FileNotFoundError: if no name carries an ``epoch=<N>`` prefix.
    """
    import re

    epoch_by_name = {}
    for filename in filenames:
        # Parse the prefix explicitly: str.strip("epoch=") removes a *set of
        # characters* from both ends, not the literal prefix.
        match = re.match(r"epoch=(\d+)", filename)
        if match:
            epoch_by_name[filename] = int(match.group(1))
    if not epoch_by_name:
        raise FileNotFoundError(
            "no 'epoch=<N>...' checkpoint found among {}".format(list(filenames))
        )
    return max(epoch_by_name, key=epoch_by_name.get)


# Inference-time settings that are merged over the checkpoint's hyper-parameters.
inference_config = {
    "majority_cut": float(input("majority cut (0.5 for loose matching, 0.9 for strict matching, 1.0 for perfect matching")),
    "score_cut": 0.7
}

checkpoint_dir = "{}{}/checkpoints/".format(ROOT_PATH, input("input the wandb run ID to load model's state dict"))
# Exclude last.ckpt without failing when it is absent; of the remaining files,
# use the one saved at the highest epoch.
model_paths = [filename for filename in os.listdir(checkpoint_dir) if filename != "last.ckpt"]
ckpt_name = _latest_epoch_checkpoint(model_paths)
model_path = os.path.join(checkpoint_dir, ckpt_name)

# Deserialise onto the CPU; the weights are copied into the model below.
ckpt = torch.load(model_path, map_location="cpu")
# Later keys win, so inference_config overrides any same-named hyper-parameter.
sweep_configs = {**(ckpt["hyper_parameters"]), **inference_config}

model = model_selector(ckpt["hyper_parameters"]["model"], sweep_configs)

model = load_from_pretrained(model, ckpt = ckpt)
model.setup("test")
trainer = Trainer(gpus=1)
# trainer.test returns a list; keep its first entry.
test_results = trainer.test(model, model.test_dataloader())[0]