# Example of Metric Learning in Embedded Space

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import frnn

sys.path.append('../..')

from LightningModules.Embedding.Models.multi_layerless_embedding import MultiLayerlessEmbedding
import copy
from LightningModules.Embedding.utils import multi_build_edges, graph_intersection

device = "cuda" if torch.cuda.is_available() else "cpu"

## Pytorch Lightning Model

In this example notebook, we will use an approach to ML called Pytorch Lightning. Pytorch is a library like Tensorflow, which is very popular in ML engineering. Its main appeal is automatic tracking of gradients for backpropagation, and very easy manipulation of tensors on and off GPUs.

Pytorch Lightning is an extension of Pytorch that makes some decisions about the best-practices for training. Instead of you writing the training loop yourself, and moving things on and off a GPU, it handles much of this for you. You write all the data loading logic, the loss functions, etc. into a `LightningModule` and then hand this module to a `Trainer`. Together, the module and trainer are the two objects that allow training and inference. 
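To make the division of labour concrete, here is a minimal pure-Python sketch. `ToyModule` and `ToyTrainer` are hypothetical names invented for this illustration and are not the Pytorch Lightning API; they only mimic the idea that the module owns the data and loss logic while the trainer owns the loop.

```python
# Toy illustration only: these classes are hypothetical and are NOT the
# Pytorch Lightning API. They mimic the split of responsibilities.

class ToyModule:
    """Holds the model parameter, the data, and the loss logic."""

    def __init__(self, lr=0.1):
        self.w = 0.0  # single trainable weight
        self.lr = lr

    def train_batches(self):
        # Data loading logic lives in the module: pairs (x, y) with y = 3x
        return [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]

    def training_step(self, batch):
        x, y = batch
        error = self.w * x - y
        loss = error ** 2
        grad = 2 * error * x  # d(loss)/dw, written by hand in this toy
        return loss, grad


class ToyTrainer:
    """Owns the loop: epochs, iteration over batches, parameter updates."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def fit(self, module):
        history = []
        for _ in range(self.max_epochs):
            epoch_loss = 0.0
            for batch in module.train_batches():
                loss, grad = module.training_step(batch)
                module.w -= module.lr * grad
                epoch_loss += loss
            history.append(epoch_loss)
        return history


module = ToyModule(lr=0.05)
history = ToyTrainer(max_epochs=20).fit(module)
print(f"learned w = {module.w:.3f}, final epoch loss = {history[-1]:.2e}")
```

In the real library the gradient is computed by Pytorch's autograd and the update by an optimizer, so the module never has to write `grad` by hand.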

So we start by importing a class that we have written ourselves, in this case a LightningModule that is in charge of loading TrackML (Codalab) data, and training and validating an embedding/metric learning model. 

### Construct PyLightning model

An ML model typically has many knobs to turn, as well as locations of data, some training preferences, and so on. For convenience, let's put all of these parameters into a YAML file and load it.

In [3]:
with open("Multi_embedding.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)
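Since later cells index into the configuration by key, a quick check for missing keys can save a failed run. The sketch below is only an illustration: `missing_hparams` is a hypothetical helper, the dictionary values are made up, and it assumes that `model.hparams` mirrors the contents of the YAML file.

```python
def missing_hparams(hparams, required):
    """Return the required keys that are absent from the loaded config."""
    return [key for key in required if key not in hparams]


# Keys that later cells in this notebook read from the hyperparameters.
REQUIRED_KEYS = ["max_epochs", "r_test", "cell_channels", "regime", "train_split"]

# Hypothetical stand-in for the dictionary returned by yaml.load; values are made up.
example_hparams = {"max_epochs": 10, "r_test": 1.0, "regime": ["ci"]}

print(missing_hparams(example_hparams, REQUIRED_KEYS))
```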

We plug these parameters into the constructor of the `MultiLayerlessEmbedding` Lightning Module. This doesn't **do** anything yet - it merely creates the object.

In [4]:
model = MultiLayerlessEmbedding(hparams)

## Metric Learning

### Train embedding

Finally! Let's train! We instantiate a `Trainer` class that knows things like which hardware to work with, how long to train for, and a **bunch** of default options that we ignore here. Check out the Trainer class docs in Pytorch Lightning. Suffice it to say that it clears away much repetitive boilerplate in training code.

In [5]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='eff@1.0',
    mode="max",
    save_top_k=2,
    save_last=True)
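The `monitor`, `mode` and `save_top_k` arguments describe a "keep the k best" policy. The following pure-Python sketch is a simplified model of that policy, not Lightning's implementation; `update_top_k` is a hypothetical helper and the scores are made up.

```python
def update_top_k(saved, epoch, score, k=2, mode="max"):
    """Keep the k best (score, epoch) pairs seen so far.

    Simplified model of `save_top_k` / `mode`; not Lightning's implementation.
    """
    saved = saved + [(score, epoch)]
    saved.sort(key=lambda item: item[0], reverse=(mode == "max"))
    return saved[:k]


# Made-up values of the monitored metric, one per epoch.
scores = [0.80, 0.91, 0.88, 0.95, 0.93]

saved = []
for epoch, score in enumerate(scores):
    saved = update_top_k(saved, epoch, score, k=2, mode="max")

best_epochs = [epoch for _, epoch in saved]
last_epoch = len(scores) - 1  # `save_last=True` additionally keeps the most recent state
print("top-2 epochs:", best_epochs, "last epoch:", last_epoch)
```

With `mode="max"` a higher value of the monitored metric counts as better, which is the natural choice for an efficiency.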

In [6]:
logger = WandbLogger(project="ITk_0.5GeV", group="Common_track_multi_embedding")
trainer = Trainer(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    logger=logger,
    num_sanity_val_steps=2,
    callbacks=[checkpoint_callback],
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/",
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: exatrkx (use `wandb login --relogin` to force relogin)



  | Name            | Type       | Params
-----------------------------------------------
0 | layers          | ModuleList | 4.2 M 
1 | n_spaces_layers | ModuleList | 8.2 K 
2 | act             | GELU       | 0     
-----------------------------------------------
4.2 M     Trainable params
0         Non-trainable params
4.2 M     Total params
16.876    Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


tensor([1., 1., 1.], device='cuda:0')
Validation sanity check:  50%|█████     | 1/2 [00:06<00:06,  6.85s/it]

  eff = torch.tensor(cluster_true_positive / cluster_true)
  pur = torch.tensor(cluster_true_positive / cluster_positive)


tensor([1., 1., 1.], device='cuda:0')
                                                                      

  rank_zero_warn(


Epoch 0:   0%|          | 0/1020 [00:00<?, ?it/s] tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   0%|          | 1/1020 [00:00<13:38,  1.24it/s, loss=0.929, v_num=ts5c]tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   0%|          | 2/1020 [00:01<08:54,  1.90it/s, loss=0.934, v_num=ts5c]tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   0%|          | 3/1020 [00:01<07:33,  2.24it/s, loss=0.934, v_num=ts5c]tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   0%|          | 4/1020 [00:01<06:41,  2.53it/s, loss=0.935, v_num=ts5c]tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   0%|          | 5/1020 [00:01<06:13,  2.72it/s, loss=0.936, v_num=ts5c]tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   1%|          | 6/1020 [00:02<05:49,  2.90it/s, loss=0.936, v_num=ts5c]tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   1%|          | 7/1020 [00:02<05:41,  2.97it/s, loss=0.936, v_num=ts5c]tensor([1., 1., 1.], device='cuda:0')
Epoch 0:   1%|          | 8/1020 [00:02<05:35,  3.01it/s, loss=0.936, v_num=ts5

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


### Test embedding

A Pytorch Lightning Trainer has two main methods: `fit` and `test`. They represent the two main steps of any ML engineering or research: train a model, then make sure it can infer accurately on test (i.e. **hidden**) data.

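Note that `ckpt_path` tells the trainer which "checkpoint" of the model to evaluate. A checkpoint is a saved version of the model, a snapshot at a particular stage of the training process. What `ckpt_path=None` selects depends on the Pytorch Lightning version: some versions test the weights currently in memory, while in the run below the log shows the weights being restored from a saved checkpoint file. Pytorch Lightning automatically saves a checkpoint of your model in case something crashes and we need to resume.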

In [7]:
test_results = trainer.test(ckpt_path=None)

  rank_zero_warn(
Restoring states from the checkpoint path at /global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_0.5GeV/22r0ts5c/checkpoints/epoch=0-step=999.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_0.5GeV/22r0ts5c/checkpoints/epoch=0-step=999.ckpt
  rank_zero_warn("One of given dataloaders is None and it will be skipped.")


## Performance

In [8]:
from LightningModules.Embedding.utils import get_metrics

Traceback (most recent call last):
  File "/global/homes/r/ryanliu/.conda/envs/gnn/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/global/homes/r/ryanliu/.conda/envs/gnn/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/global/homes/r/ryanliu/.conda/envs/gnn/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/global/homes/r/ryanliu/.conda/envs/gnn/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


Let's see how well the model learned to embed the hits into a space that separates them into discrete clusters. As in the data visualisation above, we are going manual here. So one thing we need to do is tell the model that we are **evaluating**, not **training**, with `model.eval()`. We also make sure to wrap everything in `with torch.no_grad()` so that no gradients are tracked. This speeds things up and gives the GPU's memory a rest, since we're not interested in any kind of training from here on in.

In [9]:
model.eval();

### Test metrics

In [10]:
all_efficiencies, all_purities = [], []
all_radius = np.arange(0.5, 1.2, 0.1)

with torch.no_grad():
    for r in all_radius:

        model.hparams.r_test = r
        test_results = trainer.test(ckpt_path=None)[0]

        mean_efficiency, mean_purity = test_results["eff"], test_results["pur"]

        all_efficiencies.append(mean_efficiency)
        all_purities.append(mean_purity)

Restoring states from the checkpoint path at /global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_0.5GeV/22r0ts5c/checkpoints/epoch=0-step=999.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /global/cfs/cdirs/m3443/usr/ryanliu/ITk_embedding/ITk_0.5GeV/22r0ts5c/checkpoints/epoch=0-step=999.ckpt


IndexError: list index out of range

We should always visualise two important metrics: the efficiency (the number of true positives divided by the total number of possible true edges) and the purity (the number of true positives divided by the number of predicted edges). Is it clear to you why the graphs below behave as they do, as we widen the sphere around each hit to generate neighboring edges?
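The two definitions can be tried out on a toy case. The sketch below uses a made-up one-dimensional "embedding" and hypothetical helper names (`radius_edges`, `efficiency_purity`); it is a brute-force illustration of the definitions, not the edge builder used by the model.

```python
def radius_edges(points, radius):
    """All directed pairs (i, j), i != j, with |points[i] - points[j]| <= radius."""
    n = len(points)
    return {
        (i, j)
        for i in range(n)
        for j in range(n)
        if i != j and abs(points[i] - points[j]) <= radius
    }


def efficiency_purity(pred_edges, true_edges):
    """Efficiency = TP / number of true edges, purity = TP / number of predicted edges."""
    true_positives = len(pred_edges & true_edges)
    efficiency = true_positives / len(true_edges)
    purity = true_positives / len(pred_edges) if pred_edges else 0.0
    return efficiency, purity


# Hypothetical 1D "embedding": hits 0-2 belong to one track, hits 3-4 to another.
points = [0.0, 0.3, 0.9, 1.5, 1.7]
true_edges = {(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)}

for radius in (0.25, 0.35, 0.65, 1.0):
    eff, pur = efficiency_purity(radius_edges(points, radius), true_edges)
    print(f"r={radius:.2f}  efficiency={eff:.2f}  purity={pur:.2f}")
```

In this toy case a wider radius recovers more of the true edges, but once all of them are found every additional edge is a false positive, so the purity drops.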

In [None]:
plt.figure(figsize=(12,8))
plt.plot(all_radius, all_efficiencies);
plt.title("Embedding efficiency", fontsize=24), plt.xlabel("Radius of neighborhood", fontsize=18), plt.ylabel("Efficiency", fontsize=18);

In [None]:
plt.figure(figsize=(12,8))
plt.plot(all_radius, all_purities);
plt.title("Embedding purity", fontsize=24), plt.xlabel("Radius of neighborhood", fontsize=18), plt.ylabel("Purity", fontsize=18);

### Visualise embedding / latent space

Another useful thing to visualise is the actual space being embedded into. Since it's 8 dimensional, we can reduce it to 2 dimensions with "Principal Component Analysis". 
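For intuition about what the reduction does, here is a minimal NumPy sketch of PCA through a singular value decomposition. It is an illustration on random, made-up data and uses a hypothetical helper `pca_project`; the cells below use scikit-learn's `PCA` instead.

```python
import numpy as np


def pca_project(features, n_components=2):
    """Project the rows of `features` onto their top principal components."""
    centered = features - features.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 8))  # made-up stand-in for an (n_hits, 8) embedding
reduced = pca_project(latent, n_components=2)

# Distance between two hits before and after the projection.
d_full = np.linalg.norm(latent[0] - latent[1])
d_2d = np.linalg.norm(reduced[0] - reduced[1])
print(reduced.shape, d_2d <= d_full)
```

Because the projection keeps only two orthogonal directions, a distance measured in the 2D plot is never larger than the same distance in the full 8-dimensional space.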

In [None]:
from sklearn.decomposition import PCA

In [None]:
example_data = model.trainset[0]
particle_ids = example_data.pid
cyl_coords = example_data.x
cell_features = example_data.cell_data
all_features = torch.cat([cyl_coords, cell_features], axis=-1).to(device)

In [None]:
latent_features = model(all_features)

In [None]:
pca = PCA(n_components=2)

In [None]:
reduced_dimensions = pca.fit_transform(latent_features.detach().cpu())

In [None]:
reduced_dimensions

In [None]:
plt.figure(figsize=(12,8))
plt.scatter(reduced_dimensions[:, 0], reduced_dimensions[:, 1], s=1)

The above plot is what **all** the data looks like in the latent space. Let's pick a (long) particle track and see if the track is projected to be close together:

In [None]:
particles, counts = np.unique(example_data.pid, return_counts=True)

In [None]:
example_particle = particles[counts > 10][0]

In [None]:
plt.figure(figsize=(12,8))
plt.scatter(reduced_dimensions[:, 0], reduced_dimensions[:, 1], s=1)
plt.scatter(reduced_dimensions[particle_ids == example_particle, 0], reduced_dimensions[particle_ids == example_particle, 1])

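More than 10 hits are highlighted in the plot above. Hopefully some of them cluster together into blobs, and the blobs lie close to other blobs of the same color. We get ~99% efficiency with this model, which only requires hits joined by a true edge to fall within the neighborhood radius of each other, so a track can still spread out as a chain of blobs. Keep in mind that PCA is a linear projection: a distance in the 2D plot is never larger than the corresponding distance in the 8-dimensional space, so the plot is only a rough guide. You can check the true distances yourself!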

### Distributions

In [None]:
def calc_eta(r, z):
    theta = np.arctan2(r, z)
    return -1. * np.log(np.tan(theta / 2.))
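The helper above computes the pseudorapidity, eta = -ln(tan(theta / 2)), where theta is the polar angle obtained from the cylindrical coordinates r and z. As a sanity check, the standard-library sketch below (with a hypothetical function name `pseudorapidity`) compares it with the equivalent closed form asinh(z / r).

```python
import math


def pseudorapidity(r, z):
    """eta = -ln(tan(theta / 2)), with theta the polar angle measured from the z axis."""
    theta = math.atan2(r, z)
    return -math.log(math.tan(theta / 2.0))


for r, z in [(1.0, 0.0), (1.0, 1.0), (1.0, -1.0), (1.0, 10.0)]:
    eta = pseudorapidity(r, z)
    # Equivalent closed form: eta = asinh(z / r)
    print(f"r={r}, z={z}: eta={eta:+.4f}, asinh(z/r)={math.asinh(z / r):+.4f}")
```

A hit at z = 0 has eta = 0, and eta is antisymmetric under z -> -z.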

In [None]:
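# NOTE: build_edges is not imported at the top of this notebook; import it
# (or substitute the edge-building helper you use) before running this cell.
def get_performance(model, batch, r_max, k_max):
    with torch.no_grad():
        # Concatenate cell features and coordinates, replacing NaNs with 0
        input_data = torch.cat([batch.cell_data[:, :model.hparams["cell_channels"]], batch.x], axis=-1)
        input_data[input_data != input_data] = 0
        spatial = model(input_data)
        # Make the truth edges bidirectional
        e_bidir = torch.cat(
            [batch.modulewise_true_edges, batch.modulewise_true_edges.flip(0)], axis=-1
        )
        # Build the radius graph in the embedded space and label each edge as true or fake
        e_spatial = build_edges(spatial, spatial, indices=None, r_max=r_max, k_max=k_max)
        e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir)

    return y_cluster, e_spatial, e_bidir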

In [None]:
r_max = 1.6
k_max = 1100

In [None]:
%%time
batch = 0
model.eval()

y_cluster, e_spatial, e_bidir = get_performance(model, model.testset[batch].to(device), r_max=r_max, k_max=k_max)

In [None]:
print(f"Pur: {y_cluster.sum() / y_cluster.shape[0]}, Eff: {y_cluster.sum() / e_bidir.shape[1]}")

In [None]:
torch.cuda.max_memory_allocated() / 1024**3

In [None]:
torch.cuda.reset_peak_memory_stats()

In [None]:
eta_cuts = np.arange(-4, 4.5, 0.5)
batch_range = np.arange(0,200)
eta_batches = batch_range[:5]  # only the first few events are evaluated here

eta_eff_list = []
eta_pur_list = []

for batch_idx in eta_batches:
    
    batch = model.testset[batch_idx].to(device)
    
    y_cluster, e_spatial, e_bidir = get_performance(model, batch, r_max=r_max, k_max=k_max)
    
    # Keep eta, the edge indices and the labels on the CPU so they can index each other
    eta_hits = calc_eta(batch.x[:, 0].cpu(), batch.x[:, 2].cpu())
    y_cluster, e_spatial, e_bidir = y_cluster.cpu(), e_spatial.cpu(), e_bidir.cpu()
    av_eta_preds = (eta_hits[e_spatial[0]] + eta_hits[e_spatial[1]])/2
    av_eta_true = (eta_hits[e_bidir[0]] + eta_hits[e_bidir[1]])/2
    
    for eta1, eta2 in zip(eta_cuts[:-1], eta_cuts[1:]):
        edge_eta_pred = (av_eta_preds >= eta1) & (av_eta_preds <= eta2)
        edge_eta_true = (av_eta_true >= eta1) & (av_eta_true <= eta2)
        true_positives = y_cluster[edge_eta_pred]

        # max(1, ...) guards against empty bins
        eta_eff_list.append(true_positives.sum().item() / max(1, edge_eta_true.sum().item()))
        eta_pur_list.append(true_positives.sum().item() / max(1, true_positives.shape[0]))

In [None]:
# The lists are filled batch by batch, so reshape to (n_batches, n_bins) and transpose to (n_bins, n_batches)
eta_eff_reshape = np.array(eta_eff_list).reshape(len(eta_batches), eta_cuts.shape[0]-1).T
eta_pur_reshape = np.array(eta_pur_list).reshape(len(eta_batches), eta_cuts.shape[0]-1).T

Train set

In [None]:
plt.scatter(eta_eff_reshape.mean(0), eta_pur_reshape.mean(0), s=2)

Test set

In [None]:
plt.scatter(eta_eff_reshape.mean(0), eta_pur_reshape.mean(0), s=2)

In [None]:
eta_center = (eta_cuts[:-1] + eta_cuts[1:])/2
plt.errorbar(eta_center, eta_eff_reshape.mean(1), eta_eff_reshape.std(1), fmt="o")

In [None]:
eta_center = (eta_cuts[:-1] + eta_cuts[1:])/2
plt.errorbar(eta_center, eta_pur_reshape.mean(1), eta_pur_reshape.std(1), fmt="o")

In [None]:
pt_cuts = np.arange(900, 5000, 500)

pt_eff_list = []
pt_pur_list = []
    
for batch_idx in batch_range:
    
    batch = model.testset[batch_idx].to(device)
    
    y_cluster, e_spatial, e_bidir = get_performance(model, batch, r_max=1.2, k_max=500)
    
    av_pt_preds = (batch.pt[e_spatial[0]] + batch.pt[e_spatial[1]])/2
    av_pt_true = (batch.pt[e_bidir[0]] + batch.pt[e_bidir[1]])/2
    
    for pt1, pt2 in zip(pt_cuts[:-1], pt_cuts[1:]):
        edge_pt_pred = (av_pt_preds >= pt1) & (av_pt_preds <= pt2)
        edge_pt_true = (av_pt_true >= pt1) & (av_pt_true <= pt2)
        true_positives = y_cluster[edge_pt_pred]

        pt_eff_list.append(true_positives.sum().item() / max(1, edge_pt_true.sum().item()))
        pt_pur_list.append(true_positives.sum().item() / max(1, true_positives.shape[0]))
    

In [None]:
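# The lists are filled batch by batch, so reshape to (n_batches, n_bins) and transpose to (n_bins, n_batches)
pt_eff_reshape = np.array(pt_eff_list).reshape(len(batch_range), pt_cuts.shape[0]-1).T
pt_pur_reshape = np.array(pt_pur_list).reshape(len(batch_range), pt_cuts.shape[0]-1).T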

In [None]:
pt_center = (pt_cuts[:-1] + pt_cuts[1:])/2
plt.errorbar(pt_center, pt_eff_reshape.mean(1), pt_eff_reshape.std(1), fmt="o")

In [None]:
pt_center = (pt_cuts[:-1] + pt_cuts[1:])/2
plt.errorbar(pt_center, pt_pur_reshape.mean(1), pt_pur_reshape.std(1), fmt="o")

## Memory of Edge Builder

In [None]:
torch.cuda.max_memory_allocated() / 1024**3

In [None]:
torch.cuda.reset_peak_memory_stats()

In [None]:
r_max = 1.6
k_max = 1000

In [None]:
batch = model.testset[0].to(device)
model.eval()
with torch.no_grad():
    input_data = torch.cat([batch.cell_data[:, :model.hparams["cell_channels"]], batch.x], axis=-1)
    input_data[input_data != input_data] = 0
    spatial = model(input_data)
    e_bidir = torch.cat(
            [batch.modulewise_true_edges, batch.modulewise_true_edges.flip(0)], axis=-1
        )

In [None]:
dists, idxs, nn, grid = frnn.frnn_grid_points(points1=spatial.unsqueeze(0), points2=spatial.unsqueeze(0), lengths1=None, lengths2=None, K=k_max, r=r_max, grid=None, return_nn=False, return_sorted=True)

In [None]:
idxs = idxs.squeeze().int()
# Source index of every neighbor slot: row i of `ind` is filled with i
ind = torch.arange(idxs.shape[0], device=device).repeat(idxs.shape[1], 1).T.int()

In [None]:
positive_idxs = idxs >= 0
ind = ind[positive_idxs]
idxs = idxs[positive_idxs]

In [None]:
edge_list = torch.stack([ind, idxs]).int()

In [None]:
# Remove self-loops
edge_list = edge_list[:, edge_list[0] != edge_list[1]]

In [None]:
del ind
del idxs

In [None]:
edge_list = edge_list.long()
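The cells above turn a padded table of neighbor indices into an edge list. The pure-Python sketch below walks through the same steps on a made-up table; `neighbors_to_edges` is a hypothetical helper, and treating negative entries as padding is an assumption based on the `idxs >= 0` mask used above.

```python
def neighbors_to_edges(idxs):
    """Convert a padded neighbor table into a list of (source, target) edges.

    idxs[i] holds up to K neighbor indices of point i; negative entries are padding.
    """
    edges = []
    for source, row in enumerate(idxs):
        for target in row:
            if target < 0:  # padding entry: fewer than K neighbors were found
                continue
            if target == source:  # self-loop: a point is always its own neighbor
                continue
            edges.append((source, target))
    return edges


# Made-up neighbor table for 3 points with K = 3.
idxs = [
    [0, 1, -1],
    [1, 0, 2],
    [2, 1, -1],
]
print(neighbors_to_edges(idxs))
```

The tensor version in the cells above does the same filtering with boolean masks, which avoids Python loops on the GPU.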

# Build Edge Dataset

## Load best model

In [None]:
checkpoint_path = "/global/cscratch1/sd/danieltm/ExaTrkX/itk_lightning_checkpoints/ITk_1GeV/pdwlz89x/checkpoints/last.ckpt"
checkpoint = torch.load(checkpoint_path)

In [None]:
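# NOTE: LayerlessEmbedding is not imported at the top of this notebook; import
# the class that matches the checkpoint before running this cell.
model = LayerlessEmbedding.load_from_checkpoint(checkpoint_path).to(device)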

In [None]:
model.hparams["train_split"] = [10, 10, 10]

In [None]:
%%time
model.setup(stage="fit")

## Define Building Class

In [None]:
class EmbeddingInferenceBuilder:
    def __init__(self, model, output_dir, overwrite=False):
        self.output_dir = output_dir
        self.model = model
        self.overwrite = overwrite

        # Prep the directory to produce inference data to
        self.datatypes = ["train", "val", "test"]
        os.makedirs(self.output_dir, exist_ok=True)
        for datatype in self.datatypes:
            os.makedirs(os.path.join(self.output_dir, datatype), exist_ok=True)


    def build(self):
        print("Training finished, running inference to build graphs...")

        # By default, the set of examples propagated through the pipeline will be train+val+test set
        datasets = {
            "train": self.model.trainset,
            "val": self.model.valset,
            "test": self.model.testset,
        }
        total_length = sum([len(dataset) for dataset in datasets.values()])
        batch_incr = 0
        self.model.eval()
        with torch.no_grad():
            for set_idx, (datatype, dataset) in enumerate(datasets.items()):
                for batch_idx, batch in enumerate(dataset):
                    percent = (batch_incr / total_length) * 100
                    # Write the progress line first, then flush so it shows up immediately
                    sys.stdout.write(f"{percent:.01f}% inference complete \r")
                    sys.stdout.flush()
                    if (
                        not os.path.exists(
                            os.path.join(
                                self.output_dir, datatype, batch.event_file[-4:]
                            )
                        )
                    ) or self.overwrite:
                        batch_to_save = copy.deepcopy(batch)
                        batch_to_save = batch_to_save.to(
                            self.model.device
                        )  # Is this step necessary??
                        self.construct_downstream(batch_to_save, datatype)

                    batch_incr += 1

    def construct_downstream(self, batch, datatype):

        if "ci" in self.model.hparams["regime"]:
            input_data = torch.cat([batch.cell_data[:, :self.model.hparams["cell_channels"]], batch.x], axis=-1)
            input_data[input_data != input_data] = 0
            spatial = self.model(input_data)
        else:
            input_data = batch.x
            input_data[input_data != input_data] = 0
            spatial = self.model(input_data)

        # Make truth bidirectional
        e_bidir = torch.cat(
            [
                batch.modulewise_true_edges,
                torch.stack(
                    [batch.modulewise_true_edges[1], batch.modulewise_true_edges[0]],
                    axis=1,
                ).T,
            ],
            axis=-1,
        )

        # Build the radius graph with radius < r_test
        e_spatial = build_edges(
            spatial, spatial, indices=None, r_max=self.model.hparams["r_test"], k_max=1100
        ).long()  # This step should remove reliance on r_val, and instead compute an r_build based on the EXACT r required to reach target eff/pur

        # Arbitrary ordering to remove half of the duplicate edges
        R_dist = torch.sqrt(batch.x[:, 0] ** 2 + batch.x[:, 2] ** 2)
        e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])]

        e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir)

        # Re-introduce random direction, to avoid training bias
        random_flip = torch.randint(2, (e_spatial.shape[1],)).bool()
        e_spatial[0, random_flip], e_spatial[1, random_flip] = (
            e_spatial[1, random_flip],
            e_spatial[0, random_flip],
        )

        batch.edge_index = e_spatial
        batch.y = y_cluster

        self.save_downstream(batch, datatype)

    def save_downstream(self, batch, datatype):

        with open(
            os.path.join(self.output_dir, datatype, batch.event_file[-4:]), "wb"
        ) as pickle_file:
            torch.save(batch, pickle_file)


In [None]:
output_dir = "/project/projectdirs/m3443/data/ITk-upgrade/processed/embedding_processed/0_GeV_unweighted_subset"
model.hparams["r_test"] = 1.6
edge_builder = EmbeddingInferenceBuilder(model, output_dir, overwrite=False)

In [None]:
edge_builder.build()

In [None]:
len(model.trainset)