[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/veselm73/SU2/blob/main/notebooks/SU2_StarDist_export.ipynb)

# StarDist Competition Export & Inference

This notebook exports trained StarDist models for competition submission and provides an easy-to-use inference API.

**Features:**
- Export all K-fold models for ensemble inference
- Save model and inference configurations as JSON
- Single-function inference API: `infer_video(path) -> DataFrame`
- Visualization of predictions

**Prerequisites:**
- Run `SU2_StarDist_final.ipynb` first to train models
- Trained model checkpoints at `results/stardist/stardist_fold_<k>/best_model.ckpt` (one per fold)

## 1. Setup & Configuration

In [None]:
# Install dependencies using uv (much faster than pip); uv itself is bootstrapped with pip
!pip install uv -q

# Uninstall conflicting packages. stderr is discarded and `|| true` keeps the
# cell from failing when some of the listed packages are not installed.
!uv pip uninstall torch torchvision torchaudio tensorflow tensorflow-metal --system -q 2>/dev/null || true

# Install PyTorch from the CUDA 12.1 wheel index
!uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --system -q

# Install other dependencies (NumPy is pinned below 2.0)
!uv pip install "numpy<2" cellseg-models-pytorch pytorch-lightning laptrack tifffile --system -q

In [None]:
import sys
import os
import json
from pathlib import Path

# Check if running in Colab (the google.colab module is only loaded there)
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # On Colab, work from a checkout at /content/SU2: clone it on the first
    # run, otherwise refresh the existing checkout with `git pull`.
    if not Path('/content/SU2').exists():
        !git clone https://github.com/veselm73/SU2.git /content/SU2
    else:
        !cd /content/SU2 && git pull
    os.chdir('/content/SU2')
    repo_root = Path('/content/SU2')
else:
    # Locally, the repository root is the parent of the working directory when
    # that directory is named `notebooks`; otherwise it is the working
    # directory itself.
    notebook_dir = Path(os.getcwd())
    if notebook_dir.name == 'notebooks':
        repo_root = notebook_dir.parent
    else:
        repo_root = notebook_dir

# Put the repository root at the front of sys.path (once) so that packages
# living in the repository can be imported.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

print(f"Repository root: {repo_root}")

In [None]:
import torch
import numpy as np
import pandas as pd
import tifffile
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Import from stardist_helpers
from modules.stardist_helpers import (
    StarDistLightning,
    run_laptrack,
    ROI_X_MIN, ROI_X_MAX, ROI_Y_MIN, ROI_Y_MAX
)

# StarDist post-processing is an optional dependency; HAS_STARDIST records
# whether it could be imported so later code can degrade gracefully.
try:
    from cellseg_models_pytorch.postproc.functional.stardist.stardist import post_proc_stardist
except ImportError:
    HAS_STARDIST = False
    print("Warning: cellseg_models_pytorch not available")
else:
    HAS_STARDIST = True

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(device_name)
print(f"Using device: {DEVICE}")

In [None]:
# ------------------------------------------------------------
# Configuration
# ------------------------------------------------------------

# Training outputs (written by SU2_StarDist_final.ipynb) and the export target,
# which lives in a "competition" sub-folder of the training output directory.
MODELS_DIR = repo_root / "results" / "stardist"
EXPORT_DIR = MODELS_DIR / "competition"

# Network architecture settings; these have to agree with the training run.
MODEL_CONFIG = dict(
    architecture="StarDist",
    encoder_name="resnet18",
    n_rays=32,
    input_channels=1,
    dropout=0.1,
)

# Detection thresholds, tracking cutoffs and the region of interest used at
# inference time (optimal values from training).
INFERENCE_CONFIG = dict(
    prob_thresh=0.3,
    nms_thresh=0.3,
    track_cost_cutoff=49,            # 7px squared
    gap_closing_cost_cutoff=25,      # 5px squared
    gap_closing_max_frame_count=2,
    roi=dict(
        x_min=ROI_X_MIN,
        x_max=ROI_X_MAX,
        y_min=ROI_Y_MIN,
        y_max=ROI_Y_MAX,
    ),
)

