In [1]:
import os
from pathlib import Path

# Change the working directory to the parent folder so that `src` is importable
# and the relative `data/...` paths used below resolve.
# The guard makes the cell idempotent: once `src` is visible from the current
# directory, re-running the cell no longer climbs another level up.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

## Usage of Tracks Dataset

In [2]:
from src.tracknet.data.dataset import TrackMLTracksDataset

# Track dataset without any transforms or filters.
# Both paths are relative to the working directory selected in the first cell.
dataset = TrackMLTracksDataset(
    data_dirs=Path("data") / "trackml" / "train_100_events",
    blacklist_dir=Path("data") / "trackml" / "blacklist_training",
    # presumably the fraction of the data held out for validation - confirm
    # against TrackMLTracksDataset
    validation_split=0.1,
)

In [3]:
from tqdm import tqdm

# Iterate over the dataset under a progress bar and stop at index 10000, so
# `track` ends up holding the 10001st item (this doubles as a rough throughput
# check via the tqdm rate).
# NOTE(review): if the dataset yields nothing, `track` is never bound here and
# the lines below fail (or pick up a stale `track` from an earlier run).
for i, track in enumerate(tqdm(dataset)):
    if i == 10000:
        break
# Show the track twice: printed form first, then the cell's rich/repr display.
print(track)
track

10000it [00:03, 2641.51it/s]

Track 1296 with 12 hits, pT=0.23





Track(
	event_id=event000001011,
	track_id=1296,
	particle_id=112591983549087744,
	hits=12,
	px=0.10, py=0.21, pz=-1.70,
	pT=0.23, charge=1,
	detector_info=[
		vol8-l2-m59
		vol7-l14-m71
		vol7-l12-m71
		vol7-l10-m72
		vol7-l8-m70
		vol7-l8-m72
		vol7-l6-m70
		vol12-l8-m106
		vol12-l6-m106
		vol12-l4-m106
		vol12-l2-m103
		vol12-l2-m105
	]
)

In [4]:
# Hit positions of the selected track; in the recorded run this is a float32
# array with one row per hit and three columns (presumably x, y, z - confirm).
track.hits_xyz

array([[   15.08  ,    27.6324,  -230.094 ],
       [   44.1775,    68.8854,  -602.    ],
       [   52.9455,    79.3657,  -702.    ],
       [   64.0962,    91.6667,  -822.    ],
       [   77.2527,   105.202 ,  -958.    ],
       [   77.6356,   105.593 ,  -962.    ],
       [   91.4131,   118.562 , -1098.    ],
       [  169.966 ,   175.698 , -1795.5   ],
       [  211.342 ,   198.291 , -2145.5   ],
       [  258.921 ,   223.052 , -2545.5   ],
       [  303.762 ,   248.084 , -2948.5   ],
       [  304.467 ,   248.484 , -2954.5   ]], dtype=float32)

In [5]:
# The same hits in the track's cylindrical representation. In the recorded run
# the columns are consistent with (r, phi, theta) derived from hits_xyz, e.g.
# r = sqrt(x^2 + y^2) for the first hit - confirm against the Track class.
track.hits_cylindrical

array([[ 31.479452  ,   1.0712326 ,   3.0056255 ],
       [ 81.83428   ,   1.0005481 ,   3.0064836 ],
       [ 95.405136  ,   0.9824882 ,   3.0065155 ],
       [111.853065  ,   0.96058667,   3.006349  ],
       [130.51988   ,   0.9374014 ,   3.0061843 ],
       [131.06169   ,   0.9368124 ,   3.0061874 ],
       [149.71072   ,   0.9139806 ,   3.0060797 ],
       [244.45496   ,   0.8019792 ,   3.006276  ],
       [289.80124   ,   0.75354874,   3.0073314 ],
       [341.74884   ,   0.7111139 ,   3.0081346 ],
       [392.19513   ,   0.6848444 ,   3.0093539 ],
       [392.99423   ,   0.6844979 ,   3.0093534 ]], dtype=float32)

## Dataset with transforms and filters

In [6]:
from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import MinMaxNormalizeXYZ, DropRepeatedLayerHits
from src.tracknet.data.filters import MinHitsFilter, PtFilter, FirstLayerFilter


