In [None]:
%load_ext autoreload
%autoreload 2

### Imports
Import `model_selector`, which takes a model **name/ID** as input and initializes the model according to the default configuration YAML file in **LightningModules/TrackML_ACAT/MODEL_NAME/Configs**, together with an optional input dictionary that contains the parameters one would like to overwrite.

Here's a list of model name/ID:
- 1: Edge Classifier Based on Interaction Network (EC-IN)
- 2: Node Embeddings Network Based on Interaction Network (Embedding-IN)
- 3: Node Embeddings Network Based on Hierarchical GNN with GMM clustering (Embedding-HGNN-GMM)
- 4: Node Embeddings Network Based on Hierarchical GNN with HDBSCAN (Embedding-HGNN-HDBSCAN)
- 5: Bipartite Edge Classifier Based on Hierarchical GNN with GMM clustering (BC-HGNN-GMM)
- 6: Bipartite Edge Classifier Based on Hierarchical GNN with HDBSCAN (BC-HGNN-HDBSCAN)

GMM models are preferable to HDBSCAN models.

In [None]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning import Trainer
import wandb
import math
sys.path.append('../..')

from LightningModules.TrackML_ACAT.training_utils import model_selector

device = "cuda" if torch.cuda.is_available() else "cpu"
from pytorch_lightning.callbacks import ModelCheckpoint

from LightningModules.TrackML_ACAT.tracking_utils import eval_metrics

