In [None]:
import numpy as np
import os

from pathlib import Path

# Make the project's `src` package importable by working from the repository
# root (assumed to be the parent of the notebook directory). The guard keeps the
# cell idempotent: re-running it must not keep walking up the directory tree.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

from src.data_generation import SPDEventGenerator
from src.visualization import draw_event

In [None]:
event_gen = SPDEventGenerator(detector_eff=0.98)

# Generate and print a couple of events as a preview. Only the last one
# survives in `event`; the following cells inspect and draw it.
n_preview_events = 2
for _ in range(n_preview_events):
    event = event_gen.generate_spd_event()
    print(event)

In [None]:
# Draw the last generated event: its hits, its fake hits and its vertex,
# with the per-hit track ids passed as labels.
draw_event(
    hits=event.hits,
    fakes=event.fakes,
    vertex=event.vertex.numpy,
    labels=event.track_ids,
)

In [None]:
# Parameters of each track in the last event. The reproduction cell below
# iterates over this object and indexes it by the iterated keys (presumably track ids).
event.track_params

## Example: reproducing tracks from their parameters

In [None]:
# Re-create the hits of every track of the last event from its parameters alone,
# applying the generator's z-range cut, to check that the parametrization
# reproduces the event.
track_params = event.track_params
vertex = event.vertex
# NOTE(review): `magnetic_field` is read but not passed to generate_hit_by_params
# (the argument was commented out), so that call uses its own default field.
# Confirm the default matches `event_gen.magnetic_field`, otherwise the
# reproduced tracks will differ from the generated ones.
magnetic_field = event_gen.magnetic_field
z_coord_range = event_gen.z_coord_range
radii = np.linspace(  # station radii, mm
    event_gen.r_coord_range[0],
    event_gen.r_coord_range[1],
    event_gen.n_stations,
)

hit_rows = []     # one row (hit.numpy) per accepted hit
hit_labels = []   # track key of the corresponding row in `hit_rows`

for track in track_params:
    for r in radii:
        hit, _ = SPDEventGenerator.generate_hit_by_params(
            track_params=track_params[track],
            vertex=vertex,
            Rc=r,
        )

        # (0, 0, 0) appears to be the "no hit at this radius" sentinel -- skip it.
        if (hit.x, hit.y, hit.z) == (0, 0, 0):
            continue

        # Apply the same longitudinal acceptance as the generator.
        if not z_coord_range[0] <= hit.z <= z_coord_range[1]:
            continue

        hit_rows.append(hit.numpy)
        hit_labels.append(track)

# np.vstack on an empty list fails with an opaque message; say what went wrong.
assert hit_rows, "no hits were reproduced; check track parameters and detector ranges"
hits = np.vstack(hit_rows, dtype=np.float32)
labels = np.array(hit_labels, dtype=np.int32)

In [None]:
# Draw the hits reproduced from the track parameters (no fakes), with the
# track key of each hit passed as its label.
draw_event(
    hits=hits,
    fakes=None,
    vertex=vertex.numpy,
    labels=labels,
)

## Predict tracks with a trained TRT model and visualize them


In [None]:
from src.model import TRT
from src.training import TrainModel


In [None]:
from src.normalization import ConstraintsNormalizer, TrackParamsNormalizer
from src.postprocess import TracksFromParamsGenerator
from src.loss import MatchingLoss, HungarianMatcher


In [None]:
# Generator that turns (presumably normalized) track parameters into hits.
# It feeds the "hits" loss term and is reused later to draw predicted tracks.
hits_generator = TracksFromParamsGenerator(
    hits_normalizer=ConstraintsNormalizer(),
    params_normalizer=TrackParamsNormalizer(),
    # NOTE(review): hard-coded; compare with event_gen.n_stations used above.
    n_stations=35,
)

# Matching-based loss: candidates are assigned to tracks by a Hungarian matcher,
# then the label, parameter and hit terms are computed.
loss = MatchingLoss(
    matcher=HungarianMatcher(class_cost=0.5, params_cost=2.0),
    hits_generator=hits_generator,
    num_classes=1,
    eos_coef=0.2,
    losses=["labels", "params", "hits"],
)


