In [6]:
import os
from pathlib import Path
# Change the working directory so that the `src` package is importable.
# Guarded so the cell is idempotent: once the current directory already
# contains `src`, re-running the cell no longer climbs one level further up.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)

In [7]:
import numpy as np
import os

from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader

from src.visualization import draw_event
from src.dataset import SPDEventsDataset
from src.data_generation import SPDEventGenerator
from src.normalization import ConstraintsNormalizer, TrackParamsNormalizer
from src.dataset import collate_fn_for_set_loss
from src.model import TRT

# Single place to change the experiment seed; handed to Lightning's seeding helper.
SEED = 13
seed_everything(SEED)

Seed set to 13


13

## Prepare single batch for overfitting

In [8]:
# Experiment configuration used by the dataset, loader and model cells below.
MAX_EVENT_TRACKS = 10       # passed to the datasets as `max_event_tracks`
TRUNCATION_LENGTH = 1024    # passed as `truncation_length` (datasets) and `n_points` (models)
BATCH_SIZE = 64             # samples per DataLoader batch
NUM_EVENTS_TRAIN = 50_000   # `n_samples` of the training dataset
NUM_EVENTS_VALID = 10_000   # `n_samples` of the validation dataset

In [9]:
from src.dataset import collate_fn_with_class_loss, DatasetMode

# The same normalizer instances are handed to both the train and validation datasets.
hits_norm = ConstraintsNormalizer()
params_norm = TrackParamsNormalizer()

# Keyword arguments shared by both dataset splits.
dataset_kwargs = dict(
    max_event_tracks=MAX_EVENT_TRACKS,
    truncation_length=TRUNCATION_LENGTH,
    generate_fixed_tracks_num=False,
    hits_normalizer=hits_norm,
    track_params_normalizer=params_norm,
)

# Keyword arguments shared by both loaders.
# NOTE(review): the validation loader also shuffles; that is kept as-is here,
# but shuffling is not normally needed for validation -- confirm.
# NOTE(review): the training run recorded further down ended with
# "DataLoader worker ... exited unexpectedly"; consider num_workers=0 while debugging.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn_with_class_loss,
    num_workers=4,
    pin_memory=False,
    persistent_workers=True,
)

train_data = SPDEventsDataset(
    n_samples=NUM_EVENTS_TRAIN,
    mode=DatasetMode.train,
    shuffle=True,
    **dataset_kwargs,
)
train_loader = DataLoader(train_data, **loader_kwargs)

val_data = SPDEventsDataset(
    n_samples=NUM_EVENTS_VALID,
    mode=DatasetMode.val,
    **dataset_kwargs,
)
val_loader = DataLoader(val_data, **loader_kwargs)


In [10]:
# Pull exactly one collated batch from the training loader and show the
# target parameters and class labels of its first sample.
batch = next(iter(train_loader))
print(batch["targets"][0])
print(batch["labels"][0])

tensor([[0.4986, 0.5002, 0.5029, 0.2888, 0.3288, 0.8035, 1.0000],
        [0.4986, 0.5002, 0.5029, 0.2582, 0.7055, 0.6422, 0.0000],
        [0.4986, 0.5002, 0.5029, 0.5950, 0.2489, 0.1212, 1.0000],
        [0.4986, 0.5002, 0.5029, 0.9962, 0.6036, 0.3745, 0.0000],
        [0.4986, 0.5002, 0.5029, 0.5224, 0.5975, 0.8601, 0.0000],
        [0.4986, 0.5002, 0.5029, 0.4812, 0.9956, 0.2096, 0.0000],
        [0.4986, 0.5002, 0.5029, 0.7359, 0.3758, 0.1811, 1.0000],
        [0.4986, 0.5002, 0.5029, 0.2623, 0.3439, 0.2612, 1.0000],
        [0.4986, 0.5002, 0.5029, 0.6519, 0.3243, 0.5381, 1.0000],
        [0.4986, 0.5002, 0.5029, 0.9035, 0.0878, 0.3956, 1.0000]])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [11]:
