[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/veselm73/SU2/blob/main/notebooks/SU2_StarDist_inference.ipynb)

# Cell Detection & Tracking - Model Benchmark

**Author:** Mateusz Vesel  
**Task:** Compare 3 trained StarDist ensembles on held-out validation data

---

## Pipeline Overview

### Detection: StarDist with 5-Fold Ensemble
- **Architecture:** StarDist with ResNet18 encoder
- **Training:** 5-Fold cross-validation on training data
- **Inference:** Ensemble averaging of probability maps from all 5 folds
- **Post-processing:** Non-maximum suppression to extract cell centroids

### Evaluation
- **Held-out data:** 10 manually annotated frames (58-67) never seen during training
- **Metrics:** DetA (Detection Accuracy), Precision, Recall

---

## Available Pre-trained Models

| Model | Epochs | Augmentation | N_Rays |
|-------|--------|--------------|--------|
| `100e_noaug_32rays` | 100 | No | 32 |
| `100e_noaug_64rays` | 100 | No | 64 |
| `120e_aug_32rays` | 120 | Yes | 32 |

---

## How to Use This Notebook

1. **Upload `val.tif`** when prompted (validation video containing frames 58-67)
2. **Run all cells** - model weights auto-download from GitHub
3. All 3 models are benchmarked on 10 held-out annotated frames
4. Results show true generalization performance (unbiased evaluation)

## 1. Setup & Dependencies

In [None]:
# Install dependencies (fast with uv)
# `uv` is installed first and then used as the installer for everything else.
!pip install uv -q
# Remove any torch build already present; stderr is discarded and `|| true`
# keeps the cell from failing when the packages are not installed.
!uv pip uninstall torch torchvision torchaudio --system -q 2>/dev/null || true
# Reinstall torch from the CUDA 12.1 (cu121) wheel index.
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --system -q
# NumPy is pinned below 2.0; the remaining packages are used by the cells below.
!uv pip install "numpy<2" cellseg-models-pytorch pytorch-lightning laptrack tifffile --system -q

In [None]:
import sys
import os
from pathlib import Path

# Clone/update repository
# Colab is detected by checking whether the `google.colab` module is loaded.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # First run: clone into /content/SU2; later runs: pull the latest changes.
    if not Path('/content/SU2').exists():
        !git clone https://github.com/veselm73/SU2.git /content/SU2
    else:
        !cd /content/SU2 && git pull
    # Work from the repository root so relative paths resolve inside it.
    os.chdir('/content/SU2')
    repo_root = Path('/content/SU2')
else:
    # Local run: when started from the `notebooks/` folder, the repository
    # root is its parent; otherwise the current directory is used as the root.
    notebook_dir = Path(os.getcwd())
    repo_root = notebook_dir.parent if notebook_dir.name == 'notebooks' else notebook_dir

# Make the repository importable (needed for `modules.stardist_helpers`).
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print(f"Repository: {repo_root}")

In [None]:
import torch
import numpy as np
import pandas as pd
import tifffile
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from skimage.measure import regionprops

from cellseg_models_pytorch.postproc.functional.stardist.stardist import post_proc_stardist

# Import tracking and metrics
from modules.stardist_helpers import (
    run_laptrack,
    hota,
    ROI_X_MIN, ROI_X_MAX, ROI_Y_MIN, ROI_Y_MAX
)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {DEVICE}")

In [None]:
# Pre-trained ensembles to benchmark, keyed by weights-folder name with a
# short human-readable description as the value.
AVAILABLE_MODELS = dict(
    [
        ("100e_noaug_32rays", "100 epochs, no aug, 32 rays"),
        ("100e_noaug_64rays", "100 epochs, no aug, 64 rays"),
        ("120e_aug_32rays", "120 epochs, with aug, 32 rays"),
    ]
)

print("Models to benchmark:")
for name in AVAILABLE_MODELS:
    desc = AVAILABLE_MODELS[name]
    print(f"  - {name}: {desc}")

## 2. Download Model Weights

In [None]:
import urllib.request
import json

# Number of cross-validation folds in every ensemble (one .pth file per fold).
K_FOLDS = 5


