In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning import Trainer
import wandb
import math
sys.path.append('../..')

from LightningModules.HGNN.Models.pyramid_models import HierarchicalGNN

device = "cuda" if torch.cuda.is_available() else "cpu"
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
def kaiming_init(model):
    """Re-initialise ``model``'s parameters in place with Kaiming-style normal draws.

    For every named parameter of ``model``:

    * names ending in ``.bias`` are zero-filled;
    * other parameters with fewer than two dimensions (e.g. normalisation scales)
      are left untouched, because they have no ``shape[1]`` fan-in dimension;
    * the weight at index ``0`` of a container (``"0.weight"`` or
      ``"<prefix>.0.weight"``) is drawn from N(0, std = 1 / sqrt(fan_in)), since
      the first layer has no ReLU applied to its input;
    * every other weight is drawn from N(0, std = sqrt(2) / sqrt(fan_in)),
      i.e. He initialisation for ReLU inputs,

    where ``fan_in = param.shape[1]``.

    Args:
        model: a ``torch.nn.Module``; its parameters are modified in place.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name.endswith(".bias"):
                param.zero_()
                continue
            if param.dim() < 2:
                # No fan-in dimension to scale by; leave e.g. LayerNorm weights alone.
                continue
            fan_in = param.shape[1]
            # Match the whole index component. A plain endswith("0.weight") would also
            # catch layers 10, 20, 30, ... and give them the first-layer variance.
            # NOTE(review): names like "fc0.weight" are no longer treated as first
            # layers; the Sequential-style "<i>.weight" naming is assumed.
            is_first_layer = name == "0.weight" or name.endswith(".0.weight")
            gain = 1.0 if is_first_layer else math.sqrt(2)
            param.normal_(0, gain / math.sqrt(fan_in))

In [None]:
def load_from_pretrained(model, path, map_location="cpu", strict=False):
    """Copy the weights stored in a checkpoint's ``"state_dict"`` entry into ``model``.

    Args:
        model: module whose parameters are overwritten in place.
        path: checkpoint file written with ``torch.save`` that contains a
            ``"state_dict"`` key (Lightning ``.ckpt`` files do).
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so that a
            checkpoint written on a GPU node can be read on a CPU-only machine.
            ``load_state_dict`` copies the values into ``model``'s existing tensors,
            so the model stays on its own device.
        strict: forwarded to ``load_state_dict``. Defaults to ``False`` so that a
            partially matching checkpoint can still be used to warm-start.

    Returns:
        The same ``model`` object, for call chaining.

    Raises:
        RuntimeError: from ``load_state_dict`` when ``strict`` is True and keys differ.

    Warns:
        UserWarning: when keys are missing from, or unexpected in, the checkpoint.
            With ``strict=False`` these are otherwise ignored without notice.
    """
    import warnings

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    incompatible = model.load_state_dict(checkpoint["state_dict"], strict=strict)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {path!r} does not match the model: "
            f"missing keys {incompatible.missing_keys}, "
            f"unexpected keys {incompatible.unexpected_keys}"
        )
    return model

## Construct PyTorch Lightning model

In [None]:
# Default hyperparameters; the `use_toy` / `use_pretrain` switches are read in the next cell.
with open("object_condensation_default.yaml") as f:
    hparams = yaml.load(stream=f, Loader=yaml.FullLoader)

In [None]:
# Toy mode: no detector regimes and 2 spatial channels (presumably a 2-D toy geometry).
if hparams["use_toy"]:
    hparams.update(regime=[], spatial_channels=2)

model = HierarchicalGNN(hparams)
# Also read later by `training()` in the sweep section when `use_pretrain` is set.
model_path = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_object_condensation/ITk_object_condensation/1org0kz3/checkpoints/last.ckpt"

# Either initialise from scratch or warm-start from an earlier run's last checkpoint.
if not hparams["use_pretrain"]:
    kaiming_init(model)
else:
    model = load_from_pretrained(model, model_path)

## Metric Learning

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the 2 checkpoints with the highest logged `track_eff`, plus `last.ckpt`.
checkpoint_callback = ModelCheckpoint(
    save_top_k=2,
    save_last=True,
    monitor="track_eff",
    mode="max",
)

In [None]:
logger = WandbLogger(project="ITk_object_condensation")
# NOTE(review): `accumulator` is not in the Trainer's `callbacks`, so this
# accumulation schedule currently has no effect. Add it there or delete it.
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 8: 4})
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    gradient_clip_val=0.5,
    num_sanity_val_steps=2,
    log_every_n_steps=50,
    logger=logger,
    callbacks=[checkpoint_callback],
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_object_condensation/",
)
trainer.fit(model)

In [None]:
# Resume an interrupted run. The W&B run id is also the name of the run's
# checkpoint directory under the project folder.
training_id = input("W&B run id to resume: ").strip()  # drop whitespace from a pasted id
model_path = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_object_condensation/ITk_object_condensation/{}/checkpoints/last.ckpt".format(training_id)
model = HierarchicalGNN.load_from_checkpoint(model_path)

# NOTE(review): wandb.init's `resume` option ("allow"/"must", forwarded by WandbLogger)
# may be needed for this to append to the existing run. Confirm for the installed wandb.
logger = WandbLogger(project="ITk_object_condensation", id=training_id)
# NOTE(review): not passed to the Trainer below, so this schedule has no effect.
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 8: 4})
# Use a fresh callback rather than the instance from the training cell.
# ModelCheckpoint keeps the `dirpath` it resolved during its first fit, so reusing
# that instance would write this run's checkpoints into the earlier run's directory.
checkpoint_callback = ModelCheckpoint(
    monitor="track_eff",
    mode="max",
    save_top_k=2,
    save_last=True,
)
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    gradient_clip_val=0.5,
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_object_condensation/",
)
# `ckpt_path` makes Lightning resume the full training state from this checkpoint.
trainer.fit(model, ckpt_path=model_path)

## Sweep

In [None]:
# Name of the W&B sweep created below; surrounding whitespace from a paste is dropped.
run_name = input("Sweep name: ").strip()

In [None]:
# Sweep search space (passed as the sweep's "parameters") and the defaults that
# every sweep run starts from before its sweep values are overlaid.
with open("object_condensation_sweep.yaml") as f:
    sweep_hparams = yaml.load(stream=f, Loader=yaml.FullLoader)
with open("object_condensation_default.yaml") as f:
    default_hparams = yaml.load(stream=f, Loader=yaml.FullLoader)

In [None]:
# Grid search over `sweep_hparams`, ranking runs by the logged `track_eff` metric.
sweep_configuration = dict(
    name=run_name,
    project="ITk_object_condensation",
    metric=dict(name="track_eff", goal="maximize"),
    method="grid",
    parameters=sweep_hparams,
)

In [None]:
def training():
    """Train one model for a W&B sweep agent (``wandb.agent(..., function=training)``).

    Steps:

    1. Start a W&B run and overlay the sweep-chosen values in ``wandb.config`` on
       ``default_hparams``.
    2. Build the model.
    3. Warm-start it from the global ``model_path`` checkpoint when
       ``use_pretrain`` is set; otherwise Kaiming-initialise it.
    4. Fit it, keeping checkpoints ranked by ``track_eff``.
    """
    wandb.init()
    # Sweep values override the defaults.
    hparams = {**default_hparams, **wandb.config}
    if hparams["use_toy"]:
        hparams["regime"] = []
        hparams["spatial_channels"] = 2

    # `DualHierarchicalGNN` was never imported in this notebook, so every sweep run
    # raised NameError. Use the model class the rest of the notebook trains.
    # TODO(review): import DualHierarchicalGNN explicitly if that variant was intended.
    model = HierarchicalGNN(hparams)

    # Initialise from scratch only when not warm-starting. Previously kaiming_init
    # ran unconditionally after loading and wiped out the pretrained weights.
    if hparams["use_pretrain"]:
        model = load_from_pretrained(model, model_path)
    else:
        kaiming_init(model)

    checkpoint_callback = ModelCheckpoint(
        monitor="track_eff",
        mode="max",
        save_top_k=2,
        save_last=True,
    )

    logger = WandbLogger()
    trainer = Trainer(
        gpus=1,
        # Read from the merged config so that a swept `max_epochs` is honoured.
        max_epochs=hparams["max_epochs"],
        log_every_n_steps=50,
        logger=logger,
        callbacks=[checkpoint_callback],
        # Same output root as the other cells (was ITk_barrel_embedding, which put sweep
        # checkpoints outside the directory the resume/test cells read from).
        default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_object_condensation/",
    )
    trainer.fit(model)

In [None]:
# Register the grid sweep with W&B.
sweep_id = wandb.sweep(sweep_configuration, project="ITk_object_condensation")

# The agent calls `training()` for each configuration it draws from the sweep.
wandb.agent(sweep_id, function=training)

## Test

In [None]:
# Rebuild the model from the default config, load a trained checkpoint and run the test loop.
with open("object_condensation_default.yaml") as f:
    hparams = yaml.load(stream=f, Loader=yaml.FullLoader)
if hparams["use_toy"]:
    hparams.update(regime=[], spatial_channels=2)

model = HierarchicalGNN(hparams)
model_path = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_object_condensation/ITk_object_condensation/1org0kz3/checkpoints/last.ckpt"
model = load_from_pretrained(model, model_path)

# setup("test") runs before test_dataloader() is requested, presumably because the
# dataloader depends on data prepared in setup. TODO confirm.
model.setup("test")
trainer = Trainer(gpus=1)
test_results = trainer.test(model, model.test_dataloader())

In [None]:
# Display the metrics returned by `trainer.test` via Jupyter's rich display.
test_results