In [None]:
from torch.optim import AdamW

# Untrained TRT network: 20 candidate tracks with 6 output parameters each.
base_model = TRT(
    num_candidates=20,
    num_out_params=6,
    dropout=0.1,
    n_points=512,
)
# Training wrapper bundling the network, loss and optimizer.
# NOTE(review): this `model` is replaced by the checkpoint-loaded one in the next cell.
model = TrainModel(
    model=base_model,
    criterion=loss,
    metrics=[],
    optimizer=AdamW(params=base_model.parameters(), lr=0.001),
)

In [None]:
# Trained-model checkpoint. The default is a machine-specific path; set the
# TRT_CHECKPOINT environment variable to use another one without editing the cell.
PATH = os.environ.get(
    "TRT_CHECKPOINT",
    r"D:\projects\trt\results\hydra\2024-06-26\00-10-24\TRT\version_0\epoch=45-step=57500.ckpt",
)
model = TrainModel.load_from_checkpoint(
    PATH,
    model=base_model,
    criterion=loss,
    metrics=[],
    optimizer=AdamW(lr=0.001, params=base_model.parameters()),
)
# Only inference follows. Switch to eval mode explicitly (as the Lightning
# checkpoint docs do after load_from_checkpoint) so dropout is disabled; the
# trailing ';' stops Jupyter from printing the whole module repr.
model.eval();


In [None]:
import torch

# Pack the single reproduced event into a batch of size 1. Because the batch
# holds only this event, the padded length equals its hit count and every
# position is valid.
batch_size = 1
maxlen, n_features = len(hits), hits.shape[-1]

batch_inputs = np.zeros((batch_size, maxlen, n_features), dtype=np.float32)
batch_inputs[0] = hits

batch_mask = np.zeros((batch_size, maxlen), dtype=bool)
batch_mask[0] = True

inputs = torch.from_numpy(batch_inputs)
mask = torch.from_numpy(batch_mask)

In [None]:
# Inference only: disable dropout and gradient tracking so predictions are
# deterministic and no autograd graph is built.
model.eval()
with torch.no_grad():
    preds = model(inputs={"inputs": inputs, "mask": mask})

In [None]:
# Append a charge column to the predicted parameters: the argmax class index k
# of each candidate's logits is mapped to 2k - 1 (0 -> -1, 1 -> +1).
# NOTE(review): with num_classes=1 these logits may encode track / no-track
# rather than charge -- confirm this mapping is intended.
source_params = preds["params"]
source_charges = preds["logits"].argmax(dim=-1).float().mul(2).sub(1).unsqueeze(-1)
source_params = torch.cat([source_params, source_charges], dim=-1)

# Drop the batch dimension and turn the parameters back into hit coordinates.
source_tracks, _ = hits_generator.generate_tracks(
    source_params.detach().cpu().numpy().squeeze()
)

In [None]:
# Tracks generated from the predicted parameters (presumably one entry per candidate).
source_tracks

In [None]:
# Flatten the predicted tracks into one point array plus a per-point label
# holding the index of the candidate track each point came from. Previously
# this cell overwrote both with copies of the ground-truth `hits`/`labels`.
# Assumes each element of `source_tracks` is an (n_points, n_features)
# array-like -- TODO confirm against TracksFromParamsGenerator.generate_tracks.
pred_track_arrays = [np.asarray(track, dtype=np.float32) for track in source_tracks]
assert pred_track_arrays, "the model produced no candidate tracks"

pred_hits = np.vstack(pred_track_arrays)
pred_labels = np.repeat(
    np.arange(len(pred_track_arrays), dtype=np.int32),
    [len(track) for track in pred_track_arrays],
)

In [None]:
# Shape of the `pred_labels` array built in the previous cell.
pred_labels.shape

In [None]:
# Overlay the tracks generated from the model's predictions on the hits
# reproduced from the true track parameters. Each predicted track is
# identified by its candidate index.
draw_event(
    hits=hits,
    fakes=None,
    vertex=vertex.numpy,
    predicted_hits=source_tracks,
    predicted_tracks=np.arange(len(source_tracks)),
    labels=labels,
)