In [None]:
def kaiming_init(model):
    """Re-initialise the parameters of ``model`` in place, Kaiming-style.

    Parameters are visited through ``model.named_parameters()`` and handled
    according to their name and shape:

    * names ending in ``.bias`` are filled with zeros;
    * weights whose last module index is exactly ``0`` (``"0.weight"`` or
      ``"<prefix>.0.weight"``) are drawn from a normal distribution with
      standard deviation ``1 / sqrt(shape[1])`` -- the first layer does not
      have ReLU applied on its input;
    * every other parameter with at least two dimensions is drawn from a
      normal distribution with standard deviation ``sqrt(2) / sqrt(shape[1])``;
    * remaining parameters with fewer than two dimensions (they have no
      ``shape[1]``) are left untouched.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        None. The parameters are modified in place.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name.endswith(".bias"):
                param.fill_(0)
                continue

            if param.dim() < 2:
                # No second axis to scale by. Skip explicitly instead of
                # relying on an IndexError raised by ``param.shape[1]``.
                continue

            # Compare the whole final index so that "10.weight", "20.weight", ...
            # are not mistaken for the first layer.
            is_first_layer = name == "0.weight" or name.endswith(".0.weight")
            gain = 1.0 if is_first_layer else math.sqrt(2)
            param.normal_(0, gain / math.sqrt(param.shape[1]))

In [None]:
def load_from_pretrained(model, path = None, ckpt = None):
    """Copy the weights stored in a checkpoint into ``model``.

    Args:
        model: module whose ``load_state_dict`` is called.
        path: location of a checkpoint readable by ``torch.load``. Only used
            when ``ckpt`` is not given.
        ckpt: an already-loaded checkpoint mapping that has a
            ``"state_dict"`` entry. Takes precedence over ``path``.

    Returns:
        The same ``model`` object, after loading.

    Raises:
        ValueError: if neither ``path`` nor ``ckpt`` is provided.

    Note:
        Loading is non-strict (``strict=False``): keys missing from the
        checkpoint or unknown to the model are ignored rather than reported.
    """
    if ckpt is None:
        if path is None:
            raise ValueError("load_from_pretrained requires either `path` or `ckpt`")
        # Deserialise onto the CPU so a checkpoint written on a GPU machine
        # can be opened anywhere; load_state_dict copies the values into the
        # model's own tensors afterwards.
        # NOTE: torch.load unpickles data -- only open checkpoints you trust.
        ckpt = torch.load(path, map_location="cpu")

    model.load_state_dict(ckpt["state_dict"], strict=False)
    return model

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpointing policy shared by the training cells below: rank checkpoints
# by the logged 'val_loss' (lower is better), keep the best two, and also
# keep the most recent one.
checkpoint_callback = ModelCheckpoint(
    save_last=True,
    save_top_k=2,
    mode="min",
    monitor='val_loss',
)

### Training A New Model

In [None]:
# Ask interactively which model to build (see the name/ID list in the
# "Imports" section), construct it with model_selector, then re-initialise
# its parameters in place with kaiming_init.
print("input model ID/name")
model_name = input()
model = model_selector(model_name)
kaiming_init(model)

In [None]:
# Metrics are logged to the "TrackML_1GeV" Weights & Biases project.
logger = WandbLogger(project="TrackML_1GeV")

# NOTE(review): this scheduler is created but is not part of the Trainer's
# `callbacks` list below, so it appears to have no effect in this cell --
# confirm whether it was meant to be registered.
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 8: 4})

trainer = Trainer(
    gpus=1,
    max_epochs=model.hparams["max_epochs"],
    gradient_clip_val=0.5,
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/TrackML/",
)
trainer.fit(model)

### Resume an interrupted training

In [None]:
# Resume a run from its `last.ckpt`: the model is rebuilt from the
# hyper-parameters stored in the checkpoint, and the same checkpoint path is
# handed to `trainer.fit` to continue from it.
print("input the wandb run ID to resume the run")
training_id = input()
model_path = "/global/cfs/cdirs/m3443/usr/ryanliu/TrackML/TrackML_1GeV/{}/checkpoints/last.ckpt".format(training_id)

ckpt = torch.load(model_path)
resume_hparams = ckpt["hyper_parameters"]
model = model_selector(resume_hparams["model"], resume_hparams)

# Re-use the original wandb run ID for the logger.
logger = WandbLogger(project="TrackML_1GeV", id=training_id)

# NOTE(review): created but not passed to the Trainer's `callbacks` below --
# confirm whether it was meant to be registered.
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 8: 4})

trainer = Trainer(
    gpus=1,
    max_epochs=resume_hparams["max_epochs"],
    gradient_clip_val=0.5,
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/TrackML/",
)
trainer.fit(model, ckpt_path=model_path)

### Test
Run the model on the test dataset.

In [None]:
def _latest_epoch_checkpoint(ckpt_dir):
    """Return the name of the file in ``ckpt_dir`` with the highest epoch.

    Only file names containing ``epoch=<digits>`` are considered; anything
    else (for example ``last.ckpt``) is ignored. If several files share the
    highest epoch, the first one returned by ``os.listdir`` wins.

    Args:
        ckpt_dir: directory to scan.

    Returns:
        The bare file name (not the full path) of the selected checkpoint.

    Raises:
        FileNotFoundError: if no file name contains ``epoch=<digits>``.
    """
    import re  # local import keeps this cell self-contained

    epoch_by_name = {}
    for fname in os.listdir(ckpt_dir):
        # Parse the number explicitly: str.strip("epoch=") removes a *set of
        # characters* from both ends, not the literal prefix.
        found = re.search(r"epoch=(\d+)", fname)
        if found:
            epoch_by_name[fname] = int(found.group(1))

    if not epoch_by_name:
        raise FileNotFoundError(
            "no 'epoch=<N>' checkpoint found in {}".format(ckpt_dir)
        )
    return max(epoch_by_name, key=epoch_by_name.get)


# Inference-time settings; the majority cut is asked for interactively.
print("input the majority cut (0.5 for loose matching, 0.9 for strict matching, 1.0 for perfect matching")
inference_config = {
    "majority_cut": float(input()),
    "score_cut": 0.7
}

# Locate the checkpoint directory of the requested run and pick the
# checkpoint from the latest epoch among the saved `epoch=...` files.
print("input the wandb run ID to load model's state dict")
model_path = "/global/cfs/cdirs/m3443/usr/ryanliu/TrackML/TrackML_1GeV/{}/checkpoints/".format(input())
ckpt_name = _latest_epoch_checkpoint(model_path)
model_path = os.path.join(model_path, ckpt_name)

# Load the selected checkpoint and layer the inference settings on top of
# the hyper-parameters it was saved with (inference settings win on clashes).
ckpt = torch.load(model_path)
trained_hparams = ckpt["hyper_parameters"]
sweep_configs = dict(trained_hparams)
sweep_configs.update(inference_config)

# Rebuild the model, restore its weights, and prepare it for the test stage.
model = model_selector(trained_hparams["model"], sweep_configs)
model = load_from_pretrained(model, ckpt=ckpt)
model.setup("test")

# `trainer.test` returns a list; keep its first entry.
trainer = Trainer(gpus=1)
test_loader = model.test_dataloader()
test_results = trainer.test(model, test_loader)[0]

In [None]:
# Print each value stored in test_results, one per line.
print(*test_results.values(), sep="\n")

In [None]:
# Print each key of test_results, one per line (same order as the values).
print(*test_results.keys(), sep="\n")