def download_weights_if_needed(
    model_name,
    repo_root,
    n_folds=K_FOLDS,
    weights_url="https://raw.githubusercontent.com/veselm73/SU2/main/weights",
):
    """Make sure the configs and fold weights of ``model_name`` exist locally.

    The expected layout under ``repo_root / "weights" / model_name`` is::

        model_config.json
        inference_config.json
        models/fold_1.pth ... models/fold_<n_folds>.pth

    Only files that are missing (or empty) are fetched. Each file is written
    to a temporary ``*.part`` path and renamed once complete, so an
    interrupted download never leaves a truncated file that would later be
    mistaken for a valid one.

    Args:
        model_name: Name of the model folder, locally and under ``weights_url``.
        repo_root: ``pathlib.Path`` of the repository root.
        n_folds: Number of fold checkpoints to expect (defaults to ``K_FOLDS``).
        weights_url: Base URL that contains one sub-folder per model.

    Returns:
        Path of the local weights directory for ``model_name``.

    Raises:
        urllib.error.URLError: If a required file cannot be downloaded
            (``HTTPError`` is a subclass).
    """
    weights_dir = repo_root / "weights" / model_name
    models_dir = weights_dir / "models"

    required = [weights_dir / cfg for cfg in ("model_config.json", "inference_config.json")]
    required += [models_dir / f"fold_{fold}.pth" for fold in range(1, n_folds + 1)]

    # A zero-byte file is treated as missing: it cannot be a valid config/checkpoint.
    missing = [path for path in required if not path.is_file() or path.stat().st_size == 0]
    if not missing:
        print(f"Weights found locally: {weights_dir}")
        return weights_dir

    print(f"Downloading weights for '{model_name}' from GitHub...")
    models_dir.mkdir(parents=True, exist_ok=True)

    base_url = f"{weights_url.rstrip('/')}/{model_name}"
    for dest in missing:
        url = f"{base_url}/{dest.relative_to(weights_dir).as_posix()}"
        size_note = " (~53MB)" if dest.suffix == ".pth" else ""
        print(f"  Downloading {dest.name}{size_note}...")

        # Download next to the destination, then rename so the swap is atomic.
        partial = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial)
            partial.replace(dest)
        finally:
            # Only still present when the download or rename failed.
            if partial.exists():
                partial.unlink()

    print("Download complete!")
    return weights_dir

## 3. Load Pre-trained Models

In [None]:
import pytorch_lightning as pl
from cellseg_models_pytorch.models.stardist.stardist import StarDist
import torch.nn as nn

class StarDistLightning(pl.LightningModule):
    """Lightning wrapper exposing a StarDist network for inference.

    Args:
        n_rays: Number of rays passed to ``StarDist``.
        encoder_name: Encoder name passed to ``StarDist`` as ``enc_name``.
        dropout: When greater than zero an ``nn.Dropout2d`` layer is stored
            on ``self.dropout``; note that ``forward`` does not apply it.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, n_rays=32, encoder_name="resnet18", dropout=0.0, **kwargs):
        super().__init__()
        self.n_rays = n_rays

        # `in_chans` is forwarded to the encoder — presumably a single
        # grayscale input channel; confirm against cellseg_models_pytorch.
        encoder_options = {"in_chans": 1}
        stardist = StarDist(
            n_nuc_classes=1,
            n_rays=n_rays,
            enc_name=encoder_name,
            model_kwargs={"encoder_kws": encoder_options},
        )
        # Only the inner network of the StarDist wrapper object is kept.
        self.model = stardist.model

        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Return the raw output of the wrapped network for ``x``."""
        network = self.model
        return network(x)


