In [20]:
import numpy as np
import os

from pathlib import Path

# change working directory to make src visible
os.chdir(Path.cwd().parent)

from src.data_generation import SPDEventGenerator
from src.visualization import draw_event

In [21]:
event_gen = SPDEventGenerator(detector_eff=0.98)

for _ in range(2):
    event = event_gen.generate_spd_event()
    print(event)

Event:
Shape of real hits: (335, 3)
Shape of momentums: (335, 3)
Shape of fake hits: (3473, 3)
Fraction of fakes: 0.91
Fraction of missing hits: 0.03
Number of unique tracks: 10
Vertex: Vertex(x=-2.92, y=-8.57, z=218.22)
Track parameters:
	Track ID: 0, TrackParams(pt=835.05, phi=1.39, theta=2.57, charge=-1)
	Track ID: 1, TrackParams(pt=468.73, phi=5.67, theta=0.46, charge=-1)
	Track ID: 2, TrackParams(pt=467.29, phi=5.92, theta=2.80, charge=-1)
	Track ID: 3, TrackParams(pt=137.80, phi=3.49, theta=1.15, charge=-1)
	Track ID: 4, TrackParams(pt=392.14, phi=1.73, theta=1.68, charge=-1)
	Track ID: 5, TrackParams(pt=155.34, phi=2.09, theta=2.23, charge=-1)
	Track ID: 6, TrackParams(pt=106.49, phi=0.65, theta=0.85, charge=-1)
	Track ID: 7, TrackParams(pt=114.19, phi=1.74, theta=1.46, charge=1)
	Track ID: 8, TrackParams(pt=191.58, phi=0.21, theta=0.92, charge=-1)
	Track ID: 9, TrackParams(pt=601.52, phi=4.50, theta=0.60, charge=-1)


