In [1]:
import os
from pathlib import Path

# Make the project root the working directory so that `src` is importable and
# the relative `data/...` paths used below resolve.
# The move is guarded: it only happens while `src` is not yet visible from the
# current directory, so re-running this cell does not climb one directory
# higher on every execution.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

## Usage of Tracks Dataset

In [2]:
from src.tracknet.data.dataset import TrackMLTracksDataset

# Plain dataset: no transforms and no filters are passed here.
# validation_split=0.1 is presumably the fraction reserved for validation —
# confirm against TrackMLTracksDataset.
events_dir = Path("data/trackml/train_100_events")
blacklist_dir = Path("data/trackml/blacklist_training")

dataset = TrackMLTracksDataset(
    data_dirs=events_dir,
    blacklist_dir=blacklist_dir,
    validation_split=0.1,
)

In [3]:
from tqdm import tqdm

# Advance through the dataset until index 10000 is reached; `track` keeps the
# element at which the loop stopped and is inspected in the following cells.
for position, track in enumerate(tqdm(dataset)):
    if position >= 10000:
        break

# print() shows the short one-line summary, while the bare expression on the
# last line renders the detailed representation.
print(track)
track

10000it [00:03, 2664.66it/s]

Track 1296 with 12 hits, pT=0.23





Track(
	event_id=event000001011,
	track_id=1296,
	particle_id=112591983549087744,
	hits=12,
	px=0.10, py=0.21, pz=-1.70,
	pT=0.23, charge=1,
	detector_info=[
		vol8-l2-m59
		vol7-l14-m71
		vol7-l12-m71
		vol7-l10-m72
		vol7-l8-m70
		vol7-l8-m72
		vol7-l6-m70
		vol12-l8-m106
		vol12-l6-m106
		vol12-l4-m106
		vol12-l2-m103
		vol12-l2-m105
	]
)

In [4]:
# Hit positions of the track selected above: one (x, y, z) row per hit,
# displayed as a float32 array (12 rows for this 12-hit track).
track.hits_xyz

array([[   15.08  ,    27.6324,  -230.094 ],
       [   44.1775,    68.8854,  -602.    ],
       [   52.9455,    79.3657,  -702.    ],
       [   64.0962,    91.6667,  -822.    ],
       [   77.2527,   105.202 ,  -958.    ],
       [   77.6356,   105.593 ,  -962.    ],
       [   91.4131,   118.562 , -1098.    ],
       [  169.966 ,   175.698 , -1795.5   ],
       [  211.342 ,   198.291 , -2145.5   ],
       [  258.921 ,   223.052 , -2545.5   ],
       [  303.762 ,   248.084 , -2948.5   ],
       [  304.467 ,   248.484 , -2954.5   ]], dtype=float32)

In [5]:
# The same hits in cylindrical form. In the displayed output the columns are
# consistent with (r, phi, z): r = sqrt(x^2 + y^2), phi = atan2(y, x), and
# the last column repeats z from hits_xyz.
track.hits_cylindrical

array([[ 3.14794521e+01,  1.07123256e+00, -2.30093994e+02],
       [ 8.18342819e+01,  1.00054812e+00, -6.02000000e+02],
       [ 9.54051361e+01,  9.82488215e-01, -7.02000000e+02],
       [ 1.11853065e+02,  9.60586667e-01, -8.22000000e+02],
       [ 1.30519882e+02,  9.37401414e-01, -9.58000000e+02],
       [ 1.31061691e+02,  9.36812401e-01, -9.62000000e+02],
       [ 1.49710724e+02,  9.13980603e-01, -1.09800000e+03],
       [ 2.44454956e+02,  8.01979184e-01, -1.79550000e+03],
       [ 2.89801239e+02,  7.53548741e-01, -2.14550000e+03],
       [ 3.41748840e+02,  7.11113930e-01, -2.54550000e+03],
       [ 3.92195129e+02,  6.84844375e-01, -2.94850000e+03],
       [ 3.92994232e+02,  6.84497893e-01, -2.95450000e+03]], dtype=float32)

## Dataset with transforms and filters

In [6]:
from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import MinMaxNormalizeXYZ, DropRepeatedLayerHits
from src.tracknet.data.filters import MinHitsFilter, PtFilter, FirstLayerFilter


# Per-track transforms, listed in the order they are handed to the dataset.
track_transforms = [
    DropRepeatedLayerHits(),
    MinMaxNormalizeXYZ(
        min_xyz=(-1000.0, -1000.0, -3000.0),
        max_xyz=(1000.0, 1000.0, 3000.0),
    ),
]

# Track selection criteria.
# NOTE(review): the pairs given to FirstLayerFilter look like (volume, layer)
# ids of the first hit — confirm against the filter implementation.
track_filters = [
    MinHitsFilter(min_hits=3),
    PtFilter(min_pt=1.0),
    FirstLayerFilter({(8, 2), (7, 14), (9, 2)}),
]

# Same source directories as the plain dataset, now with transforms/filters.
dataset = TrackMLTracksDataset(
    data_dirs=Path("data/trackml/train_100_events"),
    blacklist_dir=Path("data/trackml/blacklist_training"),
    validation_split=0.1,
    transforms=track_transforms,
    filters=track_filters,
)

In [13]:
from tqdm import tqdm

# Advance through the filtered dataset until index 8010 is reached; `track`
# keeps the element at which the loop stopped for the cells below.
for position, track in enumerate(tqdm(dataset)):
    if position >= 8010:
        break

# Short summary via print(), detailed representation via the last expression.
print(track)
track

8010it [00:16, 479.16it/s]

Track 4679 with 4 hits, pT=1.10





Track(
	event_id=event000001017,
	track_id=4679,
	particle_id=364814384683286528,
	hits=4,
	px=0.93, py=0.58, pz=-2.44,
	pT=1.10, charge=-1,
	detector_info=[
		vol8-l2-m90
		vol8-l4-m147
		vol8-l6-m187
		vol8-l8-m125
	]
)

In [14]:
# With MinMaxNormalizeXYZ applied, hits_xyz holds the normalized coordinates:
# the displayed values lie in [0, 1] instead of detector coordinates.
track.hits_xyz

array([[0.5134294 , 0.5085623 , 0.48804566],
       [0.5299982 , 0.51957726, 0.473306  ],
       [0.54780436, 0.53205657, 0.45720285],
       [0.57052004, 0.54894966, 0.43624935]], dtype=float32)

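Cylindrical coordinates are computed from the original (un-normalized) hits rather than from the normalized `hits_xyz`. Converting min-max-normalized Cartesian coordinates to cylindrical ones would be mathematically incorrect: the normalization shifts and rescales each axis independently, so the resulting radius and angle would no longer describe the real hit positions.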

In [15]:
# Cylindrical coordinates of the same track. Unlike hits_xyz in the previous
# cell, the displayed values are not confined to [0, 1]: they are in the
# original, un-normalized scale.
track.hits_cylindrical

array([[  31.85358   ,    0.56759626,  -71.726     ],
       [  71.64246   ,    0.5782107 , -160.164     ],
       [ 115.11536   ,    0.5907058 , -256.783     ],
       [ 171.68736   ,    0.6067717 , -382.504     ]], dtype=float32)

## Usage in a DataLoader

In [16]:
from torch.utils.data import DataLoader
from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import MinMaxNormalizeXYZ, DropRepeatedLayerHits
from src.tracknet.data.filters import MinHitsFilter, PtFilter, FirstLayerFilter
from src.tracknet.data.collate import collate_fn

# Number of tracks collated into one batch.
BATCH_SIZE = 4

# Transforms and filters match the ones used for the dataset in the previous
# section; only the split differs.
val_transforms = [
    DropRepeatedLayerHits(),
    MinMaxNormalizeXYZ(
        min_xyz=(-1000.0, -1000.0, -3000.0),
        max_xyz=(1000.0, 1000.0, 3000.0),
    ),
]
val_filters = [
    MinHitsFilter(min_hits=3),
    PtFilter(min_pt=1.0),
    FirstLayerFilter({(8, 2), (7, 14), (9, 2)}),
]

# Validation part of the data, selected through split="validation".
val_dataset = TrackMLTracksDataset(
    data_dirs=Path("data/trackml/train_100_events"),
    blacklist_dir=Path("data/trackml/blacklist_training"),
    validation_split=0.1,
    split="validation",
    transforms=val_transforms,
    filters=val_filters,
)

# Batches are assembled by the project's collate_fn; four worker processes
# are used and kept alive between passes over the loader.
loader_options = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=4,
    persistent_workers=True,
)
val_loader = DataLoader(val_dataset, **loader_options)

In [17]:
# Consume batches until step 10; `batch` keeps the batch at which the loop
# stopped.
for step, batch in enumerate(val_loader):
    if step >= 10:
        break

# For every sample in that batch print its inputs, its input length and its
# targets, followed by an empty separator line.
for sample_idx in range(BATCH_SIZE):
    for field in ("inputs", "input_lengths", "targets"):
        print(batch[field][sample_idx])
    print()

tensor([[0.4874, 0.4901, 0.4956],
        [0.4713, 0.4778, 0.4899],
        [0.4537, 0.4650, 0.4840],
        [0.4309, 0.4492, 0.4763],
        [0.3918, 0.4240, 0.4635],
        [0.3516, 0.4007, 0.4506],
        [0.2857, 0.3674, 0.4302],
        [0.2140, 0.3375, 0.4089],
        [0.1376, 0.3122, 0.3861]])
9
tensor([[0.4713, 0.4778, 0.4899],
        [0.4537, 0.4650, 0.4840],
        [0.4309, 0.4492, 0.4763],
        [0.3918, 0.4240, 0.4635],
        [0.3516, 0.4007, 0.4506],
        [0.2857, 0.3674, 0.4302],
        [0.2140, 0.3375, 0.4089],
        [0.1376, 0.3122, 0.3861],
        [0.0358, 0.2880, 0.3582]])

tensor([[0.4856, 0.4929, 0.4996],
        [0.4673, 0.4842, 0.4990],
        [0.4477, 0.4751, 0.4985],
        [0.4219, 0.4636, 0.4977],
        [0.3802, 0.4458, 0.4965],
        [0.3338, 0.4274, 0.4953],
        [0.2672, 0.4034, 0.4934],
        [0.1946, 0.3802, 0.4913],
        [0.1152, 0.3585, 0.4891]])
9
tensor([[0.4673, 0.4842, 0.4990],
        [0.4477, 0.4751, 0.4985],
      

In [18]:
from tqdm import tqdm

# One complete pass over the validation loader; the progress bar printed by
# tqdm reports the number of batches and the throughput. Each batch is only
# touched, nothing is computed from it.
batch_stream = tqdm(val_loader)
for i, batch in enumerate(batch_stream):
    _ = batch["inputs"]

10780it [00:23, 449.29it/s]


In [19]:
# Scan the loader and stop at the first batch that contains an input of
# length 2; `batch` then holds that batch for the cells below.
for step, batch in enumerate(val_loader):
    lengths_in_batch = batch["input_lengths"]
    if any(length == 2 for length in lengths_in_batch):  # track with 3 hits
        break

# Show the input lengths of the batch at which the loop stopped.
batch["input_lengths"]

[9, 9, 8, 2]

In [20]:
# Inputs and input lengths of the batch from index 2 onward (the last two
# samples here). In the displayed output, rows beyond a sample's length are
# all zeros, i.e. shorter sequences are zero-padded to a common length.
batch["inputs"][2:], batch["input_lengths"][2:]

(tensor([[[0.4842, 0.5042, 0.4983],
          [0.4652, 0.5091, 0.4963],
          [0.4440, 0.5141, 0.4940],
          [0.4167, 0.5202, 0.4911],
          [0.3737, 0.5286, 0.4865],
          [0.3254, 0.5363, 0.4815],
          [0.2539, 0.5447, 0.4741],
          [0.1728, 0.5497, 0.4658],
          [0.0000, 0.0000, 0.0000]],
 
         [[0.4845, 0.5052, 0.4800],
          [0.4658, 0.5110, 0.4558],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000]]]),
 [8, 2])

In [21]:
# Targets of the same two samples. In the displayed output every target row
# equals the following input row, i.e. the targets are the inputs shifted by
# one hit, with the same zero padding at the end.
batch["targets"][2:]

tensor([[[0.4652, 0.5091, 0.4963],
         [0.4440, 0.5141, 0.4940],
         [0.4167, 0.5202, 0.4911],
         [0.3737, 0.5286, 0.4865],
         [0.3254, 0.5363, 0.4815],
         [0.2539, 0.5447, 0.4741],
         [0.1728, 0.5497, 0.4658],
         [0.0918, 0.5502, 0.4576],
         [0.0000, 0.0000, 0.0000]],

        [[0.4658, 0.5110, 0.4558],
         [0.4445, 0.5170, 0.4286],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

In [22]:
# Boolean target mask of the same two samples.
# NOTE(review): the displayed mask has 17 entries per sample while inputs and
# targets have 9 rows; the layout of the mask is not evident from this
# notebook — confirm against collate_fn.
batch["target_mask"][2:]

tensor([[ True,  True,  True,  True,  True,  True,  True,  True, False,  True,
          True,  True,  True,  True,  True,  True, False],
        [ True,  True, False, False, False, False, False, False, False,  True,
         False, False, False, False, False, False, False]])

### Verify min hit filter

In [23]:
# Sanity check for MinHitsFilter(min_hits=3).
# The input length of a track is one less than its number of hits (an input
# length of 2 corresponds to a track with 3 hits), so every input length must
# be at least 2. Comparing with `< MIN_INPUT_LENGTH` also catches 1-hit tracks
# (input length 0), which a test for "exactly 1" would let through.
MIN_INPUT_LENGTH = 2
for step, batch in enumerate(val_loader):
    if any(length < MIN_INPUT_LENGTH for length in batch["input_lengths"]):
        raise ValueError("Track with less than 3 hits found")