[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/veselm73/SU2/blob/main/notebooks/SU2_StarDist_final.ipynb)

# StarDist Cell Detection Training Pipeline

This notebook trains a StarDist model for cell detection using K-Fold cross-validation.

**Key Features:**
- StarDist architecture with configurable backbone (ResNet18/34/50, EfficientNet)
- Configurable loss function (Focal or plain BCE, optional Dice, Smooth L1 or plain L1); the final run configured below uses BCE + L1
- Data augmentation with albumentations
- Regularization: Weight decay + LR scheduling + Early stopping
- Post-training threshold sweep for optimal prob_thresh/nms_thresh
- Tracking evaluation with HOTA metric, using LapTrack with fixed (untuned) parameters

**Data Sources:**
- Validation video: Downloaded from UTIA server
- Bonus training data: Fetched from GitHub repository (annotated frames)

**Outputs:**
- Trained StarDist model (best fold)
- Detection predictions CSV
- Training curves and metrics
- Tracking visualization

## 1. Setup & Configuration

In [None]:
# ------------------------------------------------------------
# Repository configuration
# ------------------------------------------------------------
# Override these when running from Colab, a fork, or another checkout.

# Location of the bonus training data, relative to the repository root.
BONUS_DATA_SUBPATH = "/".join(("annotation", "sam_data", "unet_train"))

# GitHub repository and branch the notebook works against.
REPO_URL, REPO_BRANCH = "https://github.com/veselm73/SU2", "main"

In [None]:
# Bootstrap uv, which is used for every install below (much faster than pip)
!pip install uv -q

# Remove any preinstalled torch / tensorflow builds that could conflict.
# stderr is discarded and `|| true` keeps the cell going when a package is absent.
!uv pip uninstall torch torchvision torchaudio tensorflow tensorflow-metal --system -q 2>/dev/null || true

# Install PyTorch from the CUDA 12.1 wheel index
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --system -q

# Install the remaining dependencies (numpy is pinned below 2 for compatibility)
!uv pip install "numpy<2" cellseg-models-pytorch pytorch-lightning laptrack btrack "albumentations>=1.3.1" tifffile opencv-python-headless --system -q

In [None]:
import sys
import os
from pathlib import Path

# Check if running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone repo if not already present, otherwise pull latest
    if not Path('/content/SU2').exists():
        !git clone https://github.com/veselm73/SU2.git /content/SU2
    else:
        !cd /content/SU2 && git pull
    os.chdir('/content/SU2')
    repo_root = Path('/content/SU2')
else:
    # Local setup - find repo root
    notebook_dir = Path(os.getcwd())
    if notebook_dir.name == 'notebooks':
        repo_root = notebook_dir.parent
    else:
        repo_root = notebook_dir

if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print(f"Repository root: {repo_root}")
print(f"Working directory: {os.getcwd()}")

In [None]:
# Import helper functions
from modules.stardist_helpers import (
    # Data fetching
    download_validation_data,
    fetch_training_data_from_github,
    prepare_grand_dataset,
    create_stardist_label_mask,
    # Training
    train_stardist_kfold,
    infer_stardist_full_video,
    sweep_stardist_thresholds,
    # Metrics
    calculate_deta_robust,
    hota,
    # Tracking
    run_laptrack,
    # Visualization
    plot_training_history,
    show_detection_overlay,
    show_tracking_animation,
    print_results_summary,
    # Utilities
    set_seed,
    get_device,
    ROI_X_MIN, ROI_X_MAX, ROI_Y_MIN, ROI_Y_MAX
)

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Fix the random seed through the project helper (presumably for reproducible
# runs — see modules.stardist_helpers.set_seed for what exactly it seeds).
set_seed(42)
# Device selected by the project helper; passed to training and used for
# ensemble inference further down.
device = get_device()

## 2. Data Preparation

This section downloads/fetches the required data:
1. **Validation video**: From UTIA server (val.tif + val.csv)
2. **Bonus training data**: From GitHub repository (annotated frames with masks)

In [None]:
# Locations of the input data, all relative to the repository root
VAL_DIR = repo_root.joinpath("data", "val")
VAL_TIF, VAL_CSV = (VAL_DIR / file_name for file_name in ("val.tif", "val.csv"))
BONUS_DATA_DIR = repo_root.joinpath("bonus_training_data")

print(
    "Data paths:",
    f"  Validation: {VAL_DIR}",
    f"  Bonus data: {BONUS_DATA_DIR}",
    sep="\n",
)

In [None]:
# Download validation data from UTIA server.
# Both files are required downstream (val.tif for the frames, val.csv for the
# ground truth), so trigger the download when either one is missing rather
# than checking the video alone.
missing_val_files = [path for path in (VAL_TIF, VAL_CSV) if not path.exists()]

if missing_val_files:
    print("Downloading validation data from UTIA server...")
    download_validation_data(target_dir=str(VAL_DIR))
else:
    print(f"Validation data exists: {VAL_TIF}")

In [None]:
# Obtain the bonus training data.
# The helper is asked to prefer a copy already present locally (e.g. in the
# cloned repo) and to download from GitHub otherwise.
fetch_options = {
    "repo_url": REPO_URL,
    "branch": REPO_BRANCH,
    "data_subpath": BONUS_DATA_SUBPATH,
    "target_dir": str(BONUS_DATA_DIR),
    "use_local_if_available": True,
}
bonus_path = fetch_training_data_from_github(**fetch_options)

# A falsy result means the fetch failed; training then falls back to video frames only
print(
    f"\nBonus data ready at: {bonus_path}"
    if bonus_path
    else "\nWarning: Could not fetch bonus data. Training will use only video frames."
)

In [None]:
# Build the combined training set under experiment_dataset/ from:
#   * video frames cropped to the ROI, paired with generated disk masks
#   * bonus annotated frames with instance masks (only when they were fetched)

DISK_RADIUS = 3  # disk mask radius in pixels (kept small for precise localization)
DATASET_DIR = repo_root / "experiment_dataset"
VIDEO_MAP_PATH = DATASET_DIR / "video_map.csv"

print("Preparing combined dataset...")
# Pass no bonus directory at all when the fetch above came back empty
bonus_dir_arg = str(BONUS_DATA_DIR) if bonus_path else None
prepare_grand_dataset(
    bonus_data_dir=bonus_dir_arg,
    val_tif_path=str(VAL_TIF),
    val_csv_path=str(VAL_CSV),
    out_dir=str(DATASET_DIR),
    disk_radius=DISK_RADIUS,
)

print(f"\nDataset ready at: {DATASET_DIR}")
print(f"Video frame map: {VIDEO_MAP_PATH}")

## 3. Training Configuration

Configure the training parameters below.

**Mode selection:**
- **Baseline mode**: Set `USE_AUGMENTATION = False`, `WEIGHT_DECAY = 0`
- **Improved mode**: Use defaults (all regularization features enabled)

In [None]:
# ============================================================
# TRAINING CONFIGURATION - FINAL RUN FOR BEST DetA
# ============================================================

# --- Core schedule (extended for the final run) ---
K_SPLITS = 5          # cross-validation folds
EPOCHS = 100          # epochs per fold
BATCH_SIZE = 4        # lower this on out-of-memory errors
LR = 1e-4             # initial learning rate

# --- StarDist model ---
N_RAYS = 32                 # radial directions (32, 64, or 96)
ENCODER_NAME = "resnet18"   # backbone: resnet18, resnet34, resnet50, efficientnet-b0

# --- Detection thresholds: held fixed during training, swept afterwards ---
PROB_THRESH = 0.3     # probability threshold
NMS_THRESH = 0.3      # NMS threshold
MATCH_THRESH = 5.0    # matching distance (pixels) used for DetA

# --- Data ---
USE_BONUS = True      # include the bonus training data from GitHub

# --- Output location ---
SAVE_DIR = repo_root / "results" / "stardist"

# ============================================================
# FINAL RUN CONFIGURATION - OPTIMIZED FOR BEST DetA
# ============================================================
# Choices below follow earlier experiments on this dataset:
#   * the BASELINE loss (BCE + L1) scored best
#   * light augmentation helps generalization
#   * moderate dropout (0.1) limits overfitting
#   * no per-fold threshold sweep; one ensemble sweep at the end instead

TRAINING_MODE = "FINAL"  # label for this custom configuration

# --- Loss (BASELINE) ---
USE_SIMPLE_BCE = True   # plain BCE instead of Focal
USE_SIMPLE_L1 = True    # plain L1 instead of Smooth L1
FOCAL_WEIGHT = 1.0
DICE_WEIGHT = 0.0       # Dice term switched off - not helpful for this task
DIST_WEIGHT = 1.0

# --- Regularization ---
DROPOUT = 0.1           # light dropout
WEIGHT_DECAY = 0.0      # none (BASELINE)

# --- Augmentation: light, to avoid distorting cells ---
USE_AUGMENTATION = True
AUG_PARAMS = dict(
    rotate_p=0.3,       # light rotation
    flip_p=0.5,         # flips are safe for cells
    brightness_p=0.2,   # light brightness variation
    noise_p=0.1,        # very light noise
    blur_p=0.0,         # blur disabled
    elastic_p=0.0,      # elastic deformation disabled (distorts cells)
)

# --- Schedule: patient, to let training converge ---
EARLY_STOPPING_PATIENCE = 30  # epochs without improvement before stopping
SCHEDULER_PATIENCE = 10       # plateau epochs before the LR is reduced
SCHEDULER_FACTOR = 0.5        # LR multiplier on plateau (halves it)

# ============================================================
# THRESHOLD SWEEP STRATEGY
# ============================================================
# Per-fold sweeps are skipped during training (saves time, avoids overfitting);
# a single ensemble sweep runs after training instead.
RUN_THRESHOLD_SWEEP = False

# Grid searched by the ensemble sweep after training
ENSEMBLE_PROB_THRESHOLDS = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.5]
ENSEMBLE_NMS_THRESHOLDS = [0.1, 0.2, 0.3, 0.4, 0.5]

# --- Echo the configuration ---
_rule = "=" * 60
_config_lines = [
    _rule,
    "FINAL RUN CONFIGURATION - OPTIMIZED FOR BEST DetA",
    _rule,
    "\nModel:",
    f"  Encoder: {ENCODER_NAME}, N_Rays: {N_RAYS}",
    f"  Dropout: {DROPOUT}",
    "\nTraining:",
    f"  Epochs: {EPOCHS} (extended)",
    f"  Batch Size: {BATCH_SIZE}, LR: {LR}",
    f"  Early Stopping: patience={EARLY_STOPPING_PATIENCE}",
    f"  LR Scheduler: patience={SCHEDULER_PATIENCE}, factor={SCHEDULER_FACTOR}",
    "\nLoss: BCE + L1 (BASELINE - best for this dataset)",
    f"\nAugmentation: {USE_AUGMENTATION}",
]
if USE_AUGMENTATION:
    _config_lines += [
        f"  rotate_p={AUG_PARAMS['rotate_p']}, flip_p={AUG_PARAMS['flip_p']}",
        f"  brightness_p={AUG_PARAMS['brightness_p']}, noise_p={AUG_PARAMS['noise_p']}",
    ]
_config_lines += [
    "\nThreshold Strategy:",
    "  Per-fold sweep: DISABLED (faster training)",
    f"  Fixed threshold during training: prob={PROB_THRESH}, nms={NMS_THRESH}",
    f"  Ensemble sweep after training: {len(ENSEMBLE_PROB_THRESHOLDS)}x{len(ENSEMBLE_NMS_THRESHOLDS)} combinations",
    f"\nSave directory: {SAVE_DIR}",
    _rule,
]
print("\n".join(_config_lines))

## 4. Training

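Run K-Fold cross-validation training. Each fold:
1. Trains with the loss configured in Section 3 (the final run uses BCE + L1; the Focal, Dice and Smooth L1 variants remain available through the loss flags and weights)
2. Uses LR scheduling and early stopping
3. Skips the per-fold threshold sweep (`RUN_THRESHOLD_SWEEP = False`) and detects with the fixed `PROB_THRESH`/`NMS_THRESH`; thresholds are tuned afterwards in the ensemble sweep
4. Evaluates DetA metric on validation set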

In [None]:
# Run K-fold training with the configuration defined above.
# Returns the per-fold results, the best fold, and the collected predictions.
fold_results, best_fold, all_preds_df = train_stardist_kfold(
    dataset_root=DATASET_DIR,
    video_map_path=VIDEO_MAP_PATH,
    val_csv_path=str(VAL_CSV),
    k_splits=K_SPLITS,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    n_rays=N_RAYS,
    prob_thresh=PROB_THRESH,
    nms_thresh=NMS_THRESH,
    use_bonus=USE_BONUS,
    save_dir=SAVE_DIR,
    match_thresh=MATCH_THRESH,
    device=device,
    # Improvement parameters
    use_augmentation=USE_AUGMENTATION,
    aug_params=AUG_PARAMS if USE_AUGMENTATION else None,
    weight_decay=WEIGHT_DECAY,
    focal_weight=FOCAL_WEIGHT,
    dice_weight=DICE_WEIGHT,
    dist_weight=DIST_WEIGHT,
    scheduler_patience=SCHEDULER_PATIENCE,
    scheduler_factor=SCHEDULER_FACTOR,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    encoder_name=ENCODER_NAME,
    # Per-fold threshold sweep follows the configuration flag (False for the
    # final run: the ensemble sweep after training replaces it). Previously
    # this was hard-coded to False, so changing the flag had no effect.
    run_threshold_sweep=RUN_THRESHOLD_SWEEP,
    # Baseline mode: use simple BCE + L1
    use_simple_bce=USE_SIMPLE_BCE,
    use_simple_l1=USE_SIMPLE_L1,
    # Dropout for regularization
    dropout=DROPOUT
)

## 5. Results & Evaluation

In [None]:
# Report DetA and stopping epoch for every fold, then the aggregate
banner = "=" * 60
print("\n" + banner)
print("K-FOLD TRAINING RESULTS")
print(banner)

deta_values = []
for result in fold_results:
    deta_values.append(result['deta'])
    print(f"\nFold {result['fold']}:")
    print(f"  DetA = {result['deta']:.4f} (with fixed thresh: prob={PROB_THRESH}, nms={NMS_THRESH})")
    print(f"  Stopped at epoch: {result['stopped_epoch']}")

mean_deta = np.mean(deta_values)
std_deta = np.std(deta_values)
print("\n" + "-" * 40)
print(f"Mean DetA: {mean_deta:.4f} +/- {std_deta:.4f}")
print(f"Best Fold: {best_fold['fold']} (DetA = {best_fold['deta']:.4f})")
print(f"\nNote: These DetA values use fixed threshold {PROB_THRESH}")
print("Ensemble threshold sweep will optimize this next.")

In [None]:
# Training curves of the best fold: losses on the left, LR schedule on the right
history = best_fold['history']

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes[0], axes[1]

# Left panel: train / validation loss per epoch
loss_series = (
    ('train_loss', 'Train Loss', 'blue'),
    ('val_loss', 'Val Loss', 'orange'),
)
for key, label, color in loss_series:
    ax1.plot(history[key], label=label, color=color, linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title(f'Training Curves - Best Fold {best_fold["fold"]}')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right panel: learning rate on a log axis, or a placeholder when it was not recorded
if 'lr' not in history or not history['lr']:
    ax2.text(0.5, 0.5, 'LR history not available', ha='center', va='center', transform=ax2.transAxes)
else:
    ax2.plot(history['lr'], color='green', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Learning Rate')
    ax2.set_title('Learning Rate Schedule')
    ax2.set_yscale('log')
    ax2.grid(True, alpha=0.3)

plt.tight_layout()
# The output directory may not exist yet, so create it before saving
SAVE_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(SAVE_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Ensemble Threshold Sweep

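Now we find the **optimal thresholds for ensemble inference** by sweeping `prob_thresh`/`nms_thresh` on the validation video.

**Caveat on leakage:** the cell below averages the outputs of *all* fold models on every frame, so these are not strictly out-of-fold (OOF) predictions: in K-fold training most of the ensemble members will typically have seen a given frame. The resulting DetA is therefore likely optimistic. The no-leakage argument only holds for true OOF predictions, where:
- Each frame's prediction comes from a model that didn't see it during training
- We're finding one global threshold for the entire ensemble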

In [None]:
# Ensemble threshold sweep on the validation video.
#
# For every video frame all fold models are run once, their outputs are
# averaged, and the averaged maps are post-processed with every
# (prob_thresh, nms_thresh) pair of the grid. The network outputs do not depend
# on the thresholds, so inference runs once per frame instead of once per
# frame *and* threshold pair.

from tqdm.auto import tqdm
import tifffile

print("=" * 60)
print("ENSEMBLE THRESHOLD SWEEP")
print("=" * 60)
print(f"\nSweeping {len(ENSEMBLE_PROB_THRESHOLDS)} x {len(ENSEMBLE_NMS_THRESHOLDS)} = {len(ENSEMBLE_PROB_THRESHOLDS) * len(ENSEMBLE_NMS_THRESHOLDS)} combinations")

# Load ground truth and keep only the annotations inside the ROI
gt_df = pd.read_csv(VAL_CSV)
gt_roi = gt_df[
    (gt_df.x >= ROI_X_MIN) & (gt_df.x < ROI_X_MAX) &
    (gt_df.y >= ROI_Y_MIN) & (gt_df.y < ROI_Y_MAX)
].copy()

# Load the video and crop it to the same ROI
video = tifffile.imread(VAL_TIF)
video_roi = video[:, ROI_Y_MIN:ROI_Y_MAX, ROI_X_MIN:ROI_X_MAX]

# Frame indices of the video frames listed in the dataset map
video_map = pd.read_csv(VIDEO_MAP_PATH)
video_frames = video_map[video_map['source'] == 'video']['frame'].tolist()

from cellseg_models_pytorch.postproc.functional.stardist.stardist import post_proc_stardist
from skimage.measure import regionprops

best_deta = 0
best_prob_thresh = PROB_THRESH
best_nms_thresh = NMS_THRESH
sweep_results = []

# Move all models to device and switch them to inference mode
for result in fold_results:
    result['model'].to(device)
    result['model'].eval()

# One prediction list per threshold pair, in grid order (prob outer, nms inner)
threshold_grid = [
    (prob_t, nms_t)
    for prob_t in ENSEMBLE_PROB_THRESHOLDS
    for nms_t in ENSEMBLE_NMS_THRESHOLDS
]
preds_per_combo = [[] for _ in threshold_grid]
postproc_failures = 0

for frame_idx in tqdm(video_frames, desc="Ensemble sweep (frames)"):
    # Per-frame standardization; epsilon guards against a constant frame
    frame = video_roi[frame_idx].astype(np.float32)
    frame_norm = (frame - frame.mean()) / (frame.std() + 1e-8)
    x = torch.from_numpy(frame_norm).float().unsqueeze(0).unsqueeze(0).to(device)

    # Collect the outputs of every fold model for this frame
    all_dist = []
    all_bin = []
    with torch.no_grad():
        for result in fold_results:
            pred_dist, pred_bin = result['model'].model(x)
            all_dist.append(pred_dist.cpu().numpy()[0])
            all_bin.append(torch.sigmoid(pred_bin).cpu().numpy()[0, 0])

    # Average across folds
    avg_dist = np.mean(all_dist, axis=0)
    avg_bin = np.mean(all_bin, axis=0)

    # Post-process the same averaged maps with every threshold pair.
    # NOTE(review): confirm that the installed post_proc_stardist accepts the
    # keywords `prob_thresh` / `nms_thresh`. If that release names them
    # differently and also takes **kwargs, they would be ignored silently and
    # every grid point would score the same.
    for combo_preds, (prob_t, nms_t) in zip(preds_per_combo, threshold_grid):
        try:
            labels = post_proc_stardist(avg_dist, avg_bin, prob_thresh=prob_t, nms_thresh=nms_t)
            props = regionprops(labels)
        except Exception as exc:
            # Best effort: a frame that cannot be post-processed contributes no
            # detections, but the failure is reported instead of being hidden.
            postproc_failures += 1
            if postproc_failures <= 3:
                print(f"  Warning: post-processing failed on frame {frame_idx} "
                      f"(prob={prob_t}, nms={nms_t}): {exc!r}")
            continue
        for prop in props:
            y, cx = prop.centroid
            combo_preds.append({
                'frame': frame_idx,
                'x': cx + ROI_X_MIN,
                'y': y + ROI_Y_MIN
            })

if postproc_failures:
    print(f"  Post-processing failed {postproc_failures} time(s) in total; "
          f"those frame/threshold pairs yielded no detections.")

# Ground truth for every frame that was run through the ensemble. Filtering by
# "frames that have predictions" would drop frames without any detection from
# the evaluation and inflate DetA for strict thresholds.
gt_eval = gt_roi[gt_roi.frame.isin(video_frames)]
gt_eval_frames = set(gt_eval.frame.unique())

# Score every threshold pair in grid order
for (prob_t, nms_t), combo_preds in zip(threshold_grid, preds_per_combo):
    deta = 0.0
    if combo_preds:
        preds_df = pd.DataFrame(combo_preds)
        # Predictions on frames without ground truth cannot be scored
        pred_eval = preds_df[preds_df.frame.isin(gt_eval_frames)]
        if not pred_eval.empty:
            deta = calculate_deta_robust(gt_eval, pred_eval, match_thresh=MATCH_THRESH)

    sweep_results.append({
        'prob_thresh': prob_t,
        'nms_thresh': nms_t,
        'deta': deta,
        'n_detections': len(combo_preds)
    })

    if deta > best_deta:
        best_deta = deta
        best_prob_thresh = prob_t
        best_nms_thresh = nms_t
        print(f"  New best: prob={prob_t:.2f}, nms={nms_t:.2f} -> DetA={deta:.4f}")

# Store results
sweep_df = pd.DataFrame(sweep_results)
print("\n" + "=" * 60)
print("ENSEMBLE THRESHOLD SWEEP RESULTS")
print("=" * 60)
print("\nBest ensemble thresholds:")
print(f"  prob_thresh: {best_prob_thresh}")
print(f"  nms_thresh: {best_nms_thresh}")
print(f"  DetA: {best_deta:.4f}")

# Names used by the cells further down
best_prob = best_prob_thresh
best_nms = best_nms_thresh

## 7. Final Ensemble Predictions

Generate final predictions using the **best ensemble thresholds** found in the sweep.

In [None]:
# Generate the final ensemble predictions with the thresholds chosen by the sweep
print("Generating final ensemble predictions...")
print(f"  Using: prob_thresh={best_prob}, nms_thresh={best_nms}")

final_preds = []
failed_frames = []

for frame_idx in tqdm(video_frames, desc="Ensemble inference"):
    # Per-frame standardization; epsilon guards against a constant frame
    frame = video_roi[frame_idx].astype(np.float32)
    frame_norm = (frame - frame.mean()) / (frame.std() + 1e-8)
    x = torch.from_numpy(frame_norm).float().unsqueeze(0).unsqueeze(0).to(device)

    # Average predictions from all folds
    all_dist = []
    all_bin = []

    with torch.no_grad():
        for result in fold_results:
            model = result['model']
            pred_dist, pred_bin = model.model(x)
            all_dist.append(pred_dist.cpu().numpy()[0])
            all_bin.append(torch.sigmoid(pred_bin).cpu().numpy()[0, 0])

    avg_dist = np.mean(all_dist, axis=0)
    avg_bin = np.mean(all_bin, axis=0)

    # Post-process with best thresholds. A failing frame contributes no
    # detections (best effort), but it is reported rather than silently skipped.
    try:
        labels = post_proc_stardist(avg_dist, avg_bin, prob_thresh=best_prob, nms_thresh=best_nms)
        props = regionprops(labels)
    except Exception as exc:
        failed_frames.append(frame_idx)
        print(f"  Warning: post-processing failed on frame {frame_idx}: {exc!r}")
        continue

    for prop in props:
        y, cx = prop.centroid
        final_preds.append({
            'frame': frame_idx,
            'x': cx + ROI_X_MIN,
            'y': y + ROI_Y_MIN
        })

if failed_frames:
    print(f"  Post-processing failed on {len(failed_frames)} frame(s): {failed_frames}")

# Explicit columns keep the frame/x/y schema even when nothing was detected,
# so the column accesses below cannot fail on an empty result.
full_preds_df = pd.DataFrame(final_preds, columns=['frame', 'x', 'y'])

# Calculate final DetA against the ground truth of every processed frame.
# Frames without any detection stay in the ground truth; dropping them would
# hide the missed cells and inflate the score.
gt_filtered = gt_roi[gt_roi.frame.isin(video_frames)]
pred_filtered = full_preds_df[full_preds_df.frame.isin(gt_filtered.frame.unique())]

if pred_filtered.empty:
    final_deta = 0.0
else:
    final_deta = calculate_deta_robust(gt_filtered, pred_filtered, match_thresh=MATCH_THRESH)

print("\n" + "=" * 60)
print("FINAL ENSEMBLE RESULTS")
print("=" * 60)
print(f"Total detections: {len(full_preds_df)}")
print(f"Frames covered: {full_preds_df['frame'].nunique()}")
print(f"Final DetA: {final_deta:.4f}")
print(f"Thresholds: prob={best_prob}, nms={best_nms}")
print("=" * 60)

## 7. Tracking

**Note on Parameter Selection:** We use fixed tracking parameters (not tuned on GT) to avoid data leakage. These defaults are based on typical cell movement characteristics:
- `track_cost_cutoff=25` (5px): Maximum squared distance for frame-to-frame linking
- `gap_closing_cost_cutoff=25`: Distance for gap closing
- `gap_closing_max_frame_count=1`: Maximum frames to skip for gap closing

This ensures HOTA scores reflect what you'll achieve on unknown test data.

In [None]:
# Run LapTrack with FIXED parameters (no data leakage)
# Parameters chosen based on typical cell movement (~5 pixels per frame)

# Fixed tracking parameters (not tuned on GT!)
TRACK_COST_CUTOFF = 25          # squared distance, i.e. 5 pixels
GAP_CLOSING_COST_CUTOFF = 25    # squared distance, i.e. 5 pixels, for gap closing
GAP_CLOSING_MAX_FRAMES = 1      # only close 1-frame gaps

# Link the ensemble detections into tracks
tracked_df = run_laptrack(
    detections_df=full_preds_df,
    track_cost_cutoff=TRACK_COST_CUTOFF,
    gap_closing_cost_cutoff=GAP_CLOSING_COST_CUTOFF,
    gap_closing_max_frame_count=GAP_CLOSING_MAX_FRAMES
)

print("Tracking complete with fixed parameters (no data leakage)!")
# The cutoff is a squared distance, so its square root is shown in pixels.
# The approximation sign was mis-encoded (mojibake) in the original string.
print(f"  track_cost_cutoff: {TRACK_COST_CUTOFF} (≈{int(TRACK_COST_CUTOFF**0.5)}px)")
print(f"  gap_closing_cost_cutoff: {GAP_CLOSING_COST_CUTOFF}")
print(f"  gap_closing_max_frame_count: {GAP_CLOSING_MAX_FRAMES}")
print(f"  Total tracks: {tracked_df['track_id'].nunique()}")
print(f"  Total detections: {len(tracked_df)}")

In [None]:
# Score the tracks against the ROI ground truth with HOTA
gt_for_tracking = gt_roi.loc[:, ['frame', 'x', 'y', 'track_id']].copy()

hota_scores = hota(gt_for_tracking, tracked_df, match_thresh=MATCH_THRESH)

print("HOTA Metrics (with fixed tracking parameters - no data leakage):")
for metric_name in ('HOTA', 'DetA', 'AssA'):
    print(f"  {metric_name}: {hota_scores[metric_name]:.4f}")

# Kept under its own name for the summary cell
hota_score = hota_scores['HOTA']

## 8. Visualization

In [None]:
# Visualize sample frames with ground truth and predicted detections
import tifffile

# Use a dedicated name for the image stack. Reassigning `video_frames` here
# would clobber the list of frame indices the inference cells iterate over
# and break them when they are re-run after this cell.
video_stack = tifffile.imread(VAL_TIF)
roi = video_stack[:, ROI_Y_MIN:ROI_Y_MAX, ROI_X_MIN:ROI_X_MAX]

sample_frames = [0, 30, 60, 90]
fig, axes = plt.subplots(1, len(sample_frames), figsize=(16, 4))

for ax, fidx in zip(axes, sample_frames):
    # Hide the axes up front so panels for missing frames stay blank
    ax.axis('off')
    if fidx >= len(roi):
        continue

    ax.imshow(roi[fidx], cmap='gray')

    # Plot GT (green circles), shifted into ROI coordinates
    frame_gt = gt_roi[gt_roi.frame == fidx]
    ax.scatter(
        frame_gt.x - ROI_X_MIN,
        frame_gt.y - ROI_Y_MIN,
        c='lime', s=30, marker='o', alpha=0.8, label='GT'
    )

    # Plot predictions (red crosses), shifted into ROI coordinates
    frame_preds = full_preds_df[full_preds_df.frame == fidx]
    ax.scatter(
        frame_preds.x - ROI_X_MIN,
        frame_preds.y - ROI_Y_MIN,
        c='red', s=20, marker='x', alpha=0.8, label='Predicted'
    )

    ax.set_title(f'Frame {fidx}')

axes[0].legend(loc='upper left')
plt.suptitle('StarDist Detection Results', fontsize=14)
plt.tight_layout()
# Make sure the output directory exists before writing the figure
SAVE_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(SAVE_DIR / 'detection_samples.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Save Final Results

In [None]:
# Save predictions, tracks, and a one-row summary of the run
SAVE_DIR.mkdir(parents=True, exist_ok=True)

full_preds_df.to_csv(SAVE_DIR / 'stardist_predictions.csv', index=False)
tracked_df.to_csv(SAVE_DIR / 'stardist_tracked.csv', index=False)

# Summary of metrics, chosen thresholds, tracking parameters and model settings
summary = {
    'best_fold': best_fold['fold'],
    'best_deta': best_fold['deta'],
    'final_deta': final_deta,
    'hota': hota_score,
    'deta_from_hota': hota_scores['DetA'],
    'assa': hota_scores['AssA'],
    'best_prob_thresh': best_prob,
    'best_nms_thresh': best_nms,
    'track_cost_cutoff': TRACK_COST_CUTOFF,
    'gap_closing_cost_cutoff': GAP_CLOSING_COST_CUTOFF,
    'gap_closing_max_frame_count': GAP_CLOSING_MAX_FRAMES,
    'total_detections': len(full_preds_df),
    'total_tracks': tracked_df['track_id'].nunique(),
    'use_augmentation': USE_AUGMENTATION,
    'weight_decay': WEIGHT_DECAY,
    'encoder': ENCODER_NAME,
    'n_rays': N_RAYS,
    'dropout': DROPOUT,
    'repo_url': REPO_URL
}

pd.DataFrame([summary]).to_csv(SAVE_DIR / 'training_summary.csv', index=False)

print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
print("\nDetection:")
print(f"  Best Fold: {summary['best_fold']}")
# final_deta comes from the all-fold ensemble run on the validation video, so
# it is labelled as the ensemble score (it is not an out-of-fold estimate).
print(f"  Ensemble DetA: {summary['final_deta']:.4f}")
print(f"  prob_thresh: {summary['best_prob_thresh']:.2f}")
print(f"  nms_thresh: {summary['best_nms_thresh']:.2f}")
print("\nTracking (HOTA):")
print(f"  HOTA: {summary['hota']:.4f}")
print(f"  DetA: {summary['deta_from_hota']:.4f}")
print(f"  AssA: {summary['assa']:.4f}")
print(f"\nResults saved to: {SAVE_DIR}")

## 10. Export Competition Artifacts

This section exports all trained models and configurations for competition submission.

**Exported files:**
- `models/fold_1.pth` ... `fold_5.pth` - All K-fold model weights
- `model_config.json` - Architecture configuration
- `inference_config.json` - Optimal thresholds and tracking parameters

**Usage:** Load models and configs to run inference on new test data.

In [None]:
# Set up the export layout: <SAVE_DIR>/competition/ with weights under models/
import json

COMPETITION_DIR = SAVE_DIR.joinpath("competition")
MODELS_DIR = COMPETITION_DIR.joinpath("models")
# parents=True creates the competition directory along the way
MODELS_DIR.mkdir(parents=True, exist_ok=True)

print(
    "Exporting competition artifacts...",
    f"  Export directory: {COMPETITION_DIR}",
    sep="\n",
)

In [None]:
# Write each fold model's state_dict to models/fold_<n>.pth
print("\nExporting fold models...")

for result in fold_results:
    weights_name = f"fold_{result['fold']}.pth"
    torch.save(result['model'].state_dict(), MODELS_DIR / weights_name)
    print(f"  Saved: {weights_name}")

print(f"\nExported {len(fold_results)} models to {MODELS_DIR}")

In [None]:
# Architecture settings needed to rebuild the network
model_config = dict(
    architecture="StarDist",
    encoder_name=ENCODER_NAME,
    n_rays=N_RAYS,
    input_channels=1,
    dropout=DROPOUT,
)

model_config_path = COMPETITION_DIR / "model_config.json"
model_config_path.write_text(json.dumps(model_config, indent=2))
print("\nSaved: model_config.json")

# Inference settings: chosen thresholds, tracking parameters and the ROI.
# Values are cast to plain int/float so they serialize as JSON numbers.
roi_bounds = {
    "x_min": int(ROI_X_MIN),
    "x_max": int(ROI_X_MAX),
    "y_min": int(ROI_Y_MIN),
    "y_max": int(ROI_Y_MAX),
}
inference_config = {
    "prob_thresh": float(best_prob),
    "nms_thresh": float(best_nms),
    "track_cost_cutoff": int(TRACK_COST_CUTOFF),
    "gap_closing_cost_cutoff": int(GAP_CLOSING_COST_CUTOFF),
    "gap_closing_max_frame_count": int(GAP_CLOSING_MAX_FRAMES),
    "match_thresh": float(MATCH_THRESH),
    "roi": roi_bounds,
}

inference_config_path = COMPETITION_DIR / "inference_config.json"
inference_config_path.write_text(json.dumps(inference_config, indent=2))
print("Saved: inference_config.json")

In [None]:
# List everything that was exported, then print the usage notes
heavy_rule = "=" * 60
light_rule = "-" * 60

print("\n" + heavy_rule)
print("COMPETITION EXPORT COMPLETE")
print(heavy_rule)

print(f"\nExported to: {COMPETITION_DIR}")
print("\nFiles:")
exported_files = sorted(path for path in COMPETITION_DIR.rglob('*') if path.is_file())
for item in exported_files:
    size_kb = item.stat().st_size / 1024
    print(f"  {item.relative_to(COMPETITION_DIR)} ({size_kb:.1f} KB)")

print("\n" + light_rule)
print("USAGE INSTRUCTIONS")
print(light_rule)
print("""
To run inference on new test data:

1. Load model_config.json and inference_config.json
2. Create StarDistLightning model with config params
3. Load fold weights (ensemble or single best)
4. Run inference with prob_thresh/nms_thresh from config
5. Run LapTrack with tracking params from config

See SU2_StarDist_export.ipynb for ready-to-use inference API.
""")
print(heavy_rule)

## 11. Save Weights to GitHub / Google Drive

**IMPORTANT:** Colab runtime files are deleted when the session ends. Run one of the cells below to persist your trained models.

**Option A (Preferred):** Push to GitHub repository  
**Option B (Fallback):** Save to Google Drive

In [None]:
# ============================================================
# OPTION A (PREFERRED): Push weights to GitHub
# ============================================================
# This saves your trained models directly to the repository

import subprocess

# Configure git (required for Colab)
!git config --global user.email "veselm73@gmail.com"
!git config --global user.name "veselm73"

# Check current status
print("Current git status:")
!cd {repo_root} && git status --short

# Add competition artifacts to git
!cd {repo_root} && git add results/stardist/competition/

# Create commit with training results
commit_msg = f"Add trained StarDist models (DetA={final_deta:.4f}, HOTA={hota_score:.4f})"
!cd {repo_root} && git commit -m "{commit_msg}"

# Push to GitHub (you may need to authenticate)
print("\nPushing to GitHub...")
print("If prompted, enter your GitHub Personal Access Token as password")
!cd {repo_root} && git push

print("\n" + "="*60)
print("SUCCESS: Weights pushed to GitHub!")
print(f"Repository: {REPO_URL}")
print("="*60)

In [None]:
# ============================================================
# OPTION B (FALLBACK): Save to Google Drive
# ============================================================
# Alternative to the GitHub push: copy the artifacts into a Drive folder.

import shutil

from google.colab import drive

# Mount Drive so it is reachable under /content/drive
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Backup folder inside the user's Drive
DRIVE_BACKUP_DIR = Path('/content/drive/MyDrive/SU2_competition_backup')
DRIVE_BACKUP_DIR.mkdir(parents=True, exist_ok=True)

# Competition artifacts (weights + configs); existing files are overwritten
shutil.copytree(COMPETITION_DIR, DRIVE_BACKUP_DIR / 'competition', dirs_exist_ok=True)

# Training summary, predictions and tracks
for csv_name in ('training_summary.csv', 'stardist_predictions.csv', 'stardist_tracked.csv'):
    shutil.copy(SAVE_DIR / csv_name, DRIVE_BACKUP_DIR / csv_name)

rule = "=" * 60
print("\n" + rule)
print("SUCCESS: Weights saved to Google Drive!")
print(f"Location: {DRIVE_BACKUP_DIR}")
print(rule)

# Show what ended up in the backup folder
print("\nSaved files:")
backed_up_files = sorted(path for path in DRIVE_BACKUP_DIR.rglob('*') if path.is_file())
for item in backed_up_files:
    size_kb = item.stat().st_size / 1024
    print(f"  {item.relative_to(DRIVE_BACKUP_DIR)} ({size_kb:.1f} KB)")