# StarDist Cell Detection Training Pipeline

This notebook trains a StarDist model for cell detection using K-Fold cross-validation.

**Key Features:**
- StarDist architecture with configurable backbone (ResNet18/34/50, EfficientNet)
- Combined loss function (Focal + Dice + Smooth L1)
- Data augmentation with albumentations
- Regularization: Weight decay + LR scheduling + Early stopping
- Post-training threshold sweep for optimal prob_thresh/nms_thresh
- Tracking evaluation with HOTA metric

**Data Sources:**
- Validation video: Downloaded from UTIA server
- Bonus training data: Fetched from GitHub repository (annotated frames)

**Outputs:**
- Trained StarDist model (best fold)
- Detection predictions CSV
- Training curves and metrics
- Tracking visualization

## 1. Setup & Configuration

In [None]:
# ============================================================
# REPOSITORY CONFIGURATION
# ============================================================
# Source repository and branch for the bonus training data. Point these at a
# fork or another branch when running from Colab or from a different location.
REPO_URL, REPO_BRANCH = "https://github.com/veselm73/SU2", "main"

# Location of the annotated bonus frames, relative to the repository root.
BONUS_DATA_SUBPATH = "annotation/sam_data/unet_train"

In [None]:
# Install dependencies using uv (a faster installer for pip packages).
# NOTE: these lines are IPython shell escapes (`!`), so this cell only runs
# inside Jupyter/Colab, not as a plain Python script.
!pip install uv -q

# Uninstall possibly conflicting preinstalled packages before installing the
# CUDA build of torch below. Errors (e.g. a package that is not installed) are
# silenced and ignored via `2>/dev/null || true`.
!uv pip uninstall torch torchvision torchaudio tensorflow tensorflow-metal --system -q 2>/dev/null || true

# Install PyTorch wheels built for CUDA 12.1 (cu121 index).
# NOTE(review): assumes the runtime's GPU driver supports CUDA 12.1 — confirm.
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --system -q

# Install the remaining dependencies; numpy is pinned below 2 for compatibility
# with the other packages in this list.
!uv pip install "numpy<2" cellseg-models-pytorch pytorch-lightning laptrack btrack "albumentations>=1.3.1" tifffile opencv-python-headless --system -q

In [None]:
import sys
import os
import subprocess
from pathlib import Path

# True inside Google Colab, where the google.colab package is already imported.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    repo_root = Path('/content/SU2')
    # Clone the configured repository/branch if it is not present yet, otherwise
    # update it. REPO_URL / REPO_BRANCH come from the repository configuration
    # cell, so pointing those at a fork or another branch affects this checkout.
    if not repo_root.exists():
        subprocess.run(
            ["git", "clone", "--branch", REPO_BRANCH, REPO_URL, str(repo_root)],
            check=True,  # without a checkout nothing below can work
        )
    else:
        # Best effort: a failed pull (e.g. local edits) should not abort the
        # setup; the existing checkout is used instead.
        pull = subprocess.run(["git", "-C", str(repo_root), "pull"], check=False)
        if pull.returncode != 0:
            print("Warning: 'git pull' failed; using the existing checkout.")
    os.chdir(repo_root)
else:
    # Local setup: the repo root is the working directory, or its parent when
    # the notebook is run from the notebooks/ folder.
    notebook_dir = Path(os.getcwd())
    if notebook_dir.name == 'notebooks':
        repo_root = notebook_dir.parent
    else:
        repo_root = notebook_dir

# Make the repository importable (needed for `modules.stardist_helpers`).
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print(f"Repository root: {repo_root}")
print(f"Working directory: {os.getcwd()}")

In [None]:
# Import helper functions
from modules.stardist_helpers import (
    # Data fetching
    download_validation_data,
    fetch_training_data_from_github,
    prepare_grand_dataset,
    create_stardist_label_mask,
    # Training
    train_stardist_kfold,
    infer_stardist_full_video,
    sweep_stardist_thresholds,
    # Metrics
    calculate_deta_robust,
    hota,
    # Tracking
    run_btrack_tracking,
    run_laptrack,
    run_tracking_sweep,
    # Visualization
    plot_training_history,
    show_detection_overlay,
    show_tracking_animation,
    print_results_summary,
    # Utilities
    set_seed,
    get_device,
    ROI_X_MIN, ROI_X_MAX, ROI_Y_MIN, ROI_Y_MAX
)

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Fix the random seed for reproducibility. Which RNGs are seeded is decided by
# modules.stardist_helpers.set_seed (presumably python/numpy/torch — confirm).
set_seed(42)
# Compute device chosen by the helper; passed to training and inference below.
device = get_device()

