In [1]:
import os
from pathlib import Path

# Move to the parent of the notebook's directory so that the `src` package and
# the relative `data/...` paths used below resolve.
# The guard keeps this cell idempotent: re-running it must not keep climbing
# the directory tree.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

## Prepare dataset

In [2]:
from torch.utils.data import DataLoader
from src.tracknet.data.dataset import TrackMLTracksDataset
from src.tracknet.data.transformations import MinMaxNormalizeXYZ, DropRepeatedLayerHits
from src.tracknet.data.utils import collate_fn

BATCH_SIZE = 4
NUM_WORKERS = 4  # set to 0 to load in the main process (easier debugging)

DATA_DIR = Path("data/trackml/train_100_events")
BLACKLIST_DIR = Path("data/trackml/blacklist_training")

# Bounds handed to MinMaxNormalizeXYZ, ordered (x, y, z).
XYZ_MIN = (-1000.0, -1000.0, -3000.0)
XYZ_MAX = (1000.0, 1000.0, 3000.0)

val_dataset = TrackMLTracksDataset(
    data_dirs=DATA_DIR,
    blacklist_dir=BLACKLIST_DIR,
    min_hits=3,  # TODO: this should be supplemented by filter
    validation_split=0.1,
    split="validation",
    transforms=[
        DropRepeatedLayerHits(),
        MinMaxNormalizeXYZ(min_xyz=XYZ_MIN, max_xyz=XYZ_MAX),
    ],
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=NUM_WORKERS,
    # DataLoader rejects persistent_workers=True when num_workers == 0, so
    # derive it instead of hard-coding it.
    persistent_workers=NUM_WORKERS > 0,
)

## Initialize the model and loss

In [3]:
from src.tracknet.model import StepAheadTrackNET
from src.tracknet.loss import TrackNetLoss, PointInAreaLoss, AreaSizeLoss

import torch

# Seed the global RNG so the randomly initialised weights, and therefore the
# loss value printed in the next section, are reproducible across restarts.
SEED = 42
torch.manual_seed(SEED)

model = StepAheadTrackNET()
# The loader above serves the validation split, so put the model in evaluation
# mode (affects layers such as dropout / batch-norm, if the model has any).
model.eval()
criterion = TrackNetLoss()

## Compute model outputs and loss on one validation batch

In [4]:
# Run the model on a single validation batch as a smoke test.
# `next(iter(...))` raises StopIteration right away if the loader is empty,
# instead of silently leaving `outputs` / `targets` undefined for later cells.
batch = next(iter(val_loader))
# Split the supervision entries off the batch; every remaining key is
# forwarded to the model as a keyword argument.
targets = batch.pop("targets")
target_mask = batch.pop("target_mask")
outputs = model(**batch)
loss = criterion(outputs, targets, target_mask)
print(f"Loss: {loss.item()}")

Loss: 1.2082741260528564


In [10]:
# Evaluate the two loss components separately to inspect their shapes.
point_in_area_loss = PointInAreaLoss()
area_size_loss = AreaSizeLoss()

points_in_area = point_in_area_loss(outputs, targets)
area_size = area_size_loss(outputs)

print(points_in_area.shape, area_size.shape, target_mask.shape)

# The three shapes currently agree; assert it so that a change to the loss
# modules that breaks alignment with `target_mask` is caught immediately.
assert points_in_area.shape == area_size.shape == target_mask.shape, (
    f"shape mismatch: {points_in_area.shape}, {area_size.shape}, {target_mask.shape}"
)

torch.Size([4, 17]) torch.Size([4, 17]) torch.Size([4, 17])
