# Example of Filtering Embedded Pairs

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
run_name = input()

In [1]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import wandb

sys.path.append('../..')
device = "cuda" if torch.cuda.is_available() else "cpu"

from LightningModules.Filter.Models.pyramid_filter import PyramidFilter
from LightningModules.Filter.utils import graph_intersection
from pytorch_lightning.callbacks import ModelCheckpoint

## Pytorch Lightning Model

### Construct PyLightning model

An ML model typically has many knobs to turn, as well as locations of data, some training preferences, and so on. For convenience, let's put all of these parameters into a YAML file and load it.

In [None]:
with open("filter-sweep.yaml") as f:
        sweep_hparams = yaml.load(f, Loader=yaml.FullLoader)
with open("filter.yaml") as f:
        default_hparams = yaml.load(f, Loader=yaml.FullLoader)

In [None]:
sweep_configuration = {
    "name": run_name,
    "project": "ITk_barrel_full_filter",
    "metric": {"name": "pur", "goal": "maximize"},
    "method": "grid",
    "parameters": sweep_hparams
}

In [None]:
def training():
    wandb.init()
    model = PyramidFilter({**default_hparams, **wandb.config})

    checkpoint_callback = ModelCheckpoint(
        monitor='pur',
        mode="max",
        save_top_k=2,
        save_last=True)

    logger = WandbLogger()
    trainer = Trainer(gpus=1, max_steps=default_hparams["max_steps"], val_check_interval = 1000, logger=logger, callbacks=[checkpoint_callback], default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_filter/")
    trainer.fit(model)

In [None]:
sweep_id = wandb.sweep(sweep_configuration, project = "ITk_barrel_full_filter")

# run the sweep
wandb.agent(sweep_id, function=training)

## Metric Learning

### Train filter

Finally! Let's train! We instantiate a `Trainer` class that knows things like which hardware to work with, how long to train for, and a **bunch** of default options that we ignore here. Check out the Trainer class docs in Pytorch Lightning. Suffice it to say that it clears away much repetitive boilerplate in training code.

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='eff',
    mode="max",
    save_top_k=2,
    save_last=True)

model = PyramidFilter({**default_hparams})

In [None]:
%%time
# logger = WandbLogger(project="ITk_0GeV_Filter")
trainer = Trainer(gpus=1, max_epochs=default_hparams["max_epochs"], num_sanity_val_steps=2, logger=None, callbacks=[checkpoint_callback])
trainer.fit(model)

## Build Edges

In [31]:
checkpoint_path = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_filter/ITk_barrel_full_filter/clzmphv8/checkpoints/last.ckpt"
checkpoint = torch.load(checkpoint_path)

In [32]:
model = PyramidFilter.load_from_checkpoint(checkpoint_path).to(device)

In [33]:
%%time
model.setup(stage="fit")

CPU times: user 4.89 ms, sys: 4.75 ms, total: 9.64 ms
Wall time: 55.6 ms


