[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/veselm73/SU2/blob/main/notebooks/SU2_StarDist_inference.ipynb)

# Cell Detection & Tracking - Competition Submission

**Author:** Mateusz Vesel  
**Task:** Detect and track cells in microscopy video

---

## Pipeline Overview

This notebook presents my solution for cell detection and tracking:

### 1. Detection: StarDist with 5-Fold Ensemble
- **Architecture:** StarDist with ResNet18 encoder
- **Training:** 5-Fold cross-validation on 120 annotated frames + bonus data
- **Inference:** Ensemble averaging of probability maps from all 5 folds
- **Post-processing:** Non-maximum suppression to extract cell centroids

### 2. Tracking: LapTrack
- **Method:** Linear Assignment Problem (LAP) based tracking
- **Features:** Frame-to-frame linking with gap closing for missed detections

### Key Results (Validation Set)
- **DetA:** ~0.82 (detection accuracy)
- **AssA:** ~0.69 (association accuracy)  
- **HOTA:** ~0.76 (overall tracking accuracy)

---

## How to Use This Notebook

1. **Run all cells** to load pre-trained models
2. **Upload your test video** (TIFF format)
3. **Call `infer_video()`** to get predictions
4. **Download results** as CSV

## 1. Setup & Dependencies

In [None]:
# Install dependencies (fast with uv)
# uv is used as a pip-compatible installer; `--system` installs into the
# active interpreter instead of a virtual environment, `-q` keeps output quiet.
!pip install uv -q
# Remove any torch stack that is already installed. stderr is discarded and
# `|| true` keeps the cell from failing when there is nothing to uninstall.
!uv pip uninstall torch torchvision torchaudio --system -q 2>/dev/null || true
# Reinstall torch from the CUDA 12.1 (cu121) wheel index.
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --system -q
# NOTE(review): numpy is pinned below 2.x, presumably for compatibility with
# one of the packages on this line - confirm which one requires it.
!uv pip install "numpy<2" cellseg-models-pytorch pytorch-lightning laptrack tifffile --system -q

In [None]:
import sys
import os
from pathlib import Path

# Clone/update repository
# Colab is detected by checking whether the google.colab module is already
# loaded in this interpreter session.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # First run: clone the repository. Later runs: pull the latest commits.
    if not Path('/content/SU2').exists():
        !git clone https://github.com/veselm73/SU2.git /content/SU2
    else:
        !cd /content/SU2 && git pull
    os.chdir('/content/SU2')
    repo_root = Path('/content/SU2')
else:
    # Local run: when started from a directory named "notebooks", the repo
    # root is its parent; otherwise the working directory itself is used.
    notebook_dir = Path(os.getcwd())
    repo_root = notebook_dir.parent if notebook_dir.name == 'notebooks' else notebook_dir

# Put the repo root on sys.path so later cells can import from `modules`.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print(f"Repository: {repo_root}")

In [None]:
import torch
import numpy as np
import pandas as pd
import tifffile
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from skimage.measure import regionprops

from cellseg_models_pytorch.postproc.functional.stardist.stardist import post_proc_stardist

# Import model class and tracking
from modules.stardist_helpers import (
    run_laptrack,
    hota,
    ROI_X_MIN, ROI_X_MAX, ROI_Y_MIN, ROI_Y_MAX
)

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Device: {DEVICE}")

## 2. Model Configuration

These settings match the trained models. **Do not modify unless retraining.**

In [None]:
# ------------------------------------------------------------
# Model hyper-parameters (must match training)
# ------------------------------------------------------------
MODEL_CONFIG = dict(
    encoder_name="resnet18",
    n_rays=32,
    dropout=0.1,
)