def load_fold_model(fold_path, config):
    """Build a ``StarDistLightning`` model and load one fold's weights into it.

    Args:
        fold_path: Path to the ``.pth`` file; its content is passed directly
            to ``load_state_dict``, so it must be a plain state dict.
        config: Model configuration with the keys ``'n_rays'`` and
            ``'encoder_name'`` and optionally ``'dropout'`` (default ``0.0``).

    Returns:
        The model on the CPU, switched to evaluation mode.
    """
    model = StarDistLightning(
        n_rays=config['n_rays'],
        encoder_name=config['encoder_name'],
        dropout=config.get('dropout', 0.0),
    )
    # The checkpoint comes from a network download. `weights_only=True`
    # restricts unpickling to tensors and plain containers, so a tampered
    # file cannot execute arbitrary code while being loaded.
    state_dict = torch.load(fold_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

## 4. Inference Functions

In [None]:
def preprocess_frame(frame):
    """Rescale a frame using its 1st and 99.8th intensity percentiles.

    The frame is cast to float32, clipped to the two percentile values and
    mapped so the lower percentile becomes 0 and the upper one becomes ~1.
    A small epsilon in the denominator avoids division by zero for a
    constant frame.
    """
    img = frame.astype(np.float32)
    lo, hi = np.percentile(img, (1, 99.8))
    span = hi - lo + 1e-8
    return (np.clip(img, lo, hi) - lo) / span


def detect_cells_ensemble(models, frame, prob_thresh, nms_thresh, device):
    """Detect cell centroids with an ensemble of StarDist models.

    Every model is run on the same frame; the sigmoid probability maps and
    the ``aux_map`` outputs are averaged across models before a single
    StarDist post-processing pass.

    Args:
        models: Iterable of callables returning a mapping whose ``'nuc'``
            entry has ``aux_map`` and ``binary_map`` tensors.
        frame: 2-D NumPy array (a batch and a channel axis are prepended).
        prob_thresh: Passed to ``post_proc_stardist`` as ``score_thresh``.
        nms_thresh: Passed to ``post_proc_stardist`` as ``iou_thresh``.
        device: Device the input tensor is moved to.

    Returns:
        List of ``(x, y)`` centroids, one per labelled region. If
        post-processing fails, a ``RuntimeWarning`` is emitted and an empty
        list is returned (best-effort behaviour is kept, but no longer silent).
    """
    import warnings

    x = torch.from_numpy(frame).float().unsqueeze(0).unsqueeze(0).to(device)

    ray_maps = []
    prob_maps = []
    with torch.no_grad():
        for model in models:
            nuc_out = model(x)['nuc']
            ray_maps.append(nuc_out.aux_map.cpu().numpy()[0])
            prob_maps.append(torch.sigmoid(nuc_out.binary_map).cpu().numpy()[0, 0])

    # Average across ensemble members before thresholding.
    avg_stardist = np.mean(ray_maps, axis=0)
    avg_prob = np.mean(prob_maps, axis=0)

    try:
        labels = post_proc_stardist(
            avg_prob, avg_stardist,
            score_thresh=prob_thresh,
            iou_thresh=nms_thresh,
        )
        # regionprops centroids are (row, col); swap to (x, y).
        return [(prop.centroid[1], prop.centroid[0]) for prop in regionprops(labels)]
    except Exception as exc:
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate; the warning makes a failure distinguishable from a
        # frame that genuinely contains no cells.
        warnings.warn(
            f"StarDist post-processing failed ({exc!r}); returning no detections.",
            RuntimeWarning,
        )
        return []

## 5. Upload Validation Video & Load Annotations

Upload `val.tif` to benchmark the models on the 10 manually annotated held-out frames (58-67).

In [None]:
# Upload val.tif 
import os
import shutil

# Expected location of the validation video inside the repository.
VAL_TIF = repo_root / "data" / "val" / "val.tif"

if VAL_TIF.exists():
    print(f"Found: {VAL_TIF}")
else:
    print("val.tif not found. Please upload it.")
    if not IN_COLAB:
        raise FileNotFoundError(f"Please place val.tif at {VAL_TIF}")

    from google.colab import files
    VAL_TIF.parent.mkdir(parents=True, exist_ok=True)
    print("\nUpload val.tif:")
    uploaded = files.upload()
    if not uploaded:
        # Fail here with a clear message instead of a confusing imread error below.
        raise FileNotFoundError("No file was uploaded; val.tif is required to continue.")
    # Exactly one file can become val.tif: prefer one with that name,
    # otherwise use the first upload (previously every uploaded file was
    # moved onto the same destination, so the last one silently won).
    chosen = "val.tif" if "val.tif" in uploaded else next(iter(uploaded))
    shutil.move(chosen, VAL_TIF)
    print(f"Moved to {VAL_TIF}")

# Load video
video = tifffile.imread(VAL_TIF)
print(f"Video shape: {video.shape}")

# Download held-out validation annotations (frames 58-67)
SANITY_CHECK_CSV_URL = "https://raw.githubusercontent.com/veselm73/SU2/main/data/val_annotations/ensemble_sanity_check.csv"
SANITY_CHECK_CSV = repo_root / "data" / "val_annotations" / "ensemble_sanity_check.csv"

if not SANITY_CHECK_CSV.exists():
    print("Downloading held-out validation annotations...")
    SANITY_CHECK_CSV.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(SANITY_CHECK_CSV_URL, SANITY_CHECK_CSV)

sanity_gt = pd.read_csv(SANITY_CHECK_CSV)
annotated_frames = sorted(sanity_gt['frame'].unique())
print(f"\nLoaded {len(sanity_gt)} annotations across frames {annotated_frames}")
# max(..., 1) avoids a ZeroDivisionError when the CSV has no rows.
print(f"Annotations per frame: ~{len(sanity_gt) // max(len(annotated_frames), 1)}")

## 6. Benchmark All Models

Compare all 3 trained models on the held-out validation frames.

In [None]:
def _count_greedy_matches(pred_coords, gt_coords, match_thresh):
    """Count one-to-one prediction/GT pairs no farther apart than ``match_thresh``.

    Matching is greedy: the globally closest unmatched pair is taken first,
    and each prediction and each GT point is used at most once. This is not
    an optimal (Hungarian) assignment.
    """
    from scipy.spatial.distance import cdist

    if len(pred_coords) == 0 or len(gt_coords) == 0:
        return 0

    dist_matrix = cdist(pred_coords, gt_coords)
    matches = 0
    for _ in range(min(dist_matrix.shape)):
        pred_i, gt_j = np.unravel_index(np.argmin(dist_matrix), dist_matrix.shape)
        if dist_matrix[pred_i, gt_j] > match_thresh:
            # The closest remaining pair is too far apart, so all others are too.
            break
        matches += 1
        # Retire the matched prediction (row) and GT point (column).
        dist_matrix[pred_i, :] = np.inf
        dist_matrix[:, gt_j] = np.inf
    return matches


def benchmark_model_on_sanity_check(model_name, sanity_gt, video, device=DEVICE, match_thresh=5.0):
    """Run one ensemble on the annotated frames and compute detection metrics.

    Args:
        model_name: Weights-folder name handed to ``download_weights_if_needed``.
        sanity_gt: DataFrame with the columns ``'frame'``, ``'x'`` and ``'y'``.
            Per the original notes the coordinates are relative to the ROI
            crop — confirm against the annotation CSV.
        video: Array indexed as ``video[frame, y, x]``.
        device: Device used for inference.
        match_thresh: Maximum prediction-to-GT distance (pixels) for a match.

    Returns:
        Dict with ``'model'``, ``'DetA'`` (TP / (TP + FP + FN)),
        ``'Precision'``, ``'Recall'``, the pooled ``'TP'``/``'FP'``/``'FN'``
        counts and ``'per_frame'``, a DataFrame with one row per frame.
    """
    # Load model weights and config
    weights_dir = download_weights_if_needed(model_name, repo_root)
    with open(weights_dir / "model_config.json") as f:
        model_config = json.load(f)
    with open(weights_dir / "inference_config.json") as f:
        inf_config = json.load(f)

    # One model per cross-validation fold.
    models = [
        load_fold_model(weights_dir / "models" / f"fold_{fold}.pth", model_config).to(device)
        for fold in range(1, K_FOLDS + 1)
    ]

    frames = sorted(sanity_gt['frame'].unique())

    # Crop window inside the full frame; the fallback is the default ROI.
    roi = inf_config.get('roi', {'x_min': 256, 'x_max': 512, 'y_min': 512, 'y_max': 768})

    all_results = []
    total_tp, total_fp, total_fn = 0, 0, 0

    try:
        for frame_idx in tqdm(frames, desc=f"Benchmarking {model_name}"):
            frame = video[frame_idx, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]
            frame_norm = preprocess_frame(frame)

            detections = detect_cells_ensemble(
                models, frame_norm,
                inf_config['prob_thresh'], inf_config['nms_thresh'], device
            )

            frame_gt = sanity_gt[sanity_gt['frame'] == frame_idx][['x', 'y']].values

            # Both arrays are in (x, y) order. With either side empty the
            # helper returns 0, so FP/FN reduce to the plain counts.
            tp = _count_greedy_matches(np.array(detections), frame_gt, match_thresh)
            fp = len(detections) - tp
            fn = len(frame_gt) - tp

            total_tp += tp
            total_fp += fp
            total_fn += fn

            all_results.append({
                'frame': frame_idx,
                'n_gt': len(frame_gt),
                'n_pred': len(detections),
                'tp': tp, 'fp': fp, 'fn': fn
            })
    finally:
        # Release the models even if a frame fails, so GPU memory is not
        # held while the next model is benchmarked.
        del models
        torch.cuda.empty_cache()

    # Metrics are pooled over all frames; each falls back to 0 on an empty denominator.
    n_pred = total_tp + total_fp
    n_gt = total_tp + total_fn
    n_all = total_tp + total_fp + total_fn
    precision = total_tp / n_pred if n_pred > 0 else 0
    recall = total_tp / n_gt if n_gt > 0 else 0
    deta = total_tp / n_all if n_all > 0 else 0

    return {
        'model': model_name,
        'DetA': deta,
        'Precision': precision,
        'Recall': recall,
        'TP': total_tp,
        'FP': total_fp,
        'FN': total_fn,
        'per_frame': pd.DataFrame(all_results)
    }

In [None]:
# Benchmark every registered model on the held-out frames and collect the
# result dictionaries in `benchmark_results` (same order as AVAILABLE_MODELS).
separator = "="*60

print(separator)
print("BENCHMARK: Comparing 3 Models on Held-Out Validation Data")
print(separator)
print(f"Frames: {sorted(sanity_gt['frame'].unique())}")
print(f"Total annotations: {len(sanity_gt)}")
print(separator)

benchmark_results = []
for model_name in AVAILABLE_MODELS:
    print(f"\n>>> Testing: {model_name}")
    result = benchmark_model_on_sanity_check(model_name, sanity_gt, video)
    benchmark_results.append(result)
    print(f"    DetA: {result['DetA']:.4f} | Precision: {result['Precision']:.4f} | Recall: {result['Recall']:.4f}")

print("\n" + separator)

## 7. Results Summary

In [None]:
# Display benchmark summary table (scores formatted to four decimals).
summary_df = pd.DataFrame([{
    'Model': r['model'],
    'DetA': f"{r['DetA']:.4f}",
    'Precision': f"{r['Precision']:.4f}",
    'Recall': f"{r['Recall']:.4f}",
    'TP': r['TP'],
    'FP': r['FP'],
    'FN': r['FN']
} for r in benchmark_results])

print("\n" + "="*60)
print("BENCHMARK RESULTS SUMMARY")
print("="*60)
print(summary_df.to_string(index=False))
print("="*60)

# Best model = highest DetA (the first one wins on ties). Cast to a plain
# int so the index can be reused for list indexing in later cells.
best_idx = int(np.argmax([r['DetA'] for r in benchmark_results]))
# The trophy is written as an escape sequence: the literal emoji had been
# corrupted into mojibake by an encoding round-trip.
print(f"\n\U0001F3C6 Best Model: {benchmark_results[best_idx]['model']} (DetA={benchmark_results[best_idx]['DetA']:.4f})")

In [None]:
# Visualize benchmark: one bar chart per metric, one bar per model.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

models = [r['model'] for r in benchmark_results]
deta_scores = [r['DetA'] for r in benchmark_results]
precision_scores = [r['Precision'] for r in benchmark_results]
recall_scores = [r['Recall'] for r in benchmark_results]

colors = ['#2ecc71', '#3498db', '#e74c3c']
positions = range(len(models))
# Break model names at underscores so the tick labels stay narrow.
tick_labels = [m.replace('_', '\n') for m in models]

# (y-axis label, panel title, scores) for each of the three panels.
panels = [
    ('DetA', 'Detection Accuracy (DetA)', deta_scores),
    ('Precision', 'Precision (TP / (TP + FP))', precision_scores),
    ('Recall', 'Recall (TP / (TP + FN))', recall_scores),
]

for ax, (ylabel, title, scores) in zip(axes, panels):
    bars = ax.bar(positions, scores, color=colors)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, fontsize=9)
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.set_ylim(0, 1)
    # Print the numeric score just above each bar.
    for bar, score in zip(bars, scores):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                f'{score:.3f}', ha='center', va='bottom', fontsize=10)