In [34]:
import copy
class FilterInferenceBuilder:
    def __init__(self, model, output_dir, overwrite=False):
        self.output_dir = output_dir
        self.model = model
        self.overwrite = overwrite

        # Prep the directory to produce inference data to
        self.datatypes = ["test", "val", "train"]
        os.makedirs(self.output_dir, exist_ok=True)
        [
            os.makedirs(os.path.join(self.output_dir, datatype), exist_ok=True)
            for datatype in self.datatypes
        ]


    def build(self):
        print("Training finished, running inference to build graphs...")

        # By default, the set of examples propagated through the pipeline will be train+val+test set
        datasets = {
            # "test": self.model.testset,
            "val": self.model.valset,
            # "train": self.model.trainset,
        }
        total_length = sum([len(dataset) for dataset in datasets.values()])
        batch_incr = 0
        eff = 0
        pur = 0
        self.model.eval()
        with torch.no_grad():
            for set_idx, (datatype, dataset) in enumerate(datasets.items()):
                for batch_idx, batch in enumerate(dataset):
                    
                    batch = torch.load(batch, map_location=torch.device("cpu"))
                    percent = (batch_incr / total_length) * 100
                    sys.stdout.flush()
                    sys.stdout.write(f"{percent:.01f}% inference complete, eff: {eff:.02f}%, pur: {pur:.03f}%\r")
                    if (
                        not os.path.exists(
                            os.path.join(
                                self.output_dir, datatype, batch.event_file[-5:]
                            )
                        )
                    ) or self.overwrite:
                        batch_to_save = copy.deepcopy(batch)
                        batch_to_save = batch_to_save.to(
                            self.model.device
                        )
                        eff, pur = self.construct_downstream(batch_to_save, datatype)

                    batch_incr += 1

    def construct_downstream(self, batch, datatype):

        score_list = []
        chunks = torch.chunk(batch.idxs, self.model.hparams["n_chunks"], dim = 1)
        input_data = self.model.get_input_data(batch)
        e_bidir = torch.cat(
            [batch.modulewise_true_edges, batch.modulewise_true_edges.flip(0)], axis=-1
        )
        all_edges = torch.empty([2, 0], dtype=torch.int64).cpu()
        all_y = torch.empty([0], dtype=torch.int64).cpu()
        all_scores = torch.empty([0], dtype=torch.float).cpu()
        
        for chunk in chunks:

            scores = torch.zeros(chunk.shape).to(self.model.device)
            ind = torch.Tensor.repeat(torch.arange(chunk.shape[0], device=self.model.device), (chunk.shape[1], 1)).T.int()
            
            positive_idxs = chunk >= 0
            edges = torch.stack([ind[positive_idxs], chunk[positive_idxs]]).long()
            
            output = self.model(
                    input_data,
                    edges
                ).squeeze()
            scores[positive_idxs] = torch.sigmoid(output)
            score_list.append(scores.detach().cpu())
            
            # compute val loss
            truth_mask = (batch.pid[edges[0]] == batch.pid[edges[1]]) & (batch.pid[edges] != 0).all(0)
            edges_easy_fake = edges[:,truth_mask.logical_not()].clone().detach()
            edges_ambiguous = edges[:,truth_mask].clone().detach()
            if edges_ambiguous.numel() != 0:
                edges_ambiguous, y_ambiguous = graph_intersection(edges_ambiguous, e_bidir)
                edges = torch.cat([edges_easy_fake, edges_ambiguous.to(self.model.device)], dim = 1)
                y = torch.cat([torch.zeros(edges_easy_fake.shape[1]), y_ambiguous], dim = 0)
            else: 
                edges = edges_easy_fake
                y = torch.zeros(edges_easy_fake.shape[1])
            
            output = self.model(
                    input_data,
                    edges
                ).squeeze()
            
            all_scores = torch.cat([all_scores, torch.sigmoid(output).cpu()], dim = 0)
            all_edges = torch.cat([all_edges, edges.cpu()], dim = 1)
            all_y = torch.cat([all_y, y.cpu()], dim = 0)
            
        score_list = torch.cat(score_list, dim = 1)
        
        
        # Find Cut
        pt_mask = (batch.pt[e_bidir] >= self.model.hparams["signal_pt_cut"]).all(0)

        eff_cut_score = 0.296
        
        cut_list = (all_scores >= eff_cut_score)
        
        # For efficeincy and purity, evaluate on modulewise truth.
        modulewise_true = pt_mask.sum()
        prediction_pt_mask = (batch.pt[all_edges] >= self.model.hparams["signal_pt_cut"]).all(0)
        modulewise_true_positive = (all_y.bool() & cut_list)[prediction_pt_mask].sum()
        modulewise_true_positive_without_cut = (all_y.bool() & cut_list).sum()
        modulewise_positive = cut_list.sum()
        
        eff = (modulewise_true_positive / modulewise_true).clone().detach()
        pur = (modulewise_true_positive_without_cut / modulewise_positive).clone().detach()
        
        batch.idxs_scores = score_list.clone().detach()

        self.save_downstream(batch, datatype)
        
        return eff, pur

    def save_downstream(self, batch, datatype):

        with open(
            os.path.join(self.output_dir, datatype, batch.event_file[-5:]), "wb"
        ) as pickle_file:
            torch.save(batch, pickle_file)


In [35]:
output_dir = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_filter/filter_processed"

In [36]:
edge_builder = FilterInferenceBuilder(model, output_dir, overwrite=True)
edge_builder.build()

Training finished, running inference to build graphs...
80.0% inference complete, eff: 0.95%, pur: 0.022%

In [8]:
class FilteringSelecting:
    def __init__(self, input_dirs, score_cut):
        
        self.input_dirs = input_dirs
        self.score_cut = score_cut

    def select(self):
        all_events = []
        print("Selecting data...")
        for input_dir in self.input_dirs:
            events = os.listdir(input_dir)
            all_events.extend([os.path.join(input_dir, event) for event in events])

        all_events = sorted(all_events)
        
        total_length = len(all_events)
        batch_incr = 0
        
        for event in all_events:
                percent = (batch_incr / total_length) * 100
                sys.stdout.flush()
                sys.stdout.write(f"{percent:.01f}% select complete\r")
                
                try:
                    batch = torch.load(event, map_location=torch.device("cuda"))
                except:
                    batch_incr += 1
                    continue
                    
                batch.idxs[batch.idxs_scores < self.score_cut] = -1
                
                delattr(batch, "idxs_scores")
                
                with open(
                    event, "wb"
                ) as pickle_file:
                    torch.save(batch, pickle_file)

                batch_incr += 1

In [9]:
selector = FilteringSelecting(["/global/cfs/cdirs/m3443/usr/ryanliu/ITk_filter/filter_processed/test",
                               "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_filter/filter_processed/train"
                              ], 0.296)