# How many cross-validation folds to look for when loading models.
K_FOLDS = 5

# Echo the effective settings so they are recorded in the cell output.
print("Configuration:")
print(f"  Models dir: {MODELS_DIR}")
print(f"  Export dir: {EXPORT_DIR}")
print(f"  Model: {MODEL_CONFIG['encoder_name']}, n_rays={MODEL_CONFIG['n_rays']}")
print(f"  Thresholds: prob={INFERENCE_CONFIG['prob_thresh']}, nms={INFERENCE_CONFIG['nms_thresh']}")

## 2. Load Trained Models

In [None]:
def load_model_from_checkpoint(checkpoint_path, config=MODEL_CONFIG):
    """
    Restore a StarDist model from a Lightning ``.ckpt`` file.

    Args:
        checkpoint_path: Location of the checkpoint file.
        config: Architecture settings. ``n_rays`` and ``encoder_name`` are
            required; ``dropout`` falls back to 0.0 when missing.

    Returns:
        The model with the checkpoint's ``state_dict`` entry loaded,
        switched to eval mode.
    """
    arch_kwargs = {
        'n_rays': config['n_rays'],
        'encoder_name': config['encoder_name'],
        'dropout': config.get('dropout', 0.0),
    }
    net = StarDistLightning(**arch_kwargs)

    # Deserialize onto the CPU; the weights sit under the 'state_dict' key.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    weights = ckpt['state_dict']
    net.load_state_dict(weights)

    net.eval()
    return net


def load_model_from_state_dict(state_dict_path, config=MODEL_CONFIG):
    """
    Restore a StarDist model from a plain ``state_dict`` file (``.pth``).

    Args:
        state_dict_path: Location of the ``.pth`` file.
        config: Architecture settings. ``n_rays`` and ``encoder_name`` are
            required; ``dropout`` falls back to 0.0 when missing.

    Returns:
        The model with the stored weights loaded, switched to eval mode.
    """
    arch_kwargs = {
        'n_rays': config['n_rays'],
        'encoder_name': config['encoder_name'],
        'dropout': config.get('dropout', 0.0),
    }
    net = StarDistLightning(**arch_kwargs)

    # The file holds the state_dict itself, deserialized onto the CPU.
    weights = torch.load(state_dict_path, map_location='cpu')
    net.load_state_dict(weights)

    net.eval()
    return net

In [None]:
# Collect one model per fold. A Lightning checkpoint from training is
# preferred; a previously exported state_dict is the fallback. Folds with
# neither file are reported and skipped.
print("Loading K-fold models...")
fold_models = []

for fold in range(1, K_FOLDS + 1):
    ckpt_path = MODELS_DIR / f"stardist_fold_{fold}" / "best_model.ckpt"
    pth_path = EXPORT_DIR / "models" / f"fold_{fold}.pth"

    has_ckpt = ckpt_path.exists()
    if not has_ckpt and not pth_path.exists():
        print(f"  Fold {fold}: NOT FOUND (expected at {ckpt_path})")
        continue

    if has_ckpt:
        model = load_model_from_checkpoint(ckpt_path)
        print(f"  Fold {fold}: Loaded from {ckpt_path}")
    else:
        model = load_model_from_state_dict(pth_path)
        print(f"  Fold {fold}: Loaded from {pth_path}")
    fold_models.append(model)

print(f"\nLoaded {len(fold_models)} models")

## 3. Export Competition Artifacts

In [None]:
# Make sure the export tree exists. Creating the "models" sub-folder with
# parents=True also creates EXPORT_DIR (and its ancestors) when missing.
models_export_dir = EXPORT_DIR / "models"
models_export_dir.mkdir(parents=True, exist_ok=True)

print(f"Export directory: {EXPORT_DIR}")

In [None]:
# Write each loaded model's weights to <EXPORT_DIR>/models/fold_<n>.pth.
# NOTE(review): <n> is the model's 1-based position in fold_models, which only
# matches the training fold number when no fold was skipped during loading.
print("Exporting fold models...")

