[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/veselm73/SU2/blob/main/notebooks/SU2_StarDist_inference.ipynb)

# Cell Detection & Tracking - Competition Inference

**Author:** Mateusz Vesel  
**Task:** Detect and track cells in microscopy video

---

## Pipeline Overview

### 1. Detection: StarDist with 5-Fold Ensemble
- **Architecture:** StarDist with ResNet18 encoder
- **Training:** 5-Fold stratified cross-validation on 120 annotated frames + bonus data
- **Inference:** Ensemble averaging of probability maps from all 5 folds
- **Post-processing:** Non-maximum suppression to extract cell centroids

### 2. Tracking: LapTrack
- **Method:** Linear Assignment Problem (LAP) based tracking
- **Features:** Frame-to-frame linking with gap closing for missed detections

---

## Available Pre-trained Models

| Model | Epochs | Augmentation | N_Rays |
|-------|--------|--------------|--------|
| `100e_noaug_32rays` | 100 | No | 32 |
| `100e_noaug_64rays` | 100 | No | 64 |
| `120e_aug_32rays` | 120 | Yes | 32 |

> **Note:** Run Section 8 (Benchmark) to compare models on held-out validation data.

---

## How to Use This Notebook

1. **Select a model** in the configuration cell below
2. **Run all cells** - weights auto-download from GitHub if needed
3. Results are evaluated against ground truth and visualized
4. **Section 8:** Benchmark all models on held-out sanity check data (unbiased evaluation)

## 1. Setup & Dependencies

In [None]:
# Install dependencies (fast with uv)
# `uv` is installed with pip first, then used for every later install step.
!pip install uv -q
# Remove any torch packages already present; errors from this step are
# discarded (`2>/dev/null`) and ignored (`|| true`).
!uv pip uninstall torch torchvision torchaudio --system -q 2>/dev/null || true
# Reinstall torch from the CUDA 12.1 wheel index.
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --system -q
# Remaining libraries used by this notebook; numpy is pinned below 2.
!uv pip install "numpy<2" cellseg-models-pytorch pytorch-lightning laptrack tifffile --system -q

In [None]:
import sys
import os
from pathlib import Path

# Clone/update repository
# Colab is detected by checking whether the `google.colab` module is loaded.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # First run on this runtime: clone the repo; otherwise pull the latest changes.
    if not Path('/content/SU2').exists():
        !git clone https://github.com/veselm73/SU2.git /content/SU2
    else:
        !cd /content/SU2 && git pull
    # Work from inside the checkout and use it as the repository root.
    os.chdir('/content/SU2')
    repo_root = Path('/content/SU2')
else:
    # Local run: when started from a `notebooks` directory, the repository
    # root is taken to be its parent; otherwise the current directory is used.
    notebook_dir = Path(os.getcwd())
    repo_root = notebook_dir.parent if notebook_dir.name == 'notebooks' else notebook_dir

# Put the repository root on sys.path so packages inside it can be imported.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print(f"Repository: {repo_root}")

In [None]:
import torch
import numpy as np
import pandas as pd
import tifffile
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from skimage.measure import regionprops

from cellseg_models_pytorch.postproc.functional.stardist.stardist import post_proc_stardist

# Import tracking and metrics
from modules.stardist_helpers import (
    run_laptrack,
    hota,
    ROI_X_MIN, ROI_X_MAX, ROI_Y_MIN, ROI_Y_MAX
)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Device: {DEVICE}")

## 2. Model Selection & Configuration

**Choose your model below.** Weights will be auto-downloaded from GitHub if not found locally.

In [None]:
# ============================================================
# MODEL SELECTION
# ============================================================
# Available pre-trained models:
#   - "100e_noaug_32rays": 100 epochs, no augmentation, 32 rays
#   - "100e_noaug_64rays": 100 epochs, no augmentation, 64 rays
#   - "120e_aug_32rays":   120 epochs, with augmentation, 32 rays
#
# Run Section 8 (Benchmark) for unbiased comparison on held-out data.

# Maps each model identifier to a short human-readable description.
AVAILABLE_MODELS = {
    "100e_noaug_32rays": "100 epochs, no aug, 32 rays",
    "100e_noaug_64rays": "100 epochs, no aug, 64 rays",
    "120e_aug_32rays": "120 epochs, with aug, 32 rays",
}