# ------------------------------------------------------------
# Inference settings (optimized on validation set)
# ------------------------------------------------------------
# Region of interest in full-image pixel coordinates. The bounds come from
# modules.stardist_helpers; the values in the trailing notes are carried over
# from the original notebook and are not re-checked here.
_roi_bounds = dict(
    x_min=ROI_X_MIN,  # noted as 256
    x_max=ROI_X_MAX,  # noted as 512
    y_min=ROI_Y_MIN,  # noted as 512
    y_max=ROI_Y_MAX,  # noted as 768
)

INFERENCE_CONFIG = dict(
    # Detection thresholds (tuned via ensemble sweep)
    prob_thresh=0.5,        # forwarded as score_thresh to post_proc_stardist
    nms_thresh=0.3,         # forwarded as iou_thresh to post_proc_stardist
    # Tracking parameters
    track_max_dist=5,       # pixels
    gap_closing_frames=2,
    # Region of Interest
    roi=_roi_bounds,
)

# Where the per-fold weight files are expected, and how many folds there are.
MODELS_DIR = repo_root.joinpath("results", "stardist", "competition", "models")
K_FOLDS = 5

# Echo the active configuration so it is recorded in the cell output.
_config_summary = [
    "Configuration loaded:",
    f"  Encoder: {MODEL_CONFIG['encoder_name']}",
    f"  Detection: prob_thresh={INFERENCE_CONFIG['prob_thresh']}, nms_thresh={INFERENCE_CONFIG['nms_thresh']}",
    f"  Tracking: max_dist={INFERENCE_CONFIG['track_max_dist']}px, gap={INFERENCE_CONFIG['gap_closing_frames']} frames",
    f"  ROI: {INFERENCE_CONFIG['roi']}",
]
print("\n".join(_config_summary))

## 3. Load Pre-trained Models

Load the 5-fold ensemble from saved checkpoints.

In [None]:
# Define a minimal, inference-only StarDistLightning wrapper in this notebook.
# A class of the same name is defined in stardist_helpers; it is not imported here.
import pytorch_lightning as pl
from cellseg_models_pytorch.models.stardist.stardist import StarDist
import torch.nn as nn

class StarDistLightning(pl.LightningModule):
    """Inference-only Lightning wrapper around a StarDist network.

    Args:
        n_rays: Number of rays passed to the StarDist constructor.
        encoder_name: Encoder name, passed to StarDist as ``enc_name``.
        dropout: When > 0 a ``Dropout2d`` layer is created and stored on
            ``self.dropout``; otherwise ``self.dropout`` is ``None``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, n_rays=32, encoder_name="resnet18", dropout=0.0, **kwargs):
        super().__init__()
        self.n_rays = n_rays

        # Build the library wrapper and keep only its inner network.
        # The encoder is configured with in_chans=1.
        self.model = StarDist(
            n_nuc_classes=1,
            n_rays=n_rays,
            enc_name=encoder_name,
            model_kwargs={"encoder_kws": {"in_chans": 1}},
        ).model

        # The dropout layer is only stored here; forward() does not apply it.
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Return the wrapped network's output for ``x`` unchanged."""
        return self.model(x)


def load_fold_model(fold_path, config):
    """Build a StarDistLightning model and restore one fold's weights.

    Args:
        fold_path: Path to the saved state dict (.pth file).
        config: Dict with 'n_rays' and 'encoder_name', and optionally
            'dropout' (0.0 is used when the key is missing).

    Returns:
        The model with weights loaded, left on the CPU and in eval mode.
    """
    net = StarDistLightning(
        n_rays=config['n_rays'],
        encoder_name=config['encoder_name'],
        dropout=config.get('dropout', 0.0),
    )
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source, or consider passing weights_only=True.
    weights = torch.load(fold_path, map_location='cpu')
    net.load_state_dict(weights)
    # eval() returns the module itself.
    return net.eval()

In [None]:
# Build the ensemble: one model per fold checkpoint that exists on disk.
print("Loading 5-fold ensemble models...")
print(f"  Looking in: {MODELS_DIR}")

fold_models = []