weights_dir = EXPORT_DIR / "models"
for i, model in enumerate(fold_models, start=1):
    save_path = weights_dir / f"fold_{i}.pth"
    weights = model.state_dict()
    torch.save(weights, save_path)
    print(f"  Saved: {save_path}")

print(f"\nExported {len(fold_models)} models")

In [None]:
# Persist both configuration dicts next to the exported weights as
# pretty-printed JSON.
print("Exporting configuration files...")

# Architecture settings
model_config_path = EXPORT_DIR / "model_config.json"
with model_config_path.open('w') as f:
    json.dump(MODEL_CONFIG, f, indent=2)
print(f"  Saved: {model_config_path}")

# Thresholds, tracking cutoffs and ROI
inference_config_path = EXPORT_DIR / "inference_config.json"
with inference_config_path.open('w') as f:
    json.dump(INFERENCE_CONFIG, f, indent=2)
print(f"  Saved: {inference_config_path}")

## 4. Easy-to-Use Inference API

In [None]:
def infer_frame(model, frame, prob_thresh=0.3, nms_thresh=0.3, device=DEVICE):
    """
    Detect objects in one frame with a single StarDist model.

    The frame is standardized (zero mean, unit variance), passed through
    ``model.model``, and the two network outputs are post-processed into a
    label image whose region centroids are returned.

    Args:
        model: StarDistLightning model
        frame: 2D numpy array (H, W)
        prob_thresh: Probability threshold forwarded to the post-processing
        nms_thresh: NMS threshold forwarded to the post-processing
        device: torch device the model and input are moved to

    Returns:
        List of (x, y) centroid coordinates; an empty list when the StarDist
        post-processing could not be imported.
    """
    net = model.to(device)
    net.eval()

    # Standardize; the small epsilon guards against constant frames.
    standardized = (frame - frame.mean()) / (frame.std() + 1e-8)

    # Add leading batch and channel axes: (H, W) -> (1, 1, H, W).
    batch = torch.from_numpy(standardized).float()[None, None].to(device)

    with torch.no_grad():
        dist_out, bin_out = net.model(batch)
        dist_map = dist_out.cpu().numpy()[0]
        # The second output is squashed with a sigmoid before use.
        prob_map = torch.sigmoid(bin_out).cpu().numpy()[0, 0]

    if not HAS_STARDIST:
        return []

    # NOTE(review): confirm argument order and keyword names against the
    # installed cellseg_models_pytorch version.
    labels = post_proc_stardist(
        dist_map, prob_map,
        prob_thresh=prob_thresh,
        nms_thresh=nms_thresh
    )

    # regionprops reports centroids as (row, col); flip them to (x, y).
    from skimage.measure import regionprops
    return [(cx, cy) for cy, cx in (region.centroid for region in regionprops(labels))]


def infer_frame_ensemble(models, frame, prob_thresh=0.3, nms_thresh=0.3, device=DEVICE):
    """
    Run ensemble inference on a single frame.

    Every model is applied to the same standardized frame; both network
    outputs (the distance maps and the sigmoid-activated probability maps)
    are averaged across models before post-processing.

    Args:
        models: Non-empty list of StarDistLightning models
        frame: 2D numpy array (H, W)
        prob_thresh: Probability threshold forwarded to the post-processing
        nms_thresh: NMS threshold forwarded to the post-processing
        device: torch device the models and input are moved to

    Returns:
        List of (x, y) detection coordinates; an empty list when the StarDist
        post-processing could not be imported.

    Raises:
        ValueError: If ``models`` is empty.
    """
    # Averaging over zero models would yield NaN maps and an obscure failure
    # further down, so reject the empty ensemble explicitly.
    if not models:
        raise ValueError("infer_frame_ensemble requires at least one model.")

    # Standardize the frame; the small epsilon guards against constant frames.
    frame_norm = (frame - frame.mean()) / (frame.std() + 1e-8)
    # Add leading batch and channel axes: (H, W) -> (1, 1, H, W).
    batch = torch.from_numpy(frame_norm).float().unsqueeze(0).unsqueeze(0).to(device)

    # Collect predictions from all models
    all_dist = []
    all_bin = []

    for model in models:
        model = model.to(device)
        model.eval()

        with torch.no_grad():
            pred_dist, pred_bin = model.model(batch)
            all_dist.append(pred_dist.cpu().numpy()[0])
            all_bin.append(torch.sigmoid(pred_bin).cpu().numpy()[0, 0])

    # Average predictions across the ensemble
    avg_dist = np.mean(all_dist, axis=0)
    avg_bin = np.mean(all_bin, axis=0)

    if not HAS_STARDIST:
        return []

    # NOTE(review): confirm argument order and keyword names against the
    # installed cellseg_models_pytorch version.
    labels = post_proc_stardist(
        avg_dist, avg_bin,
        prob_thresh=prob_thresh,
        nms_thresh=nms_thresh
    )

    # regionprops reports centroids as (row, col); flip them to (x, y).
    from skimage.measure import regionprops
    return [(cx, cy) for cy, cx in (region.centroid for region in regionprops(labels))]