# >>> CHANGE THIS TO SELECT MODEL <<<
SELECTED_MODEL = "100e_noaug_32rays"

# Fail fast on a typo: the identifier is later used to build the weights
# path and download URL, where a wrong name would only show up as a
# download failure.
if SELECTED_MODEL not in AVAILABLE_MODELS:
    raise ValueError(
        f"Unknown model '{SELECTED_MODEL}'. "
        f"Choose one of: {', '.join(AVAILABLE_MODELS)}"
    )

print("Available models:")
for name, desc in AVAILABLE_MODELS.items():
    marker = ">>>" if name == SELECTED_MODEL else "   "
    print(f"  {marker} {name}: {desc}")
print(f"\nSelected: {SELECTED_MODEL}")

In [None]:
import urllib.request
import json

def download_weights_if_needed(model_name, repo_root, n_folds=5):
    """Ensure the config files and fold weights for *model_name* exist locally.

    Files are expected under ``repo_root / "weights" / model_name``: the two
    JSON configs directly in that directory and ``fold_<k>.pth`` inside its
    ``models`` sub-directory. Anything missing is fetched from the GitHub
    repository; files that are already present are left untouched.

    Args:
        model_name: Name of the weights folder (also used in the download URL).
        repo_root: ``pathlib.Path`` of the repository root.
        n_folds: Number of fold checkpoints to expect (folds ``1..n_folds``).

    Returns:
        The ``pathlib.Path`` of the model's weights directory.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on a failed download
        (for example ``urllib.error.URLError``); the error is not swallowed.
    """
    weights_dir = repo_root / "weights" / model_name
    models_dir = weights_dir / "models"

    base_url = f"https://raw.githubusercontent.com/veselm73/SU2/main/weights/{model_name}"

    # (path relative to base_url, local destination, label shown to the user)
    wanted = [
        (cfg, weights_dir / cfg, cfg)
        for cfg in ("model_config.json", "inference_config.json")
    ]
    wanted += [
        (f"models/fold_{fold}.pth", models_dir / f"fold_{fold}.pth", f"fold_{fold}.pth (~53MB)")
        for fold in range(1, n_folds + 1)
    ]

    missing = [item for item in wanted if not item[1].is_file()]
    if not missing:
        print(f"Weights found locally: {weights_dir}")
        return weights_dir

    print(f"Downloading weights for '{model_name}' from GitHub...")
    # Creating models_dir with parents=True also creates weights_dir.
    models_dir.mkdir(parents=True, exist_ok=True)

    for rel_path, dest, label in missing:
        print(f"  Downloading {label}...")
        # Download to a side file and rename only on success, so an
        # interrupted transfer never leaves a truncated file that would be
        # mistaken for a complete one on the next run.
        partial = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(f"{base_url}/{rel_path}", partial)
            partial.replace(dest)
        finally:
            if partial.exists():
                partial.unlink()

    print("Download complete!")
    return weights_dir


# Resolve the weights directory for the selected model (fetched if absent)
WEIGHTS_DIR = download_weights_if_needed(SELECTED_MODEL, repo_root)

# Parse the two JSON configs stored next to the weights
with open(WEIGHTS_DIR / "model_config.json") as f:
    MODEL_CONFIG = json.load(f)
with open(WEIGHTS_DIR / "inference_config.json") as f:
    INFERENCE_CONFIG = json.load(f)

# When the inference config carries no ROI, use the constants imported
# from the helpers module
if 'roi' not in INFERENCE_CONFIG:
    INFERENCE_CONFIG['roi'] = dict(
        x_min=ROI_X_MIN,
        x_max=ROI_X_MAX,
        y_min=ROI_Y_MIN,
        y_max=ROI_Y_MAX,
    )

# Paths and ensemble size used by the cells below
K_FOLDS = 5
MODELS_DIR = WEIGHTS_DIR / "models"
_val_dir = repo_root / "data" / "val"
VAL_TIF = _val_dir / "val.tif"
VAL_CSV = _val_dir / "val.csv"