## 2. Data Preparation

This section downloads/fetches the required data:
1. **Validation video**: From UTIA server (val.tif + val.csv)
2. **Bonus training data**: From GitHub repository (annotated frames with masks)

In [None]:
# Locations of the validation video/annotations and of the bonus training data,
# all resolved relative to the repository root found during setup.
VAL_DIR = repo_root.joinpath("data", "val")
VAL_TIF = VAL_DIR.joinpath("val.tif")
VAL_CSV = VAL_DIR.joinpath("val.csv")
BONUS_DATA_DIR = repo_root.joinpath("bonus_training_data")

print(
    "Data paths:",
    f"  Validation: {VAL_DIR}",
    f"  Bonus data: {BONUS_DATA_DIR}",
    sep="\n",
)

In [None]:
# Download the validation video (val.tif) and its ground truth (val.csv) from
# the UTIA server. Both files are used later (dataset preparation, training,
# evaluation), so download again if EITHER one is missing — checking only the
# .tif would skip the download when just the CSV is absent.
missing_val_files = [p for p in (VAL_TIF, VAL_CSV) if not p.exists()]
if missing_val_files:
    print("Downloading validation data from UTIA server...")
    download_validation_data(target_dir=str(VAL_DIR))
    still_missing = [p for p in missing_val_files if not p.exists()]
    if still_missing:
        # Warn rather than fail: the helper owns the file naming.
        print(f"Warning: expected validation files not found after download: {still_missing}")
else:
    print(f"Validation data exists: {VAL_TIF}")

In [None]:
# Obtain the annotated bonus frames. With use_local_if_available=True the helper
# is expected to reuse a copy from the local checkout and otherwise download
# the data from GitHub into BONUS_DATA_DIR. A falsy result means it failed.
bonus_path = fetch_training_data_from_github(
    repo_url=REPO_URL,
    branch=REPO_BRANCH,
    data_subpath=BONUS_DATA_SUBPATH,
    target_dir=str(BONUS_DATA_DIR),
    use_local_if_available=True,
)

if not bonus_path:
    print("\nWarning: Could not fetch bonus data. Training will use only video frames.")
else:
    print(f"\nBonus data ready at: {bonus_path}")

In [None]:
# Prepare the combined dataset structure (creates experiment_dataset/), made of:
# - video frames cropped to the ROI, with generated disk masks
# - bonus annotated frames with instance masks (only if they were fetched)
#
# The bonus directory passed here is the path returned by
# fetch_training_data_from_github, not BONUS_DATA_DIR. When the helper reuses
# data already present in the local checkout, the returned path may differ from
# its target_dir, and BONUS_DATA_DIR might not exist at all.

DATASET_DIR = repo_root / "experiment_dataset"
DISK_RADIUS = 6  # Radius for disk masks (StarDist needs instance labels)

print("Preparing combined dataset...")
prepare_grand_dataset(
    bonus_data_dir=str(bonus_path) if bonus_path else None,
    val_tif_path=str(VAL_TIF),
    val_csv_path=str(VAL_CSV),
    out_dir=str(DATASET_DIR),
    disk_radius=DISK_RADIUS
)

# Frame map written by prepare_grand_dataset; used by training and inference.
VIDEO_MAP_PATH = DATASET_DIR / "video_map.csv"
print(f"\nDataset ready at: {DATASET_DIR}")
print(f"Video frame map: {VIDEO_MAP_PATH}")

## 3. Training Configuration

Configure the training parameters below.

**Mode selection:**
- **Baseline mode**: Set `USE_AUGMENTATION = False`, `WEIGHT_DECAY = 0`
- **Improved mode**: Use defaults (all regularization features enabled)

In [None]:
# ============================================================
# TRAINING CONFIGURATION - EDIT THESE VALUES
# ============================================================

# Basic training parameters
K_SPLITS = 5          # Number of cross-validation folds (must be >= 2)
EPOCHS = 50           # Maximum epochs per fold (may stop earlier with early stopping)
BATCH_SIZE = 4        # Batch size (reduce if OOM)
LR = 1e-4             # Initial learning rate

# StarDist model parameters
N_RAYS = 32           # Number of radial directions (32, 64, or 96)
ENCODER_NAME = "resnet18"  # Backbone: resnet18, resnet34, resnet50, efficientnet-b0