import torch
# Sanity check: is a CUDA device usable? The training cells below use the same
# check to choose their device.
torch.cuda.is_available()

True

## Define functions for the Hungarian loss

In [12]:
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment


def match_targets(outputs, targets):
    """Find the minimum-cost one-to-one pairing of predictions and targets.

    The cost of pairing row ``i`` of ``outputs`` with row ``j`` of ``targets``
    is their L1 distance (``torch.cdist`` with ``p=1``); the assignment itself
    is solved by SciPy's ``linear_sum_assignment``.

    Returns:
        ``(row_ind, col_ind)`` as NumPy index arrays: ``outputs[row_ind[k]]``
        is paired with ``targets[col_ind[k]]``.
    """
    # SciPy needs a host NumPy array, so detach from the graph and move to CPU.
    pairwise_l1 = torch.cdist(outputs, targets, p=1).detach().cpu().numpy()
    return linear_sum_assignment(pairwise_l1)


def hungarian_loss(outputs, targets):
    """Mean L1 error between optimally matched prediction and target rows.

    Rows are paired with ``match_targets``; the loss is ``F.l1_loss`` over
    the paired rows only.
    """
    pred_idx, target_idx = match_targets(outputs, targets)
    return F.l1_loss(outputs[pred_idx], targets[target_idx])


def criterion(preds, targets, preds_lengths, targets_lengths):
    """Batch-mean Hungarian L1 loss over padded predictions and targets.

    Args:
        preds: Tensor whose first dimension is the batch and second the
            candidate index.
        targets: Tensor laid out like ``preds`` for the ground truth.
        preds_lengths: Per-sample number of leading candidates to use.
        targets_lengths: Per-sample number of leading targets to use.

    Returns:
        Scalar tensor on ``preds.device``: the per-sample ``hungarian_loss``
        averaged over the batch (zero for an empty batch).
    """
    batch_size = preds.shape[0]
    # Accumulate on the predictions' device; a CPU accumulator cannot be
    # updated in place with per-sample losses that live on the GPU.
    total = torch.zeros((), device=preds.device)
    for i in range(batch_size):
        total = total + hungarian_loss(
            preds[i, :preds_lengths[i]],
            targets[i, :targets_lengths[i]]
        )
    # batchmean; max(..., 1) avoids 0/0 -> NaN on an empty batch
    return total / max(batch_size, 1)

In [13]:
# List the fields that the collate function put into the batch dictionary.
batch.keys()

dict_keys(['inputs', 'mask', 'targets', 'orig_params', 'n_tracks_per_sample', 'labels'])

In [14]:
# Shape of the target tensor; its last dimension is later used as `num_out_params`.
batch["targets"].shape

torch.Size([64, 10, 7])

#### Test same inputs, but shuffled

In [15]:
def adjust_targets(row_ind, col_ind, targets, num_candidates=10, unmatched_label=1):
    """Build one class label per candidate from a prediction/target matching.

    Args:
        row_ind: Indices of the matched candidates (N entries for N matched pairs).
        col_ind: Indices into ``targets`` of the ground-truth entries matched
            to ``row_ind`` (N entries).
        targets: 1-D tensor of ground-truth class labels.
        num_candidates (int): Number of candidates predicted per sample (default=10).
        unmatched_label (int): Label assigned to candidates that take part in
            no matched pair (default=1).

    Returns:
        adjusted_targets: ``torch.long`` tensor of shape ``[num_candidates]``
        on ``targets.device``. Position ``row_ind[k]`` holds
        ``targets[col_ind[k]]``; every other position holds ``unmatched_label``.
    """
    # Start with every candidate marked as unmatched ...
    adjusted_targets = torch.full(
        (num_candidates,),
        unmatched_label,
        dtype=torch.long,
        device=targets.device,
    )
    # ... then overwrite the matched candidates with their target's label.
    adjusted_targets[row_ind] = targets[col_ind]
    return adjusted_targets