for fold in range(1, K_FOLDS + 1):
    fold_path = MODELS_DIR / f"fold_{fold}.pth"

    # Missing checkpoints are reported and skipped; the ensemble may be partial.
    if not fold_path.exists():
        print(f"  ✗ Fold {fold}: NOT FOUND at {fold_path}")
        continue

    model = load_fold_model(fold_path, MODEL_CONFIG).to(DEVICE)
    fold_models.append(model)
    print(f"  ✓ Fold {fold}: Loaded")

print(f"\nLoaded {len(fold_models)}/{K_FOLDS} models")

# Nothing loaded at all: point the user at the expected weights location.
if not fold_models:
    print("\n⚠️  No models found! Please ensure trained weights are in:")
    print(f"   {MODELS_DIR}")
    print("\n   Run SU2_StarDist_final.ipynb first to train models.")

## 4. Inference Functions

Core functions for detection and tracking.

In [None]:
def preprocess_frame(frame, p_low=1, p_high=99.8):
    """Rescale a frame to roughly [0, 1] with percentile normalization.

    Intensities are clipped to the [p_low, p_high] percentile range of the
    frame and then mapped linearly so the lower bound becomes 0 and the
    upper bound becomes (almost exactly) 1.

    Args:
        frame: Array-like image; converted to float32.
        p_low: Lower percentile (0-100). Defaults to 1.
        p_high: Upper percentile (0-100). Defaults to 99.8.

    Returns:
        float32 numpy array with the same shape as ``frame``. A constant
        frame maps to all zeros.
    """
    frame = np.asarray(frame, dtype=np.float32)
    lo, hi = np.percentile(frame, (p_low, p_high))

    # Compute the scale in Python floats, then cast the bounds to float32 so
    # the arithmetic below cannot promote the image to float64.
    # The epsilon keeps the division finite when lo == hi.
    scale = np.float32(float(hi) - float(lo) + 1e-8)
    lo = np.float32(lo)
    hi = np.float32(hi)

    frame = np.clip(frame, lo, hi)
    return (frame - lo) / scale