# Detection thresholds (starting values; the sweep selects per-fold values when enabled)
PROB_THRESH = 0.5     # Initial probability threshold
NMS_THRESH = 0.3      # Initial NMS IoU threshold
MATCH_THRESH = 5.0    # Distance threshold for DetA calculation (pixels)

# Data options
USE_BONUS = True      # Include bonus training data from GitHub

# ============================================================
# IMPROVEMENT SETTINGS (set to False/0 for baseline)
# ============================================================

# Data augmentation
USE_AUGMENTATION = True  # Enable data augmentation (False for baseline)
AUG_PARAMS = {           # Augmentation probabilities (only used if USE_AUGMENTATION=True)
    'rotate_p': 0.7,
    'flip_p': 0.5,
    'brightness_p': 0.3,
    'noise_p': 0.2,
    'blur_p': 0.1,
    'elastic_p': 0.2
}

# Regularization
WEIGHT_DECAY = 1e-4      # L2 regularization (0 to disable)

# Combined loss weights
FOCAL_WEIGHT = 1.0       # Focal loss weight (handles class imbalance)
DICE_WEIGHT = 1.0        # Dice loss weight (overlap-based)
DIST_WEIGHT = 1.0        # Distance regression weight

# Learning rate scheduling
SCHEDULER_PATIENCE = 5   # Epochs to wait before reducing LR
SCHEDULER_FACTOR = 0.5   # Factor to reduce LR by

# Early stopping
EARLY_STOPPING_PATIENCE = 10  # Epochs to wait before stopping

# Post-training threshold sweep
RUN_THRESHOLD_SWEEP = True
SWEEP_PROB_THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7]
SWEEP_NMS_THRESHOLDS = [0.1, 0.2, 0.3, 0.4, 0.5]

# Output directory
SAVE_DIR = repo_root / "models" / "stardist" / "run"

# Fail fast on a fold count that K-fold cross-validation cannot use.
if K_SPLITS < 2:
    raise ValueError(f"K_SPLITS must be >= 2 for K-fold cross-validation, got {K_SPLITS}")

# Run label derived from BOTH switches that define the modes:
# baseline = no augmentation and no weight decay, improved = both enabled,
# anything else is a custom mix (previously only augmentation was checked).
if USE_AUGMENTATION and WEIGHT_DECAY > 0:
    RUN_MODE = 'IMPROVED'
elif not USE_AUGMENTATION and WEIGHT_DECAY == 0:
    RUN_MODE = 'BASELINE'
else:
    RUN_MODE = 'CUSTOM'

print("Configuration:")
print(f"  Mode: {RUN_MODE}")
print(f"  Encoder: {ENCODER_NAME}, N_Rays: {N_RAYS}")
print(f"  Epochs: {EPOCHS}, Batch: {BATCH_SIZE}, LR: {LR}")
print(f"  Augmentation: {USE_AUGMENTATION}, Weight Decay: {WEIGHT_DECAY}")
print(f"  Save directory: {SAVE_DIR}")

## 4. Training

Run K-Fold cross-validation training. Each fold:
1. Trains with combined loss (Focal + Dice + Smooth L1)
2. Uses LR scheduling and early stopping
3. Runs a threshold sweep (when `RUN_THRESHOLD_SWEEP = True`) to find the optimal prob/nms thresholds
4. Evaluates DetA metric on validation set

In [None]:
# Launch K-fold training. Returns the list of per-fold result dicts, the best
# fold's result dict, and (presumably) a DataFrame of validation predictions.
fold_results, best_fold, all_preds_df = train_stardist_kfold(
    # Data locations
    dataset_root=DATASET_DIR,
    video_map_path=VIDEO_MAP_PATH,
    val_csv_path=str(VAL_CSV),
    save_dir=SAVE_DIR,
    use_bonus=USE_BONUS,
    # Model and optimisation
    encoder_name=ENCODER_NAME,
    n_rays=N_RAYS,
    k_splits=K_SPLITS,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    device=device,
    # Loss weighting
    focal_weight=FOCAL_WEIGHT,
    dice_weight=DICE_WEIGHT,
    dist_weight=DIST_WEIGHT,
    # Augmentation (probabilities only forwarded when augmentation is on)
    use_augmentation=USE_AUGMENTATION,
    aug_params=AUG_PARAMS if USE_AUGMENTATION else None,
    # LR schedule and early stopping
    scheduler_patience=SCHEDULER_PATIENCE,
    scheduler_factor=SCHEDULER_FACTOR,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    # Detection thresholds, evaluation and threshold sweep
    prob_thresh=PROB_THRESH,
    nms_thresh=NMS_THRESH,
    match_thresh=MATCH_THRESH,
    run_threshold_sweep=RUN_THRESHOLD_SWEEP,
    sweep_prob_thresholds=SWEEP_PROB_THRESHOLDS,
    sweep_nms_thresholds=SWEEP_NMS_THRESHOLDS,
)