print("\nConfiguration loaded:")
print(f"  Model: {MODEL_CONFIG['encoder_name']}, n_rays={MODEL_CONFIG['n_rays']}")
print(f"  Detection: prob_thresh={INFERENCE_CONFIG['prob_thresh']}, nms_thresh={INFERENCE_CONFIG['nms_thresh']}")
# The tracking distance is optional in the config, so only report it when set
if 'track_max_dist' in INFERENCE_CONFIG:
    print(f"  Tracking: max_dist={INFERENCE_CONFIG['track_max_dist']}px")

## 3. Load Pre-trained Models

In [None]:
import pytorch_lightning as pl
from cellseg_models_pytorch.models.stardist.stardist import StarDist
import torch.nn as nn

class StarDistLightning(pl.LightningModule):
    """Lightning wrapper exposing a StarDist network for inference.

    Builds a single-class StarDist model with a one-channel encoder input
    and keeps only its inner ``.model``. Extra keyword arguments are
    accepted and ignored. The optional ``dropout`` layer is created but is
    not applied in ``forward``.
    """

    def __init__(self, n_rays=32, encoder_name="resnet18", dropout=0.0, **kwargs):
        super().__init__()
        self.n_rays = n_rays

        stardist_builder = StarDist(
            n_nuc_classes=1,
            n_rays=n_rays,
            enc_name=encoder_name,
            model_kwargs={"encoder_kws": {"in_chans": 1}},
        )
        self.model = stardist_builder.model

        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw output."""
        return self.model(x)


def load_fold_model(fold_path, config):
    """Build a StarDistLightning model and load one fold's weights into it.

    Args:
        fold_path: Path to the fold's ``.pth`` file; its contents are passed
            straight to ``load_state_dict``.
        config: Model configuration dict. ``n_rays`` and ``encoder_name``
            are required; ``dropout`` is optional and defaults to 0.0.

    Returns:
        The model on the CPU, switched to evaluation mode.
    """
    model = StarDistLightning(
        n_rays=config['n_rays'],
        encoder_name=config['encoder_name'],
        dropout=config.get('dropout', 0.0)
    )
    # The checkpoint may have been downloaded from the network, so restrict
    # unpickling to tensors and plain containers rather than arbitrary objects.
    state_dict = torch.load(fold_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
# Load 5-fold ensemble
# Folds whose checkpoint file is missing are reported and skipped, so
# `fold_models` may end up with fewer than K_FOLDS entries.
print(f"Loading models from: {MODELS_DIR}")

fold_models = []
for fold in range(1, K_FOLDS + 1):
    fold_path = MODELS_DIR / f"fold_{fold}.pth"
    if fold_path.exists():
        model = load_fold_model(fold_path, MODEL_CONFIG)
        model = model.to(DEVICE)
        fold_models.append(model)
        print(f"  ✓ Fold {fold}")
    else:
        print(f"  ✗ Fold {fold}: NOT FOUND")

print(f"\nLoaded {len(fold_models)}/{K_FOLDS} models")

## 4. Inference Functions

In [None]:
def preprocess_frame(frame):
    """Rescale a frame to roughly [0, 1] using its 1st and 99.8th percentiles.

    The frame is converted to float32, clipped to the two percentile values
    and then shifted and scaled by their difference. A small epsilon in the
    denominator avoids division by zero on constant frames.
    """
    as_float = frame.astype(np.float32)
    lower, upper = np.percentile(as_float, (1, 99.8))
    clipped = np.clip(as_float, lower, upper)
    return (clipped - lower) / (upper - lower + 1e-8)


def detect_cells_ensemble(models, frame, prob_thresh, nms_thresh, device):
    """Detect cells by averaging the outputs of several models before post-processing.

    Each model is run on the same frame. The per-model distance maps
    (``aux_map``) and sigmoid probability maps (``binary_map``) are averaged
    and the averages are passed to ``post_proc_stardist``.

    Args:
        models: Sequence of models; each call must return a mapping whose
            ``'nuc'`` entry has ``aux_map`` and ``binary_map`` tensors.
        frame: 2D numpy array; batch and channel axes are added here.
        prob_thresh: Passed to post-processing as ``score_thresh``.
        nms_thresh: Passed to post-processing as ``iou_thresh``.
        device: Device the input tensor is moved to.

    Returns:
        A list of ``(x, y)`` centroids, one per labelled region. Returns an
        empty list when no models are given or when post-processing fails
        (a warning is emitted in the latter case).
    """
    import warnings

    # Nothing to average with an empty ensemble.
    if not models:
        return []

    x = torch.from_numpy(frame).float().unsqueeze(0).unsqueeze(0).to(device)

    all_stardist = []
    all_prob = []

    with torch.no_grad():
        for model in models:
            nuc_out = model(x)['nuc']
            all_stardist.append(nuc_out.aux_map.cpu().numpy()[0])
            all_prob.append(torch.sigmoid(nuc_out.binary_map).cpu().numpy()[0, 0])

    avg_stardist = np.mean(all_stardist, axis=0)
    avg_prob = np.mean(all_prob, axis=0)

    try:
        labels = post_proc_stardist(
            avg_prob, avg_stardist,
            score_thresh=prob_thresh,
            iou_thresh=nms_thresh
        )
        # regionprops centroids are (row, col); swap them to (x, y).
        return [(prop.centroid[1], prop.centroid[0]) for prop in regionprops(labels)]
    except Exception as exc:
        # Best effort: a frame that cannot be post-processed yields no
        # detections, but the failure is reported instead of hidden, and
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        warnings.warn(
            f"StarDist post-processing failed; returning no detections: {exc!r}",
            RuntimeWarning,
        )
        return []


def infer_video(video_path, config, models=None):
    """Run ensemble detection followed by tracking on a video.

    Args:
        video_path: Path of the TIFF stack, read with ``tifffile.imread``.
            It is indexed as ``video[:, y, x]``, i.e. frames on the first axis.
        config: Inference configuration. Must contain ``roi`` (with
            ``x_min``/``x_max``/``y_min``/``y_max``), ``prob_thresh`` and
            ``nms_thresh``. Tracking settings are read from
            ``track_max_dist``/``gap_closing_frames`` or, failing that,
            from ``track_cost_cutoff``/``gap_closing_max_frame_count``.
        models: Models to ensemble; defaults to the module-level
            ``fold_models``.

    Returns:
        The DataFrame returned by ``run_laptrack``. When nothing was
        detected, an empty DataFrame with columns ``frame``, ``x``, ``y``
        and ``track_id`` is returned without calling the tracker.
    """
    if models is None:
        models = fold_models

    video = tifffile.imread(video_path)
    roi = config['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]

    print(f"Video: {video.shape} -> ROI: {video_roi.shape}")
    print(f"Running ensemble detection ({len(models)} models)...")

    all_detections = []
    for frame_idx in tqdm(range(len(video_roi))):
        frame = preprocess_frame(video_roi[frame_idx])
        detections = detect_cells_ensemble(
            models, frame,
            config['prob_thresh'], config['nms_thresh'], DEVICE
        )
        # Detections are relative to the crop; shift them back to
        # full-frame coordinates.
        all_detections.extend(
            {'frame': frame_idx, 'x': x + roi['x_min'], 'y': y + roi['y_min']}
            for x, y in detections
        )

    # Explicit columns keep the schema stable even when nothing was detected.
    detections_df = pd.DataFrame(all_detections, columns=['frame', 'x', 'y'])
    print(f"Detections: {len(detections_df)}")

    if detections_df.empty:
        # There is nothing to link, so skip the tracker and return an
        # empty, correctly shaped result.
        print("No detections - skipping tracking.")
        return detections_df.assign(track_id=pd.Series(dtype='int64'))

    # Handle different tracking config formats
    if 'track_max_dist' in config:
        max_dist = config['track_max_dist']
        closing_gap = config.get('gap_closing_frames', 2)
    elif 'track_cost_cutoff' in config:
        # Convert squared distance to distance (truncated to whole pixels)
        max_dist = int(np.sqrt(config['track_cost_cutoff']))
        closing_gap = config.get('gap_closing_max_frame_count', 1)
    else:
        # Default values
        max_dist = 5
        closing_gap = 2

    print(f"Running tracking (max_dist={max_dist}px, gap={closing_gap})...")
    tracked_df = run_laptrack(
        detections_df,
        max_dist=max_dist,
        closing_gap=closing_gap,
        min_length=2
    )
    print(f"Tracks: {tracked_df['track_id'].nunique()}")

    return tracked_df

## 5. Run Inference on Validation Video

In [None]:
# Inference needs at least one loaded fold model and the validation video on disk
if not fold_models or not VAL_TIF.exists():
    print("Models not loaded or validation video not found.")
else:
    predictions = infer_video(VAL_TIF, INFERENCE_CONFIG)

## 6. Evaluation (HOTA Metrics)

In [None]:
# Score the predictions only when ground truth exists and inference produced rows
if VAL_CSV.exists() and 'predictions' in globals() and len(predictions) > 0:
    # Ground truth, restricted to the same ROI the detector was run on
    gt_df = pd.read_csv(VAL_CSV)
    roi = INFERENCE_CONFIG['roi']
    within_x = (gt_df.x >= roi['x_min']) & (gt_df.x < roi['x_max'])
    within_y = (gt_df.y >= roi['y_min']) & (gt_df.y < roi['y_max'])
    gt_roi = gt_df[within_x & within_y].copy()

    # HOTA with a matching threshold of 5.0
    hota_scores = hota(gt_roi, predictions, threshold=5.0)

    rule = "=" * 50
    print(rule)
    print("RESULTS")
    print(rule)
    print(f"\n  HOTA: {hota_scores['HOTA']:.4f}")
    for metric in ('DetA', 'AssA'):
        print(f"  {metric}: {hota_scores[metric]:.4f}")
    print(f"\n  Detections: {len(predictions)}")
    print(f"  Tracks: {predictions['track_id'].nunique()}")
    print(rule)

## 7. Visualization

In [None]:
# Detection visualization: GT vs Predictions
# Needs the validation video on disk and a non-empty `predictions` table
# from the inference cell. Ground truth is overlaid only when its CSV exists.
if VAL_TIF.exists() and 'predictions' in dir() and len(predictions) > 0:
    video = tifffile.imread(VAL_TIF)
    roi = INFERENCE_CONFIG['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]

    gt_available = VAL_CSV.exists()
    if gt_available:
        gt_df = pd.read_csv(VAL_CSV)
        gt_roi = gt_df[
            (gt_df.x >= roi['x_min']) & (gt_df.x < roi['x_max']) &
            (gt_df.y >= roi['y_min']) & (gt_df.y < roi['y_max'])
        ]

    sample_frames = [0, 30, 60, 90]
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    for ax, fidx in zip(axes, sample_frames):
        ax.imshow(video_roi[fidx], cmap='gray')

        if gt_available:
            frame_gt = gt_roi[gt_roi.frame == fidx]
            # Hollow rings: when `c` is passed, scatter gives it precedence
            # over `facecolors`, so the markers would be filled. The ring
            # colour therefore goes through `edgecolors`.
            ax.scatter(frame_gt.x - roi['x_min'], frame_gt.y - roi['y_min'],
                       facecolors='none', edgecolors='lime', s=40, marker='o',
                       linewidths=1.5, label='GT')

        # Coordinates are stored in full-frame space; shift them into the crop.
        frame_preds = predictions[predictions.frame == fidx]
        ax.scatter(frame_preds.x - roi['x_min'], frame_preds.y - roi['y_min'],
                   c='red', s=25, marker='x', linewidths=1.5, label='Pred')

        ax.set_title(f'Frame {fidx}')
        ax.axis('off')

    axes[0].legend(loc='upper left')
    plt.suptitle('Detection: Green=GT, Red=Predicted', fontsize=12)
    plt.tight_layout()
    plt.show()

In [None]:
# Tracking visualization: color by track_id
# Guarded like the detection plot: needs the video on disk and a non-empty
# `predictions` table.
if VAL_TIF.exists() and 'predictions' in dir() and len(predictions) > 0:
    video = tifffile.imread(VAL_TIF)
    roi = INFERENCE_CONFIG['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]

    sample_frames = [0, 30, 60, 90]
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    # One colour per track, sampled from the tab20 colormap.
    unique_tracks = predictions['track_id'].unique()
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_tracks)))
    track_colors = dict(zip(unique_tracks, colors))

    for ax, fidx in zip(axes, sample_frames):
        ax.imshow(video_roi[fidx], cmap='gray')
        frame_preds = predictions[predictions.frame == fidx]
        if len(frame_preds) > 0:
            # A single scatter call per frame instead of one call per row.
            ax.scatter(frame_preds['x'] - roi['x_min'], frame_preds['y'] - roi['y_min'],
                       c=[track_colors[t] for t in frame_preds['track_id']],
                       s=30, marker='o')
        ax.set_title(f'Frame {fidx} ({len(frame_preds)} cells)')
        ax.axis('off')

    plt.suptitle('Tracking: color = track ID', fontsize=12)
    plt.tight_layout()
    plt.show()

## 8. Model Benchmark on Held-Out Validation Data

**Unbiased evaluation** comparing all 3 trained models on manually annotated frames (58-67).

These frames were annotated separately and never seen during training - this gives a true measure of generalization performance.

In [None]:
# Sanity-check annotations: remote source and local copy inside the repo
SANITY_CHECK_CSV_URL = "https://raw.githubusercontent.com/veselm73/SU2/main/data/val_annotations/ensemble_sanity_check.csv"
SANITY_CHECK_CSV = repo_root / "data" / "val_annotations" / "ensemble_sanity_check.csv"

# Download only when the local copy is missing
if not SANITY_CHECK_CSV.exists():
    print("Downloading sanity check annotations...")
    SANITY_CHECK_CSV.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(SANITY_CHECK_CSV_URL, SANITY_CHECK_CSV)

sanity_gt = pd.read_csv(SANITY_CHECK_CSV)

_frame_ids = sanity_gt['frame']
_n_annotations = len(sanity_gt)
print(f"Loaded {_n_annotations} annotations across frames {_frame_ids.min()}-{_frame_ids.max()}")
print(f"Annotations per frame: ~{_n_annotations // _frame_ids.nunique()}")

In [None]:
def _greedy_match_count(pred_coords, gt_coords, match_thresh):
    """Count one-to-one prediction/GT pairs within *match_thresh*, closest first."""
    from scipy.spatial.distance import cdist

    if len(pred_coords) == 0 or len(gt_coords) == 0:
        return 0

    dist_matrix = cdist(pred_coords, gt_coords)
    matches = 0
    for _ in range(min(dist_matrix.shape)):
        row, col = np.unravel_index(np.argmin(dist_matrix), dist_matrix.shape)
        # Once the closest remaining pair is too far apart, no further
        # pair can match.
        if not dist_matrix[row, col] <= match_thresh:
            break
        matches += 1
        # Each prediction and each GT point may be used only once.
        dist_matrix[row, :] = np.inf
        dist_matrix[:, col] = np.inf
    return matches


def benchmark_model_on_sanity_check(model_name, sanity_gt, video, device=DEVICE, match_thresh=5.0):
    """Run one model's 5-fold ensemble on the sanity-check frames and score detections.

    Predictions are matched to ground truth greedily (closest pair first,
    one-to-one) within ``match_thresh`` pixels. This is a simple
    approximation, not an optimal (Hungarian) assignment.

    Args:
        model_name: Name of the weights folder; weights are downloaded if needed.
        sanity_gt: DataFrame with ``frame``, ``x`` and ``y`` columns.
            NOTE(review): coordinates are assumed to be relative to the ROI
            crop, because they are compared directly with crop-relative
            detections - confirm against the CSV.
        video: Array indexed as ``video[frame, y, x]``.
        device: Device used for inference.
        match_thresh: Maximum distance in pixels for a prediction to count
            as a match.

    Returns:
        A dict with ``model``, ``DetA`` (TP / (TP + FP + FN)), ``Precision``,
        ``Recall``, the total ``TP``/``FP``/``FN`` counts, and ``per_frame``
        (a DataFrame with one row of counts per frame).
    """
    # Load model weights and config
    weights_dir = download_weights_if_needed(model_name, repo_root)
    with open(weights_dir / "model_config.json") as f:
        model_config = json.load(f)
    with open(weights_dir / "inference_config.json") as f:
        inf_config = json.load(f)

    # Load the five fold models
    models = [
        load_fold_model(weights_dir / "models" / f"fold_{fold}.pth", model_config).to(device)
        for fold in range(1, 6)
    ]

    # Frames that have sanity-check annotations
    frames = sorted(sanity_gt['frame'].unique())

    # NOTE(review): this fallback ROI is hard-coded, while the main pipeline
    # falls back to the ROI_* constants from the helpers module - confirm
    # that both describe the same region.
    roi = inf_config.get('roi', {'x_min': 256, 'x_max': 512, 'y_min': 512, 'y_max': 768})

    all_results = []
    total_tp, total_fp, total_fn = 0, 0, 0

    for frame_idx in tqdm(frames, desc=f"Benchmarking {model_name}"):
        # Crop the frame to the ROI and normalise it like the main pipeline does
        frame = video[frame_idx, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]
        frame_norm = preprocess_frame(frame)

        # Detect cells; the returned (x, y) coordinates are relative to the crop
        detections = detect_cells_ensemble(
            models, frame_norm,
            inf_config['prob_thresh'], inf_config['nms_thresh'], device
        )

        # Ground truth for this frame as an (n, 2) array of (x, y)
        frame_gt = sanity_gt[sanity_gt['frame'] == frame_idx][['x', 'y']].values

        tp = _greedy_match_count(np.array(detections), frame_gt, match_thresh)
        fp = len(detections) - tp
        fn = len(frame_gt) - tp

        total_tp += tp
        total_fp += fp
        total_fn += fn

        all_results.append({
            'frame': frame_idx,
            'n_gt': len(frame_gt),
            'n_pred': len(detections),
            'tp': tp, 'fp': fp, 'fn': fn
        })

    # Compute DetA, precision and recall; every metric is 0 when its
    # denominator is 0
    if total_tp + total_fp + total_fn > 0:
        precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        deta = total_tp / (total_tp + total_fp + total_fn)
    else:
        precision, recall, deta = 0, 0, 0

    # Drop the model references and release cached GPU memory before the
    # next benchmark
    del models
    torch.cuda.empty_cache()

    return {
        'model': model_name,
        'DetA': deta,
        'Precision': precision,
        'Recall': recall,
        'TP': total_tp,
        'FP': total_fp,
        'FN': total_fn,
        'per_frame': pd.DataFrame(all_results)
    }

In [None]:
# Benchmark every available model on the sanity-check frames
_rule = "=" * 60
print(_rule)
print("BENCHMARK: Comparing 3 Models on Sanity Check Data")
print(_rule)
print(f"Frames: {sorted(sanity_gt['frame'].unique())}")
print(f"Total annotations: {len(sanity_gt)}")
print(_rule)

# The video is read once and shared by all benchmark runs
video = tifffile.imread(VAL_TIF)

# Collect one result dict per model, in the order the models are listed
benchmark_results = []
for model_name in AVAILABLE_MODELS:
    print(f"\n>>> Testing: {model_name}")
    result = benchmark_model_on_sanity_check(model_name, sanity_gt, video)
    benchmark_results.append(result)
    print(f"    DetA: {result['DetA']:.4f} | Precision: {result['Precision']:.4f} | Recall: {result['Recall']:.4f}")

print("\n" + _rule)

In [None]:
# Display benchmark summary table
# Scores are formatted to four decimals for display; the raw values stay
# in `benchmark_results`.
summary_df = pd.DataFrame([{
    'Model': r['model'],
    'DetA': f"{r['DetA']:.4f}",
    'Precision': f"{r['Precision']:.4f}",
    'Recall': f"{r['Recall']:.4f}",
    'TP': r['TP'],
    'FP': r['FP'],
    'FN': r['FN']
} for r in benchmark_results])

print("\n" + "="*60)
print("BENCHMARK RESULTS SUMMARY")
print("="*60)
print(summary_df.to_string(index=False))
print("="*60)

# Find best model (highest DetA; the first one wins on ties)
best_idx = np.argmax([r['DetA'] for r in benchmark_results])
print(f"\n🏆 Best Model: {benchmark_results[best_idx]['model']} (DetA={benchmark_results[best_idx]['DetA']:.4f})")

In [None]:
# Visualize benchmark: bar chart comparison
# One panel per metric (DetA, precision, recall), one bar per model.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

models = [r['model'] for r in benchmark_results]
deta_scores = [r['DetA'] for r in benchmark_results]
precision_scores = [r['Precision'] for r in benchmark_results]
recall_scores = [r['Recall'] for r in benchmark_results]

colors = ['#2ecc71', '#3498db', '#e74c3c']

# (y-axis label, panel title, scores) for each panel, left to right
panels = [
    ('DetA', 'Detection Accuracy (DetA)', deta_scores),
    ('Precision', 'Precision (TP / (TP + FP))', precision_scores),
    ('Recall', 'Recall (TP / (TP + FN))', recall_scores),
]
# Break model names at underscores so the tick labels stay narrow
tick_labels = [m.replace('_', '\n') for m in models]

for ax, (ylabel, title, scores) in zip(axes, panels):
    bars = ax.bar(range(len(models)), scores, color=colors)
    ax.set_xticks(range(len(models)))
    ax.set_xticklabels(tick_labels, fontsize=9)
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.set_ylim(0, 1)
    # Print each score just above its bar
    for bar, score in zip(bars, scores):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                f'{score:.3f}', ha='center', va='bottom', fontsize=10)

# Take the counts from the loaded annotations so the title cannot drift
# out of sync with the data
n_frames = sanity_gt['frame'].nunique()
plt.suptitle(
    f'Model Benchmark on Sanity Check Validation Data ({n_frames} frames, {len(sanity_gt)} annotations)',
    fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Visualize detections from best model on sample frames
best_result = benchmark_results[best_idx]
best_model_name = best_result['model']

# Reload best model for visualization (the benchmark function deletes its
# models when it finishes)
weights_dir = download_weights_if_needed(best_model_name, repo_root)
with open(weights_dir / "model_config.json") as f:
    model_config = json.load(f)
with open(weights_dir / "inference_config.json") as f:
    inf_config = json.load(f)

best_models = []
for fold in range(1, 6):
    model = load_fold_model(weights_dir / "models" / f"fold_{fold}.pth", model_config)
    model = model.to(DEVICE)
    best_models.append(model)

roi = inf_config.get('roi', {'x_min': 256, 'x_max': 512, 'y_min': 512, 'y_max': 768})
frames_to_show = sorted(sanity_gt['frame'].unique())[:5]  # First 5 frames

fig, axes = plt.subplots(2, 5, figsize=(20, 8))

# Hide every panel up front so columns without a frame stay blank when
# fewer than five annotated frames exist.
for ax in axes.ravel():
    ax.axis('off')

for i, frame_idx in enumerate(frames_to_show):
    # Get frame
    frame = video[frame_idx, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]
    frame_norm = preprocess_frame(frame)

    # Detect
    detections = detect_cells_ensemble(
        best_models, frame_norm,
        inf_config['prob_thresh'], inf_config['nms_thresh'], DEVICE
    )

    # Get GT
    frame_gt = sanity_gt[sanity_gt['frame'] == frame_idx][['x', 'y']].values

    # Plot GT (top row)
    axes[0, i].imshow(frame, cmap='gray')
    # Hollow rings: when `c` is passed, scatter gives it precedence over
    # `facecolors`, so the ring colour goes through `edgecolors` instead.
    axes[0, i].scatter(frame_gt[:, 0], frame_gt[:, 1], facecolors='none',
                       edgecolors='lime', s=40, marker='o', linewidths=1.5)
    axes[0, i].set_title(f'Frame {frame_idx} - GT ({len(frame_gt)})')

    # Plot Predictions (bottom row)
    axes[1, i].imshow(frame, cmap='gray')
    if detections:
        pred_coords = np.array(detections)
        axes[1, i].scatter(pred_coords[:, 0], pred_coords[:, 1], c='red', s=40,
                           marker='x', linewidths=1.5)
    axes[1, i].set_title(f'Frame {frame_idx} - Pred ({len(detections)})')

plt.suptitle(f'Best Model: {best_model_name}\nTop: Ground Truth (green) | Bottom: Predictions (red)',
             fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

# Cleanup
del best_models
torch.cuda.empty_cache()