# Manual check of adjust_targets on a hand-written matching.
num_candidates = 10
num_classes = 5
num_matched_pairs = 3

# Random logits of shape [num_candidates, num_classes].
# NOTE(review): `logits` is not passed to adjust_targets below, but the
# torch.randn call still draws from the seeded global RNG.
logits = torch.randn(num_candidates, num_classes)

# Fixed (not random) matched pairs: candidates 1, 4, 7 are matched to targets 0, 1, 2.
row_ind =  torch.tensor([1, 4, 7])
col_ind = torch.tensor([0, 1, 2])

# Class labels of the three ground-truth entries.
targets =  torch.tensor([2, 0, 3])

# Build per-candidate labels; candidates without a match get the default label.
adjusted_targets = adjust_targets(row_ind, col_ind, targets, num_candidates)

print("Adjusted Targets:")
print(adjusted_targets)

Adjusted Targets:
tensor([1, 2, 1, 1, 0, 1, 1, 3, 1, 1])


## Prepare loss and overfit on a single batch

In [16]:
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=1, gamma=2, reduction='mean'):
    """
    Multi-class focal loss computed from raw logits.

    Args:
        logits: Raw (un-softmaxed) class scores with shape [..., C], where C is
                the number of classes, e.g. [N, C] or [B, N, C].
        targets: Integer class labels with shape [...] (the shape of `logits`
                 without its last dimension); each value is in [0, C-1].
        alpha (float, optional): A balancing factor applied to the loss (default=1).
        gamma (float, optional): Focusing parameter for hard examples (default=2).
        reduction (string, optional): 'mean' averages over all elements,
                                      'sum' adds them up, any other value
                                      (conventionally 'none') returns the
                                      unreduced loss (default='mean').
    Returns:
        Loss: Scalar if a reduction is applied, otherwise a tensor shaped like `targets`.
    """
    # Work in log-space: log_softmax stays finite and keeps a useful gradient
    # even when the target-class probability underflows to zero, which
    # log(softmax(x) + eps) does not.
    log_probs = F.log_softmax(logits, dim=-1)  # [..., C]

    # Log-probability of the target class for every element
    log_pt = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # [...]
    pt = log_pt.exp()

    # Focal loss equation: down-weight elements that are already well classified
    loss = -alpha * (1 - pt) ** gamma * log_pt

    # Apply the reduction
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss  # No reduction

In [17]:
import torch
import torch.nn.functional as F
from torch import nn
from scipy.optimize import linear_sum_assignment
from typing import Callable


def match_targets(outputs, targets):
    """Optimal one-to-one matching between prediction rows and target rows.

    Uses the pairwise L1 distance as the assignment cost and SciPy's
    ``linear_sum_assignment`` as the solver. Returns NumPy index arrays
    ``(row_ind, col_ind)`` such that ``outputs[row_ind[k]]`` is matched with
    ``targets[col_ind[k]]``.

    NOTE: this repeats the ``match_targets`` defined in an earlier cell and
    shadows it once this cell has run.
    """
    # The matching is not differentiated through, so no graph is recorded.
    with torch.no_grad():
        cost = torch.cdist(outputs, targets, p=1)
    assignment = linear_sum_assignment(cost.cpu().numpy())
    row_ind, col_ind = assignment
    return row_ind, col_ind

def hungarian_loss(outputs, targets, distance: Callable):
    """Regression term of the set loss for already-matched rows.

    Args:
        outputs: Matched prediction rows.
        targets: Matched target rows, aligned with ``outputs``.
        distance: Callable ``distance(outputs, targets)`` returning the loss,
            e.g. ``F.l1_loss``, ``F.smooth_l1_loss`` or ``F.mse_loss``.

    Returns:
        Whatever ``distance`` returns.
    """
    return distance(outputs, targets)

def class_loss(outputs, targets, loss_fn: Callable):
    """Classification term of the set loss: ``loss_fn(outputs, targets)``."""
    value = loss_fn(outputs, targets)
    return value