# Same data source as before, now with per-track transforms and filters.
dataset = TrackMLTracksDataset(
    data_dirs=Path("data") / "trackml" / "train_100_events",
    blacklist_dir=Path("data") / "trackml" / "blacklist_training",
    validation_split=0.1,
    transforms=[
        DropRepeatedLayerHits(),
        # Bounds used for the min-max scaling of the three coordinates.
        MinMaxNormalizeXYZ(
            min_xyz=(-1000.0, -1000.0, -3000.0),
            max_xyz=(1000.0, 1000.0, 3000.0),
        ),
    ],
    filters=[
        MinHitsFilter(min_hits=3),
        PtFilter(min_pt=1.0),
        # presumably (volume, layer) pairs allowed for a track's first hit -
        # confirm against FirstLayerFilter
        FirstLayerFilter({(8, 2), (7, 14), (9, 2)}),
    ],
)

In [7]:
from tqdm import tqdm

# Iterate over the transformed/filtered dataset and stop at index 10000, so
# `track` holds the 10001st item that the dataset yields.
# NOTE(review): if the dataset yields nothing, `track` silently keeps whatever
# an earlier cell assigned to it - restart the kernel before trusting the
# output below.
for i, track in enumerate(tqdm(dataset)):
    if i == 10000:
        break
# Show the track twice: printed form first, then the cell's rich/repr display.
print(track)
track

10000it [00:20, 481.48it/s]

Track 5100 with 4 hits, pT=1.53





Track(
	event_id=event000001019,
	track_id=5100,
	particle_id=414337006873608192,
	hits=4,
	px=1.41, py=-0.58, pz=3.14,
	pT=1.53, charge=-1,
	detector_info=[
		vol8-l2-m135
		vol8-l4-m302
		vol8-l6-m543
		vol8-l8-m971
	]
)

In [8]:
# Hit positions after the dataset transforms; in the recorded run all values
# lie between 0 and 1, i.e. they have been min-max scaled.
track.hits_xyz

array([[0.5148276 , 0.49398863, 0.5130898 ],
       [0.5330384 , 0.4869121 , 0.5265055 ],
       [0.553887  , 0.479182  , 0.5417965 ],
       [0.580604  , 0.4697951 , 0.5612945 ]], dtype=float32)

In [9]:
# Cylindrical representation of the same track.
# NOTE(review): in the recorded run the first column is still in raw detector
# units (about 32 to 172), so the XYZ normalization does not appear to be
# applied to the cylindrical coordinates - confirm this is intended.
track.hits_cylindrical

array([[ 31.999722  ,  -0.38516566,   0.38690275],
       [ 71.07261   ,  -0.37717617,   0.42027697],
       [115.53697   ,  -0.36866397,   0.4317265 ],
       [172.15504   ,  -0.35853574,   0.43781093]], dtype=float32)

## Usage in a DataLoader

In [10]:
from torch.utils.data import DataLoader
from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import MinMaxNormalizeXYZ, DropRepeatedLayerHits
from src.tracknet.data.filters import MinHitsFilter, PtFilter, FirstLayerFilter
from src.tracknet.data.collate import collate_fn

# Number of tracks per batch; also used when printing a batch further down.
BATCH_SIZE = 4

# Validation split of the dataset, with the same transforms and filters as in
# the previous section.
val_dataset = TrackMLTracksDataset(
    data_dirs=Path("data") / "trackml" / "train_100_events",
    blacklist_dir=Path("data") / "trackml" / "blacklist_training",
    validation_split=0.1,
    split="validation",
    transforms=[
        DropRepeatedLayerHits(),
        MinMaxNormalizeXYZ(
            min_xyz=(-1000.0, -1000.0, -3000.0),
            max_xyz=(1000.0, 1000.0, 3000.0),
        ),
    ],
    filters=[
        MinHitsFilter(min_hits=3),
        PtFilter(min_pt=1.0),
        FirstLayerFilter({(8, 2), (7, 14), (9, 2)}),
    ],
)

# Batches are assembled by the project's collate_fn.
# NOTE(review): if the dataset is iterable-style, confirm that it shards its
# data across the 4 worker processes; otherwise each worker may yield the same
# tracks.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=4,
    persistent_workers=True,
)

In [11]:
# Advance to the batch at step index 10 (or to the last batch if the loader
# yields fewer than 11 batches).
for step, batch in enumerate(val_loader):
    if step == 10:
        break

# Print every sample that is actually in the batch. Iterating over the batch's
# own contents, rather than range(BATCH_SIZE), also works for a final batch
# that holds fewer than BATCH_SIZE tracks.
for inputs, length, targets in zip(
    batch["inputs"], batch["input_lengths"], batch["targets"]
):
    print(inputs)
    print(length)
    print(targets)
    print()

tensor([[0.4874, 0.4901, 0.4956],
        [0.4713, 0.4778, 0.4899],
        [0.4537, 0.4650, 0.4840],
        [0.4309, 0.4492, 0.4763],
        [0.3918, 0.4240, 0.4635],
        [0.3516, 0.4007, 0.4506],
        [0.2857, 0.3674, 0.4302],
        [0.2140, 0.3375, 0.4089],
        [0.1376, 0.3122, 0.3861]])
