In [1]:
import os
from pathlib import Path

# Change the working directory to the parent folder so that the `src` package
# (imported by the cells below) and the relative `data/...` paths resolve.
# The guard keeps this cell idempotent: re-running it after `src` is already
# visible from the current directory must not climb one more level up.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

## Prepare dataset

In [2]:
from torch.utils.data import DataLoader
from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import MinMaxNormalizeXYZ, DropRepeatedLayerHits
from src.tracknet.data.filters import MinHitsFilter, PtFilter, FirstLayerFilter
from src.tracknet.data.collate import collate_fn

BATCH_SIZE = 4

# Min-max normalizer configured with fixed per-axis (x, y, z) bounds.
# The same instance is also handed to the density metrics in a later cell.
hits_normalizer = MinMaxNormalizeXYZ(
    min_xyz=(-1000.0, -1000.0, -3000.0),
    max_xyz=(1000.0, 1000.0, 3000.0),
)

# Transforms passed to the dataset, in this order.
val_transforms = [DropRepeatedLayerHits(), hits_normalizer]

# Filters passed to the dataset.
# NOTE(review): the pairs given to FirstLayerFilter look like
# (volume_id, layer_id) identifiers - confirm against the filter implementation.
val_filters = [
    MinHitsFilter(min_hits=3),
    PtFilter(min_pt=1.0),
    FirstLayerFilter({(8, 2), (7, 14), (9, 2)}),
]

# Validation split (fraction 0.1) of the 100-event TrackML directory.
val_dataset = TrackMLTracksDataset(
    data_dirs=Path("data/trackml/train_100_events"),
    blacklist_dir=Path("data/trackml/blacklist_training"),
    validation_split=0.1,
    split="validation",
    transforms=val_transforms,
    filters=val_filters,
)

# Loader settings gathered in one place; the project's collate_fn builds the batches.
loader_options = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=4,
    persistent_workers=True,
)
val_loader = DataLoader(val_dataset, **loader_options)

## Initialize the model and loss

In [3]:
from src.tracknet.model import StepAheadTrackNET
from src.tracknet.loss import TrackNetLoss
# , HitDensityMetric2
from src.tracknet.metrics import SearchAreaMetric, HitEfficiencyMetric, HitDensityMetric

# Model and loss are created with their default constructor arguments.
model = StepAheadTrackNET()
criterion = TrackNetLoss()

# Every metric exists in two variants, tagged "t1" and "t2"; each pair is
# built in that order by unpacking a two-element generator.
t1_area_metric, t2_area_metric = (
    SearchAreaMetric(tag) for tag in ("t1", "t2")
)
t1_efficiency, t2_efficiency = (
    HitEfficiencyMetric(tag) for tag in ("t1", "t2")
)
# Both density metrics read the same statistics file and share the
# normalizer instance that is used in the dataset transforms.
t1_density, t2_density = (
    HitDensityMetric(
        "outputs/hit_density_stats.npz", tag, normalizer=hits_normalizer)
    for tag in ("t1", "t2")
)

# Disabled alternative density metric: HitDensityMetric2 is not imported in
# this cell (its import is commented out above), so these lines stay inactive.
# t1_density2 = HitDensityMetric2(
#     "outputs/hit_density_stats.npz", "t1", normalizer=hits_normalizer)
# t2_density2 = HitDensityMetric2(
#     "outputs/hit_density_stats.npz", "t2", normalizer=hits_normalizer)

## Calculate outputs, loss value and metrics

In [7]:
import torch

# Smoke-test the model, loss and metrics on the first validation batch only.
# The pass runs in eval mode and without autograd: nothing is trained here, so
# building a gradient graph would only cost memory, and train-only layer
# behaviour should not influence validation numbers.
# NOTE(review): assumes StepAheadTrackNET is a torch.nn.Module - confirm.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for batch in val_loader:
            # The collated batch carries the targets and their mask next to
            # the model inputs; remove them so that only the inputs are
            # forwarded to the model as keyword arguments.
            targets = batch.pop("targets")
            target_mask = batch.pop("target_mask")
            outputs = model(**batch)
            # calculate loss
            loss = criterion(outputs, targets, target_mask)
            print(f"Loss: {loss.item()}")
            # calculate metrics
            t1_area = t1_area_metric(outputs, target_mask)
            t2_area = t2_area_metric(outputs, target_mask)
            t1_eff = t1_efficiency(outputs, targets, target_mask)
            t2_eff = t2_efficiency(outputs, targets, target_mask)
            t1_density_val = t1_density(outputs, target_mask)
            t2_density_val = t2_density(outputs, target_mask)
            print(f"Metrics:\n\tt1_area: {t1_area}, t1_eff: {t1_eff}, t1_density: {t1_density_val}\n\t"
                  f"t2_area: {t2_area}, t2_eff: {t2_eff}, t2_density: {t2_density_val}\n\t")
            # One batch is enough for this check.
            break
finally:
    # Put the model back into whatever mode it was in before this cell ran.
    model.train(was_training)

Loss: 1.2560718059539795
Metrics:
	t1_area: 1.4762532711029053, t1_eff: 0.0357142873108387, t1_density: 797.7449951171875
	t2_area: 1.3770127296447754, t2_eff: 0.0, t2_density: 625.7937622070312
	