class TRTHungarianLoss(nn.Module):
    """Set-prediction loss: matched-parameter distance plus candidate classification.

    For every sample the leading ``preds_lengths[i]`` predicted parameter rows
    are matched one-to-one to the leading ``targets_lengths[i]`` target rows
    (``match_targets``). The loss is the weighted sum of

    * the ``distance`` between matched prediction and target parameters, and
    * the ``class_loss`` between the logits of *all* candidates and their
      labels, where matched candidates take the label of their target and the
      rest take the default "unmatched" label of ``adjust_targets``.

    Both terms are averaged over the batch.
    """

    def __init__(
            self,
            distance: Callable = F.l1_loss,
            class_loss: Callable = F.cross_entropy,
            weights: tuple[float, float] = (1, 1)
    ):
        """
        Args:
            distance: ``distance(matched_preds, matched_targets)`` -> scalar tensor.
            class_loss: ``class_loss(logits, labels)`` -> scalar tensor.
            weights: ``(distance_weight, class_weight)`` used in the final sum.
        """
        super().__init__()
        self._distance = distance
        self._class_loss = class_loss
        self._weights = weights

    def forward(
            self,
            preds: dict[str, torch.Tensor],
            targets: dict[str, torch.Tensor],
            preds_lengths,
            targets_lengths,
    ):
        """
        Args:
            preds: Dict with ``"params"`` (indexed ``[sample, candidate, ...]``)
                and ``"logits"`` (indexed ``[sample, candidate, class]``).
            targets: Dict with ``"targets"`` (parameter rows per sample) and
                ``"labels"`` (class label per target row).
            preds_lengths: Per-sample number of leading candidates that take
                part in the matching.
            targets_lengths: Per-sample number of valid (non-padded) targets.

        Returns:
            Scalar tensor ``weights[0] * distance_term + weights[1] * class_term``.
        """
        pred_logits = preds["logits"]
        pred_params = preds["params"]
        target_params = targets["targets"]
        target_labels = targets["labels"]
        batch_size = pred_params.shape[0]

        # Allocate the accumulators directly on the predictions' device.
        hungarian = torch.zeros((), device=pred_params.device)
        label_loss = torch.zeros((), device=pred_params.device)
        for i in range(batch_size):
            n_preds = int(preds_lengths[i])
            n_targets = int(targets_lengths[i])
            row_ind, col_ind = match_targets(
                pred_params[i, :n_preds],
                target_params[i, :n_targets])
            # With no matched pair (e.g. a sample without targets) a
            # mean-reduced distance over empty tensors is NaN and would poison
            # the whole batch, so the regression term is skipped for it.
            if len(row_ind) > 0:
                hungarian = hungarian + hungarian_loss(
                    pred_params[i, row_ind],
                    target_params[i, col_ind],
                    distance=self._distance
                )
            # Labels for all candidates of this sample (unmatched -> default label).
            candidate_labels = adjust_targets(
                row_ind,
                col_ind,
                target_labels[i, :n_targets],
                num_candidates=pred_logits.shape[1]
            )
            label_loss = label_loss + class_loss(
                pred_logits[i],
                candidate_labels,
                loss_fn=self._class_loss
            )
        # batchmean; max(..., 1) avoids 0/0 -> NaN on an empty batch
        denom = max(batch_size, 1)
        hungarian = hungarian / denom
        label_loss = label_loss / denom
        return self._weights[0] * hungarian + self._weights[1] * label_loss

In [18]:
# NOTE(review): `partial` is not used in any cell visible here.
from functools import partial

# Loss with the focal loss as the classification term; the distance and the
# weights keep the TRTHungarianLoss defaults.
criterion = TRTHungarianLoss(class_loss=focal_loss)

#### Check the loss

In [19]:
import sys
# Put the parent of the current working directory on the module search path.
sys.path.append("../")

# Debug aid: list the attributes currently present on the imported `src` package.
import src
dir(src)