def detect_cells_ensemble(models, frame, prob_thresh, nms_thresh, device):
    """
    Detect cells in one frame using an ensemble of models.

    Every model is run on the frame. The sigmoid probability maps and the
    ray maps are averaged across models, and the averaged maps are passed
    once to ``post_proc_stardist``. Centroids of the resulting labelled
    regions are returned.

    Args:
        models: List of StarDistLightning models
        frame: 2D numpy array (H, W), already preprocessed
        prob_thresh: Detection threshold (passed as score_thresh)
        nms_thresh: NMS threshold (passed as iou_thresh)
        device: torch device the input tensor is moved to

    Returns:
        List of (x, y) centroid coordinates in ROI space. An empty list is
        returned when post-processing raises; a RuntimeWarning carrying the
        original error is emitted in that case so failures are not silent.
    """
    import warnings

    # Add batch and channel axes: (H, W) -> (1, 1, H, W)
    x = torch.from_numpy(frame).float().unsqueeze(0).unsqueeze(0).to(device)

    # Collect per-model predictions
    all_stardist = []
    all_prob = []

    with torch.no_grad():
        for model in models:
            nuc_out = model(x)['nuc']

            # aux_map holds the ray maps, binary_map the probability logits;
            # the batch axis (and the channel axis of the probability map)
            # is dropped by the indexing.
            all_stardist.append(nuc_out.aux_map.cpu().numpy()[0])
            all_prob.append(torch.sigmoid(nuc_out.binary_map).cpu().numpy()[0, 0])

    # Average across all folds
    avg_stardist = np.mean(all_stardist, axis=0)
    avg_prob = np.mean(all_prob, axis=0)

    # Post-process to get detections. This stays best-effort (a failing frame
    # yields no detections) but the failure is reported instead of hidden.
    try:
        labels = post_proc_stardist(
            avg_prob,       # dist_map (H, W)
            avg_stardist,   # stardist_map (n_rays, H, W)
            score_thresh=prob_thresh,
            iou_thresh=nms_thresh
        )
        # regionprops centroids are (row, col); flip to (x, y)
        return [
            (prop.centroid[1], prop.centroid[0])
            for prop in regionprops(labels)
        ]
    except Exception as e:
        warnings.warn(
            f"StarDist post-processing failed; returning no detections for this frame: {e!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return []

In [None]:
def infer_video(video_path, config=None, models=None, output_csv=None):
    """
    Run full detection + tracking pipeline on a video.

    This is the main inference API. The video is cropped to the configured
    ROI, every frame is normalized and passed through the model ensemble,
    and the detections (shifted back to full-image coordinates) are linked
    into tracks with ``run_laptrack``.

    Args:
        video_path: Path to TIFF video file. A single 2D image is treated
            as a one-frame video.
        config: Inference config dict (uses INFERENCE_CONFIG if None)
        models: List of models (uses global fold_models if None)
        output_csv: Optional path to save predictions. When given, the file
            is written even if nothing was detected, so a stale file from
            an earlier run is never left behind.

    Returns:
        DataFrame with columns: frame, x, y, track_id

    Raises:
        ValueError: If no models are available, the TIFF is not a 2D image
            or a 3D (frames, height, width) stack, or the ROI crop is empty.

    Example:
        >>> predictions = infer_video('test.tif')
        >>> predictions.to_csv('submission.csv', index=False)
    """
    # Use defaults
    if config is None:
        config = INFERENCE_CONFIG
    if models is None:
        models = fold_models

    if len(models) == 0:
        raise ValueError("No models loaded! Run model loading cell first.")

    # Load video
    print(f"Loading: {video_path}")
    video = tifffile.imread(video_path)
    print(f"  Video shape: {video.shape}")

    # A single-page TIFF loads as (H, W); promote it to a one-frame stack so
    # the ROI slicing below addresses the right axes.
    if video.ndim == 2:
        video = video[np.newaxis]
    elif video.ndim != 3:
        raise ValueError(
            f"Expected a (frames, height, width) video, got shape {video.shape}"
        )

    # Extract ROI
    roi = config['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]
    print(f"  ROI shape: {video_roi.shape}")
    if video_roi.size == 0:
        raise ValueError(
            f"ROI crop is empty: ROI {roi} vs video shape {video.shape}"
        )

    # Detection
    print(f"\nDetecting cells (ensemble of {len(models)} models)...")
    all_detections = []

    for frame_idx, raw_frame in enumerate(tqdm(video_roi, desc="Detection")):
        detections = detect_cells_ensemble(
            models, preprocess_frame(raw_frame),
            prob_thresh=config['prob_thresh'],
            nms_thresh=config['nms_thresh'],
            device=DEVICE
        )

        # Convert from ROI coordinates to full image coordinates
        all_detections.extend(
            {'frame': frame_idx, 'x': x + roi['x_min'], 'y': y + roi['y_min']}
            for x, y in detections
        )

    detections_df = pd.DataFrame(all_detections)
    print(f"  Total detections: {len(detections_df)}")

    if detections_df.empty:
        print("  ⚠️ No detections found!")
        tracked_df = pd.DataFrame(columns=['frame', 'x', 'y', 'track_id'])
    else:
        # Tracking
        print(f"\nTracking (max_dist={config['track_max_dist']}px)...")
        tracked_df = run_laptrack(
            detections_df,
            max_dist=config['track_max_dist'],
            closing_gap=config['gap_closing_frames'],
            min_length=2
        )
        print(f"  Total tracks: {tracked_df['track_id'].nunique()}")

    # Save if requested (also for the empty result, see docstring)
    if output_csv:
        tracked_df.to_csv(output_csv, index=False)
        print(f"\n  Saved: {output_csv}")

    return tracked_df

## 5. Run Inference

### Option A: Use validation video (for testing)

In [None]:
# Validation inputs and the directory that receives all outputs.
_val_dir = repo_root / "data" / "val"
VAL_TIF = _val_dir / "val.tif"
VAL_CSV = _val_dir / "val.csv"
OUTPUT_DIR = repo_root / "results" / "stardist" / "competition"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Inference needs both the video and at least one loaded model; otherwise
# report which of the two prerequisites is missing.
if not VAL_TIF.exists() or not fold_models:
    print("Validation video not found or models not loaded.")
    print(f"  VAL_TIF exists: {VAL_TIF.exists()}")
    print(f"  Models loaded: {len(fold_models)}")
else:
    predictions = infer_video(VAL_TIF, output_csv=OUTPUT_DIR / "predictions.csv")

### Option B: Upload your own video

In [None]:
# Uncomment to upload and process your own video
# from google.colab import files
# uploaded = files.upload()
# test_video_path = list(uploaded.keys())[0]
# predictions = infer_video(test_video_path, output_csv="my_predictions.csv")

## 6. Evaluate Results

Compare predictions against ground truth (if available).

In [None]:
# Evaluate against ground truth.
# The two prerequisites are reported separately so the message names the
# actual reason evaluation was skipped.
if not VAL_CSV.exists():
    print("Ground truth not available for evaluation.")
elif not ('predictions' in dir() and len(predictions) > 0):
    print("No predictions available for evaluation. Run inference first.")
else:
    print("Evaluating against ground truth...")

    # Load GT and keep only the annotations that fall inside the ROI
    gt_df = pd.read_csv(VAL_CSV)
    roi = INFERENCE_CONFIG['roi']
    gt_roi = gt_df[
        (gt_df.x >= roi['x_min']) & (gt_df.x < roi['x_max']) &
        (gt_df.y >= roi['y_min']) & (gt_df.y < roi['y_max'])
    ].copy()

    # Calculate HOTA metrics
    hota_scores = hota(gt_roi, predictions, threshold=5.0)

    print("\n" + "="*50)
    print("EVALUATION RESULTS")
    print("="*50)
    print(f"\n  HOTA: {hota_scores['HOTA']:.4f}")
    print(f"  DetA: {hota_scores['DetA']:.4f}  (detection accuracy)")
    print(f"  AssA: {hota_scores['AssA']:.4f}  (association accuracy)")
    print(f"\n  Detections: {len(predictions)}")
    print(f"  Tracks: {predictions['track_id'].nunique()}")
    print("="*50)

## 7. Visualization

In [None]:
# Visualize predictions vs ground truth on a few sample frames
if VAL_TIF.exists() and 'predictions' in dir() and len(predictions) > 0:
    video = tifffile.imread(VAL_TIF)
    roi = INFERENCE_CONFIG['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]

    # Load GT if available, restricted to the ROI
    gt_available = VAL_CSV.exists()
    if gt_available:
        gt_df = pd.read_csv(VAL_CSV)
        gt_roi = gt_df[
            (gt_df.x >= roi['x_min']) & (gt_df.x < roi['x_max']) &
            (gt_df.y >= roi['y_min']) & (gt_df.y < roi['y_max'])
        ]

    # Plot sample frames
    sample_frames = [0, 30, 60, 90]
    fig, axes = plt.subplots(1, len(sample_frames), figsize=(16, 4))

    for ax, fidx in zip(axes, sample_frames):
        # Turn the axis off first so panels skipped below stay blank
        # instead of showing an empty frame with ticks.
        ax.axis('off')
        if fidx >= len(video_roi):
            continue

        ax.imshow(video_roi[fidx], cmap='gray')

        # Plot GT (hollow green circles). The outline colour has to go through
        # edgecolors: when `c` is given, scatter ignores facecolors='none'
        # and draws filled markers.
        if gt_available:
            frame_gt = gt_roi[gt_roi.frame == fidx]
            ax.scatter(
                frame_gt.x - roi['x_min'],
                frame_gt.y - roi['y_min'],
                s=40, marker='o', facecolors='none', edgecolors='lime',
                linewidths=1.5, label='GT'
            )

        # Plot predictions (red crosses), shifted into ROI coordinates
        frame_preds = predictions[predictions.frame == fidx]
        ax.scatter(
            frame_preds.x - roi['x_min'],
            frame_preds.y - roi['y_min'],
            c='red', s=25, marker='x', linewidths=1.5, label='Predicted'
        )

        ax.set_title(f'Frame {fidx}')

    axes[0].legend(loc='upper left')
    plt.suptitle('Detection Results: Green=GT, Red=Predicted', fontsize=12)
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / 'detection_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Visualize tracking (color by track_id)
# The validation video must exist: predictions may have come from another
# video, and reading a missing VAL_TIF would raise.
if VAL_TIF.exists() and 'predictions' in dir() and len(predictions) > 0:
    video = tifffile.imread(VAL_TIF)
    roi = INFERENCE_CONFIG['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]

    sample_frames = [0, 30, 60, 90]
    fig, axes = plt.subplots(1, len(sample_frames), figsize=(16, 4))

    # One RGBA colour per track, sampled evenly from the tab20 colormap
    unique_tracks = predictions['track_id'].unique()
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_tracks)))
    track_to_color = dict(zip(unique_tracks, colors))

    for ax, fidx in zip(axes, sample_frames):
        # Turn the axis off first so panels skipped below stay blank
        ax.axis('off')
        if fidx >= len(video_roi):
            continue

        ax.imshow(video_roi[fidx], cmap='gray')

        frame_preds = predictions[predictions.frame == fidx]
        if not frame_preds.empty:
            # Draw all points of the frame in a single scatter call, with
            # one RGBA row per point.
            point_colors = np.array(
                [track_to_color[t] for t in frame_preds['track_id']]
            )
            ax.scatter(
                frame_preds['x'] - roi['x_min'],
                frame_preds['y'] - roi['y_min'],
                c=point_colors, s=30, marker='o'
            )

        ax.set_title(f'Frame {fidx} ({len(frame_preds)} cells)')

    plt.suptitle('Tracking Results (color = track ID)', fontsize=12)
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / 'tracking_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()

## 8. Export for Submission

In [None]:
# Print a listing of the output directory followed by the CSV format notes.
_rule = "=" * 60

print(_rule)
print("COMPETITION SUBMISSION FILES")
print(_rule)

print(f"\nOutput directory: {OUTPUT_DIR}")
print("\nFiles:")
# Regular files only, in sorted order, each with its size in KB.
for _entry in sorted(p for p in OUTPUT_DIR.glob('*') if p.is_file()):
    print(f"  {_entry.name} ({_entry.stat().st_size / 1024:.1f} KB)")

print("\n" + _rule)
print("SUBMISSION FORMAT")
print(_rule)
# The empty first and last items reproduce the leading and trailing newlines
# of the original text block.
_format_notes = "\n".join([
    "",
    "The predictions.csv file contains:",
    "  - frame: Frame index (0-based)",
    "  - x: Cell x-coordinate (full image)",
    "  - y: Cell y-coordinate (full image)  ",
    "  - track_id: Unique track identifier",
    "",
    "To submit:",
    "  1. Download predictions.csv",
    "  2. Rename if needed",
    "  3. Upload to competition platform",
    "",
])
print(_format_notes)
print(_rule)

In [None]:
# Offer predictions.csv as a browser download (only when running in Colab
# and the file has actually been written).
if IN_COLAB:
    _predictions_csv = OUTPUT_DIR / 'predictions.csv'
    if _predictions_csv.exists():
        from google.colab import files
        files.download(str(_predictions_csv))
        print("Download started!")