In [1]:
import os
from pathlib import Path
# change working directory to make src visible
os.chdir(Path.cwd().parent)

In [2]:
import numpy as np
import os

from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader

from src.visualization import draw_event
from src.dataset import SPDEventsDataset
from src.data_generation import SPDEventGenerator
from src.normalization import ConstraintsNormalizer, TrackParamsNormalizer
from src.dataset import collate_fn_for_set_loss
from src.model import TRT

seed_everything(13)

Seed set to 13


13

## Prepare single batch for overfitting

In [3]:
MAX_EVENT_TRACKS = 5
TRUNCATION_LENGTH = 512
BATCH_SIZE = 16

In [4]:
hits_norm = ConstraintsNormalizer()
params_norm = TrackParamsNormalizer()
train_data = SPDEventsDataset(
    max_event_tracks=MAX_EVENT_TRACKS,
    generate_fixed_tracks_num=False,
    hits_normalizer=hits_norm,
    track_params_normalizer=params_norm,
    shuffle=True,
    truncation_length=TRUNCATION_LENGTH
)

train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn_for_set_loss,
    num_workers=4,
    pin_memory=False,
    persistent_workers=True,
)

In [5]:
# read single batch for test
for batch in train_loader:
    break

## Define functions for hungarian loss

In [6]:
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment


def match_targets(outputs, targets):
    cost_matrix = torch.cdist(outputs, targets, p=1)
    row_ind, col_ind = linear_sum_assignment(
        cost_matrix.cpu().detach().numpy()
    )
    return row_ind, col_ind


def hungarian_loss(outputs, targets):
    row_ind, col_ind = match_targets(outputs, targets)
    matched_outputs = outputs[row_ind]
    matched_targets = targets[col_ind]
    loss = F.l1_loss(matched_outputs, matched_targets)
    return loss


def criterion(preds, targets, preds_lengths, targets_lengths):
    hungarian = torch.tensor(0.0)
    for i in range(preds.shape[0]):
        hungarian += hungarian_loss(
            preds[i, :preds_lengths[i]],
            targets[i, :targets_lengths[i]]
        )
    hungarian /= preds.shape[0]  # batchmean
    return hungarian

In [7]:
batch.keys()

dict_keys(['inputs', 'mask', 'targets', 'orig_params', 'n_tracks_per_sample'])

In [8]:
batch["targets"].shape

torch.Size([16, 5, 7])

#### Test same inputs, but shuffled

In [9]:
batch["n_tracks_per_sample"]

tensor([2, 1, 3, 1, 5, 4, 5, 3, 3, 3, 1, 5, 3, 3, 5, 1])

In [10]:
# take a track with maximum number of tracks
idx = slice(4, 5)

In [11]:
rand_idx = torch.randperm(batch["targets"][idx].shape[1])
print(rand_idx)

tensor([3, 0, 2, 4, 1])


In [12]:
batch["targets"][idx]

tensor([[[0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000]]])

In [13]:
batch["targets"][idx][:, rand_idx]

tensor([[[0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000]]])

In [14]:
row_ind, col_ind = match_targets(
    batch["targets"][idx.start][rand_idx],
    batch["targets"][idx.start]
)

print(row_ind, col_ind)
# check that the permutation is the same
assert all(col_ind == rand_idx.tolist())

[0 1 2 3 4] [3 0 2 4 1]


In [15]:
rand_idx, col_ind

(tensor([3, 0, 2, 4, 1]), array([3, 0, 2, 4, 1]))

In [16]:
matched_outputs = batch["targets"][idx.start][rand_idx][row_ind]
matched_targets = batch["targets"][idx.start][col_ind]
print(matched_outputs)
print(matched_targets)
loss = F.l1_loss(matched_outputs, matched_targets)
print(loss)