# Frame and annotation counts are read from the loaded ground truth rather
# than hard-coded, so the title stays correct if the annotation file changes.
n_frames = sanity_gt['frame'].nunique()
plt.suptitle(f'Model Benchmark on Held-Out Validation Data ({n_frames} frames, {len(sanity_gt)} annotations)',
             fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

## 8. Visualization: Best Model Detections

In [None]:
# Visualize detections from best model on sample frames
best_result = benchmark_results[best_idx]
best_model_name = best_result['model']

# Reload best model for visualization (the benchmark released its models).
weights_dir = download_weights_if_needed(best_model_name, repo_root)
with open(weights_dir / "model_config.json") as f:
    model_config = json.load(f)
with open(weights_dir / "inference_config.json") as f:
    inf_config = json.load(f)

best_models = [
    load_fold_model(weights_dir / "models" / f"fold_{fold}.pth", model_config).to(DEVICE)
    for fold in range(1, K_FOLDS + 1)
]

try:
    roi = inf_config.get('roi', {'x_min': 256, 'x_max': 512, 'y_min': 512, 'y_max': 768})
    frames_to_show = sorted(sanity_gt['frame'].unique())[:5]  # First 5 frames

    # One column per displayed frame; squeeze=False keeps `axes` 2-D even
    # for a single column, so axes[row, col] indexing always works.
    n_cols = max(len(frames_to_show), 1)
    fig, axes = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)

    for i, frame_idx in enumerate(frames_to_show):
        frame = video[frame_idx, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]
        frame_norm = preprocess_frame(frame)

        detections = detect_cells_ensemble(
            best_models, frame_norm,
            inf_config['prob_thresh'], inf_config['nms_thresh'], DEVICE
        )

        frame_gt = sanity_gt[sanity_gt['frame'] == frame_idx][['x', 'y']].values

        # Plot GT (top row) as hollow green circles. The colour is given via
        # `edgecolors`: with `c='lime'` matplotlib ignores `facecolors='none'`
        # and draws filled markers that hide the cells underneath.
        axes[0, i].imshow(frame, cmap='gray')
        axes[0, i].scatter(frame_gt[:, 0], frame_gt[:, 1], s=40, marker='o',
                           facecolors='none', edgecolors='lime', linewidths=1.5)
        axes[0, i].set_title(f'Frame {frame_idx} - GT ({len(frame_gt)})')
        axes[0, i].axis('off')

        # Plot Predictions (bottom row) as red crosses.
        axes[1, i].imshow(frame, cmap='gray')
        if detections:
            pred_coords = np.array(detections)
            axes[1, i].scatter(pred_coords[:, 0], pred_coords[:, 1], c='red', s=40,
                               marker='x', linewidths=1.5)
        axes[1, i].set_title(f'Frame {frame_idx} - Pred ({len(detections)})')
        axes[1, i].axis('off')

    plt.suptitle(f'Best Model: {best_model_name}\nTop: Ground Truth (green) | Bottom: Predictions (red)',
                 fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()
finally:
    # Cleanup: free the models even if plotting or inference failed.
    del best_models
    torch.cuda.empty_cache()