## Embedding Inference

- _run inference using Lightning Callbacks (load trained model from `checkpoint` with `weights` and `hparams`, run `test_step()` via `pl.Trainer`)_ 
- _examine input graphs generated through graph embeddings_

In [None]:
import glob, os, sys, yaml

In [None]:
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pprint
import seaborn as sns
import trackml.dataset

In [None]:
import torch
import pytorch_lightning as pl
from torch_geometric.data import Data
import itertools

In [None]:
# make the project's `src` directory importable (the path is relative to the
# current working directory); the guard keeps re-runs of this cell from
# appending duplicate entries to sys.path
if 'src' not in sys.path:
    sys.path.append('src')

In [None]:
# choose the compute device: CUDA when torch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# point EXATRKX_DATA at the absolute path of the current working directory
os.environ.update(EXATRKX_DATA=os.path.abspath(os.curdir))

In [None]:
# local imports
from LightningModules.edge_construction import LayerlessEmbedding
from src import SttCSVDataReader, SttTorchDataReader
from src import Visualize_CSVEvent, Visualize_TorchEvent
from src import Build_Event, Build_Event_Viz, Visualize_Edges

### _1. Model Configuration_

_We can load the model configuration used during training, as well as the model architecture._

- _**Load Model Configuration**_

In [None]:
# load hparams from config file
config_file = os.path.join(os.curdir, 'LightningModules/edge_construction/configs/train_quickstart_embedding.yaml')
with open(config_file, encoding='utf-8') as f:
    try:
        # safe_load is sufficient for a plain hyper-parameter file and never
        # constructs arbitrary Python objects from YAML tags
        config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        # report the parse error, then re-raise: swallowing it would leave
        # `config` undefined and make every later cell fail with a NameError
        print(e)
        raise

In [None]:
# pretty-print the hyper-parameters loaded for this stage (2-space indent)
pp = pprint.PrettyPrinter(indent=2)
print(pp.pformat(config))

- _**Load Model Architecture**_

In [None]:
# Layerless Embedding: build the embedding model from the training config `config`
e_model = LayerlessEmbedding(config)

In [None]:
# Model Summary: print the module's repr (its registered sub-modules)
print(e_model)

In [None]:
# use torchsummary if available
# from torchsummary import summary
# summary(e_model, input_size=(3, ), batch_size=1, device=device)

### _2. Model Checkpoints_

_PyTorch Lightning stores checkpoints with the model's entire internal state; see [Saving and Loading Checkpoints (Basic)](https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html) for details. We can load a checkpoint and see what information is stored. There are two ways to load a model checkpoint: **(a)** simply load the last checkpoint, or **(b)** recreate the ModelCheckpoint callback instance. In both cases, one needs the path to a checkpoint, e.g. `last.ckpt`._

In [None]:
# checkpoint path (last saved checkpoint); the run directory (version_2) is
# hard-coded, so update it to point at the training run you want to inspect
ckpt_path = "run/lightning_models/lightning_checkpoints/EmbeddingStudy/version_2/checkpoints/last.ckpt"

_**(a).** simply load last checkpoint (use this one)_

In [None]:
# load checkpoint as a raw dict, mapping its tensors onto `device`
# NOTE(review): newer torch releases default torch.load to weights_only=True, which may
# refuse a full Lightning checkpoint; if loading fails, pass weights_only=False (trusted files only)
checkpoint = torch.load(ckpt_path, map_location=device)

In [None]:
# checkpoint keys: the top-level entries stored in the checkpoint dict
checkpoint.keys()

In [None]:
# examine checkpoint state
# checkpoint["epoch"]
# checkpoint["global_step"]
# checkpoint["pytorch-lightning_version"]
# checkpoint["optimizer_states"]
# checkpoint["lr_schedulers"]
# checkpoint["callbacks"]
# checkpoint["hyper_parameters"]

_**(b).** recreate the ModelCheckpoint callback instance (ignore this one)_

In [None]:
# ModelCheckpoint Callback
# checkpoint = torch.load(ckpt_path, map_location=device)
# checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
# checkpoint_callback.on_load_checkpoint(checkpoint['callbacks'])          # failed: need more positional arguments
# best_model_path = checkpoint_callback.best_model_path
# print(f"Best model path: {best_model_path}")

### _3. Load Model from a Checkpoint_