9
tensor([[0.4713, 0.4778, 0.4899],
        [0.4537, 0.4650, 0.4840],
        [0.4309, 0.4492, 0.4763],
        [0.3918, 0.4240, 0.4635],
        [0.3516, 0.4007, 0.4506],
        [0.2857, 0.3674, 0.4302],
        [0.2140, 0.3375, 0.4089],
        [0.1376, 0.3122, 0.3861],
        [0.0358, 0.2880, 0.3582]])

tensor([[0.4856, 0.4929, 0.4996],
        [0.4673, 0.4842, 0.4990],
        [0.4477, 0.4751, 0.4985],
        [0.4219, 0.4636, 0.4977],
        [0.3802, 0.4458, 0.4965],
        [0.3338, 0.4274, 0.4953],
        [0.2672, 0.4034, 0.4934],
        [0.1946, 0.3802, 0.4913],
        [0.1152, 0.3585, 0.4891]])
9
tensor([[0.4673, 0.4842, 0.4990],
        [0.4477, 0.4751, 0.4985],
      

In [12]:
from tqdm import tqdm

# Time one full pass over the validation loader (the tqdm rate is the result).
# The loop no longer assigns to `_`, which IPython uses for the last displayed
# output, and the unused enumerate index is gone.
for batch in tqdm(val_loader):
    batch["inputs"]  # key lookup only; the value is intentionally discarded

10780it [00:24, 443.13it/s]


In [13]:
# Find the first batch that contains a track with 3 hits. Such a track has an
# input sequence of length 2, because the targets are the inputs shifted by
# one hit.
for step, batch in enumerate(val_loader):
    if 2 in batch["input_lengths"]:
        break
else:
    # Without this, the cells below would silently display whatever batch
    # happened to be last (or a stale `batch` from an earlier cell).
    raise LookupError("No batch containing a track with 3 hits was found")

batch["input_lengths"]

[9, 9, 8, 2]

In [14]:
# Padded inputs and true lengths of the batch from index 2 onwards. In the
# recorded run, rows beyond a sample's length are all-zero padding.
batch["inputs"][2:], batch["input_lengths"][2:]

(tensor([[[0.4842, 0.5042, 0.4983],
          [0.4652, 0.5091, 0.4963],
          [0.4440, 0.5141, 0.4940],
          [0.4167, 0.5202, 0.4911],
          [0.3737, 0.5286, 0.4865],
          [0.3254, 0.5363, 0.4815],
          [0.2539, 0.5447, 0.4741],
          [0.1728, 0.5497, 0.4658],
          [0.0000, 0.0000, 0.0000]],
 
         [[0.4845, 0.5052, 0.4800],
          [0.4658, 0.5110, 0.4558],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000]]]),
 [8, 2])

In [15]:
# Targets for the same samples. In the recorded run each target row equals the
# next input row (inputs shifted by one hit), followed by zero padding.
batch["targets"][2:]

tensor([[[0.4652, 0.5091, 0.4963],
         [0.4440, 0.5141, 0.4940],
         [0.4167, 0.5202, 0.4911],
         [0.3737, 0.5286, 0.4865],
         [0.3254, 0.5363, 0.4815],
         [0.2539, 0.5447, 0.4741],
         [0.1728, 0.5497, 0.4658],
         [0.0918, 0.5502, 0.4576],
         [0.0000, 0.0000, 0.0000]],

        [[0.4658, 0.5110, 0.4558],
         [0.4445, 0.5170, 0.4286],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

In [16]:
# Boolean mask accompanying the targets for the same samples.
# NOTE(review): in the recorded run each mask row has 17 entries while the
# targets have 9 time steps - confirm the mask layout against collate_fn.
batch["target_mask"][2:]

tensor([[ True,  True,  True,  True,  True,  True,  True,  True, False,  True,
          True,  True,  True,  True,  True,  True, False],
        [ True,  True, False, False, False, False, False, False, False,  True,
         False, False, False, False, False, False, False]])

### Verify min hit filter

In [36]:
# The loader's dataset was built with MinHitsFilter(min_hits=3), so every
# input sequence must have length >= 2 (a track with n hits gives n - 1 inputs).
# Checking `< 2` also catches input length 0, which the previous
# `1 in set(...)` test let through despite the "less than 3 hits" message.
for step, batch in enumerate(val_loader):
    if any(length < 2 for length in batch["input_lengths"]):
        raise ValueError("Track with less than 3 hits found")