In [1]:
import numpy as np
import os

from pathlib import Path

# Change the working directory to the repository root so the local `src`
# package is importable. os.chdir is not idempotent, so guard it: re-running
# this cell must not keep climbing the directory tree.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

from src.dataset import SPDEventsDataset
from src.normalization import ConstraintsNormalizer, TrackParamsNormalizer

In [2]:
# Fixed seed so the notebook is reproducible; presumably SPDEventsDataset
# (shuffle=True) draws from NumPy's global RNG — TODO confirm in src/dataset.py.
SEED = 42
np.random.seed(SEED)

dataset = SPDEventsDataset(
    hits_normalizer=ConstraintsNormalizer(),
    track_params_normalizer=TrackParamsNormalizer(),
    shuffle=True,
)

In [3]:
# Load the first event as a dict of NumPy arrays; the cells below inspect its
# "hits", "hit_labels", "params" and "param_labels" entries.
sample = dataset[0]

In [4]:
# Display the whole sample dict (NumPy abbreviates the long arrays with "...").
sample

{'hits': array([[-0.28403901,  0.15882117, -0.37502865],
        [-0.80633537, -0.22986594, -0.19392248],
        [-0.56487029,  0.12186746,  0.54244963],
        ...,
        [-0.35769091, -0.25200208, -0.65893922],
        [ 0.1086406 , -0.87180812, -0.56872047],
        [-0.49563133, -0.045159  ,  0.81616737]]),
 'hit_labels': array([ 8, -1, -1, ..., -1, -1, -1]),
 'params': array([[0.5029184 , 0.49918765, 0.5291694 , 0.38292688, 0.9717121 ,
         0.25415015, 0.        ],
        [0.5029184 , 0.49918765, 0.5291694 , 0.5245114 , 0.4103968 ,
         0.08475863, 0.        ],
        [0.5029184 , 0.49918765, 0.5291694 , 0.9497999 , 0.7257195 ,
         0.4271637 , 1.        ],
        [0.5029184 , 0.49918765, 0.5291694 , 0.47347176, 0.17320187,
         0.5422352 , 0.        ],
        [0.5029184 , 0.49918765, 0.5291694 , 0.5986585 , 0.15601864,
         0.74151593, 0.        ],
        [0.5029184 , 0.49918765, 0.5291694 , 0.6775644 , 0.01658783,
         0.49230057, 0.        ],
  

In [5]:
# (n_hits, 3): one row of coordinates per hit.
np.shape(sample["hits"])

(3288, 3)

In [6]:
# Exactly one label per hit; -1 presumably marks hits that belong to no track.
assert sample["hit_labels"].shape == sample["hits"].shape[:1], "expected one label per hit"
sample["hit_labels"].shape

(3288,)

In [7]:
# Per-track labels of this event: here a permutation of 0..9, one per row of
# "params" — presumably reflecting shuffle=True (TODO confirm).
sample["param_labels"]

array([1, 8, 6, 4, 0, 2, 3, 5, 9, 7], dtype=int32)

In [8]:
# Normalized track parameters, one row per track, rounded for readability.
# The first three columns are identical for every track in this event
# (presumably a shared vertex — TODO confirm).
sample["params"].round(3)

array([[0.503, 0.499, 0.529, 0.383, 0.972, 0.254, 0.   ],
       [0.503, 0.499, 0.529, 0.525, 0.41 , 0.085, 0.   ],
       [0.503, 0.499, 0.529, 0.95 , 0.726, 0.427, 1.   ],
       [0.503, 0.499, 0.529, 0.473, 0.173, 0.542, 0.   ],
       [0.503, 0.499, 0.529, 0.599, 0.156, 0.742, 0.   ],
       [0.503, 0.499, 0.529, 0.678, 0.017, 0.492, 0.   ],
       [0.503, 0.499, 0.529, 0.844, 0.93 , 0.829, 1.   ],
       [0.503, 0.499, 0.529, 0.946, 0.781, 0.781, 0.   ],
       [0.503, 0.499, 0.529, 0.908, 0.316, 0.443, 1.   ],
       [0.503, 0.499, 0.529, 0.973, 0.602, 0.686, 0.   ]], dtype=float32)

In [10]:
# Undo the parameter normalization for the first 5 tracks. apply_along_axis
# with axis=1 hands denormalize() one track's parameter vector (row) at a time.
orig_params = np.apply_along_axis(
    dataset.track_params_normalizer.denormalize, axis=1, arr=sample["params"][:5]
)
np.round(orig_params, 2)

array([[ 4.9700e+00, -1.3800e+00,  1.3920e+02,  4.4463e+02,  6.1100e+00,
         8.0000e-01, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  5.7206e+02,  2.5800e+00,
         2.7000e-01, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  9.5482e+02,  4.5600e+00,
         1.3400e+00,  1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  5.2612e+02,  1.0900e+00,
         1.7000e+00, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  6.3879e+02,  9.8000e-01,
         2.3300e+00, -1.0000e+00]], dtype=float32)

In [11]:
# Disable parameter normalization IN PLACE: the DataLoader cells below reuse
# this `dataset` and rely on it returning raw (denormalized) parameters.
dataset.track_params_normalizer = None
sample = dataset[0]
# The raw parameters must agree with the manually denormalized ones from In [10].
np.testing.assert_allclose(sample["params"][:5], orig_params, atol=1e-2)
np.round(sample["params"], 2)

array([[ 4.9700e+00, -1.3800e+00,  1.3920e+02,  4.4463e+02,  6.1100e+00,
         8.0000e-01, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  5.7206e+02,  2.5800e+00,
         2.7000e-01, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  9.5482e+02,  4.5600e+00,
         1.3400e+00,  1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  5.2612e+02,  1.0900e+00,
         1.7000e+00, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  6.3879e+02,  9.8000e-01,
         2.3300e+00, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  7.0981e+02,  1.0000e-01,
         1.5500e+00, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  8.5979e+02,  5.8400e+00,
         2.6000e+00,  1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  9.5158e+02,  4.9100e+00,
         2.4500e+00, -1.0000e+00],
       [ 4.9700e+00, -1.3800e+00,  1.3920e+02,  9.1708e+02,  1.9900e+00,
         1.3900e+00,  1.0000e+00],
       [ 4.9700e+00, -1.3800

## Test dataloader

In [12]:
from torch.utils.data import DataLoader
from src.dataset import collate_fn

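#### Without normalization of track parameters

`dataset.track_params_normalizer` was set to `None` in In [11], so the `targets` below hold the original (denormalized) parameter values.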

In [13]:
# Batches of 4 events; track parameters are NOT normalized here (see In [11]).
train_loader = DataLoader(dataset, collate_fn=collate_fn, batch_size=4)
batch = next(iter(train_loader))
# Shape of every tensor in the collated batch.
[f"{name}: {tensor.shape}" for name, tensor in batch.items()]

['inputs: torch.Size([4, 3807, 3])',
 'mask: torch.Size([4, 3807])',
 'targets: torch.Size([4, 10, 7])',
 'orig_params: torch.Size([4, 10, 7])']

In [14]:
# Raw (un-normalized) targets, because `dataset.track_params_normalizer` is None
# at this point; the first event matches the parameters printed in In [11].
batch["targets"]

tensor([[[ 4.9671e+00, -1.3826e+00,  1.3920e+02,  4.4463e+02,  6.1054e+00,
           7.9844e-01, -1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  5.7206e+02,  2.5786e+00,
           2.6628e-01, -1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  9.5482e+02,  4.5598e+00,
           1.3420e+00,  1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  5.2612e+02,  1.0883e+00,
           1.7035e+00, -1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  6.3879e+02,  9.8029e-01,
           2.3295e+00, -1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  7.0981e+02,  1.0422e-01,
           1.5466e+00, -1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  8.5979e+02,  5.8435e+00,
           2.6044e+00,  1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  9.5158e+02,  4.9090e+00,
           2.4545e+00, -1.0000e+00],
         [ 4.9671e+00, -1.3826e+00,  1.3920e+02,  9.1708e+02,  1.9867e+00,
           1.3933e+00,  1.000

In [15]:
event_idx = 3
# Number of distinct tracks in the event. np.setdiff1d returns the unique labels
# with -1 (presumably the "no track"/noise label) removed, so the count stays
# correct even for an event that contains no -1 hits.
np.setdiff1d(dataset[event_idx]["hit_labels"], [-1]).size

10

In [16]:
# Label distribution of the hits from index 512 onward (512 is the
# truncation_length used further below), for the same event as the cell above.
np.unique(dataset[event_idx]["hit_labels"][512:], return_counts=True)

(array([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9]),
 array([1925,   28,   27,   27,   31,   24,   23,   29,   29,   32,   27]))

#### With normalized parameters

In [17]:
def print_batch_statistics(batch, mask_key="mask"):
    """Print a one-line summary of every tensor in a collated batch.

    The mask entry (``batch[mask_key]``) is summarised by the fraction of
    ``True`` entries (the "real hits ratio"); every other tensor is summarised
    by its min, mean and max. Empty tensors only report their shape and dtype.

    Args:
        batch: Mapping from name to ``torch.Tensor``, e.g. one batch produced
            by a ``DataLoader`` using ``collate_fn``.
        mask_key: Key of the boolean mask entry. Defaults to ``"mask"``.
    """
    for name, tensor in batch.items():
        header = f"{name}: shape: {list(tensor.shape)}, dtype: {tensor.dtype}"
        if tensor.numel() == 0:
            # min()/max() raise on empty tensors, so only the header is reported.
            print(f"{header}, (empty)")
        elif name == mask_key:
            real_ratio = tensor.sum().item() / tensor.numel()
            print(f"{header}, real hits ratio: {real_ratio:.4f}")
        else:
            # Tensor.mean() rejects integer/bool dtypes; average those in float64.
            mean = tensor.mean() if tensor.is_floating_point() else tensor.double().mean()
            print(
                f"{header}, (min, mean, max): "
                f"({tensor.min().item():.4f}, {mean.item():.4f}, {tensor.max().item():.4f})"
            )

In [18]:
# Fresh dataset with both the hits and the track parameters normalized.
dataset = SPDEventsDataset(
    shuffle=True,
    hits_normalizer=ConstraintsNormalizer(),
    track_params_normalizer=TrackParamsNormalizer(),
)
train_loader = DataLoader(dataset, collate_fn=collate_fn, batch_size=16)

batch = next(iter(train_loader))
print_batch_statistics(batch)

inputs: shape: [16, 3819, 3], dtype: torch.float32, (min, mean, max): (-0.9999, -0.0004, 1.0000)
mask: shape: [16, 3819], dtype: torch.bool, real hits ratio: 0.7733
targets: shape: [16, 10, 7], dtype: torch.float32, (min, mean, max): (0.0000, 0.5111, 1.0000)
orig_params: shape: [16, 10, 7], dtype: torch.float32, (min, mean, max): (-266.1517, 85.3959, 998.2501)


#### With normalization and truncation enabled

In [19]:
# Normalized dataset capped at 10 tracks per event and 512 hits per event.
dataset = SPDEventsDataset(
    shuffle=True,
    max_event_tracks=10,
    truncation_length=512,
    hits_normalizer=ConstraintsNormalizer(),
    track_params_normalizer=TrackParamsNormalizer(),
)
train_loader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)

batch = next(iter(train_loader))
print_batch_statistics(batch)

inputs: shape: [32, 512, 3], dtype: torch.float32, (min, mean, max): (-0.9999, 0.0002, 0.9999)
mask: shape: [32, 512], dtype: torch.bool, real hits ratio: 1.0000
targets: shape: [32, 10, 7], dtype: torch.float32, (min, mean, max): (0.0000, 0.5037, 1.0000)
orig_params: shape: [32, 10, 7], dtype: torch.float32, (min, mean, max): (-280.1523, 84.8576, 997.9048)


With a smaller number of tracks (`max_event_tracks=5`)

In [20]:
# Same as above but with at most 5 tracks per event. The iterator is kept in
# `batch_gen` so the next cell can pull a second batch from it.
dataset = SPDEventsDataset(
    shuffle=True,
    max_event_tracks=5,
    truncation_length=512,
    hits_normalizer=ConstraintsNormalizer(),
    track_params_normalizer=TrackParamsNormalizer(),
)
train_loader = DataLoader(dataset, collate_fn=collate_fn, batch_size=4)

batch_gen = iter(train_loader)
batch = next(batch_gen)
print_batch_statistics(batch)

inputs: shape: [4, 512, 3], dtype: torch.float32, (min, mean, max): (-0.9970, 0.0228, 0.9986)
mask: shape: [4, 512], dtype: torch.bool, real hits ratio: 1.0000
targets: shape: [4, 5, 7], dtype: torch.float32, (min, mean, max): (0.0000, 0.4657, 1.0000)
orig_params: shape: [4, 5, 7], dtype: torch.float32, (min, mean, max): (-191.0368, 70.4013, 997.3988)


Fetch the next batch from the same iterator (`batch_gen`).

In [21]:
# Advance the iterator created in In [20] to get the second batch.
batch = next(batch_gen)
print_batch_statistics(batch)

inputs: shape: [4, 512, 3], dtype: torch.float32, (min, mean, max): (-0.9999, -0.0232, 0.9961)
mask: shape: [4, 512], dtype: torch.bool, real hits ratio: 1.0000
targets: shape: [4, 5, 7], dtype: torch.float32, (min, mean, max): (0.0000, 0.5007, 1.0000)
orig_params: shape: [4, 5, 7], dtype: torch.float32, (min, mean, max): (-165.8356, 60.5103, 923.9129)