In [None]:
hparams = checkpoint["hyper_parameters"]  # hyper-parameters saved in the checkpoint (kept for inspection)
# `load_from_checkpoint` is a classmethod: it constructs a new LayerlessEmbedding from the
# hyper-parameters stored in the checkpoint and restores its weights. Call it on the class;
# recent Lightning versions raise when it is called on an instance, and an instance built
# from `config` beforehand would be discarded anyway.
e_model = LayerlessEmbedding.load_from_checkpoint(checkpoint_path=ckpt_path)
e_model.eval()                            # evaluation mode: disables dropout etc.

In [None]:
# check model (new) hparams from checkpoint
# e_model.hparams

In [None]:
# use SttTorchDataReader to read event `event_id` from the feature store
# (the processed input consumed by the embedding model below)
event_id=14996
inputdir="./run/feature_store"
torch_reader = SttTorchDataReader(inputdir)
feature_data = torch_reader(evtid=event_id)

In [None]:
# inspect the feature-stage event (its `.x` is the model input)
feature_data

In [None]:
# predict with the model: inference only, so skip autograd bookkeeping, and
# feed the input on whatever device the model's parameters live on
with torch.no_grad():
    spatial = e_model(feature_data.x.to(e_model.device))

In [None]:
# copy to host memory before converting: .numpy() only works on CPU tensors
spatial_np = spatial.detach().cpu().numpy()

In [None]:
# shape of the embedding-model output
spatial.shape

### _4. Run Inference_

In [None]:
from LightningModules.edge_construction import EmbeddingBuilder, EmbeddingTelemetry

In [None]:
# run inference callbacks (EmbeddingTelemetry has issues)
# EmbeddingBuilder is attached as a trainer callback; presumably it writes the embedded
# events that section 5 reads from ./run/edge_construction/test (confirm)
trainer = pl.Trainer(callbacks=[EmbeddingBuilder()])

In [None]:
# run Lightning's test loop on the loaded model; the callbacks attached to `trainer` run during it
trainer.test(model=e_model, verbose=True)

### _5. Embedded Data_

In [None]:
# read the same event (event_id) back from the embedding stage's test output directory
inputdir="./run/edge_construction/test"
torch_reader = SttTorchDataReader(inputdir)
embedded_data = torch_reader(evtid=event_id)

In [None]:
# feature-stage event, for side-by-side comparison with embedded_data
feature_data

In [None]:
# embedding-stage event for the same event_id
embedded_data

In [None]:
# first 10 rows of embedded_data.x
embedded_data.x[:10]

In [None]:
# first 10 entries of embedded_data.pt (presumably per-hit transverse momentum — confirm)
embedded_data.pt[:10]

In [None]:
# file this event was built from (presumably the raw-event path prefix — confirm)
embedded_data.event_file

In [None]:
!ls ./ctstrkx/data_all/event0000014996-hits.csv

In [None]:
# raw-data prefix for the same event as `event_id`, zero-padded to 10 digits
# (e.g. event_id=14996 -> ./data_all/event0000014996), so changing event_id updates both
event_prefix = "./data_all/event{:010d}".format(event_id)

In [None]:
# see the corresponding raw data: load the four per-event tables for `event_prefix`
hits, tubes, particles, truth = trackml.dataset.load_event(event_prefix)

In [None]:
# first rows of the raw particle table
particles.head()

In [None]:
# first rows of the raw truth table
truth.head()

### _6. Visualize Embedding_

In [None]:
# unique particle ids in the event, plus how many hits carry each id
sel_pids, sel_pids_fr = np.unique(feature_data.pid, return_counts=True)

In [None]:
# particle ids used below to colour the plots (one scatter call per id)
sel_pids

In [None]:
# plot hits in the embedding space: one figure per pair of embedding dimensions,
# one colour per track, each saved as embedding_<i>_<j>.png
embedding_dims = [(0, 1), (2, 3), (4, 5), (6, 7)]
for id1, id2 in embedding_dims:
    fig, ax = plt.subplots(figsize=(6,6))
    for pid in sel_pids:
        # mask selecting the hits of this track
        idx = feature_data.pid == pid
        ax.scatter(spatial_np[idx, id1], spatial_np[idx, id2], label='track %d'%pid)

    ax.grid(True)
    ax.legend(fontsize=10, loc='best')
    fig.tight_layout()
    fig.savefig("embedding_{}_{}.png".format(id1, id2))
    del fig

### TSNE from 8 dim to 2 dim

- We project the 8D embedding space onto a 2D space (dimensionality reduction).
- We use scikit-learn's [TSNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html) implementation.

In [None]:
from sklearn.manifold import TSNE

