In [1]:
# Install uv (a fast pip-compatible installer) with its official install script.
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Use uv to install a compatible stack (pydantic v1 + albumentations 1.3.1)
# NOTE(review): the pins are deliberate. The transforms cell uses albumentations 1.x
# argument names (GaussNoise(var_limit=...), ElasticTransform(alpha_affine=...),
# CoarseDropout(max_holes=...)); pydantic<2 is presumably needed by btrack 0.6.5 -- confirm
# before bumping either.
# torch is not installed here; it is assumed to be preinstalled (e.g. on Colab).
# --system installs into the interpreter's own environment instead of a virtualenv.
!uv pip install --system \
    "btrack==0.6.5" \
    "pydantic<2" \
    "albumentations==1.3.1" \
    numpy \
    pandas \
    scipy \
    scikit-image \
    scikit-learn \
    opencv-python-headless \
    matplotlib \
    seaborn \
    tqdm \
    ipywidgets \
    tifffile \
    numba


downloading uv 0.9.15 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!
[2mUsing Python 3.12.12 environment at: /usr[0m
[2K[2mResolved [1m118 packages[0m [2min 1.03s[0m[0m
[2K[2mPrepared [1m5 packages[0m [2min 719ms[0m[0m
[2mUninstalled [1m2 packages[0m [2min 19ms[0m[0m
[2K[2mInstalled [1m5 packages[0m [2min 61ms[0m[0m
 [31m-[39m [1malbumentations[0m[2m==2.0.8[0m
 [32m+[39m [1malbumentations[0m[2m==1.3.1[0m
 [32m+[39m [1mbtrack[0m[2m==0.6.5[0m
 [32m+[39m [1mjedi[0m[2m==0.19.2[0m
 [31m-[39m [1mpydantic[0m[2m==2.12.3[0m
 [32m+[39m [1mpydantic[0m[2m==1.10.24[0m
 [32m+[39m [1mqudida[0m[2m==0.0.4[0m


In [1]:
# @title Imports & Globals
import os
import math
import shutil
from pathlib import Path
from typing import List, Tuple

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tifffile
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
from sklearn.model_selection import KFold
from skimage.measure import label, regionprops
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

# Region of interest in full-frame pixel coordinates (half-open):
# rows [ROI_Y_MIN, ROI_Y_MAX), columns [ROI_X_MIN, ROI_X_MAX).
ROI_Y_MIN = 512
ROI_Y_MAX = 768
ROI_X_MIN = 256
ROI_X_MAX = 512
# Size of a crop taken with the bounds above.
ROI_H = ROI_Y_MAX - ROI_Y_MIN
ROI_W = ROI_X_MAX - ROI_X_MIN

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Let cuDNN benchmark and cache the fastest conv algorithms (helps with fixed input sizes).
torch.backends.cudnn.benchmark = True
print(f"Using device: {device}")

import random
from IPython.display import HTML, display
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle
import matplotlib.collections as mc
from scipy.ndimage import gaussian_filter
from PIL import Image
from scipy import spatial, optimize

Using device: cuda


In [2]:
# @title Data Download Helpers (validation + optional Drive training)
import os, shutil, zipfile, requests

# Download validation video+labels zip (with SSL chain)
def download_validation_data(target_dir: str = "val_data",
                             url: str = "https://su2.utia.cas.cz/files/labs/final2025/val_and_sota.zip",
                             cert_url: str = "https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem"):
    """Download the validation archive and extract it into ``target_dir``.

    The archive host is verified against a certificate chain that is downloaded
    first from ``cert_url`` and passed to ``requests`` as ``verify=``.

    Nothing happens if ``target_dir`` already exists and is non-empty. If
    extraction fails, ``target_dir`` is removed again, because a half-extracted
    directory would make that check skip the download on every later call. The
    downloaded zip is always deleted, whether or not the call succeeds.

    Args:
        target_dir: Directory the archive is extracted into.
        url: URL of the zip archive.
        cert_url: URL of the PEM certificate chain used to verify ``url``.

    Raises:
        requests.HTTPError: If either download returns an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    if os.path.exists(target_dir) and len(os.listdir(target_dir)) > 0:
        print(f"'{target_dir}' already exists. Skipping download.")
        return

    chain_path = "chain-harica-cross.pem"
    print("1) Downloading SSL certificate chain...")
    # The chain is tiny: read it in one go and close the response deterministically.
    with requests.get(cert_url, timeout=10) as r:
        r.raise_for_status()
        with open(chain_path, "wb") as f:
            f.write(r.content)

    print("2) Downloading validation archive...")
    zip_name = os.path.basename(url)
    try:
        with requests.get(url, stream=True, verify=chain_path, timeout=30) as resp:
            resp.raise_for_status()
            with open(zip_name, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        print("3) Extracting...")
        os.makedirs(target_dir, exist_ok=True)
        try:
            with zipfile.ZipFile(zip_name, "r") as zf:
                zf.extractall(target_dir)
        except Exception:
            # target_dir was missing or empty when we started (see the early return),
            # so removing it only discards our own partial extraction.
            shutil.rmtree(target_dir, ignore_errors=True)
            raise
    finally:
        # Remove the archive, including a partial download after a failure.
        if os.path.exists(zip_name):
            os.remove(zip_name)
    print(f"Done. Data in '{target_dir}/'")


# Optional: copy training data from Google Drive (manual annotations)
def fetch_from_drive(DRIVE_SOURCE_PATH: str = "/content/drive/MyDrive/unet_train",
                     TARGET_DIR: str = "real_training_data"):
    """Copy the manually annotated training data from Google Drive (Colab only).

    Drive is mounted at ``/content/drive``, then ``DRIVE_SOURCE_PATH`` is copied to
    ``TARGET_DIR``. The copy is skipped when ``TARGET_DIR`` already has content or
    when the source does not exist. Outside Colab a notice is printed and nothing
    else happens. If the copy fails, the error is printed and any partially
    written ``TARGET_DIR`` is deleted.
    """
    try:
        from google.colab import drive
    except ImportError:
        print("google.colab not available (not running in Colab).")
        return

    print("?? Mounting Drive...")
    drive.mount('/content/drive')

    target_has_content = os.path.exists(TARGET_DIR) and bool(os.listdir(TARGET_DIR))
    if target_has_content:
        print(f"'{TARGET_DIR}' already exists and is not empty. Skipping copy.")
        return
    if not os.path.exists(DRIVE_SOURCE_PATH):
        print(f"? Source not found: {DRIVE_SOURCE_PATH}")
        return

    print(f"Copying {DRIVE_SOURCE_PATH} -> {TARGET_DIR}")
    try:
        shutil.copytree(DRIVE_SOURCE_PATH, TARGET_DIR)
    except Exception as err:
        print(f"? Copy failed: {err}")
        if os.path.exists(TARGET_DIR):
            shutil.rmtree(TARGET_DIR)
    else:
        print("? Copy complete.")


In [37]:
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two stages of 3x3 conv -> BatchNorm -> ReLU, with optional 2-D dropout at the end.

    Stride 1 and padding 1 keep the spatial size; channels go from ``in_channels``
    to ``out_channels``. The dropout layer only exists when ``dropout_rate > 0``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable for checkpoints.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if dropout_rate > 0:
            self.dropout = nn.Dropout2d(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(norm(conv(x)))
        return x if self.dropout is None else self.dropout(x)

class LightweightUNet(nn.Module):
    """Small U-Net that returns per-pixel logits.

    Args:
        in_channels: Number of input image channels.
        n_classes: Number of output (logit) channels.
        features: Encoder channel widths, shallowest level first. The bottleneck
            has ``2 * features[-1]`` channels and the decoder mirrors the encoder.
            Any non-empty sequence is accepted; the default is an immutable tuple,
            so no mutable default is shared between instances.

    ``forward`` returns raw logits of shape ``(N, n_classes, H, W)``; apply
    ``torch.sigmoid`` to get probabilities. If ``H``/``W`` are not divisible by
    ``2 ** len(features)``, each upsampled map is resized to its skip connection.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=1, n_classes=1, features=(16, 32, 64, 128)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        in_ch = in_channels
        for feature in features:
            self.downs.append(ConvBlock(in_ch, feature))
            in_ch = feature

        self.bottleneck = ConvBlock(features[-1], features[-1] * 2)

        # self.ups alternates [ConvTranspose2d, ConvBlock] per level, deepest first.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(ConvBlock(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], n_classes, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        for level, skip in enumerate(reversed(skip_connections)):
            upsample, block = self.ups[2 * level], self.ups[2 * level + 1]
            x = upsample(x)
            # Pooling floors odd sizes, so the upsampled map can be a pixel short.
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=True)
            x = block(torch.cat((skip, x), dim=1))

        return self.final_conv(x)

class FocalLoss(nn.Module):
    """Binary focal loss on logits (Lin et al., 2017), averaged over all elements.

        loss = alpha_t * (1 - p_t) ** gamma * BCE(logits, targets)

    Here ``p_t`` is the predicted probability of the true class. ``alpha_t`` is
    ``alpha`` for positive targets and ``1 - alpha`` for negative ones, so
    ``alpha > 0.5`` up-weights foreground pixels. A constant ``alpha`` on every
    pixel would only rescale the loss and would not balance the classes.

    Args:
        alpha: Weight of the positive class, in [0, 1].
        gamma: Focusing parameter; ``gamma = 0`` gives alpha-weighted BCE.
    """

    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, inputs, targets):
        bce_loss = self.bce(inputs, targets)
        # For targets in {0, 1}, exp(-BCE) is the probability assigned to the true class.
        pt = torch.exp(-bce_loss)
        alpha_t = self.alpha * targets + (1.0 - self.alpha) * (1.0 - targets)
        focal_loss = alpha_t * (1 - pt) ** self.gamma * bce_loss
        return focal_loss.mean()

class ComboLoss(nn.Module):
    """Equal-weight sum of focal loss and soft Dice loss, both on logits.

    ``forward`` also accepts a list or tuple of logit tensors (e.g. several output
    heads); the combined loss is then averaged over them.

    Raises:
        ValueError: From ``forward`` if given an empty list/tuple.
    """

    def __init__(self):
        super().__init__()
        self.focal = FocalLoss(alpha=0.75, gamma=2)

    def dice_loss(self, pred, target, eps=1e-6):
        """Soft Dice loss ``1 - Dice`` over the whole batch.

        Args:
            pred: Logits; a sigmoid is applied here.
            target: Binary target tensor of the same shape.
            eps: Extra denominator term for numerical safety.
        """
        pred = torch.sigmoid(pred)
        smooth = 1.0  # an empty prediction on an empty target gives Dice ~= 1 (loss ~= 0)
        intersection = (pred * target).sum()
        dice = (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth + eps)
        return 1 - dice

    def _combined(self, logits, targets):
        """0.5 * focal + 0.5 * Dice for a single logit tensor."""
        return 0.5 * self.focal(logits, targets) + 0.5 * self.dice_loss(logits, targets)

    def forward(self, inputs, targets):
        if isinstance(inputs, (list, tuple)):
            if len(inputs) == 0:
                raise ValueError("ComboLoss received an empty list of predictions")
            return sum(self._combined(item, targets) for item in inputs) / len(inputs)
        return self._combined(inputs, targets)

# Quick sanity check: a forward pass on a dummy 256x256 single-channel image must
# return a logit map with the same spatial size. No gradients are needed for this.
_dummy = torch.randn(1, 1, 256, 256)
_model = LightweightUNet()
with torch.no_grad():
    _out = _model(_dummy)
assert _out.shape == (1, 1, 256, 256), f"unexpected output shape {tuple(_out.shape)}"
print(f"Model ok, output shape: {_out.shape}")

Model ok, output shape: torch.Size([1, 1, 256, 256])


In [4]:
# @title Data prep: build experiment_dataset

def prepare_grand_dataset(
    real_data_dir: str = "real_training_data",
    val_tif_path: str = "val_data/val.tif",
    val_csv_path: str = "val_data/val.csv",
    out_dir: str = "experiment_dataset",
    roi_y: Tuple[int, int] = (ROI_Y_MIN, ROI_Y_MAX),
    roi_x: Tuple[int, int] = (ROI_X_MIN, ROI_X_MAX),
    radius: int = 5,
) -> Path:
    """Build ``out_dir`` with a "bonus" split (copied annotations) and a "video" split.

    Layout::

        out_dir/bonus/{images,masks}/   copied from real_data_dir/{images,masks}
        out_dir/video/{images,masks}/   frame_XXX.png crops of the validation video
        out_dir/video_map.csv           filename -> 0-based frame index in the video

    Each video mask is a filled disc of ``radius`` pixels (value 255) around every
    annotated point of that frame that falls inside the crop.

    Args:
        real_data_dir: Folder containing ``images/`` and ``masks/``.
        val_tif_path: Multi-frame TIFF; each frame is cropped to ``roi_y`` x ``roi_x``.
        val_csv_path: CSV with ``frame``, ``x`` and ``y`` columns in full-frame pixels.
        out_dir: Output directory; deleted and rebuilt on every call.
        roi_y, roi_x: Half-open ``(min, max)`` crop bounds along rows / columns.
        radius: Disc radius in pixels used to draw each point.

    Returns:
        Path to ``out_dir``.

    Raises:
        FileNotFoundError: If an input is missing. All inputs are checked before
            ``out_dir`` is touched, so a failed call keeps a previous build.
    """
    src_images = Path(real_data_dir) / "images"
    src_masks = Path(real_data_dir) / "masks"
    if not src_images.exists() or not src_masks.exists():
        raise FileNotFoundError("real_training_data must contain images/ and masks/")
    if not Path(val_tif_path).exists():
        raise FileNotFoundError(f"Missing video file: {val_tif_path}")
    if not Path(val_csv_path).exists():
        raise FileNotFoundError(f"Missing CSV file: {val_csv_path}")

    out_path = Path(out_dir)
    bonus_images = out_path / "bonus" / "images"
    bonus_masks = out_path / "bonus" / "masks"
    video_images = out_path / "video" / "images"
    video_masks = out_path / "video" / "masks"

    if out_path.exists():
        shutil.rmtree(out_path)
    for p in [bonus_images, bonus_masks, video_images, video_masks]:
        p.mkdir(parents=True, exist_ok=True)

    shutil.copytree(src_images, bonus_images, dirs_exist_ok=True)
    shutil.copytree(src_masks, bonus_masks, dirs_exist_ok=True)

    video = tifffile.imread(val_tif_path)
    coords = pd.read_csv(val_csv_path)
    y_min, y_max = roi_y
    x_min, x_max = roi_x
    # Group the annotations once instead of scanning the whole table for every frame.
    points_by_frame = {int(f): grp for f, grp in coords.groupby("frame")}
    records: List[dict] = []

    for idx, frame in enumerate(video):
        crop = frame[y_min:y_max, x_min:x_max]
        # Size the mask from the actual crop, not the global ROI constants, so that
        # non-default roi_y/roi_x (or frames smaller than the ROI) stay aligned.
        crop_h, crop_w = crop.shape[:2]
        mask = np.zeros((crop_h, crop_w), dtype=np.uint8)
        points = points_by_frame.get(idx)
        if points is not None:
            for x, y in zip(points["x"], points["y"]):
                cx = int(round(x - x_min))
                cy = int(round(y - y_min))
                if 0 <= cx < crop_w and 0 <= cy < crop_h:
                    cv2.circle(mask, (cx, cy), radius, 1, -1)
        fname = f"frame_{idx + 1:03}.png"
        cv2.imwrite(str(video_images / fname), crop)
        cv2.imwrite(str(video_masks / fname), mask * 255)
        records.append({"filename": fname, "real_frame_idx": idx})

    pd.DataFrame(records).to_csv(out_path / "video_map.csv", index=False)
    print(f"Dataset written to {out_path.resolve()}")
    print(f"Bonus samples: {len(list(bonus_images.glob('*.png')))} | Video frames: {len(records)}")
    return out_path


In [20]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

def get_train_transforms(
    rotate_p: float = 0.7,
    hflip_p: float = 0.5,
    vflip_p: float = 0.5,
    clahe_p: float = 0.5,
    brightness_p: float = 0.5,
    gauss_p: float = 0.3,
    elastic_p: float = 0.2,
    coarse_p: float = 0.5,
    coarse_max_holes: int = 16,
    coarse_min_holes: int = 8,
    coarse_max_hw: int = 16,
    coarse_min_hw: int = 8,
    crop_scale_min: float = 0.8,
    crop_scale_max: float = 1.0,
    crop_ratio_min: float = 0.9,
    crop_ratio_max: float = 1.1,
):
    """Training augmentation pipeline (albumentations 1.3.x argument names).

    Steps, in order: rotation and flips; CLAHE, brightness/contrast and Gaussian
    noise; an elastic warp; coarse dropout (holes are also zeroed in the mask); a
    random resized crop to 256x256; constant zero padding of the image up to
    256x256; conversion to a torch tensor. Each ``*_p`` argument is the
    probability of the matching step.
    """
    geometric = [
        A.Rotate(limit=180, p=rotate_p),
        A.HorizontalFlip(p=hflip_p),
        A.VerticalFlip(p=vflip_p),
    ]
    intensity = [
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=clahe_p),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=brightness_p),
        A.GaussNoise(var_limit=(10.0, 50.0), p=gauss_p),
    ]
    warp = A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=elastic_p)
    holes = A.CoarseDropout(
        max_holes=coarse_max_holes,
        max_height=coarse_max_hw,
        max_width=coarse_max_hw,
        min_holes=coarse_min_holes,
        min_height=coarse_min_hw,
        min_width=coarse_min_hw,
        fill_value=0,
        mask_fill_value=0,
        p=coarse_p,
    )
    crop = A.RandomResizedCrop(
        height=256,
        width=256,
        scale=(crop_scale_min, crop_scale_max),
        ratio=(crop_ratio_min, crop_ratio_max),
        p=0.5,
    )
    pad = A.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_CONSTANT, value=0)
    return A.Compose(geometric + intensity + [warp, holes, crop, pad, ToTensorV2()])


def get_val_transforms():
    """Validation/inference pipeline: no augmentation, only conversion to a tensor."""
    steps = [ToTensorV2()]
    return A.Compose(steps)


class AugmentedMicroscopyDataset(Dataset):
    """Image/mask pairs from ``root_dir/images/*.png`` and ``root_dir/masks/*.png``.

    Images and masks are paired by position after sorting each list by path, so
    both folders are expected to use matching file names.

    Each item is ``(img, mask)``, or ``(img, mask, {"filename": ...})`` when
    ``return_meta`` is true. ``img`` is a float tensor; uint8 data is scaled to
    [0, 1]. ``mask`` is a binary float tensor with at least 3 dims ((1, H, W) for
    2-D masks).

    Args:
        root_dir: Directory containing ``images/`` and ``masks/``.
        transform: Optional albumentations pipeline called as
            ``transform(image=..., mask=...)``, expected to end with ``ToTensorV2``.
        return_meta: Whether to also return the image file name.

    Raises:
        ValueError: If the number of images and masks differ.
    """

    def __init__(self, root_dir: str, transform=None, return_meta: bool = False):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.return_meta = return_meta
        self.image_paths = sorted(self.root_dir.joinpath("images").glob("*.png"))
        self.mask_paths = sorted(self.root_dir.joinpath("masks").glob("*.png"))
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(f"Images and masks counts differ in {root_dir}")

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _to_gray(arr):
        """Collapse a BGR (3-channel) or BGRA (4-channel) array to a single channel."""
        if arr.ndim != 3:
            return arr
        code = cv2.COLOR_BGRA2GRAY if arr.shape[2] == 4 else cv2.COLOR_BGR2GRAY
        return cv2.cvtColor(arr, code)

    @staticmethod
    def _to_uint8(img: np.ndarray) -> np.ndarray:
        """Return ``img`` as uint8 without wrap-around.

        uint8 input is returned unchanged, and input already in 0..255 is cast
        directly. Anything else (e.g. a 16-bit PNG read with IMREAD_UNCHANGED) is
        min-max rescaled to 0..255, because ``astype(np.uint8)`` would wrap values
        above 255 modulo 256.
        """
        if img.dtype == np.uint8:
            return img
        lo, hi = float(img.min()), float(img.max())
        if lo >= 0 and hi <= 255:
            return img.astype(np.uint8)
        scaled = (img.astype(np.float64) - lo) * (255.0 / (hi - lo))
        return np.round(scaled).astype(np.uint8)

    @staticmethod
    def _scale_image(img: torch.Tensor) -> torch.Tensor:
        """Convert an image tensor to float and scale 8-bit data to [0, 1].

        uint8 tensors are always divided by 255. Before, scaling depended on
        ``max() > 1``, which left nearly black uint8 images unscaled. Float tensors
        keep the previous value-based rule.
        """
        if img.dtype == torch.uint8:
            return img.float() / 255.0
        img = img.float()
        return img / 255.0 if img.max() > 1 else img

    def __getitem__(self, idx: int):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
        mask = cv2.imread(str(mask_path), cv2.IMREAD_UNCHANGED)
        if img is None or mask is None:
            raise FileNotFoundError(f"Missing pair for index {idx}")
        img = self._to_gray(img)
        mask = self._to_gray(mask)

        # CLAHE and the other 8-bit augmentations require uint8 input.
        img = self._to_uint8(img)
        # Binarise the mask to {0, 255} uint8 for albumentations.
        mask = (mask > 127).astype(np.uint8) * 255

        # Truthiness (not `is not None`) on purpose: an empty Compose takes the manual path.
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented["image"], augmented["mask"]
            if isinstance(img, torch.Tensor):
                img = self._scale_image(img)
            if isinstance(mask, torch.Tensor):
                mask = mask.float()
                if mask.max() > 1:
                    mask = mask / 255.0
                mask = (mask > 0.5).float()
        else:
            # Manual conversion and normalization if no transform is applied
            img = torch.from_numpy(img).float().unsqueeze(0) / 255.0
            mask = torch.from_numpy(mask).float().unsqueeze(0) / 255.0
            mask = (mask > 0.5).float()

        # Ensure mask has a channel dimension (1, H, W)
        while mask.ndim < 3:
            mask = mask.unsqueeze(0)

        if self.return_meta:
            return img, mask, {'filename': img_path.name}
        return img, mask

In [26]:
import torch.nn as nn
import torch.nn.functional as F

def make_fold_loaders(video_root: Path, bonus_root: Path, train_idx, val_idx, batch_size=8, num_workers=2, use_bonus: bool = True, train_transform=None, val_transform=None):
    """Create the train and validation DataLoaders for one cross-validation fold.

    The training set is the video frames at ``train_idx`` plus, when ``use_bonus``
    is true, every sample under ``bonus_root``; both use the training transform.
    Validation uses the video frames at ``val_idx`` with the validation transform,
    no shuffling, and half the batch size (at least 1).

    Args:
        video_root, bonus_root: Folders with ``images/`` and ``masks/`` sub-folders.
        train_idx, val_idx: Indices into the video dataset for this fold.
        batch_size: Training batch size.
        num_workers: DataLoader worker processes for both loaders.
        use_bonus: Whether to add the bonus samples to the training set.
        train_transform, val_transform: Pipelines to use; ``None`` selects
            ``get_train_transforms()`` / ``get_val_transforms()``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_tf = get_train_transforms() if train_transform is None else train_transform
    val_tf = get_val_transforms() if val_transform is None else val_transform

    video_train = Subset(AugmentedMicroscopyDataset(video_root, transform=train_tf), train_idx)
    video_val = Subset(AugmentedMicroscopyDataset(video_root, transform=val_tf), val_idx)

    bonus_ds = AugmentedMicroscopyDataset(bonus_root, transform=train_tf) if use_bonus else None
    parts = [video_train] if bonus_ds is None else [video_train, bonus_ds]
    bonus_len = 0 if bonus_ds is None else len(bonus_ds)

    shared = dict(num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(ConcatDataset(parts), batch_size=batch_size, shuffle=True, **shared)
    val_loader = DataLoader(video_val, batch_size=max(1, batch_size // 2), shuffle=False, **shared)
    print(f"Train loader: video {len(train_idx)} + bonus {bonus_len} = {len(train_loader.dataset)}")
    print(f"Val loader: {len(val_idx)}")
    return train_loader, val_loader


def train_one_fold(model, train_loader, val_loader, device, epochs=40, lr=1e-3):
    """Train ``model`` with ComboLoss + AdamW and restore the best-validation weights.

    Each epoch does one pass over ``train_loader`` (gradient norm clipped to 1.0)
    and one evaluation pass over ``val_loader``. ``ReduceLROnPlateau`` halves the
    learning rate after 3 epochs without improvement in validation loss. The
    weights with the lowest validation loss are reloaded at the end.

    Args:
        model: Network returning logits (or a list/tuple whose last item is used).
        train_loader, val_loader: Loaders yielding ``(imgs, masks)`` batches.
        device: Device that batches are moved to.
        epochs: Number of epochs.
        lr: Initial learning rate.

    Returns:
        ``(model, best_val)``: the model with its best weights loaded, and the
        lowest sample-weighted mean validation loss (``inf`` if ``epochs == 0``).

    Raises:
        ValueError: If either loader's dataset is empty. This is checked up front;
            before, it surfaced as a ZeroDivisionError only after a full epoch.
    """
    import copy

    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    if n_train == 0 or n_val == 0:
        raise ValueError(f"train_one_fold needs non-empty datasets (train={n_train}, val={n_val})")

    def _last_head(outputs):
        # Multi-output models (list/tuple) are trained and scored on their last output.
        return outputs[-1] if isinstance(outputs, (list, tuple)) else outputs

    criterion = ComboLoss()
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    best_val = float('inf')
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        for imgs, masks in train_loader:
            imgs, masks = imgs.to(device).float(), masks.to(device).float()
            optimizer.zero_grad(set_to_none=True)
            loss = criterion(_last_head(model(imgs)), masks)
            loss.backward()
            # Clip the global gradient norm to guard against exploding gradients / NaN losses.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Weight each batch loss by its size so the epoch value is a per-sample mean.
            train_loss += loss.item() * imgs.size(0)
        train_loss /= n_train

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for imgs, masks in val_loader:
                imgs, masks = imgs.to(device).float(), masks.to(device).float()
                loss = criterion(_last_head(model(imgs)), masks)
                val_loss += loss.item() * imgs.size(0)
        val_loss /= n_val
        scheduler.step(val_loss)
        if val_loss < best_val:
            best_val = val_loss
            # deepcopy: state_dict() holds references to the live parameter tensors.
            best_state = copy.deepcopy(model.state_dict())
        print(f"Epoch {epoch:03d}: train {train_loss:.4f} | val {val_loss:.4f} | best {best_val:.4f}")
    model.load_state_dict(best_state)
    return model, best_val


def infer_fold_oof(model, val_indices, video_root: Path, video_map: Path, device, threshold: float = 0.5):
    """Detect spots on one fold's held-out video frames (out-of-fold predictions).

    Each frame's sigmoid probability map is thresholded at ``threshold``. Every
    connected component becomes one detection at its centroid, shifted by
    ``(ROI_X_MIN, ROI_Y_MIN)`` into full-frame coordinates. The frame number is
    looked up from ``video_map`` (``filename`` -> ``real_frame_idx``).

    Returns:
        List of ``{'frame': int, 'x': float, 'y': float}`` dicts.
    """
    dataset = Subset(
        AugmentedMicroscopyDataset(video_root, transform=get_val_transforms(), return_meta=True),
        val_indices,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    name_to_frame = pd.read_csv(video_map).set_index('filename')['real_frame_idx'].to_dict()
    model.eval()
    detections = []
    with torch.no_grad():
        for batch_imgs, _unused_masks, meta in loader:
            names = meta['filename'] if isinstance(meta, dict) else [entry['filename'] for entry in meta]
            logits = model(batch_imgs.to(device).float())
            if isinstance(logits, (list, tuple)):
                logits = logits[-1]
            prob_maps = torch.sigmoid(logits).cpu().numpy()
            for b, prob in enumerate(prob_maps):
                components = regionprops(label((prob[0] >= threshold).astype(np.uint8)))
                for comp in components:
                    row, col = comp.centroid  # (row, col) == (y, x)
                    detections.append({
                        'frame': int(name_to_frame[names[b]]),
                        'x': float(col + ROI_X_MIN),
                        'y': float(row + ROI_Y_MIN),
                    })
    return detections


def _filter_roi_frames(df: pd.DataFrame, frames_filter=None, roi=None):
    if frames_filter is not None:
        frame_set = set(map(int, frames_filter))
        df = df[df.frame.isin(frame_set)]
    if roi is not None:
        y_min, y_max, x_min, x_max = roi
        df = df[(df.y >= y_min) & (df.y < y_max) & (df.x >= x_min) & (df.x < x_max)]
    return df


def hota_metric(gt: pd.DataFrame, tr: pd.DataFrame, threshold: float = 5.0):
    """Alias for ``hota`` so every evaluation in the notebook shares one implementation.

    Args:
        gt, tr: Ground-truth and tracker DataFrames (``frame``, ``x``, ``y``, ``track_id``).
        threshold: Matching distance in pixels, forwarded unchanged.
    """
    result = hota(gt, tr, threshold=threshold)
    return result

def link_detections(detections_per_frame, max_dist: float = 7.0):
    """Greedy nearest-neighbour linking of per-frame detections into tracks.

    NOTE: a function with the same name is defined again in the "Metrics &
    Visualization Helpers" cell, and whichever cell ran last wins. This version
    produces the same output as that one, so results do not depend on cell order.

    In each frame, every active track (in creation order) takes the closest
    still-free detection that is strictly closer than ``max_dist``. Leftover
    detections start new tracks. A track with no detection in a frame ends;
    there is no gap closing.

    Args:
        detections_per_frame: ``detections_per_frame[f]`` is a list of ``(x, y)``.
        max_dist: Maximum frame-to-frame displacement for a link.

    Returns:
        DataFrame with columns ``frame``, ``x``, ``y``, ``track_id``, ordered by
        frame and then by detection order. The columns exist even when there are
        no detections, so downstream ``.frame``/``.track_id`` access works.
    """
    next_id = 0
    active = {}  # track_id -> (x, y, frame) of the track's detection in the previous frame
    records = []
    for frame_idx, dets in enumerate(detections_per_frame):
        owner = [None] * len(dets)  # track id assigned to each detection
        new_active = {}
        for tid, (tx, ty, _last_frame) in active.items():
            best, best_d = None, max_dist
            for i, (x, y) in enumerate(dets):
                if owner[i] is not None:
                    continue
                d = math.hypot(x - tx, y - ty)
                if d < best_d:
                    best, best_d = i, d
            if best is not None:
                owner[best] = tid
                new_active[tid] = (dets[best][0], dets[best][1], frame_idx)
        for i, (x, y) in enumerate(dets):
            if owner[i] is None:
                owner[i] = next_id
                new_active[next_id] = (x, y, frame_idx)
                next_id += 1
            records.append({'frame': frame_idx, 'x': x, 'y': y, 'track_id': owner[i]})
        active = new_active
    return pd.DataFrame(records, columns=['frame', 'x', 'y', 'track_id'])


def track_detections_simple(preds: pd.DataFrame, max_dist: float):
    """Group detections by frame and link them with ``link_detections``.

    Args:
        preds: Table with ``frame``, ``x`` and ``y`` columns (frame = 0-based index).
        max_dist: Maximum frame-to-frame displacement for a link.

    Returns:
        The DataFrame returned by ``link_detections``. Frames with no detections,
        up to the largest frame present, are passed as empty lists.
    """
    n_frames = int(preds.frame.max()) + 1 if len(preds) else 0
    per_frame = [[] for _ in range(n_frames)]
    for frame, x, y in zip(preds.frame, preds.x, preds.y):
        per_frame[int(frame)].append((float(x), float(y)))
    return link_detections(per_frame, max_dist=max_dist)


def show_val_overlay(model, dataset, val_indices, device: str, threshold: float = 0.5):
    """Plot one random validation sample next to the model's prediction.

    Panels: the input image; the sigmoid probability map; the image overlaid with
    ground-truth centroids (green circles) and predicted centroids (red crosses).
    Centroids are taken from the connected components of the GT mask (> 0.5) and
    of the thresholded probability map (>= ``threshold``).

    The forward pass runs in eval mode, so BatchNorm uses its running statistics
    and dropout is off. The model's previous train/eval mode is restored afterwards.

    Args:
        model: Network returning logits (or a list/tuple whose last item is used).
        dataset: Indexable dataset returning ``(img, mask)`` or ``(img, mask, meta)``.
        val_indices: Candidate indices into ``dataset``; one is drawn at random.
        device: Device for the forward pass.
        threshold: Probability threshold for the predicted mask.
    """
    if len(val_indices) == 0:
        print("No validation indices to visualize")
        return
    idx = int(np.random.choice(val_indices))
    sample = dataset[idx]
    img, mask = sample[0], sample[1]  # ignore a trailing meta element if present
    base = img.squeeze().cpu().numpy()
    if base.max() > 1:
        base = base / 255.0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            out = model(img.unsqueeze(0).to(device).float())
            if isinstance(out, (list, tuple)):
                out = out[-1]
            prob = torch.sigmoid(out)[0, 0].cpu().numpy()
    finally:
        model.train(was_training)

    bin_mask = (prob >= threshold).astype(float)
    gt_mask = mask.squeeze().cpu().numpy()
    # regionprops centroids are (row, col) == (y, x).
    gt_yx = [p.centroid for p in regionprops(label(gt_mask > 0.5))]
    pred_yx = [p.centroid for p in regionprops(label(bin_mask))]
    gy, gx = [c[0] for c in gt_yx], [c[1] for c in gt_yx]
    py, px = [c[0] for c in pred_yx], [c[1] for c in pred_yx]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(base, cmap='gray'); axes[0].set_title(f"Val image (idx {idx})")
    axes[1].imshow(prob, cmap='viridis'); axes[1].set_title("Prob map")
    axes[2].imshow(base, cmap='gray')
    axes[2].scatter(gx, gy, s=35, facecolors='none', edgecolors='lime', linewidths=1.2, label='GT')
    axes[2].scatter(px, py, s=30, marker='x', color='red', linewidths=1.2, label='Pred')
    axes[2].set_title("Overlay: red=pred (x), green=gt (o)")
    axes[2].legend(loc='upper right')
    for ax in axes:
        ax.axis('off')
    plt.tight_layout(); plt.show()


In [34]:
# @title Metrics & Visualization Helpers

def hota(gt: pd.DataFrame, tr: pd.DataFrame, threshold: float = 5) -> dict[str, float]:
    """Single-threshold HOTA for point tracks.

    Slightly adapted from https://github.com/JonathonLuiten/TrackEval

    The similarity of a GT point and a tracker point in the same frame is
    ``1 - min(distance / threshold, 1)``. Per-frame assignments maximise the
    global alignment score times that similarity.

    Args:
        gt, tr: DataFrames with ``frame``, ``x``, ``y`` and ``track_id`` columns.
            Frame numbers may start anywhere and may have gaps.
        threshold: Distance (pixels) at which the similarity reaches 0.

    Returns:
        dict with ``HOTA``, ``AssA``, ``DetA``, ``LocA`` and the ``HOTA TP``,
        ``HOTA FN``, ``HOTA FP`` counts.

    Notes:
        * Unlike TrackEval, there is no averaging over alpha thresholds: any
          assigned pair with similarity > 0 counts as a true positive.
        * ``LocA`` is the *sum* of matched similarities; TrackEval divides it by
          the number of true positives.
    """
    gt = gt.copy(); tr = tr.copy()
    # Relabel track ids to 0..n-1 so they can index the count matrices directly.
    gt.track_id = gt.track_id.map({old: new for old, new in zip(gt.track_id.unique(), range(gt.track_id.nunique()))})
    tr.track_id = tr.track_id.map({old: new for old, new in zip(tr.track_id.unique(), range(tr.track_id.nunique()))})
    num_gt_ids = gt.track_id.nunique(); num_tr_ids = tr.track_id.nunique()
    frames = sorted(set(gt.frame.unique()) | set(tr.frame.unique()))

    # Per-frame ids and similarity matrices, keyed by the frame *value*. Indexing a list
    # built over `frames` by the frame value is only correct when frames are exactly
    # 0..n-1; with an offset or gaps it picks the wrong frame or raises IndexError.
    gt_ids_by_frame, tr_ids_by_frame, similarities = {}, {}, {}
    for t in frames:
        gt_t = gt[gt.frame == t]; tr_t = tr[tr.frame == t]
        gt_ids_by_frame[t] = gt_t.track_id.to_numpy()
        tr_ids_by_frame[t] = tr_t.track_id.to_numpy()
        similarities[t] = 1 - np.clip(spatial.distance.cdist(gt_t[['x', 'y']], tr_t[['x', 'y']]) / threshold, 0, 1)

    potential_matches_count = np.zeros((num_gt_ids, num_tr_ids))
    gt_id_count = np.zeros((num_gt_ids, 1)); tracker_id_count = np.zeros((1, num_tr_ids))
    HOTA_TP = HOTA_FN = HOTA_FP = 0; LocA = 0.0

    # Pass 1: accumulate the soft (IoU-normalised) co-occurrence of every id pair.
    for t in frames:
        gt_ids_t = gt_ids_by_frame[t]; tr_ids_t = tr_ids_by_frame[t]
        similarity = similarities[t]
        sim_iou_denom = similarity.sum(0)[np.newaxis, :] + similarity.sum(1)[:, np.newaxis] - similarity
        sim_iou = np.zeros_like(similarity)
        mask = sim_iou_denom > 0 + np.finfo('float').eps  # i.e. denominator > eps
        sim_iou[mask] = similarity[mask] / sim_iou_denom[mask]
        potential_matches_count[gt_ids_t[:, None], tr_ids_t[None, :]] += sim_iou
        gt_id_count[gt_ids_t] += 1; tracker_id_count[0, tr_ids_t] += 1
    global_alignment_score = potential_matches_count / (gt_id_count + tracker_id_count - potential_matches_count)
    matches_count = np.zeros_like(potential_matches_count)

    # Pass 2: per-frame Hungarian matching weighted by the global alignment score.
    for t in frames:
        gt_ids_t = gt_ids_by_frame[t]; tr_ids_t = tr_ids_by_frame[t]
        if len(gt_ids_t) == 0:
            HOTA_FP += len(tr_ids_t); continue
        if len(tr_ids_t) == 0:
            HOTA_FN += len(gt_ids_t); continue
        similarity = similarities[t]
        score_mat = global_alignment_score[gt_ids_t[:, None], tr_ids_t[None, :]] * similarity
        match_rows, match_cols = optimize.linear_sum_assignment(-score_mat)
        mask = similarity[match_rows, match_cols] > 0
        alpha_match_rows = match_rows[mask]; alpha_match_cols = match_cols[mask]
        num_matches = len(alpha_match_rows)
        HOTA_TP += num_matches; HOTA_FN += len(gt_ids_t) - num_matches; HOTA_FP += len(tr_ids_t) - num_matches
        if num_matches > 0:
            LocA += float(np.sum(similarity[alpha_match_rows, alpha_match_cols]))
            matches_count[gt_ids_t[alpha_match_rows], tr_ids_t[alpha_match_cols]] += 1
    ass_a = matches_count / np.maximum(1, gt_id_count + tracker_id_count - matches_count)
    AssA = np.sum(matches_count * ass_a) / np.maximum(1, HOTA_TP)
    DetA = HOTA_TP / np.maximum(1, HOTA_TP + HOTA_FN + HOTA_FP)
    HOTA = np.sqrt(DetA * AssA)
    return {'HOTA': HOTA, 'AssA': AssA, 'DetA': DetA, 'LocA': LocA, 'HOTA TP': HOTA_TP, 'HOTA FN': HOTA_FN, 'HOTA FP': HOTA_FP}


def link_detections(detections_per_frame: list[list[tuple[float, float]]], max_dist: float = 7.0) -> pd.DataFrame:
    """Link per-frame detections into tracks by greedy nearest-neighbour matching.

    In each frame, every active track (in creation order) takes the closest
    still-free detection that is strictly closer than ``max_dist``. Leftover
    detections start new tracks. A track that finds no detection in a frame
    ends; there is no gap closing.

    Args:
        detections_per_frame: ``detections_per_frame[f]`` is the list of ``(x, y)``
            detections in frame ``f``.
        max_dist: Maximum displacement between consecutive frames for a link.

    Returns:
        DataFrame with columns ``frame``, ``x``, ``y``, ``track_id``, ordered by
        frame and then by detection order within the frame. The columns exist
        even when there are no detections, so downstream code can still access
        ``.frame`` / ``.track_id``.
    """
    next_track_id = 0
    # track_id -> (x, y, frame) of the track's detection in the previous frame
    active_tracks: dict[int, tuple[float, float, int]] = {}
    records: list[dict] = []
    for frame_idx, detections in enumerate(detections_per_frame):
        detection_track_id: list[int | None] = [None] * len(detections)
        updated_tracks: dict[int, tuple[float, float, int]] = {}
        for track_id, (tx, ty, _last_frame) in active_tracks.items():
            best_dist = max_dist; best_idx: int | None = None
            for i, (x, y) in enumerate(detections):
                if detection_track_id[i] is not None:
                    continue
                dist = math.hypot(x - tx, y - ty)
                if dist < best_dist:
                    best_dist = dist; best_idx = i
            if best_idx is not None:
                detection_track_id[best_idx] = track_id
                updated_tracks[track_id] = (detections[best_idx][0], detections[best_idx][1], frame_idx)
        for i, (x, y) in enumerate(detections):
            if detection_track_id[i] is None:
                detection_track_id[i] = next_track_id
                updated_tracks[next_track_id] = (x, y, frame_idx)
                next_track_id += 1
            records.append({'frame': frame_idx, 'x': x, 'y': y, 'track_id': detection_track_id[i]})
        active_tracks = updated_tracks
    return pd.DataFrame(records, columns=['frame', 'x', 'y', 'track_id'])


def show_tracking(data, image_stack, y_min=512, y_max=768, x_min=256, x_max=512, tail_length=10, color='yellow', show_roi=True):
    """Show the tracks inside an ROI as an animated, cropped movie.

    Args:
        data: Path to a CSV file, or a DataFrame, with ``frame``, ``x``, ``y`` and
            ``track_id`` columns in full-frame pixel coordinates.
        image_stack: Frame stack indexed as ``image_stack[frame, row, col]``. It is
            sliced as ``image_stack[:, y_min:y_max, x_min:x_max]``, so it presumably
            must be a NumPy array rather than a list -- TODO confirm.
        y_min, y_max, x_min, x_max: ROI bounds in full-frame pixels.
        tail_length: Frames before the current one that are drawn as a trailing line.
        color: Matplotlib color for the dots and tails.
        show_roi: If true, first show frame 0 with the ROI outlined in cyan.

    Only tracks whose *mean* position lies strictly inside the ROI are shown.

    Raises:
        TypeError: If ``data`` is neither a str nor a DataFrame.

    NOTE(review): ``loading_html`` and ``replace_loading_js_empty`` are not defined
    in this part of the notebook; they must come from another cell.
    """
    if isinstance(data, str):
        trajectories_df = pd.read_csv(data)
    elif isinstance(data, pd.DataFrame):
        trajectories_df = data.copy()
    else:
        raise TypeError("`data` must be a CSV file path or a pandas DataFrame.")
    # Keep whole tracks whose average position is inside the ROI (strict bounds).
    tracks_in_roi = trajectories_df.groupby('track_id').filter(lambda t: (y_min < t.y.mean() < y_max) and (x_min < t.x.mean() < x_max))
    html_code_linking = loading_html("Loading cropped region and tracks, please wait...")
    display(HTML(html_code_linking))
    if show_roi:
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(image_stack[0], cmap='magma')
        rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor='cyan', facecolor='none')
        ax.add_patch(rect); ax.set_title("Full image (cyan box shows cropped region)")
        plt.show()
    def animate_trajectories_cropped(trajectories_df, image_stack, tail_length=10, color='yellow'):
        """Build a JS/HTML animation of the cropped stack with per-track tails."""
        cropped_stack = image_stack[:, y_min:y_max, x_min:x_max]
        fig, ax = plt.subplots(); im = ax.imshow(cropped_stack[0], cmap='magma')
        particles = trajectories_df['track_id'].unique()
        # One LineCollection per track; its segments are replaced on every frame.
        line_collections = {pid: mc.LineCollection([], linewidths=1, colors=color) for pid in particles}
        for lc in line_collections.values():
            ax.add_collection(lc)
        dot = ax.scatter([], [], s=5, c=color)
        def animate(i):
            im.set_array(cropped_stack[i])
            # Detections from frames i - tail_length .. i (inclusive) form the tails.
            window = trajectories_df[(trajectories_df['frame'] >= i - tail_length) & (trajectories_df['frame'] <= i)]
            now = window[window['frame'] == i]
            # Shift full-frame coordinates into crop coordinates.
            if len(now) > 0:
                coords = np.column_stack((now.x.values - x_min, now.y.values - y_min))
                dot.set_offsets(coords)
            else:
                dot.set_offsets(np.empty((0, 2)))
            for pid in particles:
                traj = window[window['track_id'] == pid].sort_values('frame')
                if len(traj) >= 2:
                    segs = [[(x0 - x_min, y0 - y_min), (x1 - x_min, y1 - y_min)] for (x0, y0, x1, y1) in zip(traj.x.values[:-1], traj.y.values[:-1], traj.x.values[1:], traj.y.values[1:])]
                    line_collections[pid].set_segments(segs)
                else:
                    line_collections[pid].set_segments([])
            # blit=True requires returning every artist that changed.
            return [im, dot] + list(line_collections.values())
        ani = FuncAnimation(fig, animate, frames=cropped_stack.shape[0], interval=100, blit=True)
        # Close the static figure so only the rendered animation is displayed.
        plt.close(fig)
        return HTML(ani.to_jshtml())
    html = animate_trajectories_cropped(tracks_in_roi, image_stack, tail_length, color)
    display(html); display(HTML(replace_loading_js_empty))
    print("Total number of trajectories in ROI:", len(tracks_in_roi['track_id'].unique()))


def visualize_model_on_dataset(model, dataset, device, num_samples=4, threshold=0.5, sigma=1.0):
    """Plot predictions for ``num_samples`` randomly drawn samples of ``dataset``.

    Each row shows the input image, the GT mask, the (optionally Gaussian-smoothed)
    probability map, and the map thresholded at ``>= threshold``.

    Args:
        model: Network returning logits (or a list/tuple whose last item is used,
            as in the training and inference helpers).
        dataset: Dataset yielding ``(img, mask)`` or ``(img, mask, meta)``.
        device: Device for the forward pass.
        num_samples: Number of samples to plot.
        threshold: Probability threshold for the binary prediction.
        sigma: Gaussian smoothing applied to the probability map; ``<= 0`` disables it.

    The model is left in eval mode.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= num_samples:
                break
            # batch[2] (meta), if present, is not needed here.
            img_tensor = batch[0].to(device).float()
            mask_tensor = batch[1].to(device)
            logits = model(img_tensor)
            if isinstance(logits, (list, tuple)):
                logits = logits[-1]
            prob_map = torch.sigmoid(logits[0, 0]).cpu().numpy()
            if sigma > 0:
                prob_map = gaussian_filter(prob_map, sigma=sigma)
            pred_mask = (prob_map >= threshold).astype(float)
            img = img_tensor[0, 0].cpu().numpy()
            gt_mask = mask_tensor[0, 0].cpu().numpy()
            fig, axes = plt.subplots(1, 4, figsize=(14, 4))
            axes[0].imshow(img, cmap='magma'); axes[0].set_title("Input Image")
            axes[1].imshow(gt_mask, cmap='gray', vmin=0, vmax=1); axes[1].set_title("Ground Truth Mask")
            axes[2].imshow(prob_map, cmap='viridis'); axes[2].set_title("Predicted Heatmap")
            axes[3].imshow(pred_mask, cmap='gray', vmin=0, vmax=1); axes[3].set_title(f"Thresholded Output (>= {threshold})")
            for ax in axes:
                ax.axis('off')
            plt.tight_layout(); plt.show()


def show_detections(detections_per_frame, image_stack, y_min=512, y_max=768, x_min=256, x_max=512, color='yellow', max_frames=5):
    """Preview detections inside an ROI: a static overview plus a short inline animation.

    First shows frame 0 of ``image_stack`` with the ROI outlined in cyan. It then
    displays an HTML animation of the cropped ROI for the first ``max_frames`` frames,
    drawing that frame's detections as dots.

    Args:
        detections_per_frame: sequence indexed by frame. Each entry is an iterable of
            ``(x, y)`` points in full-image pixel coordinates.
        image_stack: frame stack indexable as ``[frame, y, x]`` (converted once with
            ``np.asarray``).
        y_min, y_max, x_min, x_max: ROI bounds. The crop is ``[y_min:y_max, x_min:x_max]``,
            while detections are kept with inclusive bounds (``Series.between``).
        color: dot colour used for detections.
        max_frames: maximum number of frames to animate.
    """
    stack = np.asarray(image_stack)
    rows = [(frame_idx, x, y) for frame_idx, dets in enumerate(detections_per_frame) for (x, y) in dets]
    detections_df = pd.DataFrame(rows, columns=["frame", "x", "y"])
    detections_df = detections_df[(detections_df.y.between(y_min, y_max)) & (detections_df.x.between(x_min, x_max))]

    # Static overview: full first frame with the ROI rectangle.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(stack[0], cmap='magma')
    rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor='cyan', facecolor='none')
    ax.add_patch(rect); ax.set_title("Full image (cyan box shows cropped region)")
    plt.show()

    cropped_stack = stack[:, y_min:y_max, x_min:x_max]
    n_frames = min(max_frames, cropped_stack.shape[0])

    def _animate_cropped():
        # Build a FuncAnimation over the cropped ROI and render it as inline JS/HTML.
        anim_fig, anim_ax = plt.subplots()
        im = anim_ax.imshow(cropped_stack[0], cmap='magma')
        dot = anim_ax.scatter([], [], s=10, c=color)

        def animate(i):
            im.set_array(cropped_stack[i])
            now = detections_df[detections_df['frame'] == i]
            if len(now) > 0:
                # Shift full-image coordinates into the cropped frame.
                dot.set_offsets(np.column_stack((now.x.values - x_min, now.y.values - y_min)))
            else:
                dot.set_offsets(np.empty((0, 2)))
            return [im, dot]

        ani = FuncAnimation(anim_fig, animate, frames=n_frames, interval=1000, blit=True)
        plt.close(anim_fig)
        return HTML(ani.to_jshtml())

    display(_animate_cropped())
    print(f"Total detections in ROI : {len(detections_df)} \nShowing first {n_frames} frames")


def calculate_performance(gt_path, tracks, y_min=512, y_max=768, x_min=256, x_max=512, name="Method"):
    """Score predicted ``tracks`` against ground truth with ``hota`` inside a rectangular ROI.

    A whole track (GT or predicted) is kept only if its *mean* (x, y) lies strictly
    inside the ROI, so tracks are never cut at the ROI border.

    Args:
        gt_path: path to a ground-truth CSV with at least ``track_id``, ``x`` and ``y`` columns.
        tracks: predicted tracks, either a CSV path (``str`` or ``os.PathLike``) or a
            DataFrame with the same columns. A DataFrame argument is copied, never modified.
        y_min, y_max, x_min, x_max: ROI bounds (exclusive comparisons on the track mean).
        name: label printed above the scores.

    Returns:
        The object returned by ``hota``. It is read here via the keys ``'HOTA'``,
        ``'AssA'`` and ``'DetA'``.

    Raises:
        TypeError: if ``tracks`` is neither a path nor a pandas DataFrame.
    """
    def _tracks_in_roi(df):
        # Keep only tracks whose mean position lies strictly inside the ROI.
        return df.groupby('track_id').filter(
            lambda t: (y_min < t.y.mean() < y_max) and (x_min < t.x.mean() < x_max)
        )

    val_gt = _tracks_in_roi(pd.read_csv(gt_path))
    if isinstance(tracks, (str, os.PathLike)):
        val_tracks = pd.read_csv(tracks)
    elif isinstance(tracks, pd.DataFrame):
        val_tracks = tracks.copy()
    else:
        raise TypeError("`tracks` must be a CSV path or a pandas DataFrame.")
    val_tracks = _tracks_in_roi(val_tracks)

    results = hota(val_gt, val_tracks)
    print(f"{name}:")
    print(f"  HOTA: {results['HOTA']:.2f} (AssA: {results['AssA']:.2f}, DetA: {results['DetA']:.2f})\n")
    return results


def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device) with ``seed``.

    Also puts cuDNN in deterministic mode and turns off its benchmark autotuner,
    trading some speed for run-to-run reproducibility.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)


def open_tiff_file(name: str) -> np.ndarray:
    """Read every frame of a (multi-page) image file into one NumPy array.

    Frames are stacked along a new leading axis. Singleton axes are then squeezed
    away (e.g. the frame axis of a single-page file).

    Args:
        name: path of the image file to open with PIL.

    Returns:
        ``np.array(frames).squeeze()`` of all frames read via ``seek``.
    """
    # Use the context manager so the underlying file handle is closed even if reading
    # a frame fails; the original left the PIL image (and its file) open.
    with Image.open(name) as img:
        frames = []
        for i in range(img.n_frames):
            img.seek(i)
            frames.append(np.array(img))
    return np.array(frames).squeeze()


def loading_html(message: str) -> str:
    """Build an HTML snippet: a spinning glyph followed by ``message`` in bold.

    The markup lives in a ``<div id="loading-msg">`` and carries its own
    ``@keyframes flipPause`` CSS animation for the glyph.
    """
    snippet_lines = [
        "",
        '<div id="loading-msg">',
        "  <br /><br />",
        "  <b><span style='display:inline-block;animation:flipPause 2s ease infinite;'>?</span>",
        f"  {message}</b>",
        "</div>",
        "<style>",
        "@keyframes flipPause {",
        "  0% {transform:rotate(0deg);}",
        "  40%{transform:rotate(180deg);}",
        "  50%{transform:rotate(180deg);}",
        "  90%{transform:rotate(360deg);}",
        "  100%{transform:rotate(360deg);}",
        "}",
        "</style>",
        "",
    ]
    return "\n".join(snippet_lines)

html_code_reconstruction = loading_html("Showing input validation data, please wait...")


def replace_loading_js(message: str, delay_ms: int = 0) -> str:
    """Return a ``<script>`` that replaces the ``loading-msg`` element's content.

    After ``delay_ms`` milliseconds the script sets the inner HTML of the element
    with id ``loading-msg`` (if present) to ``<br /><b>message</b>``.

    ``message`` is embedded in a single-quoted JavaScript string literal, so
    backslashes, single quotes, line breaks and ``</`` are escaped first. Unescaped,
    an apostrophe (e.g. "Don't") would end the literal early and the script would
    not run. The text is still assigned via ``innerHTML``, so HTML markup in
    ``message`` is rendered rather than escaped.
    """
    js_message = (
        str(message)
        .replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace("\r", "\\r")
        .replace("\n", "\\n")
        .replace("</", "<\\/")  # keep a literal "</script>" from closing the tag
    )
    return f"""
<script>
  setTimeout(function(){{
    var loadingDiv = document.getElementById("loading-msg");
    if (loadingDiv) {{
      loadingDiv.innerHTML = '<br /><b>{js_message}</b>';
    }}
  }}, {delay_ms});
</script>
"""

replace_loading_js_default = replace_loading_js("Only the first 50 frames are displayed.")
replace_loading_js_empty = replace_loading_js("")


Helpers ready: prepare_grand_dataset, run_grand_kfold, run_btrack_tracking


In [23]:
# @title Download / Fetch Data (run once)
AUTO_FETCH_DRIVE = True  # set False to skip copying training data from Google Drive
DRIVE_SOURCE_PATH = "/content/drive/MyDrive/unet_train"
TARGET_DIR = "real_training_data"

# Download validation data only if the .tif stack or its annotation CSV is missing.
if not (Path("val_data/val.tif").exists() and Path("val_data/val.csv").exists()):
    download_validation_data(target_dir="val_data")
else:
    print("Validation data already present.")

# Optional Drive copy for training data; TARGET_DIR must end up with images/ and masks/.
if AUTO_FETCH_DRIVE:
    fetch_from_drive(DRIVE_SOURCE_PATH=DRIVE_SOURCE_PATH, TARGET_DIR=TARGET_DIR)
else:
    # Report the configured directory rather than a hard-coded copy of its default name.
    print(f"Skipping Drive fetch (AUTO_FETCH_DRIVE=False). Ensure {TARGET_DIR}/ exists with images/ and masks/.")


Validation data already present.
?? Mounting Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
'real_training_data' already exists and is not empty. Skipping copy.


In [35]:
# This cell was a duplicate of ebc8af67 and has been removed to prevent definition conflicts.

In [38]:
# @title Run Full Pipeline (edit params as needed)
# Paths
REAL_DATA_DIR = "real_training_data"
VAL_TIF = "val_data/val.tif"
VAL_CSV = "val_data/val.csv"
OUT_DIR = "experiment_dataset"
SAVE_DIR = "."

# Pipeline knobs
K_SPLITS = 5
USE_BONUS = True          # set False to ignore bonus set
EPOCHS = 5                # training epochs per fold
BATCH_SIZE = 8
# NOTE(review): the recorded run logged train/val loss = nan from epoch 1 at this LR and
# still "saved" a checkpoint with val_loss=inf, so lowering the LR did not fix the NaNs.
# Check input normalisation / mask contents before tuning LR further.
LR = 1e-5                 # learning rate passed to run_grand_kfold
THRESHOLD = 0.5           # sigmoid->mask threshold for detections
MATCH_THRESH = 5.0        # px for DetA calculation
VISUALIZE = True          # show sample masks from all folds
SAMPLE_IDX = 0            # which sample to visualize
RUN_BTRACK = True         # set False to skip btrack tracking of the OOF detections
BTRACK_CONFIG = None      # None -> use btrack.datasets.cell_config() default (btrack>=0.6.x)
BTRACK_RADIUS = 12.0      # max_search_radius
USE_TRAIN_AUG = True      # set False to disable train-time augmentations

# Augmentation knobs, forwarded as run_grand_kfold(train_aug_kwargs=...)
TRAIN_AUG_KWARGS = dict(
    rotate_p=0.7,
    hflip_p=0.5,
    vflip_p=0.5,
    clahe_p=0.5,
    brightness_p=0.5,
    gauss_p=0.3,
    elastic_p=0.2,
    coarse_p=0.5,
    coarse_max_holes=16,
    coarse_min_holes=8,
    coarse_max_hw=16,
    coarse_min_hw=8,
    crop_scale_min=0.8,
    crop_scale_max=1.0,
    crop_ratio_min=0.9,
    crop_ratio_max=1.1,
)
# Backward-compatible alias for the previous, misspelled name.
TRAIN_AUG_KWARTS = TRAIN_AUG_KWARGS

# 1) Build grand dataset (cleans OUT_DIR each time)
prepare_grand_dataset(
    real_data_dir=REAL_DATA_DIR,
    val_tif_path=VAL_TIF,
    val_csv_path=VAL_CSV,
    out_dir=OUT_DIR,
)

# 2) Train K-fold, infer OOF, compute DetA, (optionally) run btrack, visualize
oof_df, model_paths, tracks_df = run_grand_kfold(
    out_dir=OUT_DIR,
    k_splits=K_SPLITS,
    use_bonus=USE_BONUS,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    threshold=THRESHOLD,
    save_dir=SAVE_DIR,
    visualize_samples=VISUALIZE,
    sample_idx=SAMPLE_IDX,
    match_thresh=MATCH_THRESH,
    run_btrack=RUN_BTRACK,
    btrack_config=BTRACK_CONFIG,
    btrack_radius=BTRACK_RADIUS,
    train_aug_kwargs=TRAIN_AUG_KWARGS if USE_TRAIN_AUG else {},
)

print("Done. OOF detections:", len(oof_df), "| Models:", model_paths)
if tracks_df is not None:
    print("BTrack tracks:", len(tracks_df))

Dataset written to /content/experiment_dataset
Bonus samples: 42 | Video frames: 120
Frames: 120, Bonus: 42 (use_bonus=True)

Train loader: video 96 + bonus 42 = 138
Val loader: 24
Epoch 001: train nan | val nan | best inf
Epoch 002: train nan | val nan | best inf
Epoch 003: train nan | val nan | best inf
Epoch 004: train nan | val nan | best inf
Epoch 005: train nan | val nan | best inf
Saved best checkpoint (val_loss=inf) to model_fold_1.pth


[INFO][2025/12/03 01:09:33 PM] Loaded btrack: /usr/local/lib/python3.12/dist-packages/btrack/libs/libtracker.so
INFO:btrack.libwrapper:Loaded btrack: /usr/local/lib/python3.12/dist-packages/btrack/libs/libtracker.so
[INFO][2025/12/03 01:09:33 PM] Starting BayesianTracker session
INFO:btrack.core:Starting BayesianTracker session
[INFO][2025/12/03 01:09:33 PM] Loading configuration file: /root/.cache/btrack-examples/examples/cell_config.json
INFO:btrack.config:Loading configuration file: /root/.cache/btrack-examples/examples/cell_config.json
[INFO][2025/12/03 01:09:33 PM] Objects are of type: <class 'pandas.core.frame.DataFrame'>
INFO:btrack.io.utils:Objects are of type: <class 'pandas.core.frame.DataFrame'>


Using btrack.datasets.cell_config() (btrack>=0.6.x)


[INFO][2025/12/03 01:09:33 PM] Starting tracking... 
INFO:btrack.core:Starting tracking... 
[INFO][2025/12/03 01:09:33 PM] Update using: ['MOTION']
INFO:btrack.core:Update using: ['MOTION']
[INFO][2025/12/03 01:09:33 PM] Tracking objects in frames 0 to 99 (of 110)...
INFO:btrack.core:Tracking objects in frames 0 to 99 (of 110)...
[INFO][2025/12/03 01:09:39 PM]  - Timing (Bayesian updates: 222.52ms, Linking: 2.86ms)
INFO:btrack.utils: - Timing (Bayesian updates: 222.52ms, Linking: 2.86ms)
[INFO][2025/12/03 01:09:39 PM]  - Probabilities (Link: 0.09724, Lost: 0.10917)
INFO:btrack.utils: - Probabilities (Link: 0.09724, Lost: 0.10917)
[INFO][2025/12/03 01:09:39 PM]  - Stats (Active: 616, Lost: 8530, Conflicts resolved: 686)
INFO:btrack.utils: - Stats (Active: 616, Lost: 8530, Conflicts resolved: 686)
[INFO][2025/12/03 01:09:39 PM] Tracking objects in frames 100 to 110 (of 110)...
INFO:btrack.core:Tracking objects in frames 100 to 110 (of 110)...
[INFO][2025/12/03 01:09:39 PM]  - Timing (Bay

KeyboardInterrupt: 

In [None]:
# @title Compare Baselines vs UNet (OOF tracked)
# Assumes the UNet pipeline already ran (run_grand_kfold with run_btrack=True writes
# tracks_btrack.csv) and that any baseline CSVs exist. Adjust baseline paths as needed.
# A baseline that is configured but missing on disk is reported instead of skipped silently.

BASE_GT = "val_data/val.csv"
BASE_SOTA = "val_data/sota.csv"          # example baseline detections/tracks
BASE_SIMPLE = None                        # set to CSV if available
UNET_OOF = "tracks_btrack.csv"           # produced by run_grand_kfold when run_btrack=True

# ROI bounds forwarded to every calculate_performance call below.
ROI_KWARGS = dict(y_min=ROI_Y_MIN, y_max=ROI_Y_MAX, x_min=ROI_X_MIN, x_max=ROI_X_MAX)

# NOTE(review): `roi_tuple` and `_load_and_filter` are not used in this cell
# (calculate_performance does its own per-track ROI filtering); drop them if nothing else needs them.
roi_tuple = (ROI_Y_MIN, ROI_Y_MAX, ROI_X_MIN, ROI_X_MAX)
def _load_and_filter(path):
    """Read a CSV and keep rows whose x/y fall inside the ROI (bounds inclusive)."""
    df = pd.read_csv(path)
    return df[(df.y.between(ROI_Y_MIN, ROI_Y_MAX)) & (df.x.between(ROI_X_MIN, ROI_X_MAX))]

results = []  # list of (method name, metrics) pairs

# Baseline: SOTA
if BASE_SOTA and Path(BASE_SOTA).exists():
    res_sota = calculate_performance(BASE_GT, BASE_SOTA, name="SOTA", **ROI_KWARGS)
    results.append(("SOTA", res_sota))
elif BASE_SOTA:
    print(f"SOTA baseline not found at {BASE_SOTA}; skipping.")

# Baseline: Simple
if BASE_SIMPLE and Path(BASE_SIMPLE).exists():
    res_simple = calculate_performance(BASE_GT, BASE_SIMPLE, name="Simple", **ROI_KWARGS)
    results.append(("Simple", res_simple))
elif BASE_SIMPLE:
    print(f"Simple baseline not found at {BASE_SIMPLE}; skipping.")

# UNet OOF (tracked with btrack)
if Path(UNET_OOF).exists():
    res_unet = calculate_performance(BASE_GT, UNET_OOF, name="UNet OOF", **ROI_KWARGS)
    results.append(("UNet OOF", res_unet))
else:
    print("UNet tracked output not found; run k-fold pipeline with run_btrack=True")

print("\nSummary (HOTA/DetA):")
if not results:
    print("  (no results to summarize)")
for name, res in results:
    print(f"{name}: HOTA={res['HOTA']:.3f}, DetA={res['DetA']:.3f}, AssA={res['AssA']:.3f}")


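**Environment requirements:** see [`requirements.txt`](requirements.txt) for pinned versions (btrack, pydantic<2, numpy, albumentations, etc.). The setup cell at the top of this notebook installs the same pinned stack (`btrack==0.6.5`, `pydantic<2`, `albumentations==1.3.1`, …) inline with `uv pip install --system`; run it first.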