In [1]:
import os
from pathlib import Path

# Move to the project root so that `src` is importable; the relative
# `data/...` paths used below are resolved against this directory.
# Guarded so that re-running this cell does not climb one more directory
# per run (the unguarded chdir is not idempotent).
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

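## Usage of the Tracks Dataset

Each item yielded by `TrackMLTracksDataset` is a `Track` carrying its particle information (momentum, charge) and its hits (`hits_xyz`, `hits_cylindrical`, detector info).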

In [2]:
from src.tracknet.data.dataset import TrackMLTracksDataset

# Root of the TrackML data, relative to the project root.
trackml_root = Path("data/trackml")

# Dataset without any transforms: hits are returned as stored.
dataset = TrackMLTracksDataset(
    data_dirs=trackml_root / "train_100_events",
    blacklist_dir=trackml_root / "blacklist_training",
    min_hits=3,  # presumably drops tracks with fewer hits than this — confirm
    validation_split=0.1,  # presumably the fraction held out for split="validation"
)

In [3]:
from itertools import islice

from tqdm import tqdm

# Stream through the first N_TRACKS_TO_SKIP tracks (doubles as a quick
# throughput check) and keep the next one as the example `sample`.
# If the dataset is shorter, `sample` ends up as its last track.
N_TRACKS_TO_SKIP = 10_000

# Reset first: otherwise an empty dataset raises NameError on a fresh kernel,
# or silently shows a stale `sample` left over from another cell.
sample = None
for sample in tqdm(islice(dataset, N_TRACKS_TO_SKIP + 1)):
    pass
if sample is None:
    raise RuntimeError("dataset yielded no tracks")

print(sample)  # short summary (__str__)
sample  # full representation, including per-hit detector info

10000it [00:05, 1921.26it/s]

Track 290 with 13 hits, pT=1.08





Track(
	event_id=event000001018,
	track_id=290,
	particle_id=270216458678566912,
	hits=13,
	px=0.65, py=-0.86, pz=7.24,
	pT=1.08, charge=1,
	detector_info=[
		vol8-l2-m150
		vol8-l2-m166
		vol8-l4-m428
		vol9-l2-m36
		vol9-l2-m39
		vol9-l4-m40
		vol9-l6-m38
		vol9-l8-m38
		vol9-l10-m38
		vol14-l6-m58
		vol14-l8-m58
		vol14-l10-m58
		vol14-l12-m57
	]
)

In [4]:
# Hit positions of the untransformed `sample`: one row per hit (13 rows for the
# 13-hit track shown above); columns presumably (x, y, z) as the name says.
sample.hits_xyz

array([[  18.7269,  -25.4278,  209.98  ],
       [  19.0147,  -25.8346,  213.328 ],
       [  42.3317,  -58.9278,  485.948 ],
       [  51.7074,  -72.6563,  598.    ],
       [  52.0456,  -73.1403,  602.    ],
       [  60.2827,  -85.4704,  702.    ],
       [  69.7289,  -99.8009,  818.    ],
       [  80.8678, -117.212 ,  958.    ],
       [  91.8168, -134.7   , 1098.    ],
       [ 144.255 , -223.26  , 1798.5   ],
       [ 168.742 , -267.478 , 2148.5   ],
       [ 196.584 , -318.108 , 2548.5   ],
       [ 224.887 , -369.526 , 2954.5   ]], dtype=float32)

In [5]:
# Per-hit "cylindrical" features, one row per hit. Column 0 equals
# sqrt(x^2 + y^2) and column 1 equals atan2(y, x) of `hits_xyz`.
# NOTE(review): column 2 is not z — it matches atan2(r, z), the polar angle
# (first hit: atan2(31.58, 209.98) ≈ 0.1493). Confirm the intended convention.
sample.hits_cylindrical

array([[ 3.1579578e+01, -9.3600851e-01,  1.4927454e-01],
       [ 3.2077801e+01, -9.3630469e-01,  1.4925027e-01],
       [ 7.2556587e+01, -9.4785053e-01,  1.4821444e-01],
       [ 8.9177315e+01, -9.5228016e-01,  1.4803502e-01],
       [ 8.9767746e+01, -9.5233685e-01,  1.4802516e-01],
       [ 1.0459060e+02, -9.5651883e-01,  1.4790149e-01],
       [ 1.2174703e+02, -9.6095681e-01,  1.4775039e-01],
       [ 1.4240173e+02, -9.6686238e-01,  1.4756432e-01],
       [ 1.6301660e+02, -9.7249961e-01,  1.4739020e-01],
       [ 2.6580920e+02, -9.9714643e-01,  1.4673272e-01],
       [ 3.1625677e+02, -1.0079919e+00,  1.4614934e-01],
       [ 3.7394916e+02, -1.0172619e+00,  1.4569336e-01],
       [ 4.3257788e+02, -1.0240902e+00,  1.4538027e-01]], dtype=float32)

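## Dataset with transforms

`DropRepeatedLayerHits` removes repeated hits on a detector layer that was already hit (in the example below, the 13-hit track keeps 11 hits). `MinMaxNormalizeXYZ` rescales `hits_xyz` into [0, 1] using fixed bounds. Note that `hits_cylindrical` is not normalised by these transforms.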

In [6]:
import numpy as np

from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import DropRepeatedLayerHits, MinMaxNormalizeXYZ

# Root of the TrackML data, relative to the project root.
trackml_root = Path("data/trackml")

# Transforms (presumably applied in list order): drop repeated hits on an
# already-visited layer, then min-max rescale (x, y, z) with fixed bounds.
dataset = TrackMLTracksDataset(
    data_dirs=trackml_root / "train_100_events",
    blacklist_dir=trackml_root / "blacklist_training",
    min_hits=3,
    validation_split=0.1,
    transforms=[
        DropRepeatedLayerHits(),
        MinMaxNormalizeXYZ(
            min_xyz=np.array([-1000.0, -1000.0, -3000.0]),
            max_xyz=np.array([1000.0, 1000.0, 3000.0]),
        ),
    ],
)

In [7]:
from itertools import islice

from tqdm import tqdm

# Stream through the first N_TRACKS_TO_SKIP transformed tracks (doubles as a
# quick throughput check) and keep the next one as the example `sample`.
# If the dataset is shorter, `sample` ends up as its last track.
N_TRACKS_TO_SKIP = 10_000

# Reset first: otherwise an empty dataset raises NameError on a fresh kernel,
# or silently shows the stale untransformed `sample` from an earlier cell.
sample = None
for sample in tqdm(islice(dataset, N_TRACKS_TO_SKIP + 1)):
    pass
if sample is None:
    raise RuntimeError("dataset yielded no tracks")

print(sample)  # short summary (__str__)
sample  # full representation, including per-hit detector info

10000it [00:05, 1868.43it/s]

Track 290 with 11 hits, pT=1.08





Track(
	event_id=event000001018,
	track_id=290,
	particle_id=270216458678566912,
	hits=11,
	px=0.65, py=-0.86, pz=7.24,
	pT=1.08, charge=1,
	detector_info=[
		vol8-l2-m150
		vol8-l4-m428
		vol9-l2-m36
		vol9-l4-m40
		vol9-l6-m38
		vol9-l8-m38
		vol9-l10-m38
		vol14-l6-m58
		vol14-l8-m58
		vol14-l10-m58
		vol14-l12-m57
	]
)

In [8]:
# Hit positions after the transforms: 11 rows (repeated-layer hits dropped) and
# each coordinate mapped to (v - min) / (max - min) with the bounds above,
# e.g. x of the first hit: (18.7269 + 1000) / 2000 ≈ 0.5094.
sample.hits_xyz

array([[0.5093635 , 0.4872861 , 0.5349967 ],
       [0.52116585, 0.4705361 , 0.5809913 ],
       [0.5258537 , 0.46367183, 0.59966666],
       [0.53014135, 0.4572648 , 0.617     ],
       [0.5348644 , 0.45009956, 0.63633335],
       [0.5404339 , 0.441394  , 0.65966666],
       [0.5459084 , 0.43265   , 0.683     ],
       [0.5721275 , 0.38837   , 0.79975   ],
       [0.584371  , 0.36626098, 0.8580833 ],
       [0.598292  , 0.340946  , 0.92475   ],
       [0.6124435 , 0.315237  , 0.9924167 ]], dtype=float32)

In [12]:
# NOTE(review): this cell was executed out of order (In [12]). Its saved output
# has 13 rows, while the transformed `sample` above has 11 hits, so the output
# is stale; the cell also duplicates the next one. Consider deleting it.
sample.hits_cylindrical

array([[ 3.1579578e+01, -9.3600851e-01,  1.4927454e-01],
       [ 3.2077801e+01, -9.3630469e-01,  1.4925027e-01],
       [ 7.2556587e+01, -9.4785053e-01,  1.4821444e-01],
       [ 8.9177315e+01, -9.5228016e-01,  1.4803502e-01],
       [ 8.9767746e+01, -9.5233685e-01,  1.4802516e-01],
       [ 1.0459060e+02, -9.5651883e-01,  1.4790149e-01],
       [ 1.2174703e+02, -9.6095681e-01,  1.4775039e-01],
       [ 1.4240173e+02, -9.6686238e-01,  1.4756432e-01],
       [ 1.6301660e+02, -9.7249961e-01,  1.4739020e-01],
       [ 2.6580920e+02, -9.9714643e-01,  1.4673272e-01],
       [ 3.1625677e+02, -1.0079919e+00,  1.4614934e-01],
       [ 3.7394916e+02, -1.0172619e+00,  1.4569336e-01],
       [ 4.3257788e+02, -1.0240902e+00,  1.4538027e-01]], dtype=float32)

In [9]:
# Cylindrical features of the transformed track. The dropped hits are gone
# (11 rows), but the values are NOT min-max normalised: they equal the
# corresponding rows of the untransformed track. MinMaxNormalizeXYZ appears to
# affect only `hits_xyz` — confirm this is intended before using these as model input.
sample.hits_cylindrical

array([[ 3.1579578e+01, -9.3600851e-01,  1.4927454e-01],
       [ 7.2556587e+01, -9.4785053e-01,  1.4821444e-01],
       [ 8.9177315e+01, -9.5228016e-01,  1.4803502e-01],
       [ 1.0459060e+02, -9.5651883e-01,  1.4790149e-01],
       [ 1.2174703e+02, -9.6095681e-01,  1.4775039e-01],
       [ 1.4240173e+02, -9.6686238e-01,  1.4756432e-01],
       [ 1.6301660e+02, -9.7249961e-01,  1.4739020e-01],
       [ 2.6580920e+02, -9.9714643e-01,  1.4673272e-01],
       [ 3.1625677e+02, -1.0079919e+00,  1.4614934e-01],
       [ 3.7394916e+02, -1.0172619e+00,  1.4569336e-01],
       [ 4.3257788e+02, -1.0240902e+00,  1.4538027e-01]], dtype=float32)

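## Usage in a DataLoader

`collate_fn` batches tracks into a dict with `inputs` and `targets`. As the printed example shows, each track's `targets` are its `inputs` shifted by one hit (presumably for next-hit prediction).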

In [23]:
import numpy as np
from torch.utils.data import DataLoader

from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import DropRepeatedLayerHits, MinMaxNormalizeXYZ
from src.tracknet.data.utils import collate_fn

BATCH_SIZE = 4

# Root of the TrackML data, relative to the project root.
trackml_root = Path("data/trackml")

# Validation split of the same events, with the same transforms as above.
val_dataset = TrackMLTracksDataset(
    data_dirs=trackml_root / "train_100_events",
    blacklist_dir=trackml_root / "blacklist_training",
    min_hits=3,
    validation_split=0.1,
    split="validation",
    transforms=[
        DropRepeatedLayerHits(),
        MinMaxNormalizeXYZ(
            min_xyz=np.array([-1000.0, -1000.0, -3000.0]),
            max_xyz=np.array([1000.0, 1000.0, 3000.0]),
        ),
    ],
)

# `collate_fn` turns a list of tracks into a dict with "inputs" and "targets".
# persistent_workers keeps the worker processes alive between passes.
# NOTE(review): if TrackMLTracksDataset is an IterableDataset, every worker
# iterates the whole dataset unless it shards via
# torch.utils.data.get_worker_info(); confirm there is no duplication.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    persistent_workers=True,
    collate_fn=collate_fn,
)

In [24]:
from itertools import islice

# Inspect the batch at index EXAMPLE_STEP (or the last batch if the loader is
# shorter). Reset first so an empty loader gives a clear error rather than a
# NameError or a stale `batch` from another cell.
EXAMPLE_STEP = 10

batch = None
for batch in islice(val_loader, EXAMPLE_STEP + 1):
    pass
if batch is None:
    raise RuntimeError("val_loader yielded no batches")

# Iterate over the examples actually present: a final batch may hold fewer
# than BATCH_SIZE items, which made `range(BATCH_SIZE)` raise IndexError.
for inputs, targets in zip(batch["inputs"], batch["targets"]):
    print(inputs)
    print(targets)
    print()

tensor([[0.5140, 0.5083, 0.4897],
        [0.5309, 0.5180, 0.4774],
        [0.5503, 0.5289, 0.4633],
        [0.5751, 0.5424, 0.4455],
        [0.6126, 0.5618, 0.4188],
        [0.6579, 0.5839, 0.3870],
        [0.7233, 0.6133, 0.3417],
        [0.7889, 0.6398, 0.2969],
        [0.8572, 0.6646, 0.2509]])
tensor([[0.5309, 0.5180, 0.4774],
        [0.5503, 0.5289, 0.4633],
        [0.5751, 0.5424, 0.4455],
        [0.6126, 0.5618, 0.4188],
        [0.6579, 0.5839, 0.3870],
        [0.7233, 0.6133, 0.3417],
        [0.7889, 0.6398, 0.2969],
        [0.8572, 0.6646, 0.2509],
        [0.9367, 0.6904, 0.1991]])

tensor([[0.5143, 0.5068, 0.4878],
        [0.5329, 0.5156, 0.4718],
        [0.5523, 0.5245, 0.4553],
        [0.5781, 0.5360, 0.4335],
        [0.6198, 0.5539, 0.3984],
        [0.6661, 0.5727, 0.3597],
        [0.7425, 0.6016, 0.2964],
        [0.7985, 0.6212, 0.2508],
        [0.8616, 0.6418, 0.2004]])
tensor([[0.5329, 0.5156, 0.4718],
        [0.5523, 0.5245, 0.4553],
        [0

In [25]:
from tqdm import tqdm

# Throughput check: pull every validation batch through the loader once.
# The inputs are only touched (a missing key would raise KeyError). Nothing is
# assigned to `_`: once user code assigns `_`, IPython stops updating its
# last-output variables for the rest of the session.
for batch in tqdm(val_loader):
    batch["inputs"]

12148it [00:08, 1425.57it/s]