tensor([[0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
        [0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000],
        [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
        [0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000],
        [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000]])
tensor([[0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
        [0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000],
        [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
        [0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000],
        [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000]])
tensor(0.)


In [17]:
criterion(
    preds=batch["targets"][idx],
    targets=batch["targets"][idx][:, rand_idx],
    preds_lengths=batch["n_tracks_per_sample"][idx],
    targets_lengths=batch["n_tracks_per_sample"][idx]
)

tensor(0.)

#### Test shuffled and expanded

In [18]:
# create a copy of tensors
preds = batch["targets"][idx].clone()
targets = batch["targets"][idx].clone()
preds_lengths = batch["n_tracks_per_sample"][idx].clone()
targets_lengths = batch["n_tracks_per_sample"][idx].clone()

In [19]:
# shuffle and expand predictions
rand_idx = torch.randperm(preds.shape[1])
print("Permutation indices:", rand_idx)

preds = preds[:, rand_idx]
print(preds.shape)

padded_preds = torch.zeros(
    preds.shape[0],
    preds.shape[1] + 5,
    preds.shape[2])
padded_preds[:, :preds.shape[1]] = preds.clone()
print(padded_preds.shape)

Permutation indices: tensor([4, 3, 1, 2, 0])
torch.Size([1, 5, 7])
torch.Size([1, 10, 7])


In [20]:
print(preds)
print()
print(padded_preds)

tensor([[[0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000]]])

tensor([[[0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0

In [21]:
criterion(
    preds=padded_preds,
    targets=targets,
    preds_lengths=preds_lengths,
    targets_lengths=targets_lengths
)

tensor(0.)

#### Test correctly predicted tracks + random

In [22]:
batch["n_tracks_per_sample"]

tensor([2, 1, 3, 1, 5, 4, 5, 3, 3, 3, 1, 5, 3, 3, 5, 1])

In [23]:
# take event with at least two tracks
idx2 = slice(2, 3)

In [24]:
batch["targets"][idx]

tensor([[[0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000]]])

In [25]:
batch["targets"][idx2]

tensor([[[0.5059, 0.4981, 0.5559, 0.5204, 0.8110, 0.2880, 0.0000],
         [0.5059, 0.4981, 0.5559, 0.4023, 0.9057, 0.4842, 0.0000],
         [0.5059, 0.4981, 0.5559, 0.2223, 0.6505, 0.1441, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])

In [26]:
# create a copy of tensors
preds = batch["targets"][idx].clone()
targets = batch["targets"][idx].clone()
preds_lengths = batch["n_tracks_per_sample"][idx].clone()
targets_lengths = batch["n_tracks_per_sample"][idx].clone()

# targets + random predictions
preds_targets_random = torch.hstack([preds, batch["targets"][idx2]])
preds_targets_random_lengths = torch.LongTensor(
    [batch["n_tracks_per_sample"][idx.start] + batch["n_tracks_per_sample"][idx2.start]])
print(preds_targets_random)
print(preds_targets_random_lengths)

# shuffle and expand predictions
# calculate random indices for everything except padding
rand_idx = torch.randperm(preds_targets_random_lengths[0])
print("Permutation indices:", rand_idx)
preds_targets_random[:, :preds_targets_random_lengths[0]
                     ] = preds_targets_random[:, :preds_targets_random_lengths[0]][:, rand_idx]
print(preds_targets_random)

tensor([[[0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.0214, 0.1522, 0.1365, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.4015, 0.7486, 0.4619, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000],
         [0.5059, 0.4981, 0.5559, 0.5204, 0.8110, 0.2880, 0.0000],
         [0.5059, 0.4981, 0.5559, 0.4023, 0.9057, 0.4842, 0.0000],
         [0.5059, 0.4981, 0.5559, 0.2223, 0.6505, 0.1441, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
tensor([8])
Permutation indices: tensor([4, 5, 2, 0, 6, 7, 1, 3])
tensor([[[0.4908, 0.4998, 0.4693, 0.2423, 0.5612, 0.5507, 1.0000],
         [0.5059, 0.4981, 0.5559, 0.5204, 0.8110, 0.2880, 0.0000],
         [0.4908, 0.4998, 0.4693, 0.9355, 0.2515, 0.4543, 1.0000],
         [0.4908, 0.4998, 0.4693, 0.7067, 0.8657, 0.6830, 1.0

In [27]:
criterion(
    preds=preds_targets_random,
    targets=targets,
    preds_lengths=preds_targets_random_lengths,
    targets_lengths=targets_lengths
)

tensor(0.)

## Prepare loss and overfit on a single batch

In [55]:
import torch
import torch.nn.functional as F
from torch import nn
from scipy.optimize import linear_sum_assignment
from typing import Callable


def match_targets(outputs, targets):
    cost_matrix = torch.cdist(outputs, targets, p=1)
    row_ind, col_ind = linear_sum_assignment(
        cost_matrix.cpu().detach().numpy()
    )
    return row_ind, col_ind


def hungarian_loss(outputs, targets, distance: Callable):
    row_ind, col_ind = match_targets(outputs, targets)
    matched_outputs = outputs[row_ind]
    matched_targets = targets[col_ind]
    # loss = F.l1_loss(matched_outputs, matched_targets)
    # loss = F.smooth_l1_loss(matched_outputs, matched_targets)
    # loss = F.mse_loss(matched_outputs, matched_targets)
    loss = distance(matched_outputs, matched_targets)
    return loss


class TRTHungarianLoss(nn.Module):
    def __init__(self, distance: Callable = F.l1_loss):
        super().__init__()
        self._distance = distance

    def forward(self, preds, targets, preds_lengths, targets_lengths):
        hungarian = torch.tensor(0.0).to(preds.device)
        for i in range(preds.shape[0]):
            hungarian += hungarian_loss(
                preds[i, :preds_lengths[i]],
                targets[i, :targets_lengths[i]],
                distance=self._distance
            )
        hungarian /= preds.shape[0]  # batchmean
        return hungarian

In [56]:
criterion = TRTHungarianLoss()

#### Check the loss

In [57]:
criterion(
    preds=preds_targets_random,
    targets=targets,
    preds_lengths=preds_targets_random_lengths,
    targets_lengths=targets_lengths
)

tensor(0.)

In [48]:
batch

{'inputs': tensor([[[-0.4306, -0.0775,  0.3621],
          [-0.7161,  0.3530,  0.9452],
          [-0.1093,  0.3612,  0.7205],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],
 
         [[-0.9330,  0.1032,  0.8561],
          [ 0.8773, -0.2724, -0.7827],
          [-0.0673, -0.6296, -0.0309],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],
 
         [[-0.2009,  0.3660,  0.2520],
          [-0.4346,  0.3497, -0.0374],
          [-0.0352, -0.3758, -0.1095],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],
 
         ...,
 
         [[ 0.7158, -0.0570,  0.0866],
          [-0.3957, -0.4747,  0.0431],
          [ 0.0088,  0.4602,  0.0262],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
        

In [49]:
model = TRT(
    num_candidates=MAX_EVENT_TRACKS,
    n_points=TRUNCATION_LENGTH,
    num_out_params=batch["targets"].shape[2]
)

In [50]:
with torch.no_grad():
    outputs = model(batch["inputs"], batch["mask"])

print(outputs["params"])

tensor([[[0.5804, 0.5000, 0.5080, 0.4795, 0.4953, 0.4848, 0.4939],
         [0.5229, 0.4988, 0.5438, 0.4862, 0.5641, 0.4869, 0.4947],
         [0.5207, 0.5187, 0.4933, 0.4873, 0.4992, 0.4928, 0.5127],
         [0.5425, 0.5135, 0.5000, 0.4840, 0.4954, 0.4849, 0.5029],
         [0.5000, 0.5147, 0.4887, 0.4957, 0.4991, 0.4988, 0.5362]],

        [[0.5409, 0.5000, 0.4970, 0.4863, 0.5000, 0.4888, 0.5141],
         [0.5303, 0.5117, 0.4969, 0.4929, 0.5190, 0.4972, 0.5046],
         [0.5692, 0.4944, 0.4988, 0.4856, 0.5000, 0.5000, 0.4977],
         [0.5345, 0.4999, 0.4901, 0.5000, 0.5249, 0.4973, 0.5390],
         [0.5179, 0.5000, 0.4979, 0.4887, 0.5383, 0.4907, 0.5313]],

        [[0.5322, 0.4889, 0.5000, 0.4948, 0.5691, 0.4923, 0.5135],
         [0.5455, 0.5404, 0.4925, 0.4994, 0.5314, 0.4927, 0.4976],
         [0.5151, 0.4975, 0.4962, 0.4949, 0.5598, 0.4978, 0.5347],
         [0.4944, 0.4970, 0.4842, 0.5000, 0.5476, 0.5121, 0.5956],
         [0.4983, 0.4925, 0.4815, 0.4954, 0.5134, 0.4948, 

In [51]:
torch.LongTensor([MAX_EVENT_TRACKS]*len(outputs["params"]))

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])

### L1 Distance for Hungarian Loss

In [53]:
from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.backends.mps.is_available() else 'cpu'

model = TRT(
    num_candidates=MAX_EVENT_TRACKS,
    n_points=TRUNCATION_LENGTH,
    num_out_params=batch["targets"].shape[2]
).to(device)
criterion = TRTHungarianLoss().to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)
progress_bar = tqdm(range(10000))
for epoch in progress_bar:
    train_loss = 0

    optimizer.zero_grad(set_to_none=True)
    outputs = model(batch["inputs"].to(device), batch["mask"].to(device))
    loss = criterion(
        preds=outputs["params"],
        targets=batch["targets"].to(device),
        preds_lengths=torch.LongTensor(
            [MAX_EVENT_TRACKS]*len(outputs["params"])
        ).to(device),
        targets_lengths=batch["n_tracks_per_sample"].to(device)
    )
    loss.backward()
    optimizer.step()

    progress_bar.set_postfix({"epoch": epoch, "loss": loss.detach().item()})

  0%|          | 0/10000 [00:00<?, ?it/s]

### L1 Smooth Distance for Hungarian Loss

In [60]:
from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.backends.mps.is_available() else 'cpu'

model = TRT(
    num_candidates=MAX_EVENT_TRACKS,
    n_points=TRUNCATION_LENGTH,
    num_out_params=batch["targets"].shape[2]
).to(device)
criterion = TRTHungarianLoss(distance=F.smooth_l1_loss).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)
progress_bar = tqdm(range(10000))
for epoch in progress_bar:
    train_loss = 0

    optimizer.zero_grad(set_to_none=True)
    outputs = model(batch["inputs"].to(device), batch["mask"].to(device))
    loss = criterion(
        preds=outputs["params"],
        targets=batch["targets"].to(device),
        preds_lengths=torch.LongTensor(
            [MAX_EVENT_TRACKS]*len(outputs["params"])
        ).to(device),
        targets_lengths=batch["n_tracks_per_sample"].to(device)
    )
    loss.backward()
    optimizer.step()

    progress_bar.set_postfix({"epoch": epoch, "loss": loss.detach().item()})

  0%|          | 0/10000 [00:00<?, ?it/s]

### MSE distance for Hungarian Loss

In [61]:
from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.backends.mps.is_available() else 'cpu'

model = TRT(
    num_candidates=MAX_EVENT_TRACKS,
    n_points=TRUNCATION_LENGTH,
    num_out_params=batch["targets"].shape[2]
).to(device)
criterion = TRTHungarianLoss(distance=F.mse_loss).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)
progress_bar = tqdm(range(10000))
for epoch in progress_bar:
    train_loss = 0

    optimizer.zero_grad(set_to_none=True)
    outputs = model(batch["inputs"].to(device), batch["mask"].to(device))
    loss = criterion(
        preds=outputs["params"],
        targets=batch["targets"].to(device),
        preds_lengths=torch.LongTensor(
            [MAX_EVENT_TRACKS]*len(outputs["params"])
        ).to(device),
        targets_lengths=batch["n_tracks_per_sample"].to(device)
    )
    loss.backward()
    optimizer.step()

    progress_bar.set_postfix({"epoch": epoch, "loss": loss.detach().item()})

  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 