In [None]:
# shape of the embedding array fed to t-SNE
spatial_np.shape

In [None]:
# project the embedding to 2D; fixing the seed makes the t-SNE layout (and the saved plot) reproducible
spatial_tsne = TSNE(n_components=2, random_state=0).fit_transform(spatial_np)

In [None]:
# shape of the 2D t-SNE projection (one row per hit, 2 columns)
spatial_tsne.shape

In [None]:
# t-SNE projection of the embedding, drawn with the object-oriented Matplotlib API:
# a single Axes, one scatter (colour) per track id in sel_pids
plt.close('all')
fig, ax = plt.subplots(figsize=(6, 6))

for pid in sel_pids:
    idx = feature_data.pid == pid
    tsne_x = spatial_tsne[idx, 0]
    tsne_y = spatial_tsne[idx, 1]
    ax.scatter(tsne_x, tsne_y, label='track %d'%pid)

ax.legend(fontsize=10, loc='best')
ax.grid(True)
fig.tight_layout()
fig.savefig("embeding-tsne.png")

## _Prepare Data for Plotting Edges_

- Need `x, y, z` for hits
- Corresponding Edges

In [None]:
from LightningModules.edge_construction.utils.embedding_utils import build_edges

In [None]:
# build edges from the embedding output, querying `spatial` against itself
# NOTE(review): r_max=0.1 / k_max=100 are hard-coded here; presumably a search radius and a
# max-neighbour count in embedding space — confirm against build_edges and the training config
e_spatial = build_edges(spatial, spatial, indices=None, r_max=0.1, k_max=100)

In [None]:
# e_spatial.shape

In [None]:
# e_spatial

In [None]:
# get first three pairs/edges. use all rows but cols=0,1,2
# e_spatial[:, 0], e_spatial[:, 1], e_spatial[:, 2]

In [None]:
# convert e_spatial tensor to numpy version e_spatial_np
# (copy to host memory first: .numpy() only works on CPU tensors)
e_spatial_np = e_spatial.detach().cpu().numpy()

In [None]:
# 1st row
# e_spatial_np[0]

In [None]:
# 2nd row
# e_spatial_np[1]

In [None]:
# edge/pair = 1st row first element, 2nd row 1st element
# e_spatial_np[0, 0], e_spatial_np[1, 0]

In [None]:
# hit=1 (0) of one pair is also paired next with hit=13 (12) forming another pair.
# e_spatial_np[0, 1], e_spatial_np[1, 1]

In [None]:
# access hi1=1 (0) in pair 0 and pair 1.
# e_spatial_np[0, 0], e_spatial_np[0, 1]

In [None]:
# Why ???
# NOTE(review): e_spatial[0, k] is row 0 (first hit) of edge k, so this picks the first hit
# of edge 0 and the first hit of edge 1, not both ends of one edge (those would be
# e_spatial[0, 0] and e_spatial[1, 0]). The edge indices are rows of feature_data; indexing
# the raw `hits` table with them assumes both tables share the same row order — verify.
hits.iloc[[e_spatial[0, 0], e_spatial[0, 1]]].x.values

### _Plotting Edges_

In [None]:
# The default scales below undo the feature-stage scaling this notebook assumes for
# feature_data: r and z divided by 100, phi divided by pi (so phi is NOT in radians).
def cylinderical_to_cartesion(r, phi, z, *, r_scale=100, phi_scale=np.pi, z_scale=100):
    """Convert scaled cylindrical coordinates to Cartesian coordinates.

    The angle is recovered as ``phi * phi_scale``; x and y are ``r * cos/sin(angle)``
    multiplied by ``r_scale``, and z is multiplied by ``z_scale``.

    Args:
        r, phi, z: scaled cylindrical coordinates (scalars or array-likes of equal shape).
        r_scale: factor applied to the radius (default 100).
        phi_scale: factor turning the stored phi into radians (default pi).
        z_scale: factor applied to z (default 100).

    Returns:
        Tuple ``(x, y, z)`` of Cartesian coordinates.
    """
    theta = phi * phi_scale
    x = r * np.cos(theta) * r_scale
    y = r * np.sin(theta) * r_scale
    z = z * z_scale
    return x, y, z

In [None]:
# transpose e_spatial: one row per edge, holding its (first, second) hit indices
e_spatial_np_t = e_spatial_np.T

In [None]:
# (2, n_edges): row 0 = first hit of each edge, row 1 = second hit
e_spatial_np.shape