['__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'constants',
 'data_generation',
 'dataset',
 'model',
 'normalization',
 'visualization']

In [20]:
# NOTE(review): `importlib` is not used in any cell visible here.
import importlib
import sys
# Same relative search-path entry as in the previous cell.
sys.path.append("../")
from src.model_hybrid import TRTHybrid

# Small hybrid model used only to sanity-check output shapes in the next cell.
model_h = TRTHybrid(
    num_candidates=30,
    n_points=TRUNCATION_LENGTH,
    # one output per target parameter (last dimension of the target tensor)
    num_out_params=batch["targets"].shape[2],
)

In [21]:
# Forward pass without gradient tracking on the single batch; only the output
# shapes are inspected.
with torch.no_grad():
    outputs = model_h(inputs=batch["inputs"], mask=batch["mask"])

# The model returns a dict holding "params" and "logits" tensors.
print(outputs["params"].shape)
print(outputs["logits"].shape)

torch.Size([64, 30, 7])
torch.Size([64, 30, 2])


In [22]:
# Preview of the per-sample length tensor that the training cells pass as
# `preds_lengths`: MAX_EVENT_TRACKS repeated once per sample in the batch.
torch.LongTensor([MAX_EVENT_TRACKS]*len(outputs["params"]))

tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10])

In [23]:
import os
# Show the current working directory (it was changed by the first cell).
os.getcwd()

'E:\\projects'

### L1 Distance for Hungarian Loss

In [24]:
from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.backends.mps.is_available() else 'cpu'

# Checkpoint destinations.
# NOTE(review): machine-specific absolute paths; consider a configurable directory.
TRAIN_CKPT_PATH = "E:/projects/trt/weights/trt_hybrid_10_10_train.pt"
VAL_CKPT_PATH = "E:/projects/trt/weights/trt_hybrid_10_10_val.pt"

model = TRTHybrid(
    num_candidates=MAX_EVENT_TRACKS*5,
    n_points=TRUNCATION_LENGTH,
    num_out_params=batch["targets"].shape[2]
).to(device)
criterion = TRTHungarianLoss(weights=(1, 0.02)).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)


def _batch_loss(data):
    """Run `model` on one collated batch and return the scalar loss tensor."""
    model_out = model(data["inputs"].to(device), data["mask"].to(device))
    # NOTE(review): the model predicts MAX_EVENT_TRACKS*5 candidates, but only
    # the first MAX_EVENT_TRACKS take part in the matching -- confirm this is intended.
    n_preds = torch.full(
        (len(model_out["params"]),),
        MAX_EVENT_TRACKS,
        dtype=torch.long,
        device=device,
    )
    return criterion(
        preds=model_out,
        targets={
            "targets": data["targets"].to(device),
            "labels": data["labels"].to(device),
        },
        preds_lengths=n_preds,
        targets_lengths=data["n_tracks_per_sample"].to(device)
    )


progress_bar = tqdm(range(5000))
min_loss_train = min_loss_val = 1e5
for epoch in progress_bar:
    # ---- training pass ----
    # Re-enable training mode every epoch; the validation pass below switches to eval.
    model.train()
    train_loss = 0.0
    num_train_batches = 0
    # Distinct loop variable so the single `batch` read earlier is not overwritten.
    for train_batch in train_loader:
        num_train_batches += 1
        optimizer.zero_grad(set_to_none=True)
        loss = _batch_loss(train_batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Track the best summed training loss so the checkpoint is only rewritten on improvement.
    if train_loss < min_loss_train:
        min_loss_train = train_loss
        torch.save(model.state_dict(), TRAIN_CKPT_PATH)

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    num_val_batches = 0
    # No gradients are needed for validation; this avoids building autograd graphs.
    with torch.no_grad():
        for val_batch in val_loader:
            num_val_batches += 1
            val_loss += _batch_loss(val_batch).item()
    if val_loss < min_loss_val:
        min_loss_val = val_loss
        torch.save(model.state_dict(), VAL_CKPT_PATH)

    # Report per-batch means of the accumulated epoch losses.
    progress_bar.set_postfix(
        {
            "epoch": epoch,
            "val_loss": val_loss / max(num_val_batches, 1),
            "train_loss": train_loss / max(num_train_batches, 1)}
    )

  0%|          | 0/5000 [00:00<?, ?it/s]

RuntimeError: DataLoader worker (pid(s) 25528, 23036, 9724) exited unexpectedly

In [None]:
# Save the final weights relative to the current working directory.
# torch.save does not create missing directories, so make sure the target
# directory exists first.
final_weights_path = Path("weights/trt_hybrid.pt")
final_weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), final_weights_path)

