[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/veselm73/SU2/blob/main/notebooks/SU2_StarDist_inference.ipynb)

# StarDist Cell Detection & Tracking

**Authors:** Matyáš Veselý, Ruslan Guliev

This notebook runs inference with a pre-trained **StarDist** ensemble to detect cells in TIRF-SIM microscopy images, then links the detections into tracks with **LapTrack**.

---

## Model Overview

**StarDist** is a deep learning method for object detection that predicts a star-convex polygon for each object. Key features:

- **Architecture**: ResNet18 encoder + StarDist decoder with 64 radial rays
- **Training**: 5-fold cross-validation ensemble (predictions averaged)
- **Output**: Probability map + 64 ray distances → NMS → centroid coordinates
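
To make the "ray distances" idea concrete, here is a minimal sketch of how one star-convex polygon could be rebuilt from a center point and its per-ray distances. The helper name `rays_to_polygon` and the angle convention (evenly spaced rays starting along the +x axis) are assumptions for illustration; the library's own post-processing may order or orient rays differently.

```python
import numpy as np

def rays_to_polygon(center_xy, distances):
    """Convert per-ray distances around a center into polygon vertices (x, y).

    Hypothetical helper: assumes rays are evenly spaced over the full circle.
    """
    distances = np.asarray(distances, dtype=float)
    n_rays = len(distances)
    angles = 2 * np.pi * np.arange(n_rays) / n_rays
    cx, cy = center_xy
    xs = cx + distances * np.cos(angles)
    ys = cy + distances * np.sin(angles)
    return np.stack([xs, ys], axis=1)

# The same distance along all 64 rays describes a circle of that radius.
polygon = rays_to_polygon(center_xy=(10.0, 20.0), distances=np.full(64, 3.0))
centroid = polygon.mean(axis=0)  # vertex mean recovers the center for this symmetric shape
print(polygon.shape, centroid)
```

In the notebook itself only the centroid of each detected object is kept, which is why the polygon shape never appears in the outputs.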

### Training Configuration

| Parameter | Value |
|-----------|-------|
| Encoder | ResNet18 |
| N_Rays | 64 |
| Epochs | 100 |
| Augmentation | None |
| Input Size | 256×256 (ROI crop) |
| Normalization | Percentile (1st, 99.8th) |

### K-Fold Training Results (OOF)

Out-of-fold (OOF) DetA scores using fixed thresholds (prob=0.5, nms=0.3):

| Fold | DetA | Epochs |
|------|------|--------|
| 1 | 0.8261 | 100 |
| 2 | 0.7691 | 100 |
| 3 | 0.8259 | 100 |
| 4 | 0.8149 | 100 |
| 5 | 0.8286 | 100 |
| **Mean** | **0.8129 ± 0.0224** | |
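
The summary row can be reproduced from the five fold scores in the table. The short check below suggests that the ± value corresponds to the population standard deviation (`ddof=0`); the sample standard deviation (`ddof=1`) is printed alongside for comparison.

```python
import numpy as np

# Per-fold DetA values copied from the table above.
fold_deta = np.array([0.8261, 0.7691, 0.8259, 0.8149, 0.8286])

mean = fold_deta.mean()
std_population = fold_deta.std(ddof=0)  # divides by N
std_sample = fold_deta.std(ddof=1)      # divides by N - 1

print(f"mean={mean:.4f}  std(ddof=0)={std_population:.4f}  std(ddof=1)={std_sample:.4f}")
```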

### Optimized Inference Thresholds

After a threshold sweep on the OOF predictions:
- **prob_thresh**: 0.6
- **nms_thresh**: 0.35
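
The sweep code is not part of this notebook. As a rough sketch of what such a grid search could look like, the example below scans every threshold pair and keeps the best-scoring one. Both `sweep_thresholds` and `toy_score` are hypothetical: the toy function is only a stand-in built to peak at the thresholds reported above, whereas a real sweep would score the OOF predictions with DetA.

```python
import itertools
import numpy as np

def sweep_thresholds(score_fn, prob_grid, nms_grid):
    """Return (best_score, prob_thresh, nms_thresh) over the full grid."""
    best = None
    for prob, nms in itertools.product(prob_grid, nms_grid):
        score = score_fn(prob, nms)
        if best is None or score > best[0]:
            best = (score, float(prob), float(nms))
    return best

def toy_score(prob, nms):
    # Stand-in for "DetA on OOF predictions"; NOT real data.
    return -((prob - 0.6) ** 2) - (nms - 0.35) ** 2

prob_grid = np.round(np.linspace(0.3, 0.8, 11), 2)  # 0.30, 0.35, ..., 0.80
nms_grid = np.round(np.linspace(0.2, 0.5, 7), 2)    # 0.20, 0.25, ..., 0.50

best_score, best_prob, best_nms = sweep_thresholds(toy_score, prob_grid, nms_grid)
print(best_prob, best_nms)
```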

---

## 1. Setup

In [1]:
# Install dependencies
!pip install uv -q
!uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 --system -q
!uv pip install "numpy<2" cellseg-models-pytorch pytorch-lightning tifffile scipy laptrack --system -q

In [2]:
import sys
import os
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    if not Path('/content/SU2').exists():
        !git clone https://github.com/veselm73/SU2.git /content/SU2
    os.chdir('/content/SU2')
    repo_root = Path('/content/SU2')
else:
    repo_root = Path('.').resolve()
    if repo_root.name == 'notebooks':
        repo_root = repo_root.parent

print(f"Repository: {repo_root}")

Repository: /content/SU2


In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from skimage.measure import regionprops
import urllib.request
import json
import tifffile
import networkx as nx
from scipy import spatial
from scipy.optimize import linear_sum_assignment
from laptrack import LapTrack

from cellseg_models_pytorch.postproc.functional.stardist.stardist import post_proc_stardist

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {DEVICE}")

## 2. Download Model Weights

In [4]:
MODEL_NAME = "100e_noaug_64rays"
BASE_URL = "https://raw.githubusercontent.com/veselm73/SU2/main"

def download_model_weights(model_name):
    """Download model weights from GitHub if not present."""
    weights_dir = repo_root / "weights" / model_name
    models_dir = weights_dir / "models"

    if models_dir.exists() and len(list(models_dir.glob("*.pth"))) == 5:
        print(f"Weights found locally: {weights_dir}")
        return weights_dir

    print(f"Downloading {model_name} weights...")
    weights_dir.mkdir(parents=True, exist_ok=True)
    models_dir.mkdir(parents=True, exist_ok=True)

    base_url = f"{BASE_URL}/weights/{model_name}"

    for cfg in ["model_config.json", "inference_config.json"]:
        urllib.request.urlretrieve(f"{base_url}/{cfg}", weights_dir / cfg)

    for fold in range(1, 6):
        print(f"  Downloading fold_{fold}.pth (~53MB)...")
        urllib.request.urlretrieve(f"{base_url}/models/fold_{fold}.pth", models_dir / f"fold_{fold}.pth")

    print("Download complete!")
    return weights_dir

WEIGHTS_DIR = download_model_weights(MODEL_NAME)

Weights found locally: /content/SU2/weights/100e_noaug_64rays


## 3. Model Definition & Loading

In [5]:
import pytorch_lightning as pl
from cellseg_models_pytorch.models.stardist.stardist import StarDist

class StarDistModel(pl.LightningModule):
    def __init__(self, n_rays=64, encoder_name="resnet18"):
        super().__init__()
        self.n_rays = n_rays
        wrapper = StarDist(
            n_nuc_classes=1, n_rays=n_rays, enc_name=encoder_name,
            model_kwargs={"encoder_kws": {"in_chans": 1}}
        )
        self.model = wrapper.model

    def forward(self, x):
        return self.model(x)


def load_ensemble(weights_dir):
    """Load 5-fold ensemble."""
    with open(weights_dir / "model_config.json") as f:
        config = json.load(f)
    with open(weights_dir / "inference_config.json") as f:
        inf_config = json.load(f)

    models = []
    for fold in range(1, 6):
        model = StarDistModel(n_rays=config['n_rays'], encoder_name=config['encoder_name'])
        model.load_state_dict(torch.load(weights_dir / "models" / f"fold_{fold}.pth", map_location='cpu'))
        model.eval().to(DEVICE)
        models.append(model)

    print(f"Loaded {len(models)} models: {config['encoder_name']}, n_rays={config['n_rays']}")
    print(f"Thresholds: prob={inf_config['prob_thresh']}, nms={inf_config['nms_thresh']}")

    return models, inf_config

# Load ensemble
models, INF_CONFIG = load_ensemble(WEIGHTS_DIR)
PROB_THRESH = INF_CONFIG['prob_thresh']
NMS_THRESH = INF_CONFIG['nms_thresh']



Loaded 5 models: resnet18, n_rays=64
Thresholds: prob=0.6, nms=0.35


## 4. Inference Functions

In [6]:
# ROI configuration
ROI = {'x_min': 256, 'x_max': 512, 'y_min': 512, 'y_max': 768}

def preprocess(frame):
    """Percentile normalization."""
    frame = frame.astype(np.float32)
    p1, p99 = np.percentile(frame, (1, 99.8))
    return np.clip((frame - p1) / (p99 - p1 + 1e-8), 0, 1)


def detect_single_frame(models, frame, prob_thresh, nms_thresh):
    """Run ensemble detection on a single frame."""
    x = torch.from_numpy(frame).float().unsqueeze(0).unsqueeze(0).to(DEVICE)

    all_stardist, all_prob = [], []
    with torch.no_grad():
        for model in models:
            out = model(x)['nuc']
            all_stardist.append(out.aux_map.cpu().numpy()[0])
            all_prob.append(torch.sigmoid(out.binary_map).cpu().numpy()[0, 0])

    try:
        labels = post_proc_stardist(
            np.mean(all_prob, axis=0), np.mean(all_stardist, axis=0),
            score_thresh=prob_thresh, iou_thresh=nms_thresh
        )
        return [(p.centroid[1], p.centroid[0]) for p in regionprops(labels)]
    except Exception:
        # Treat any post-processing failure as "no detections" for this frame.
        return []


def run_detection(video, models, prob_thresh, nms_thresh, roi=None):
    """Run detection on entire video."""
    if roi:
        video_roi = video[:, roi['y_min']:roi['y_max'], roi['x_min']:roi['x_max']]
    else:
        video_roi = video
        roi = {'x_min': 0, 'y_min': 0}

    all_detections = []
    detections_per_frame = []

    for frame_idx in tqdm(range(len(video_roi)), desc="Detecting"):
        frame = preprocess(video_roi[frame_idx])
        dets = detect_single_frame(models, frame, prob_thresh, nms_thresh)

        # Store for tracking (local coords)
        detections_per_frame.append(dets)

        # Store for output (global coords)
        for x, y in dets:
            all_detections.append({
                'frame': frame_idx,
                'x': x + roi['x_min'],
                'y': y + roi['y_min']
            })

    return pd.DataFrame(all_detections), detections_per_frame

print("Inference functions defined.")

Inference functions defined.


## 5. Tracking with LapTrack

**LapTrack** uses Linear Assignment Problem (LAP) optimization for frame-to-frame linking with gap closing.

Best configuration from benchmark (HOTA=0.9406):
- `track_cost_cutoff`: 25 (squared distance, i.e. a 5 px linking radius)
- `gap_closing_cost_cutoff`: 49 (squared distance, i.e. a 7 px gap-closing radius)
- `gap_closing_max_frame_count`: 1
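
To illustrate why the cutoffs are given as squared pixel distances, here is a simplified sketch of frame-to-frame linking with a squared Euclidean cost. It is not LapTrack's actual formulation, which solves a larger assignment problem that also prices leaving points unlinked; this version just runs Hungarian matching and then drops pairs above the cutoff. The helper `link_frames` and the coordinates are made up for the example.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def link_frames(coords_a, coords_b, cost_cutoff=25.0):
    """Link points between two frames; return (index_a, index_b) pairs.

    Hypothetical helper: match first, then discard pairs whose squared
    distance exceeds the cutoff.
    """
    if len(coords_a) == 0 or len(coords_b) == 0:
        return []
    cost = cdist(coords_a, coords_b, metric="sqeuclidean")
    rows, cols = linear_sum_assignment(cost)
    return [(int(r), int(c)) for r, c in zip(rows, cols) if cost[r, c] <= cost_cutoff]

frame_0 = np.array([[10.0, 10.0], [50.0, 50.0]])
frame_1 = np.array([[13.0, 13.0], [60.0, 50.0]])  # moved by ~4.2 px and by 10 px

links = link_frames(frame_0, frame_1)
print(links)  # [(0, 0)]
```

The first point moves by a squared distance of 18 and is linked; the second moves by 100, exceeds the cutoff of 25, and would start a new track.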

In [7]:
# Best LapTrack configuration (HOTA=0.9406 on GT benchmark)
LAPTRACK_CONFIG = {
    'track_cost_cutoff': 25,        # 5px squared
    'gap_closing_cost_cutoff': 49,  # 7px squared
    'gap_closing_max_frame_count': 1
}

def run_tracking(detections_per_frame, config=LAPTRACK_CONFIG):
    """Track detections using LapTrack."""
    if len(detections_per_frame) == 0:
        return pd.DataFrame(columns=['frame', 'x', 'y', 'track_id'])

    # Prepare coordinates
    coords_per_frame = []
    for dets in detections_per_frame:
        if len(dets) > 0:
            coords_per_frame.append(np.array([[x, y] for x, y in dets]))
        else:
            coords_per_frame.append(np.empty((0, 2)))

    # Run tracking
    tracker = LapTrack(
        track_cost_cutoff=config['track_cost_cutoff'],
        gap_closing_cost_cutoff=config['gap_closing_cost_cutoff'],
        gap_closing_max_frame_count=config['gap_closing_max_frame_count']
    )

    graph = tracker.predict(coords_per_frame)

    # Extract tracks
    records = []
    for track_id, component in enumerate(nx.weakly_connected_components(graph)):
        for node in component:
            frame_idx, det_idx = node
            if frame_idx < len(coords_per_frame) and det_idx < len(coords_per_frame[frame_idx]):
                x, y = coords_per_frame[frame_idx][det_idx]
                records.append({'frame': int(frame_idx), 'x': float(x), 'y': float(y), 'track_id': int(track_id)})

    if not records:
        return pd.DataFrame(columns=['frame', 'x', 'y', 'track_id'])

    return pd.DataFrame(records).sort_values(['track_id', 'frame']).reset_index(drop=True)

print(f"LapTrack config: track={np.sqrt(LAPTRACK_CONFIG['track_cost_cutoff']):.0f}px, gap={np.sqrt(LAPTRACK_CONFIG['gap_closing_cost_cutoff']):.0f}px, max_frames={LAPTRACK_CONFIG['gap_closing_max_frame_count']}")

LapTrack config: track=5px, gap=7px, max_frames=1


---

## 6. Load Test Data

**Configure your data source below.** In Colab, upload the files when prompted; for local runs, set `DATA_DIR` to your data folder. The folder should contain:
- `val.tif` - validation video (TIFF stack)
- `val.csv` - ground truth annotations (columns: frame, x, y, track_id)
- `sota.csv` - SOTA predictions for comparison (columns: frame, x, y, track_id)
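
As a quick illustration of the expected CSV layout, the sketch below parses a tiny in-memory table and checks that the four required columns are present. The helper `check_annotation_table` and the coordinate values are invented for the example and are not part of the repository.

```python
import io
import pandas as pd

REQUIRED_COLUMNS = ["frame", "x", "y", "track_id"]

def check_annotation_table(df):
    """Raise if a required column is missing; return the columns in a fixed order."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    return df[REQUIRED_COLUMNS]

# Made-up rows in the same layout as val.csv / sota.csv.
example_csv = io.StringIO(
    "frame,x,y,track_id\n"
    "0,300.5,600.0,1\n"
    "1,301.0,601.5,1\n"
    "0,400.0,700.0,2\n"
)
example_df = check_annotation_table(pd.read_csv(example_csv))
print(example_df)
```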

In [None]:
# Upload test data files (Colab only)
# Required files: val.tif, val.csv, sota.csv

if IN_COLAB:
    from google.colab import files
    
    DATA_DIR = Path("/content/data")
    DATA_DIR.mkdir(exist_ok=True)
    
    print("Upload your test data files (val.tif, val.csv, sota.csv):")
    uploaded = files.upload()
    
    for filename, content in uploaded.items():
        filepath = DATA_DIR / filename
        with open(filepath, 'wb') as f:
            f.write(content)
        print(f"Saved: {filepath}")
else:
    # Local mode: configure your data path here
    DATA_DIR = Path(r"C:\Users\Mateusz\SU2\data\test_and_sota")  # <-- Change this for local use

print(f"\nData directory: {DATA_DIR}")

In [None]:
# Verify data files exist
required_files = ["val.tif", "val.csv", "sota.csv"]
missing = [f for f in required_files if not (DATA_DIR / f).exists()]
if missing:
    raise FileNotFoundError(f"Missing files in {DATA_DIR}: {missing}")

# Load video
VIDEO_PATH = DATA_DIR / "val.tif"
video = tifffile.imread(VIDEO_PATH)
print(f"Loaded: {VIDEO_PATH}")
print(f"Shape: {video.shape} (frames, height, width)")

# Load ground truth annotations
GT_PATH = DATA_DIR / "val.csv"
gt_df = pd.read_csv(GT_PATH)
print(f"\nGround truth: {len(gt_df)} detections, {gt_df['track_id'].nunique()} tracks")

# Load SOTA predictions
SOTA_PATH = DATA_DIR / "sota.csv"
sota_df = pd.read_csv(SOTA_PATH)
print(f"SOTA predictions: {len(sota_df)} detections, {sota_df['track_id'].nunique()} tracks")

In [None]:
# Run detection
print("\nRunning detection...")
detections_df, detections_per_frame = run_detection(
    video, models, PROB_THRESH, NMS_THRESH, roi=ROI
)
print(f"Detected {len(detections_df)} cells across {len(video)} frames")
print(f"Average: {len(detections_df) / len(video):.1f} cells/frame")

In [None]:
# Run tracking
print("\nRunning tracking...")
tracks_df = run_tracking(detections_per_frame)

# Adjust to global coordinates
tracks_df['x'] += ROI['x_min']
tracks_df['y'] += ROI['y_min']

print(f"Found {tracks_df['track_id'].nunique()} tracks")
print(f"Average track length: {tracks_df.groupby('track_id').size().mean():.1f} frames")

In [None]:
# Visualize sample frame
SAMPLE_FRAME = 0

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Full frame with ROI
ax = axes[0]
ax.imshow(video[SAMPLE_FRAME], cmap='gray')
rect = plt.Rectangle((ROI['x_min'], ROI['y_min']),
                      ROI['x_max']-ROI['x_min'], ROI['y_max']-ROI['y_min'],
                      linewidth=2, edgecolor='lime', facecolor='none')
ax.add_patch(rect)
ax.set_title(f'Frame {SAMPLE_FRAME} - Full (ROI in green)')
ax.axis('off')

# ROI with detections
ax = axes[1]
roi_frame = video[SAMPLE_FRAME, ROI['y_min']:ROI['y_max'], ROI['x_min']:ROI['x_max']]
ax.imshow(roi_frame, cmap='gray')

frame_dets = detections_df[detections_df['frame'] == SAMPLE_FRAME]
if len(frame_dets) > 0:
    ax.scatter(frame_dets['x'] - ROI['x_min'], frame_dets['y'] - ROI['y_min'],
               c='red', s=50, marker='x', linewidths=1.5)
ax.set_title(f'Frame {SAMPLE_FRAME} - ROI with Detections ({len(frame_dets)})')
ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Visualize tracks
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(video[0, ROI['y_min']:ROI['y_max'], ROI['x_min']:ROI['x_max']], cmap='gray', alpha=0.5)

# Plot sample tracks
sample_tracks = np.random.choice(tracks_df['track_id'].unique(),
                                  size=min(30, tracks_df['track_id'].nunique()),
                                  replace=False)

for tid in sample_tracks:
    track = tracks_df[tracks_df['track_id'] == tid].sort_values('frame')
    ax.plot(track['x'] - ROI['x_min'], track['y'] - ROI['y_min'],
            linewidth=1.5, alpha=0.7)

ax.set_title(f'Sample Trajectories (n={len(sample_tracks)})')
ax.axis('off')
plt.show()

---

## 7. Compare with SOTA

Compare the StarDist predictions against the ground truth and a SOTA method using the HOTA, DetA, and AssA metrics.
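
For orientation, the arithmetic behind these scores can be sketched in a few lines, mirroring the single-threshold formulation used by the `hota` function in this section: a distance is turned into a similarity in [0, 1], DetA is TP / (TP + FP + FN), and HOTA is the geometric mean of DetA and AssA. The counts and the AssA value below are illustrative numbers, not results.

```python
import math

def similarity(distance, threshold=5.0):
    """1 at zero distance, falling linearly to 0 at the matching threshold."""
    return 1.0 - min(distance / threshold, 1.0)

def det_a(tp, fp, fn):
    """Detection accuracy from matched and unmatched detection counts."""
    return tp / max(1, tp + fp + fn)

def hota_score(det_a_value, ass_a_value):
    """Geometric mean of detection and association accuracy."""
    return math.sqrt(det_a_value * ass_a_value)

d = det_a(tp=80, fp=10, fn=10)  # illustrative counts
h = hota_score(d, 0.9)          # illustrative AssA
print(similarity(2.5), round(d, 4), round(h, 4))
```

A pair counts as a match only when its similarity is greater than zero, that is, when the two points are closer than the threshold.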

In [None]:
# Evaluation metric functions
def calculate_deta(gt_df, pred_df, match_thresh=5.0):
    """Calculate Detection Accuracy (DetA) using Hungarian matching."""
    gt_df = gt_df.copy()
    pred_df = pred_df.copy()
    gt_df['frame'] = gt_df['frame'].astype(int)
    pred_df['frame'] = pred_df['frame'].astype(int)

    total_tp, total_fp, total_fn = 0, 0, 0

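    # Iterate over frames present in either table so predictions in frames without GT count as FP.
    for frame in sorted(set(gt_df['frame']) | set(pred_df['frame'])):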
        gt_frame = gt_df[gt_df['frame'] == frame][['x', 'y']].values
        pred_frame = pred_df[pred_df['frame'] == frame][['x', 'y']].values

        if len(gt_frame) == 0:
            total_fp += len(pred_frame)
            continue
        if len(pred_frame) == 0:
            total_fn += len(gt_frame)
            continue

        dist_matrix = spatial.distance.cdist(gt_frame, pred_frame)
        row_ind, col_ind = linear_sum_assignment(dist_matrix)
        matches = sum(dist_matrix[row_ind[i], col_ind[i]] <= match_thresh for i in range(len(row_ind)))

        total_tp += matches
        total_fp += len(pred_frame) - matches
        total_fn += len(gt_frame) - matches

    deta = total_tp / max(1, total_tp + total_fp + total_fn)
    return deta, {'TP': total_tp, 'FP': total_fp, 'FN': total_fn}


def hota(gt: pd.DataFrame, tr: pd.DataFrame, threshold: float = 5) -> dict:
    """Calculate HOTA (Higher Order Tracking Accuracy) metric."""
    gt = gt.copy()
    tr = tr.copy()

    if 'track_id' not in gt.columns or 'track_id' not in tr.columns:
        return {'HOTA': 0.0, 'AssA': 0.0, 'DetA': 0.0}
    if gt.empty or tr.empty:
        return {'HOTA': 0.0, 'AssA': 0.0, 'DetA': 0.0}

    gt.track_id = gt.track_id.map({old: new for old, new in zip(gt.track_id.unique(), range(gt.track_id.nunique()))})
    tr.track_id = tr.track_id.map({old: new for old, new in zip(tr.track_id.unique(), range(tr.track_id.nunique()))})

    num_gt_ids = gt.track_id.nunique()
    num_tr_ids = tr.track_id.nunique()

    matches_count = np.zeros((num_gt_ids, num_tr_ids))
    gt_id_count = np.zeros((num_gt_ids, 1))
    tracker_id_count = np.zeros((1, num_tr_ids))
    HOTA_TP = HOTA_FN = HOTA_FP = 0

    similarities = [1 - np.clip(spatial.distance.cdist(
        gt[gt.frame == t][['x', 'y']].values,
        tr[tr.frame == t][['x', 'y']].values
    ) / threshold, 0, 1) for t in range(int(gt.frame.max()) + 1)]

    for t in range(int(gt.frame.max()) + 1):
        gt_ids_t = gt[gt.frame == t].track_id.to_numpy()
        tr_ids_t = tr[tr.frame == t].track_id.to_numpy()

        gt_id_count[gt_ids_t] += 1
        tracker_id_count[:, tr_ids_t] += 1

        if len(gt_ids_t) == 0:
            HOTA_FP += len(tr_ids_t)
            continue
        if len(tr_ids_t) == 0:
            HOTA_FN += len(gt_ids_t)
            continue

        similarity = similarities[t]
        match_rows, match_cols = linear_sum_assignment(-similarity)
        mask = similarity[match_rows, match_cols] > 0
        alpha_match_rows = match_rows[mask]
        alpha_match_cols = match_cols[mask]
        num_matches = len(alpha_match_rows)

        HOTA_TP += num_matches
        HOTA_FN += len(gt_ids_t) - num_matches
        HOTA_FP += len(tr_ids_t) - num_matches

        if num_matches > 0:
            matches_count[gt_ids_t[alpha_match_rows], tr_ids_t[alpha_match_cols]] += 1

    ass_a = matches_count / np.maximum(1, gt_id_count + tracker_id_count - matches_count)
    AssA = np.sum(matches_count * ass_a) / np.maximum(1, HOTA_TP)
    DetA = HOTA_TP / np.maximum(1, HOTA_TP + HOTA_FN + HOTA_FP)
    HOTA_score = np.sqrt(DetA * AssA)

    return {'HOTA': float(HOTA_score), 'AssA': float(AssA), 'DetA': float(DetA)}

print("Evaluation functions defined.")

In [None]:
# Run comparison
print("="*60)
print("PERFORMANCE COMPARISON")
print("="*60)

# StarDist metrics
stardist_hota = hota(gt_df, tracks_df)
print(f"\nStarDist (Ours):")
print(f"  HOTA={stardist_hota['HOTA']:.4f}, DetA={stardist_hota['DetA']:.4f}, AssA={stardist_hota['AssA']:.4f}")
print(f"  Tracks: {tracks_df['track_id'].nunique()}, Detections: {len(tracks_df)}")

# SOTA metrics
sota_hota = hota(gt_df, sota_df)
print(f"\nSOTA:")
print(f"  HOTA={sota_hota['HOTA']:.4f}, DetA={sota_hota['DetA']:.4f}, AssA={sota_hota['AssA']:.4f}")
print(f"  Tracks: {sota_df['track_id'].nunique()}, Detections: {len(sota_df)}")

# Ground truth stats
print(f"\nGround Truth:")
print(f"  Tracks: {gt_df['track_id'].nunique()}, Detections: {len(gt_df)}")

# Comparison table
print("\n" + "="*60)
print("COMPARISON TABLE")
print("="*60)
comparison_df = pd.DataFrame({
    'Method': ['StarDist (Ours)', 'SOTA'],
    'HOTA': [stardist_hota['HOTA'], sota_hota['HOTA']],
    'DetA': [stardist_hota['DetA'], sota_hota['DetA']],
    'AssA': [stardist_hota['AssA'], sota_hota['AssA']],
    'Tracks': [tracks_df['track_id'].nunique(), sota_df['track_id'].nunique()]
})
print(comparison_df.to_string(index=False))

# Bar chart
fig, ax = plt.subplots(figsize=(10, 6))
metrics = ['HOTA', 'DetA', 'AssA']
x = np.arange(len(metrics))
width = 0.35
bars1 = ax.bar(x - width/2, [stardist_hota[m] for m in metrics], width, label='StarDist (Ours)', color='steelblue')
bars2 = ax.bar(x + width/2, [sota_hota[m] for m in metrics], width, label='SOTA', color='coral')
ax.set_ylabel('Score')
ax.set_title('Tracking Performance Comparison')
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend()
ax.set_ylim(0, 1)
for bar in bars1 + bars2:
    ax.annotate(f'{bar.get_height():.3f}', xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                xytext=(0, 3), textcoords='offset points', ha='center', va='bottom', fontsize=9)
plt.tight_layout()
plt.show()

In [None]:
# Save results to local data directory
output_dir = DATA_DIR

# Save detections
det_path = output_dir / "stardist_detections.csv"
detections_df.to_csv(det_path, index=False)
print(f"Detections saved: {det_path}")

# Save tracks
tracks_path = output_dir / "stardist_tracks.csv"
tracks_df.to_csv(tracks_path, index=False)
print(f"Tracks saved: {tracks_path}")

print(f"\nNote: Results saved to local folder (not in git repo)")

---

## Summary

This notebook:
1. Downloads **StarDist ensemble** weights (5-fold, ResNet18, 64 rays) from GitHub
2. Loads test data from your **local data folder** (configure `DATA_DIR` in Section 6)
3. Runs cell detection and **LapTrack** tracking on the validation video
4. Compares performance against **SOTA method** using HOTA, DetA, AssA metrics
5. Saves results to your local data folder

**Required files in DATA_DIR:**
- `val.tif` - validation video
- `val.csv` - ground truth annotations
- `sota.csv` - SOTA predictions

For questions or issues, see: https://github.com/veselm73/SU2