Event:
Shape of real hits: (333, 3)
Shape of momentums: (333

In [22]:
draw_event(
    hits=event.hits,
    fakes=event.fakes,
    vertex=event.vertex.numpy,
    labels=event.track_ids,
)

In [23]:
event.track_params

{0: TrackParams(phi=4.3840293908071875, theta=1.6608810290949785, pt=158.41986388517083, charge=1),
 1: TrackParams(phi=0.4656730140207456, theta=2.678422292295492, pt=622.7430519140903, charge=-1),
 2: TrackParams(phi=1.0144020684772002, theta=2.024246997054324, pt=278.68781736292453, charge=-1),
 3: TrackParams(phi=3.645706176133503, theta=2.685267531791077, pt=492.7981727363326, charge=-1),
 4: TrackParams(phi=3.7337804732263846, theta=2.05083084958334, pt=234.6420777323817, charge=-1),
 5: TrackParams(phi=0.38193735908399806, theta=2.7047822960754138, pt=724.5157521673534, charge=-1),
 6: TrackParams(phi=5.540544602390994, theta=2.9077500188098178, pt=728.318442753845, charge=-1),
 7: TrackParams(phi=5.900918272788932, theta=2.4366475803726844, pt=261.7056796067375, charge=-1),
 8: TrackParams(phi=5.741597685240069, theta=2.6766107922504982, pt=830.3769474467848, charge=-1),
 9: TrackParams(phi=1.4625180671394413, theta=1.2582099525785695, pt=485.8350599377981, charge=1)}

## Example of reproducing tracks by their parameters

In [24]:
track_params = event.track_params
vertex = event.vertex
magnetic_field = event_gen.magnetic_field
z_coord_range = event_gen.z_coord_range
radii = np.linspace(
    event_gen.r_coord_range[0], event_gen.r_coord_range[1], event_gen.n_stations
)  # mm

hits = []
labels = []

for track in track_params:
    for r in radii:
        hit, _ = SPDEventGenerator.generate_hit_by_params(
            track_params=track_params[track],
            vertex=vertex,
            Rc=r,
            # magnetic_field=magnetic_field
        )

        if (hit.x, hit.y, hit.z) == (0, 0, 0):
            continue

        if not z_coord_range[0] <= hit.z <= z_coord_range[1]:
            continue

        hits.append(hit.numpy)
        labels.append(track)

hits = np.vstack(hits, dtype=np.float32)
labels = np.array(labels, dtype=np.int32)

In [25]:
draw_event(
    hits=hits,
    fakes=None,
    vertex=vertex.numpy,
    labels=labels,
)

## Visualize generated tracks


In [26]:
from src.model import TRT
from src.training import TrainModel

In [27]:
from src.normalization import ConstraintsNormalizer, TrackParamsNormalizer
from src.postprocess import TracksFromParamsGenerator
from src.loss import MatchingLoss, HungarianMatcher

In [28]:
hits_generator = TracksFromParamsGenerator(
    hits_normalizer=ConstraintsNormalizer(),
    params_normalizer=TrackParamsNormalizer(),
    n_stations=35,
)

loss = MatchingLoss(
    matcher=HungarianMatcher(class_cost=0.5, params_cost=2.0),
    hits_generator=hits_generator,
    num_classes=1,
    eos_coef=0.2,
    losses=["labels", "params", "hits"],
)

In [29]:
from torch.optim import AdamW

base_model = TRT(num_candidates=20, num_out_params=7, dropout=0.1, n_points=512)
model = TrainModel(
    model=base_model,
    criterion=loss,
    metrics=[],
    optimizer=AdamW(lr=0.001, params=base_model.parameters()),
)

In [30]:
PATH = r"D:\projects\trt\results\hydra\2024-08-22\14-39-07\TRT\version_0\epoch=44-step=45.ckpt"
model = TrainModel.load_from_checkpoint(
    PATH,
    model=base_model,
    criterion=loss,
    metrics=[],
    optimizer=AdamW(lr=0.001, params=base_model.parameters()),
)

RuntimeError: Error(s) in loading state_dict for TrainModel:
	Missing key(s) in state_dict: "model.class_head.weight", "model.class_head.bias". 
	Unexpected key(s) in state_dict: "model.class_head.0.weight", "model.class_head.0.bias", "model.class_head.3.weight". 

In [31]:
import torch

maxlen = len(hits)
batch_size = 1
n_features = hits.shape[-1]
mask = np.ones(len(hits), dtype=bool)
batch_inputs = np.zeros((batch_size, maxlen, n_features), dtype=np.float32)
batch_mask = np.zeros((batch_size, maxlen), dtype=bool)
# params have the fixed size - MAX_TRACKS x N_PARAMS
batch_inputs[0, : len(hits)] = hits
batch_mask[0, : len(hits)] = mask

inputs = torch.from_numpy(batch_inputs)
mask = torch.from_numpy(batch_mask)
model.eval()

TrainModel(
  (model): TRT(
    (emb_encoder): Sequential(
      (0): Conv1d(3, 128, kernel_size=(1,), stride=(1,), bias=False)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
      (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.2)
    )
    (encoder): PointTransformerEncoder(
      (conv1): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
      (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
      (conv3): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (

In [32]:
preds = model(inputs={"inputs": inputs, "mask": mask})

In [33]:
torch.argmax(preds["logits"], dim=-1).to(torch.float)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [34]:
source_params = preds["params"][..., :-1]
#source_charges = torch.argmax(preds["logits"], dim=-1).to#(torch.float)
source_charges = preds["params"][..., -1]
is_objects = source_charges < 2
source_charges = source_charges * 2 - 1
source_charges = source_charges.unsqueeze(-1)
source_params = torch.concat((source_params, source_charges), dim=-1)[is_objects]
source_tracks, _ = hits_generator.generate_tracks(
    source_params.detach().cpu().numpy().squeeze()
)

In [45]:
source_params

tensor([[ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.4261,  0.4067, -0.2218],
        [ 0.4845,  0.4719,  0.4690,  0.4024,  0.426

In [36]:
pred_hits = []
pred_labels = []

for i, track in enumerate(source_tracks):
    pred_labels.append([i for i in range(len(track))])

pred_hits = np.vstack(hits, dtype=np.float32)
pred_labels = np.array(labels, dtype=np.int32)

In [37]:
pred_labels.shape

(337,)

In [39]:
pred_hits

array([[-108.34672 , -223.61343 , -245.61658 ],
       [-119.73146 , -236.74675 , -247.18661 ],
       [-131.4774  , -249.61874 , -248.76071 ],
       ...,
       [ 252.77939 ,  760.2722  ,   37.27323 ],
       [ 261.05957 ,  775.5589  ,   42.891808],
       [ 269.47336 ,  790.7888  ,   48.515022]], dtype=float32)

In [44]:
draw_event(
    hits=hits,
    fakes=None,
    vertex=vertex.numpy,
    predicted_hits=source_tracks,
    predicted_tracks=np.array([i for i in range(len(source_tracks))]),
    labels=labels,
)