### Smooth L1 Distance for Hungarian Loss

In [None]:
from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.backends.mps.is_available() else 'cpu'

model = TRT(
    num_candidates=MAX_EVENT_TRACKS,
    n_points=TRUNCATION_LENGTH,
    num_out_params=batch["targets"].shape[2]
).to(device)
criterion = TRTHungarianLoss(distance=F.smooth_l1_loss).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)

# The same batch is reused at every step, so move it to the device once.
overfit_inputs = batch["inputs"].to(device)
overfit_mask = batch["mask"].to(device)
# TRTHungarianLoss.forward reads targets["targets"] and targets["labels"].
overfit_targets = {
    "targets": batch["targets"].to(device),
    "labels": batch["labels"].to(device),
}
overfit_targets_lengths = batch["n_tracks_per_sample"].to(device)

progress_bar = tqdm(range(10000))
for epoch in progress_bar:
    optimizer.zero_grad(set_to_none=True)
    outputs = model(overfit_inputs, overfit_mask)
    # TRTHungarianLoss.forward reads preds["params"] and preds["logits"], so
    # the whole output dict is passed rather than only the params tensor.
    # NOTE(review): assumes TRT returns a "logits" entry like TRTHybrid -- confirm.
    loss = criterion(
        preds=outputs,
        targets=overfit_targets,
        preds_lengths=torch.LongTensor(
            [MAX_EVENT_TRACKS]*len(outputs["params"])
        ).to(device),
        targets_lengths=overfit_targets_lengths
    )
    loss.backward()
    optimizer.step()

    progress_bar.set_postfix({"epoch": epoch, "loss": loss.detach().item()})

### MSE distance for Hungarian Loss

In [None]:
from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.backends.mps.is_available() else 'cpu'

model = TRT(
    num_candidates=MAX_EVENT_TRACKS,
    n_points=TRUNCATION_LENGTH,
    num_out_params=batch["targets"].shape[2]
).to(device)
criterion = TRTHungarianLoss(distance=F.mse_loss).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)

# The same batch is reused at every step, so move it to the device once.
mse_inputs = batch["inputs"].to(device)
mse_mask = batch["mask"].to(device)
# TRTHungarianLoss.forward reads targets["targets"] and targets["labels"].
mse_targets = {
    "targets": batch["targets"].to(device),
    "labels": batch["labels"].to(device),
}
mse_targets_lengths = batch["n_tracks_per_sample"].to(device)

progress_bar = tqdm(range(10000))
for epoch in progress_bar:
    optimizer.zero_grad(set_to_none=True)
    outputs = model(mse_inputs, mse_mask)
    # TRTHungarianLoss.forward reads preds["params"] and preds["logits"], so
    # the whole output dict is passed rather than only the params tensor.
    # NOTE(review): assumes TRT returns a "logits" entry like TRTHybrid -- confirm.
    loss = criterion(
        preds=outputs,
        targets=mse_targets,
        preds_lengths=torch.LongTensor(
            [MAX_EVENT_TRACKS]*len(outputs["params"])
        ).to(device),
        targets_lengths=mse_targets_lengths
    )
    loss.backward()
    optimizer.step()

    progress_bar.set_postfix({"epoch": epoch, "loss": loss.detach().item()})