## 5. Results & Evaluation

In [None]:
# Print each fold's DetA, selected thresholds and stopping epoch, then the
# mean +/- (population) std of DetA across folds and the best fold.
print("\n" + "=" * 60)
print("K-FOLD RESULTS SUMMARY")
print("=" * 60)

for result in fold_results:
    print(f"\nFold {result['fold']}:")
    print(f"  DetA = {result['deta']:.4f}")
    print(f"  Optimal thresholds: prob={result['best_prob_thresh']:.2f}, nms={result['best_nms_thresh']:.2f}")
    print(f"  Stopped at epoch: {result['stopped_epoch']}")

deta_values = [fold_result['deta'] for fold_result in fold_results]
mean_deta, std_deta = np.mean(deta_values), np.std(deta_values)

print("\n" + "-" * 40)
print(f"Mean DetA: {mean_deta:.4f} +/- {std_deta:.4f}")
print(f"Best Fold: {best_fold['fold']} (DetA = {best_fold['deta']:.4f})")

In [None]:
# Plot the best fold's train/val loss (left) and learning-rate schedule (right),
# then save the figure as training_curves.png in SAVE_DIR.
history = best_fold['history']

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

# Left panel: loss curves, one line per tracked loss key
for key, label, color in (('train_loss', 'Train Loss', 'blue'),
                          ('val_loss', 'Val Loss', 'orange')):
    ax1.plot(history[key], label=label, color=color, linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title(f'Training Curves - Best Fold {best_fold["fold"]}')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right panel: LR on a log scale, or a placeholder when no LR history was recorded
if not ('lr' in history and history['lr']):
    ax2.text(0.5, 0.5, 'LR history not available', ha='center', va='center', transform=ax2.transAxes)
else:
    ax2.plot(history['lr'], color='green', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Learning Rate')
    ax2.set_title('Learning Rate Schedule')
    ax2.set_yscale('log')
    ax2.grid(True, alpha=0.3)

fig.tight_layout()
SAVE_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(SAVE_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Full Video Inference

Run the best fold's model on all video frames, using the prob/nms thresholds selected for that fold.

In [None]:
# Predict on every video frame using the best fold's model and the
# probability/NMS thresholds recorded for that fold.
best_model = best_fold['model']
best_prob = best_fold['best_prob_thresh']
best_nms = best_fold['best_nms_thresh']

print(f"Running inference with prob_thresh={best_prob:.2f}, nms_thresh={best_nms:.2f}...")

full_preds_df = infer_stardist_full_video(
    model=best_model,
    device=device,
    video_root=DATASET_DIR.joinpath("video"),
    video_map_path=VIDEO_MAP_PATH,
    prob_thresh=best_prob,
    nms_thresh=best_nms,
)

print(f"\nTotal detections: {len(full_preds_df)}")
print(f"Frames covered: {full_preds_df['frame'].nunique()}")

In [None]:
# Calculate final DetA of the full-video predictions against the ground truth
# restricted to the ROI.
gt_df = pd.read_csv(VAL_CSV)
gt_roi = gt_df[
    (gt_df.x >= ROI_X_MIN) & (gt_df.x < ROI_X_MAX) &
    (gt_df.y >= ROI_Y_MIN) & (gt_df.y < ROI_Y_MAX)
].copy()

# Evaluate on every frame that has ground truth. Intersecting with the frames
# that contain predictions would silently drop frames where the model detected
# nothing, so their cells would never count as misses and DetA would be
# inflated. Predictions on frames without any GT row are left out, since those
# frames may simply be unannotated.
eval_frames = set(gt_roi.frame.unique())
gt_filtered = gt_roi
pred_filtered = full_preds_df[full_preds_df.frame.isin(eval_frames)]

frames_without_preds = eval_frames - set(full_preds_df.frame.unique())
if frames_without_preds:
    print(f"Note: {len(frames_without_preds)} GT frame(s) have no predictions "
          "and are included in the evaluation.")

final_deta = calculate_deta_robust(gt_filtered, pred_filtered, match_thresh=MATCH_THRESH)
print(f"\nFinal DetA (all frames): {final_deta:.4f}")

## 7. Tracking

In [None]:
# Link detections into tracks with BTrack using default parameters; if BTrack
# raises for any reason, fall back to LapTrack on the same detections.
print("Running tracking...")

try:
    tracked_df = run_btrack_tracking(full_preds_df)
except Exception as err:
    print(f"BTrack failed ({err}), falling back to LapTrack...")
    tracked_df = run_laptrack(full_preds_df)
    tracker_name = "LapTrack"
else:
    tracker_name = "BTrack"

print(f"\n{tracker_name} Results:")
print(f"  Total tracks: {tracked_df['track_id'].nunique()}")
print(f"  Total detections: {len(tracked_df)}")

In [None]:
# HOTA: compare ROI ground-truth tracks (GT column 'track', renamed to
# 'track_id' to match the tracker output) against the tracked detections.
gt_for_hota = gt_roi.loc[:, ['frame', 'x', 'y', 'track']].rename(columns={'track': 'track_id'})
pred_for_hota = tracked_df.loc[:, ['frame', 'x', 'y', 'track_id']]

hota_score = hota(gt_for_hota, pred_for_hota, match_thresh=MATCH_THRESH)
print(f"\nHOTA Score: {hota_score:.4f}")

## 8. Visualization

In [None]:
# Visualize sample frames: overlay predicted (lime circles) and ground-truth
# (red crosses) centroids on the ROI crop of the validation video, then save
# the figure as detection_samples.png in SAVE_DIR.
import tifffile

video_frames = tifffile.imread(VAL_TIF)
roi = video_frames[:, ROI_Y_MIN:ROI_Y_MAX, ROI_X_MIN:ROI_X_MAX]

# Keep only the requested frames that exist in this video, so a short video
# gets no blank panels with stray axes.
sample_frames = [fidx for fidx in (0, 30, 60, 90) if fidx < len(roi)]
if not sample_frames:
    raise ValueError(f"No frames to visualize: {VAL_TIF} contains {len(roi)} frames")

# squeeze=False keeps `axes` indexable even when only one panel remains.
fig, axes = plt.subplots(
    1, len(sample_frames), figsize=(4 * len(sample_frames), 4), squeeze=False
)
axes = axes[0]

for ax, fidx in zip(axes, sample_frames):
    ax.imshow(roi[fidx], cmap='gray')

    # Coordinates are in full-frame pixels; shift them into the ROI crop.
    frame_preds = full_preds_df[full_preds_df.frame == fidx]
    ax.scatter(
        frame_preds.x - ROI_X_MIN,
        frame_preds.y - ROI_Y_MIN,
        c='lime', s=20, marker='o', alpha=0.8, label='Predicted'
    )

    frame_gt = gt_roi[gt_roi.frame == fidx]
    ax.scatter(
        frame_gt.x - ROI_X_MIN,
        frame_gt.y - ROI_Y_MIN,
        c='red', s=30, marker='x', alpha=0.8, label='GT'
    )

    ax.set_title(f'Frame {fidx}')
    ax.axis('off')

axes[0].legend(loc='upper left')
plt.suptitle('StarDist Detection Results', fontsize=14)
plt.tight_layout()
# Ensure the output directory exists even if the earlier plotting cell was skipped.
SAVE_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(SAVE_DIR / 'detection_samples.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Save Final Results

In [None]:
# Save detections, tracks and a one-row run summary into SAVE_DIR, then print
# the summary.
SAVE_DIR.mkdir(parents=True, exist_ok=True)

full_preds_df.to_csv(SAVE_DIR / 'stardist_predictions.csv', index=False)
tracked_df.to_csv(SAVE_DIR / 'stardist_tracked.csv', index=False)

# Save summary
summary = {
    'best_fold': best_fold['fold'],
    'best_deta': best_fold['deta'],
    'final_deta': final_deta,
    'hota': hota_score,
    'best_prob_thresh': best_prob,
    'best_nms_thresh': best_nms,
    'total_detections': len(full_preds_df),
    'total_tracks': tracked_df['track_id'].nunique(),
    'use_augmentation': USE_AUGMENTATION,
    'weight_decay': WEIGHT_DECAY,
    'encoder': ENCODER_NAME,
    'n_rays': N_RAYS,
    'repo_url': REPO_URL
}

pd.DataFrame([summary]).to_csv(SAVE_DIR / 'training_summary.csv', index=False)

print("\n" + "="*60)
print("FINAL SUMMARY")
print("="*60)
for k, v in summary.items():
    # Metrics may come back as numpy scalars (e.g. np.float32, which is not a
    # subclass of float); format those with 4 decimals as well.
    if isinstance(v, (float, np.floating)):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")
print(f"\nResults saved to: {SAVE_DIR}")