# Example of Metric Learning in Embedded Space

In [1]:
# Re-import edited modules (e.g. the LightningModules package) automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [None]:
# Name of the W&B sweep launched below (used as sweep_configuration["name"]).
# A visible prompt makes it clear the notebook is waiting for input.
run_name = input("W&B sweep name: ").strip()

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import frnn

sys.path.append('../..')

from LightningModules.Embedding.Models.sphere_layerless_embedding import SphereLayerlessEmbedding
import copy
from LightningModules.Embedding.utils import multi_build_edges, graph_intersection

device = "cuda" if torch.cuda.is_available() else "cpu"

import wandb
from pytorch_lightning.callbacks import ModelCheckpoint

## Pytorch Lightning Model

In this example notebook, we use a framework called PyTorch Lightning. PyTorch is a deep-learning library, like TensorFlow, that is very popular in ML engineering. Its main appeal is foolproof tracking of gradients for backpropagation, and very easy manipulation of tensors on and off GPUs. 

PyTorch Lightning is an extension of PyTorch that makes some decisions about best practices for training. Instead of writing the training loop yourself and moving things on and off a GPU, you let Lightning handle much of this for you. You write all the data-loading logic, the loss functions, etc. into a `LightningModule` and then hand this module to a `Trainer`. Together, the module and the trainer are the two objects that allow training and inference. 

So we start by importing a class that we have written ourselves, in this case a LightningModule that is in charge of loading TrackML (Codalab) data, and training and validating an embedding/metric learning model. 

### Construct PyTorch Lightning model

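An ML model typically has many knobs to turn, as well as locations of data, some training preferences, and so on. For convenience, we put all of these parameters into YAML files and load them: `Sphere_embedding_defaults.yaml` holds the default hyperparameters, and `Sphere_embedding_sweep.yaml` holds the grid of values to sweep over.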

In [None]:
# Sweep grid (becomes the W&B sweep "parameters") and the default hyperparameters
with open("Sphere_embedding_sweep.yaml") as sweep_file:
    sweep_hparams = yaml.load(sweep_file, Loader=yaml.FullLoader)

with open("Sphere_embedding_defaults.yaml") as defaults_file:
    default_hparams = yaml.load(defaults_file, Loader=yaml.FullLoader)


In [None]:
# W&B sweep definition: grid search over `sweep_hparams`, optimizing the logged "eff" metric
sweep_configuration = dict(
    name=run_name,
    project="ITk_full",
    metric=dict(name="eff", goal="maximize"),
    method="grid",
    parameters=sweep_hparams,
)

In [None]:
def test_training():
    """Train a single model on `default_hparams` only (no sweep), logging to W&B.

    Keeps the two checkpoints with the highest logged ``eff`` plus the last one,
    and validates twice per epoch (``val_check_interval=0.5``).
    """
    model = SphereLayerlessEmbedding(dict(default_hparams))

    best_eff_checkpoint = ModelCheckpoint(
        monitor='eff', mode="max", save_top_k=2, save_last=True
    )
    wandb_logger = WandbLogger()

    trainer = Trainer(
        gpus=1,
        max_steps=default_hparams["max_steps"],
        logger=wandb_logger,
        val_check_interval=0.5,
        num_sanity_val_steps=2,
        callbacks=[best_eff_checkpoint],
        default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/",
    )
    trainer.fit(model)

In [None]:
# test_training()

## Metric Learning

### Train embedding

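Finally! Let's train! We instantiate a `Trainer` class that knows things like which hardware to work with, how long to train for, and a **bunch** of default options that we ignore here. Check out the Trainer class docs in PyTorch Lightning. Suffice it to say that it clears away much repetitive boilerplate in training code. Here, training runs as a Weights & Biases grid sweep: each trial calls `training()`, which merges that trial's parameters over the defaults and trains a fresh model, checkpointing on the validation `eff` metric.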

In [None]:
def training(default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/"):
    """Train one W&B sweep trial; passed as ``function`` to ``wandb.agent``.

    The trial's sweep parameters (``wandb.config``) are merged over
    ``default_hparams``, and the merged values drive both the model and the
    Trainer. The two checkpoints with the highest logged ``eff`` are kept,
    plus the last one.

    Args:
        default_root_dir: root directory handed to the Lightning Trainer.
    """
    wandb.init()
    # Sweep values override the defaults. Use the merged dict for max_steps too,
    # so a sweep over max_steps is not silently ignored by the Trainer.
    hparams = {**default_hparams, **wandb.config}
    model = SphereLayerlessEmbedding(hparams)

    checkpoint_callback = ModelCheckpoint(
        monitor='eff',
        mode="max",
        save_top_k=2,
        save_last=True)

    logger = WandbLogger()
    trainer = Trainer(
        gpus=1,
        max_steps=hparams["max_steps"],
        val_check_interval=1000,
        logger=logger,
        num_sanity_val_steps=2,
        callbacks=[checkpoint_callback],
        default_root_dir=default_root_dir,
    )
    trainer.fit(model)

In [None]:
# Register the sweep under the same project as its configuration, then run its
# trials in this kernel; each trial calls `training()`.
sweep_id = wandb.sweep(sweep_configuration, project=sweep_configuration["project"])
wandb.agent(sweep_id, function=training)

# Build Edge Dataset

## Load best model

In [3]:
checkpoint_path = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_full/nbse7ida/checkpoints/epoch=1-step=7999.ckpt"
# map_location lets a checkpoint saved from a GPU run load on a CPU-only kernel
# (`device` falls back to "cpu"). `checkpoint` is not used by the cells shown here;
# it is kept for interactive inspection of the saved state.
checkpoint = torch.load(checkpoint_path, map_location=device)

In [4]:
# map_location avoids failures when the checkpoint's tensors were saved on a GPU
# that this kernel does not have; the model is then moved to `device`.
model = SphereLayerlessEmbedding.load_from_checkpoint(checkpoint_path, map_location=device).to(device)

In [5]:
%%time
# Run the model's "fit" setup; presumably this populates model.trainset/valset/testset,
# which EmbeddingInferenceBuilder reads — confirm in SphereLayerlessEmbedding.setup
model.setup(stage="fit")

CPU times: user 9.69 ms, sys: 4.78 ms, total: 14.5 ms
Wall time: 13.6 ms


## Define Building Class

In [6]:
class EmbeddingInferenceBuilder:
    """Run a trained embedding model over test/val/train and save graph-annotated events.

    For each event in the model's ``testset``, ``valset`` and ``trainset``, the
    embedded hits are connected with a radius graph (radius ``hparams["r_test"]``).
    The edges are labelled against ``modulewise_true_edges``, and the batch is
    written with ``torch.save`` to
    ``<output_dir>/<datatype>/<last 5 chars of event_file>``.

    Args:
        model: trained embedding module exposing ``testset``, ``valset``,
            ``trainset``, ``hparams``, ``device`` and ``eval()``.
        output_dir: root directory for the processed events.
        overwrite: if False, events whose output file already exists are skipped.
    """

    def __init__(self, model, output_dir, overwrite=False):
        self.output_dir = output_dir
        self.model = model
        self.overwrite = overwrite

        # Prep the directory to produce inference data to
        self.datatypes = ["train", "val", "test"]
        os.makedirs(self.output_dir, exist_ok=True)
        for datatype in self.datatypes:
            os.makedirs(os.path.join(self.output_dir, datatype), exist_ok=True)

    def _event_path(self, datatype, batch):
        """Output file for ``batch``: keyed by the last 5 characters of ``event_file``."""
        return os.path.join(self.output_dir, datatype, batch.event_file[-5:])

    def build(self):
        """Process every event of the test, val and train splits, printing progress."""
        print("Training finished, running inference to build graphs...")

        # By default, the set of examples propagated through the pipeline will be train+val+test set
        datasets = {
            "test": self.model.testset,
            "val": self.model.valset,
            "train": self.model.trainset,
        }
        # Skip splits that were not set up instead of failing on len(None)
        datasets = {name: ds for name, ds in datasets.items() if ds is not None}
        total_length = sum(len(dataset) for dataset in datasets.values())
        batch_incr = 0
        eff = 0
        pur = 0
        self.model.eval()
        with torch.no_grad():
            for datatype, dataset in datasets.items():
                for batch in dataset:
                    if self.overwrite or not os.path.exists(self._event_path(datatype, batch)):
                        # Work on a copy so the dataset held by the model is left untouched
                        batch_to_save = copy.deepcopy(batch).to(self.model.device)
                        eff, pur = self.construct_downstream(batch_to_save, datatype)

                    batch_incr += 1
                    percent = (batch_incr / total_length) * 100
                    # Write first, then flush, so the current progress is actually displayed
                    sys.stdout.write(f"{percent:.01f}% inference complete, eff: {eff:.01f}%, pur: {pur:.02f}%\r")
                    sys.stdout.flush()
        # Terminate the carriage-return progress line
        print()

    def construct_downstream(self, batch, datatype):
        """Embed one event, build and label its edges, save it, and return (eff, pur) in percent.

        Sets ``batch.edge_index`` to the candidate edges (random direction, shuffled
        order) and ``batch.y`` to their truth labels before saving.
        """
        if "ci" in self.model.hparams["regime"]:
            # Prepend the first `cell_channels` columns of cell_data to the hit features
            input_data = torch.cat([batch.cell_data[:, :self.model.hparams["cell_channels"]], batch.x], axis=-1)
        else:
            input_data = batch.x
            # NaN != NaN, so this zeroes NaN features -- in place, i.e. also in batch.x
            input_data[input_data != input_data] = 0
        spatials = self.model(input_data)

        if self.model.hparams["normalize"]:
            # L2-normalize along the last (embedding) axis. dim=-1 equals the former
            # dim=2 for a 3-D output, and unlike dim=2 also works for a 2-D output.
            spatials = nn.functional.normalize(spatials, p=2.0, dim=-1, eps=1e-12)

        # Make truth bidirectional
        e_bidir = torch.cat(
            [batch["modulewise_true_edges"],
             batch["modulewise_true_edges"].flip(0)], axis=-1
        )

        # Build the radius graph with radius < r_test
        e_spatial = multi_build_edges(
            spatials, spatials, indices=None, r_max=self.model.hparams["r_test"], k_max=1000
        ).long()  # This step should remove reliance on r_val, and instead compute an r_build based on the EXACT r required to reach target eff/pur

        # Arbitrary ordering to remove half of the duplicate edges
        R_dist = torch.sqrt(batch.x[:, 0] ** 2 + batch.x[:, 2] ** 2)
        e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])]

        # Edges joining hits with different pid are labelled fake (0) directly;
        # only same-pid edges are matched against the truth graph.
        e_spatial_easy_fake = e_spatial[:, batch.pid[e_spatial[0]] != batch.pid[e_spatial[1]]]
        y_cluster_easy_fake = torch.zeros(e_spatial_easy_fake.shape[1])

        e_spatial_ambiguous = e_spatial[:, batch.pid[e_spatial[0]] == batch.pid[e_spatial[1]]]
        e_spatial_ambiguous, y_cluster_ambiguous = graph_intersection(e_spatial_ambiguous, e_bidir)

        e_spatial = torch.cat([e_spatial_easy_fake.cpu(), e_spatial_ambiguous], dim=-1)
        y_cluster = torch.cat([y_cluster_easy_fake, y_cluster_ambiguous])

        # e_bidir lists every true edge in both directions, while e_spatial keeps
        # roughly one direction per pair, hence the factor of 2 in eff.
        eff = (2*y_cluster.sum()/e_bidir.shape[1]).item()*100
        pur = (y_cluster.sum()/y_cluster.shape[0]).item()*100

        # Re-introduce random direction, to avoid training bias.
        # Advanced indexing on the right-hand side yields copies, so the swap is safe.
        random_flip = torch.randint(2, (e_spatial.shape[1],)).bool()
        e_spatial[0, random_flip], e_spatial[1, random_flip] = (
            e_spatial[1, random_flip],
            e_spatial[0, random_flip],
        )
        e_spatial = e_spatial[:, torch.randperm(e_spatial.shape[1])]

        batch.edge_index = e_spatial
        batch.y = y_cluster

        self.save_downstream(batch, datatype)

        return eff, pur

    def save_downstream(self, batch, datatype):
        """Serialize ``batch`` with ``torch.save`` to ``<output_dir>/<datatype>/<event id>``."""
        with open(self._event_path(datatype, batch), "wb") as pickle_file:
            torch.save(batch, pickle_file)


In [7]:
# Embedding-space radius for the inference graph; overrides the checkpoint's r_test
R_TEST = 1.0

output_dir = "/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_processed/ITk_full"
model.hparams["r_test"] = R_TEST
edge_builder = EmbeddingInferenceBuilder(model=model, output_dir=output_dir, overwrite=True)

In [None]:
# Embed every test/val/train event, build and label its edges, and save it under output_dir
edge_builder.build()

Training finished, running inference to build graphs...
0.9% inference complete, eff: 94.8%, pur: 0.06%

In [None]:
# Quick check: number of events in the training split
len(model.trainset)