In [None]:
# (n_edges, 2): one (first, second) index pair per edge
e_spatial_np_t.shape

In [None]:
# plotting event from processing stage i.e. feature_data
# the Cartesian conversion is loop-invariant: do it once, only the per-track mask changes
x, y, z = cylinderical_to_cartesion(r=feature_data.x[:, 0], phi=feature_data.x[:, 1], z=feature_data.x[:, 2])

fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111, projection='3d')
for pid in sel_pids:
    idx = feature_data.pid == pid
    ax.scatter(x[idx], y[idx], z[idx], label='track %d'%pid)

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.set_xlim(-40, 40)
ax.set_ylim(-40, 40)
ax.set_zlim(-10, 100)
ax.legend(fontsize=10, loc='best')
ax.grid(True)
fig.tight_layout()
# own filename, so enabling this does not overwrite the t-SNE plot
# plt.savefig(os.path.join(outdir, "feature_hits_3d.pdf"))

In [None]:
# plot edges 3D: data from processing stage, edges from embedding stage
# hit coordinates are loop-invariant: convert them once instead of once per track and
# once per edge (the latter made the edge loop O(n_edges * n_hits))
x, y, z = cylinderical_to_cartesion(r=feature_data.x[:, 0], phi=feature_data.x[:, 1], z=feature_data.x[:, 2])

fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111, projection='3d')
for pid in sel_pids:
    idx = feature_data.pid == pid
    ax.scatter(x[idx], y[idx], z[idx], label='track %d'%pid)


# add edges: e_spatial_np is (2, n_edges), so each row of its transpose is the
# (first, second) pair of feature_data row indices joined by one edge
e_spatial_np_t = e_spatial_np.T
for edge in e_spatial_np_t:
    ax.plot(x[edge], y[edge], z[edge], color='k', alpha=0.3, lw=1.)


ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.set_xlim(-40, 40)
ax.set_ylim(-40, 40)
ax.set_zlim(-10, 100)
ax.legend(fontsize=10, loc='best')
ax.grid(True)
fig.tight_layout()
# plt.savefig(os.path.join(outdir, "emedding_edges_3d.pdf"))

In [None]:
# plot edges 2D (x-y): data from processing stage, edges from embedding stage
# Edge indices are rows of feature_data (spatial = e_model(feature_data.x)), so the edge
# end-points come from the same converted feature_data coordinates as the scattered hits
# (as in the 3D plot), not from the raw `hits` table, whose row order need not match.
x, y, z = cylinderical_to_cartesion(r=feature_data.x[:, 0], phi=feature_data.x[:, 1], z=feature_data.x[:, 2])

fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)
for pid in sel_pids:
    idx = feature_data.pid == pid
    ax.scatter(x[idx], y[idx], label='track %d'%pid)

# add edges: only the first `max_edges` keep the plot readable; slicing (unlike
# range(100)) also works when the graph has fewer edges than that
e_spatial_np_t = e_spatial_np.T
max_edges = 100
for edge in e_spatial_np_t[:max_edges]:
    ax.plot(x[edge], y[edge], color='k', alpha=0.3, lw=1.)

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_xlim(-40, 40)
ax.set_ylim(-40, 40)
ax.legend(fontsize=10, loc='best')
ax.grid(True)
fig.tight_layout()
# plt.savefig(os.path.join(outdir, "embedding_edges_x_y.pdf"))

In [None]:
# plot edges 2D (z-r) in the scaled feature coordinates: data from processing stage,
# edges from embedding stage (columns 0 and 2 of feature_data.x are r and z)
r_feat = feature_data.x[:, 0]
z_feat = feature_data.x[:, 2]

fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)
for pid in sel_pids:
    idx = feature_data.pid == pid
    ax.scatter(z_feat[idx], r_feat[idx], label='track %d'%pid)

# add edges: first 100 only; slicing (unlike range(100)) also works for smaller graphs
e_spatial_np_t = e_spatial_np.T
for edge in e_spatial_np_t[:100]:
    ax.plot(z_feat[edge], r_feat[edge], color='k', alpha=0.3, lw=1.)

ax.set_xlabel('Z')
ax.set_ylabel('R')
#ax.set_xlim(-40, 100)
#ax.set_ylim(0, 50)
ax.legend(fontsize=10, loc='best')
ax.grid(True)
fig.tight_layout()
# plt.savefig(os.path.join(outdir, "embedding_edges_z_r.pdf"))

### Plotting Histograms
- The following plots need the r, phi, z coordinates.