In [1]:
import numpy as np
import os

from pathlib import Path

# change working directory to make src visible
os.chdir(Path.cwd().parent)

from hydra import compose, initialize
from hydra.utils import instantiate
from pytorch_lightning import seed_everything

from src.data_generation import SPDEventGenerator
from src.visualization import draw_event

CHECKPOINT_PATH = "results/hydra/2024-09-11/19-29-11"
seed_everything(13)

Seed set to 13


13

In [2]:
with initialize(version_base=None, config_path=str(".." / Path(CHECKPOINT_PATH) / ".hydra")):
    cfg = compose(config_name="config.yaml")

In [3]:
event_gen = SPDEventGenerator(
    max_event_tracks=cfg.dataset.max_event_tracks,
    detector_eff=cfg.dataset.detector_efficiency
)
track_params_normalizer=instantiate(cfg.dataset.track_params_normalizer)


for _ in range(2):
    event = event_gen.generate_spd_event()
    print(event)

Event:
Shape of real hits: (103, 3)
Shape of momentums: (103, 3)
Shape of fake hits: (309, 3)
Fraction of fakes: 0.75
Fraction of missing hits: 0.02
Number of unique tracks: 3
Vertex: Vertex(x=-7.12, y=7.54, z=194.57)
Track parameters:
	Track ID: 0, TrackParams(pt=969.17, phi=6.11, theta=1.66, charge=-1)
	Track ID: 1, TrackParams(pt=984.00, phi=3.30, theta=0.69, charge=1)
	Track ID: 2, TrackParams(pt=219.16, phi=4.35, theta=1.91, charge=-1)




Event:
Shape of real hits: (103, 3)
Shape of momentums: (103, 3)
Shape of fake hits: (233, 3)
Fraction of fakes: 0.69
Fraction of missing hits: 0.02
Number of unique tracks: 3
Vertex: Vertex(x=1.00, y=5.75, z=-275.20)
Track parameters:
	Track ID: 0, TrackParams(pt=243.62, phi=1.31, theta=0.78, charge=1)
	Track ID: 1, TrackParams(pt=535.48, phi=4.55, theta=2.51, charge=-1)
	Track ID: 2, TrackParams(pt=613.66, phi=1.17, theta=2.34, charge=1)




In [4]:
draw_event(
    hits=event.hits,
    fakes=event.fakes,
    vertex=event.vertex.numpy,
    labels=event.track_ids,
)

In [5]:
event.track_params

{0: TrackParams(phi=1.312273183817232, theta=0.7756352619527093, pt=243.61850011384706, charge=1),
 1: TrackParams(phi=4.548803225574217, theta=2.506917504320647, pt=535.4763634597723, charge=-1),
 2: TrackParams(phi=1.1746950416863728, theta=2.33599121984812, pt=613.6645069556786, charge=1)}

## Example of reproducing tracks by their parameters

In [6]:
hits, labels = [], []
for track, track_params in event.track_params.items():
    track_hits = event_gen.reconstruct_track_hits_from_params(track_params, event.vertex)
    hits.append(track_hits)
    labels.extend([track]*len(track_hits))
    
hits = np.vstack(hits, dtype=np.float32)
labels = np.array(labels, dtype=np.int32)

In [7]:
orig_event_fig = draw_event(
    hits=hits,
    fakes=None,
    vertex=event.vertex.numpy,
    labels=labels,
)
orig_event_fig.show()

## Visualize generated tracks


In [8]:
from src.training import TrainModel
from src.postprocess import EventRecoveryFromPredictions

### Load the model from checkpoint directory

In [9]:
model = TrainModel.load_from_checkpoint(CHECKPOINT_PATH)
model.eval()

TrainModel(
  (model): TRT(
    (emb_encoder): Sequential(
      (0): Conv1d(3, 128, kernel_size=(1,), stride=(1,), bias=False)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.2)
    )
    (encoder): PointTransformerEncoder(
      (conv1): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
      (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
      (conv3): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (

### Convert input hits to the model input

In [10]:
import torch

def event_to_model_input(hits) -> tuple[torch.Tensor, torch.Tensor]:
    maxlen = len(hits)
    batch_size = 1
    n_features = hits.shape[-1]
    mask = np.ones(len(hits), dtype=bool)
    batch_inputs = np.zeros((batch_size, maxlen, n_features), dtype=np.float32)
    batch_mask = np.zeros((batch_size, maxlen), dtype=bool)
    # params have the fixed size - MAX_TRACKS x N_PARAMS
    batch_inputs[0, :len(hits)] = hits
    batch_mask[0, :len(hits)] = mask
    inputs_torch = torch.from_numpy(batch_inputs)
    mask_torch = torch.from_numpy(batch_mask)
    return inputs_torch, mask_torch


inputs, mask = event_to_model_input(event.hits)
inputs.shape, event.hits.shape, mask.shape

(torch.Size([1, 103, 3]), (103, 3), torch.Size([1, 103]))

### Get model predictions

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

with torch.no_grad():
    preds = model(inputs={"inputs": inputs.to(device), "mask": mask.to(device)})

In [12]:
preds

{'logits': tensor([[[ 0.3288, -1.0841],
          [ 1.3489, -0.4145],
          [ 0.0906, -1.4475]]], device='mps:0'),
 'params': tensor([[[0.5032, 0.4967, 0.4956, 0.6243, 0.8008, 0.5042],
          [0.4961, 0.4995, 0.5046, 0.4547, 0.1750, 0.4974],
          [0.4975, 0.4948, 0.4967, 0.4902, 0.8373, 0.4906]]], device='mps:0')}

### Recovert event from predictions

In [13]:
# event reconstruction
event_recoverer = EventRecoveryFromPredictions(event_gen, track_params_normalizer)
# TODO: fix indices
hits, track_ids = event_recoverer(preds["params"][0], preds["logits"][0])

In [14]:
pred_event_fig = draw_event(
    hits=hits,
    fakes=None,
    vertex=None,
    labels=track_ids,
)
pred_event_fig.show()

In [15]:
from src.visualization import display_side_by_side

comparison_fig = display_side_by_side(pred_event_fig, orig_event_fig)
comparison_fig.show()