In [None]:
def infer_video(video_path, output_csv=None, use_ensemble=True,
                models=None, config=None, device=DEVICE):
    """
    Run detection + tracking on a video file.

    This is the main inference API for competition submission. The video is
    cropped to the configured ROI, every frame is run through the detector,
    detections are shifted back to full-image coordinates and then linked
    into tracks with ``run_laptrack``.

    Args:
        video_path: Path to TIFF video file, read as (T, H, W); a single
            (H, W) frame is also accepted and treated as a one-frame video
        output_csv: Optional path to save predictions CSV
        use_ensemble: Use all fold models (True) or first model only (False)
        models: List of models (uses global fold_models if None)
        config: Inference config dict (uses INFERENCE_CONFIG if None)
        device: torch device

    Returns:
        The DataFrame produced by ``run_laptrack`` (expected columns:
        frame, x, y, track_id), or an empty DataFrame with those columns
        when nothing was detected.

    Raises:
        ValueError: If no models are available, or if the loaded video is
            neither 2-D nor 3-D.

    Example:
        >>> predictions = infer_video('test.tif')
        >>> predictions.to_csv('submission.csv', index=False)
    """
    # Use defaults if not provided
    if models is None:
        models = fold_models
    if config is None:
        config = INFERENCE_CONFIG

    if len(models) == 0:
        raise ValueError("No models loaded. Run model loading cell first.")

    # Load video
    print(f"Loading video: {video_path}")
    video = tifffile.imread(video_path)
    print(f"  Shape: {video.shape}")

    # The ROI slicing below indexes (frame, row, col). Promote a single frame
    # to a one-frame video and reject layouts that would be sliced wrongly.
    if video.ndim == 2:
        video = video[np.newaxis]
    elif video.ndim != 3:
        raise ValueError(
            f"Expected a (T, H, W) video or a single (H, W) frame, got shape {video.shape}"
        )

    # Extract ROI
    roi = config['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]
    print(f"  ROI: y=[{roi['y_min']}, {roi['y_max']}], x=[{roi['x_min']}, {roi['x_max']}]")

    # Run inference
    prob_thresh = config['prob_thresh']
    nms_thresh = config['nms_thresh']
    ensemble_active = use_ensemble and len(models) > 1

    print(f"Running inference (ensemble={use_ensemble})...")
    all_detections = []

    for frame_idx, raw_frame in enumerate(tqdm(video_roi, desc="Detecting")):
        frame = raw_frame.astype(np.float32)

        if ensemble_active:
            dets = infer_frame_ensemble(models, frame, prob_thresh, nms_thresh, device)
        else:
            dets = infer_frame(models[0], frame, prob_thresh, nms_thresh, device)

        # Convert ROI-relative coordinates back to full image coords
        all_detections.extend(
            {'frame': frame_idx, 'x': x + roi['x_min'], 'y': y + roi['y_min']}
            for x, y in dets
        )

    detections_df = pd.DataFrame(all_detections)
    print(f"  Detections: {len(detections_df)}")

    if detections_df.empty:
        # Nothing to link; still fall through so a requested CSV gets written.
        tracked_df = pd.DataFrame(columns=['frame', 'x', 'y', 'track_id'])
    else:
        # Run tracking
        print("Running tracking...")
        tracked_df = run_laptrack(
            detections_df,
            track_cost_cutoff=config['track_cost_cutoff'],
            gap_closing_cost_cutoff=config['gap_closing_cost_cutoff'],
            gap_closing_max_frame_count=config['gap_closing_max_frame_count']
        )
        print(f"  Tracks: {tracked_df['track_id'].nunique()}")

    # Save if requested
    if output_csv:
        tracked_df.to_csv(output_csv, index=False)
        print(f"  Saved: {output_csv}")

    return tracked_df

## 5. Usage Example & Verification

In [None]:
# Demo: run the full pipeline on the validation video, provided both the
# video file and at least one model are available.
VAL_TIF = repo_root / "data" / "val" / "val.tif"

demo_ready = VAL_TIF.exists() and len(fold_models) > 0

if not demo_ready:
    # Report which of the two prerequisites is missing.
    print("Skipping demo: validation video not found or no models loaded")
    print(f"  VAL_TIF exists: {VAL_TIF.exists()}")
    print(f"  Models loaded: {len(fold_models)}")
else:
    print("Running demo inference on validation video...")
    demo_csv = EXPORT_DIR / "demo_predictions.csv"
    predictions = infer_video(VAL_TIF, output_csv=demo_csv, use_ensemble=True)
    print(f"\nDemo complete! Predictions shape: {predictions.shape}")

In [None]:
# Visualize sample predictions: overlay the demo detections on a few frames of
# the validation video. Runs only when the demo cell produced a non-empty
# `predictions` table.
if VAL_TIF.exists() and 'predictions' in globals() and len(predictions) > 0:
    video = tifffile.imread(VAL_TIF)
    roi = INFERENCE_CONFIG['roi']
    video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]

    sample_frames = [0, 30, 60, 90]
    # squeeze=False always returns a 2-D array of axes, so the loop below also
    # works when only one sample frame is configured.
    fig, axes = plt.subplots(1, len(sample_frames), figsize=(16, 4), squeeze=False)

    for ax, fidx in zip(axes[0], sample_frames):
        # Hide the axes decorations up front so panels for frames beyond the
        # end of the video stay blank instead of showing an empty box.
        ax.axis('off')
        if fidx >= len(video_roi):
            continue

        ax.imshow(video_roi[fidx], cmap='gray')

        # Plot predictions; they are stored in full-image coordinates, so shift
        # them back into the cropped ROI that is being displayed.
        frame_preds = predictions[predictions.frame == fidx]
        ax.scatter(
            frame_preds.x - roi['x_min'],
            frame_preds.y - roi['y_min'],
            c='red', s=20, marker='x', alpha=0.8
        )

        ax.set_title(f'Frame {fidx} ({len(frame_preds)} detections)')

    plt.suptitle('Ensemble Predictions', fontsize=14)
    plt.tight_layout()
    plt.savefig(EXPORT_DIR / 'demo_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Final report: list every file under the export directory with its size,
# followed by short usage instructions.
banner = "="*60

print("\n" + banner)
print("EXPORT SUMMARY")
print(banner)
print(f"\nExport directory: {EXPORT_DIR}")
print("\nExported files:")

# Walk the export tree recursively, keeping regular files only, in sorted order.
exported_files = sorted(entry for entry in EXPORT_DIR.rglob('*') if entry.is_file())
for item in exported_files:
    rel_path = item.relative_to(EXPORT_DIR)
    size_kb = item.stat().st_size / 1024
    print(f"  {rel_path} ({size_kb:.1f} KB)")

print("\n" + banner)
print("USAGE INSTRUCTIONS")
print(banner)
print("""
To run inference on new data:

1. Load this notebook or import the inference function
2. Call: predictions = infer_video('path/to/test.tif')
3. Save: predictions.to_csv('submission.csv', index=False)

The output CSV will have columns: frame, x, y, track_id
""")
print(banner)