In [None]:
# Install dependencies
!pip install -q cellseg-models-pytorch pytorch-lightning "btrack==0.6.5" "laptrack" "pydantic<2" "albumentations==1.3.1" "numpy<2" opencv-python-headless pandas scipy scikit-image scikit-learn matplotlib seaborn tqdm ipywidgets tifffile numba
# Uninstall TensorFlow to avoid conflicts (if any)
!pip uninstall -y tensorflow tensorflow-intel tensorflow-cpu stardist csbdeep || true


In [None]:
import os
import sys
import shutil
import random
import math
import warnings
import json
from pathlib import Path
from typing import Optional, List, Tuple, Dict, Any

import numpy as np
import pandas as pd
import cv2
import tifffile
import matplotlib.pyplot as plt
from skimage.measure import regionprops, label
from skimage.draw import disk
from sklearn.model_selection import KFold
import requests
import zipfile

# PyTorch & Lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# StarDist (PyTorch)
try:
    from cellseg_models_pytorch.models.stardist.stardist import StarDist
    from cellseg_models_pytorch.transforms.functional.stardist import gen_stardist_maps
    from cellseg_models_pytorch.postproc.functional.stardist.stardist import post_proc_stardist
    print("StarDist (PyTorch) libraries imported successfully.")
except ImportError as e:
    print(f"Error importing cellseg-models-pytorch: {e}")
    # Fallback/Exit handled by installation cell usually

# Tracking
try:
    import btrack
    print(f"btrack version: {btrack.__version__}")
except ImportError:
    print("btrack not found.")
    btrack = None

try:
    import laptrack
    from laptrack import LapTrack
    print(f"laptrack version: {laptrack.__version__}")
except ImportError:
    print("laptrack not found.")
    laptrack = None

# --- Constants ---
# Region of interest cropped out of every video frame, in full-frame pixel
# coordinates; used as slice bounds, i.e. half-open [min, max).
ROI_Y_MIN = 512
ROI_Y_MAX = 768
ROI_X_MIN = 256
ROI_X_MAX = 512
ROI_H = ROI_Y_MAX - ROI_Y_MIN  # crop height in pixels
ROI_W = ROI_X_MAX - ROI_X_MIN  # crop width in pixels

# StarDist / training settings.
N_RAYS = 32       # number of StarDist rays per pixel
BATCH_SIZE = 4
EPOCHS = 20       # Adjust as needed

# Use the GPU whenever PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")


In [None]:
def download_validation_data(target_dir: str = "val_data",
                             url: str = "https://su2.utia.cas.cz/files/labs/final2025/val_and_sota.zip",
                             cert_url: str = "https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem"):
    """Download the validation archive and unpack it into ``target_dir``.

    The archive's TLS certificate is verified against a separately downloaded
    HARICA cross-signed chain. If that chain cannot be fetched, the archive is
    downloaded with TLS verification DISABLED, and a warning is printed.

    Nothing is done when ``target_dir`` already exists and is non-empty.

    Cleanup on exit or failure:
      - The temporary zip is always removed, even when the download or extraction fails.
      - A partially extracted ``target_dir`` is deleted on failure, so the next call
        retries instead of treating the broken directory as "already downloaded".

    Args:
        target_dir: directory to extract into.
        url: URL of the zip archive.
        cert_url: URL of the PEM certificate chain used to verify ``url``.

    Raises:
        requests exceptions if the archive download fails, or any exception from
        ``zipfile`` if the archive cannot be extracted.
    """
    if os.path.exists(target_dir) and len(os.listdir(target_dir)) > 0:
        print(f"'{target_dir}' already exists. Skipping download.")
        return

    chain_path = "chain-harica-cross.pem"
    print("1) Downloading SSL certificate chain...")
    try:
        with requests.get(cert_url, timeout=10) as r:
            r.raise_for_status()
            with open(chain_path, "wb") as f:
                f.write(r.content)
    except Exception as e:
        print(f"Warning: Could not download cert chain ({e}), trying without...")
        chain_path = None

    # A path means "verify against this CA bundle"; False turns TLS verification off entirely.
    verify = chain_path if chain_path else False
    if verify is False:
        print("Warning: TLS certificate verification is DISABLED for the archive download.")

    print("2) Downloading validation archive...")
    zip_name = os.path.basename(url)
    try:
        with requests.get(url, stream=True, verify=verify, timeout=30) as resp:
            resp.raise_for_status()
            with open(zip_name, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        print("3) Extracting...")
        os.makedirs(target_dir, exist_ok=True)
        try:
            with zipfile.ZipFile(zip_name, "r") as zf:
                zf.extractall(target_dir)
        except Exception:
            # target_dir was empty or absent on entry (see the early return), so it is safe
            # to drop it; a half-filled directory would make every later call skip the download.
            shutil.rmtree(target_dir, ignore_errors=True)
            raise
    finally:
        # Never leave the (possibly partial) archive lying around.
        if os.path.exists(zip_name):
            os.remove(zip_name)
    print(f"Done. Data in '{target_dir}/'")

# Download data (a no-op when 'val_data/' already exists and is non-empty)
download_validation_data()


In [None]:
# Prepare 'experiment_dataset' with StarDist mask style (label images)
def create_stardist_label_mask(image_shape: Tuple[int, int], points, radius: int = 6) -> np.ndarray:
    '''Rasterise point annotations into an instance-labelled uint16 mask.

    Each in-bounds ``(y, x)`` point becomes a filled disk of ``radius`` pixels
    (clipped to ``image_shape``) with its own label 1, 2, 3, ... in input order.
    Points whose centre lies outside the image are skipped and consume no label.
    Where disks overlap, the later point's label overwrites the earlier one.
    '''
    n_rows, n_cols = image_shape[0], image_shape[1]

    def _outside(y, x) -> bool:
        return y < 0 or x < 0 or y >= n_rows or x >= n_cols

    labelled = np.zeros(image_shape, dtype=np.uint16)
    kept = [(y, x) for (y, x) in points if not _outside(y, x)]
    for instance_id, (y, x) in enumerate(kept, start=1):
        rr, cc = disk((y, x), radius, shape=image_shape)
        labelled[rr, cc] = instance_id
    return labelled

def _load_mask_image(path: Path) -> np.ndarray:
    """Read a raw (typically binary) mask image as a single-channel array.

    ``.tif``/``.tiff`` files are read with tifffile; other formats with OpenCV
    using ``IMREAD_UNCHANGED``. 3-D results are converted with ``COLOR_BGR2GRAY``.

    Raises:
        FileNotFoundError: the reader returned ``None``.
    """
    is_tiff = path.suffix.lower() in {'.tif', '.tiff'}
    mask = tifffile.imread(path) if is_tiff else cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if mask is None:
        raise FileNotFoundError(f"Missing mask: {path}")
    return cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if mask.ndim == 3 else mask

def _label_connected(mask: np.ndarray) -> np.ndarray:
    """Convert a mask into connected-component instance labels (uint16).

    Every non-zero pixel counts as foreground; components are numbered by
    skimage's ``label`` with its default connectivity.
    """
    foreground = mask > 0
    components = label(foreground)
    return components.astype(np.uint16)

def _prepare_bonus_split(real_data_dir: str, images_out: Path, masks_out: Path) -> None:
    """Copy real training images and write connected-component instance masks as TIFFs."""
    src_images = Path(real_data_dir) / "images"
    src_masks = Path(real_data_dir) / "masks"
    if not (src_images.exists() and src_masks.exists()):
        print(f"Note: '{real_data_dir}' not found. Skipping bonus data.")
        return

    print("Processing real training data (bonus)...")
    # Ignore sub-directories: they would crash shutil.copy and the mask reader.
    mask_lookup = {m.stem: m for m in src_masks.glob("*") if m.is_file()}
    for img_path in sorted(p for p in src_images.glob("*") if p.is_file()):
        shutil.copy(img_path, images_out / img_path.name)
        # Prefer '<stem>_mask.*', fall back to a mask with exactly the same stem.
        mask_path = mask_lookup.get(f"{img_path.stem}_mask") or mask_lookup.get(img_path.stem)
        if mask_path is None:
            print(f"Warning: No mask found for {img_path.name}")
            continue
        # Binary mask -> one integer label per connected component.
        instance_mask = _label_connected(_load_mask_image(mask_path))
        tifffile.imwrite(masks_out / f"{img_path.stem}.tif", instance_mask)


def _prepare_video_split(val_tif_path: str, val_csv_path: str, out_path: Path,
                         images_out: Path, masks_out: Path,
                         roi_y: Tuple[int, int], roi_x: Tuple[int, int],
                         disk_radius: int) -> None:
    """Crop every frame to the ROI, rasterise its points as disk masks, and write video_map.csv."""
    if not (Path(val_tif_path).exists() and Path(val_csv_path).exists()):
        print("Validation data not found.")
        return

    print("Processing validation video data...")
    video = tifffile.imread(val_tif_path)
    if video.ndim == 2:
        # A single-page TIFF is one frame, not a stack of image rows.
        video = video[np.newaxis, ...]
    coords = pd.read_csv(val_csv_path)
    y_min, y_max = roi_y
    x_min, x_max = roi_x
    # Group once instead of re-scanning the whole table per frame; groupby keeps row order within a group,
    # so label ids (and which disk wins on overlap) are assigned exactly as before.
    coords_by_frame = {int(f): g for f, g in coords.groupby("frame")}

    records = []
    for idx, frame in enumerate(video):
        crop = frame[y_min:y_max, x_min:x_max]
        frame_coords = coords_by_frame.get(idx)
        if frame_coords is None:
            points = []
        else:
            # floor() rather than int(): int() truncates toward zero, which would pull points
            # lying in (-1, 0) relative to the ROI onto row/column 0.
            ys = np.floor(frame_coords["y"].to_numpy(dtype=float) - y_min).astype(int)
            xs = np.floor(frame_coords["x"].to_numpy(dtype=float) - x_min).astype(int)
            points = list(zip(ys.tolist(), xs.tolist()))

        # Size the mask from the crop itself so it always matches the image, including custom ROIs
        # and ROIs clipped by the frame border (the global ROI_H/ROI_W would not).
        mask = create_stardist_label_mask(crop.shape[:2], points, radius=disk_radius)

        image_path = images_out / f"frame_{idx:03d}.png"
        mask_path = masks_out / f"frame_{idx:03d}.tif"
        # cv2.imwrite reports failure (e.g. unsupported dtype) only through its return value.
        if not cv2.imwrite(str(image_path), crop):
            raise IOError(f"Could not write {image_path}")
        tifffile.imwrite(str(mask_path), mask)
        records.append({
            'filename': image_path.name,
            'real_frame_idx': idx
        })

    pd.DataFrame(records).to_csv(out_path / "video_map.csv", index=False)
    print(f"Processed {len(video)} video frames.")


def prepare_grand_dataset(
    real_data_dir: str = "real_training_data",
    val_tif_path: str = "val_data/val.tif",
    val_csv_path: str = "val_data/val.csv",
    out_dir: str = "experiment_dataset",
    roi_y: Tuple[int, int] = (ROI_Y_MIN, ROI_Y_MAX),
    roi_x: Tuple[int, int] = (ROI_X_MIN, ROI_X_MAX),
    disk_radius: int = 6,
) -> Path:
    """Rebuild ``out_dir`` with the two training sources used by this notebook.

    Layout written:
      - ``bonus/images``, ``bonus/masks``: images copied from ``real_data_dir/images``;
        masks from ``real_data_dir/masks`` converted to instance labels (``.tif``).
      - ``video/images``, ``video/masks``: each frame of ``val_tif_path`` cropped to
        ``roi_y`` x ``roi_x`` (PNG), with the points from ``val_csv_path`` (columns
        ``frame``, ``x``, ``y``, full-frame coordinates) drawn as disks of
        ``disk_radius`` pixels (TIFF).
      - ``video_map.csv``: crop filename -> original frame index.

    Either source is skipped, with a message, when its inputs are missing.
    Any existing ``out_dir`` is deleted first.

    Returns:
        Path to ``out_dir``.

    Raises:
        IOError: a cropped frame could not be written as PNG.
    """
    out_path = Path(out_dir)
    bonus_images = out_path / "bonus" / "images"
    bonus_masks = out_path / "bonus" / "masks"
    video_images = out_path / "video" / "images"
    video_masks = out_path / "video" / "masks"

    # Start clean so stale frames or masks from earlier runs cannot leak into training.
    if out_path.exists():
        shutil.rmtree(out_path)
    for p in (bonus_images, bonus_masks, video_images, video_masks):
        p.mkdir(parents=True, exist_ok=True)

    _prepare_bonus_split(real_data_dir, bonus_images, bonus_masks)
    _prepare_video_split(val_tif_path, val_csv_path, out_path, video_images, video_masks,
                         roi_y, roi_x, disk_radius)
    return out_path

# Execute Preparation
# NOTE: prepare_grand_dataset deletes and rebuilds 'experiment_dataset/' on every run.
OUT_DIR = prepare_grand_dataset()
print(f"Dataset ready at {OUT_DIR}")


In [None]:
def _load_gray(path: Path) -> np.ndarray:
    """Read an image from disk as a single-channel array.

    ``.tif``/``.tiff`` files go through tifffile; anything else through OpenCV
    with ``IMREAD_UNCHANGED``. 3-D results are converted with ``COLOR_BGR2GRAY``.

    Raises:
        FileNotFoundError: the reader returned ``None``.
    """
    is_tiff = path.suffix.lower() in {'.tif', '.tiff'}
    arr = tifffile.imread(path) if is_tiff else cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if arr is None:
        raise FileNotFoundError(f"Missing image: {path}")
    return cv2.cvtColor(arr, cv2.COLOR_BGR2GRAY) if arr.ndim == 3 else arr

def _load_label_mask(path: Path) -> np.ndarray:
    """Load an instance-label mask as an ``int32`` array.

    TIFF masks are read with tifffile. Any other format is read with OpenCV using
    ``IMREAD_UNCHANGED``, so bit depth is preserved; this includes the ``.png``
    masks that ``_list_image_mask_pairs`` also accepts, which tifffile cannot read.
    3-D masks are collapsed with ``COLOR_BGR2GRAY``.

    Raises:
        FileNotFoundError: OpenCV could not read ``path``.
    """
    path = Path(path)
    if path.suffix.lower() in {'.tif', '.tiff'}:
        mask = tifffile.imread(path)
    else:
        mask = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if mask is None:
        raise FileNotFoundError(f"Missing mask: {path}")
    if mask.ndim == 3:
        # NOTE(review): luminance weighting keeps grey-replicated labels intact but may
        # merge distinct colour-coded labels; masks written by this notebook are 2-D.
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    return mask.astype(np.int32)

class StarDistDataset(Dataset):
    """Dataset of (image path, label-mask path) pairs yielding StarDist training targets.

    Each item is a dict:
      - ``"image"``: float32 tensor ``(1, H, W)``, clipped to the 1st..99.8th
        percentile and rescaled to roughly [0, 1].
      - ``"stardist_map"``: float32 tensor of ray distances from ``gen_stardist_maps``
        (layout as returned by that function; assumed ``(n_rays, H, W)`` -- TODO confirm).
      - ``"binary_map"``: float32 tensor ``(1, H, W)``, 1 where the label mask is > 0.
      - ``"id"``: the image file name.
    """

    def __init__(self, pairs, n_rays=32):
        self.pairs = pairs
        self.n_rays = n_rays

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]
        img = _load_gray(img_path)
        mask = _load_label_mask(mask_path)

        # Robust contrast normalisation (1-99.8 percentile); the epsilon guards constant images.
        p1, p99 = np.percentile(img, (1, 99.8))
        img = np.clip(img, p1, p99)
        img = ((img - p1) / (p99 - p1 + 1e-8)).astype(np.float32)

        dist_map = gen_stardist_maps(mask, n_rays=self.n_rays)
        # The model predicts float32; pin the target's dtype (and make it contiguous)
        # instead of inheriting whatever gen_stardist_maps returns, so the L1 loss
        # never mixes float64 targets with float32 predictions.
        dist_map = np.ascontiguousarray(dist_map, dtype=np.float32)

        # Foreground probability target, shape (1, H, W).
        binary_map = (mask > 0).astype(np.float32)[np.newaxis, ...]

        return {
            "image": torch.from_numpy(img[np.newaxis, ...]),
            "stardist_map": torch.from_numpy(dist_map),
            "binary_map": torch.from_numpy(binary_map),
            "id": str(img_path.name)
        }


In [None]:
class StarDistLightning(pl.LightningModule):
    """Lightning wrapper around the cellseg_models_pytorch StarDist network.

    Loss = BCE-with-logits on the foreground map
         + ``lambda_dist`` * L1 on the ray distances, averaged over foreground pixels only.

    Constructor arguments are stored with ``save_hyperparameters()``, so
    ``load_from_checkpoint`` rebuilds the same architecture (e.g. a non-default
    ``n_rays``) without the caller repeating them.
    """

    def __init__(self, n_rays=32, lr=1e-4, lambda_dist=1.0):
        super().__init__()
        self.save_hyperparameters()
        # Initialize StarDist model from cellseg_models_pytorch
        # We use a ResNet18 encoder, 1 input channel (gray), 1 output class (foreground)
        wrapper = StarDist(
            n_nuc_classes=1,
            n_rays=n_rays,
            enc_name="resnet18",
            model_kwargs={"encoder_kws": {"in_chans": 1}}
        )
        self.model = wrapper.model
        self.lr = lr
        self.lambda_dist = lambda_dist  # Weight for distance loss

    def forward(self, x):
        return self.model(x)

    def _compute_losses(self, batch):
        """Return ``(total, prob, dist)`` losses for one batch."""
        images = batch["image"]
        out = self(images)  # dict; the "nuc" entry holds the StarDist heads
        nuc_out = out["nuc"]
        pred_dist = nuc_out.aux_map    # assumed (B, n_rays, H, W)
        pred_bin = nuc_out.binary_map  # assumed (B, 1, H, W)

        # Match target dtypes to the predictions so the losses never mix precisions.
        gt_dist = batch["stardist_map"].to(pred_dist.dtype)
        gt_bin = batch["binary_map"].to(pred_bin.dtype)  # (B, 1, H, W)

        loss_prob = F.binary_cross_entropy_with_logits(pred_bin, gt_bin)

        # Distances are only defined inside objects: mask the L1 with the foreground map.
        l1 = F.l1_loss(pred_dist, gt_dist, reduction='none')
        mask = gt_bin.expand_as(l1)
        loss_dist = (l1 * mask).sum() / (mask.sum() + 1e-8)

        total_loss = loss_prob + self.lambda_dist * loss_dist
        return total_loss, loss_prob, loss_dist

    def training_step(self, batch, batch_idx):
        total_loss, loss_prob, loss_dist = self._compute_losses(batch)
        bs = batch["image"].shape[0]
        self.log("train_loss", total_loss, prog_bar=True, batch_size=bs)
        self.log("loss_prob", loss_prob, batch_size=bs)
        self.log("loss_dist", loss_dist, batch_size=bs)
        return total_loss

    def validation_step(self, batch, batch_idx):
        # Separate keys: validation must not overwrite the training metrics.
        total_loss, loss_prob, loss_dist = self._compute_losses(batch)
        bs = batch["image"].shape[0]
        self.log("val_loss", total_loss, prog_bar=True, batch_size=bs)
        self.log("val_loss_prob", loss_prob, batch_size=bs)
        self.log("val_loss_dist", loss_dist, batch_size=bs)
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
def _list_image_mask_pairs(root: Path):
    images = sorted((Path(root) / "images").glob("*"))
    mask_lookup = {}
    masks_root = Path(root) / "masks"
    for ext in ("*.tif", "*.tiff", "*.png"):
        for p in masks_root.glob(ext):
            mask_lookup[p.stem] = p
    pairs = []
    for img_path in images:
        mask_path = mask_lookup.get(img_path.stem)
        if mask_path: pairs.append((img_path, mask_path))
    return pairs


In [None]:
def train_kfold(k_splits=2, epochs=20, data_root="experiment_dataset",
                save_dir="stardist_checkpoints", seed=42, num_workers=None):
    """Train one StarDist model per K-fold split and return each fold's best checkpoint.

    Pairs come from ``<data_root>/video`` plus, when present, ``<data_root>/bonus``.
    Each fold trains a fresh ``StarDistLightning``, and ``ModelCheckpoint`` keeps the
    checkpoint with the lowest ``val_loss``.

    Args:
        k_splits: number of KFold splits.
        epochs: max epochs per fold.
        data_root: directory produced by ``prepare_grand_dataset``.
        save_dir: where checkpoints and Lightning logs are written.
        seed: seeds both the fold shuffle and Lightning's global RNGs.
        num_workers: training DataLoader workers; defaults to 2, or 0 on Windows.

    Returns:
        List of best-checkpoint paths, one per fold. A warning is printed if a
        fold reports an empty path.

    Raises:
        ValueError: fewer image/mask pairs than ``k_splits``.
    """
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)
    root = Path(data_root)

    pairs = _list_image_mask_pairs(root / "video")
    bonus_pairs = _list_image_mask_pairs(root / "bonus")
    if bonus_pairs:
        print(f"Adding {len(bonus_pairs)} bonus pairs to training.")
        pairs += bonus_pairs

    # Fail early with an actionable message instead of an opaque KFold error.
    if len(pairs) < k_splits:
        raise ValueError(
            f"Need at least {k_splits} image/mask pairs for {k_splits}-fold training, "
            f"found {len(pairs)} under '{root}'. Run prepare_grand_dataset() first."
        )

    if num_workers is None:
        num_workers = 2 if os.name != 'nt' else 0

    kf = KFold(n_splits=k_splits, shuffle=True, random_state=seed)
    best_models = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(pairs), 1):
        print(f"\n=== Training Fold {fold}/{k_splits} ===")
        # Re-seed per fold so every fold is reproducible on its own (weights init, shuffling, workers).
        pl.seed_everything(seed, workers=True)

        train_ds = StarDistDataset([pairs[i] for i in train_idx], n_rays=N_RAYS)
        val_ds = StarDistDataset([pairs[i] for i in val_idx], n_rays=N_RAYS)
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

        model = StarDistLightning(n_rays=N_RAYS)
        checkpoint_callback = ModelCheckpoint(
            dirpath=save_path,
            filename=f"stardist_fold_{fold}",
            monitor="val_loss",
            mode="min",
            save_top_k=1
        )
        trainer = pl.Trainer(
            max_epochs=epochs,
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=1,
            default_root_dir=save_path,
            callbacks=[checkpoint_callback],
            log_every_n_steps=5
        )
        trainer.fit(model, train_loader, val_loader)

        best_path = checkpoint_callback.best_model_path
        print(f"Fold {fold} Best Model: {best_path}")
        if not best_path:
            print(f"Warning: fold {fold} produced no checkpoint.")
        best_models.append(best_path)

    return best_models

# Run Training
# One best-checkpoint path per fold (the ModelCheckpoint with the lowest 'val_loss').
best_model_paths = train_kfold(k_splits=2, epochs=EPOCHS)
print("Best Models:", best_model_paths)


In [None]:
def generate_detections(model_path, out_dir="predictions", data_root="experiment_dataset",
                        score_thresh=0.5, iou_thresh=0.3, roi_offset=(0, 0)):
    """Run a trained StarDist checkpoint on the prepared video crops and save centroids.

    Args:
        model_path: Lightning checkpoint of a ``StarDistLightning`` model.
        out_dir: directory that receives ``stardist_detections.csv``.
        data_root: directory produced by ``prepare_grand_dataset``; it must contain
            ``video/images`` and ``video_map.csv``.
        score_thresh: forwarded to ``post_proc_stardist``.
        iou_thresh: NMS IoU threshold forwarded to ``post_proc_stardist``.
        roi_offset: ``(y, x)`` added to every centroid. Centroids are measured on
            the ROI crops, so the default ``(0, 0)`` gives crop coordinates. Pass
            ``(ROI_Y_MIN, ROI_X_MIN)`` to get full-frame coordinates.

    Returns:
        DataFrame with columns ``frame``, ``x``, ``y``, also written to CSV. It keeps
        those columns even when empty. If ``video_map.csv`` is missing, an empty
        frame is returned and nothing is written.
    """
    print(f"Generating detections using {model_path}...")
    root = Path(data_root)
    columns = ["frame", "x", "y"]

    # Check the (cheap) frame map before paying for model loading.
    video_map_path = root / "video_map.csv"
    if not video_map_path.exists():
        print("Video map not found, cannot map to frames.")
        return pd.DataFrame(columns=columns)
    frame_lookup = pd.read_csv(video_map_path).set_index("filename")["real_frame_idx"].to_dict()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = StarDistLightning.load_from_checkpoint(model_path)
    model.eval()
    model.to(device)

    dy, dx = roi_offset
    all_detections = []
    for img_path in sorted((root / "video" / "images").glob("*")):
        if img_path.name not in frame_lookup:
            # Guessing a frame index would silently corrupt downstream tracking.
            print(f"Warning: {img_path.name} is not in the video map; skipping.")
            continue
        frame_idx = frame_lookup[img_path.name]

        # Same 1-99.8 percentile normalisation as StarDistDataset.
        img = _load_gray(img_path)
        p1, p99 = np.percentile(img, (1, 99.8))
        img_n = ((np.clip(img, p1, p99) - p1) / (p99 - p1 + 1e-8)).astype(np.float32)
        inp = torch.from_numpy(img_n[np.newaxis, np.newaxis, ...]).to(device)

        with torch.no_grad():
            out = model(inp)
            # Foreground probability (H, W) after squeezing batch/channel dims.
            prob_map = torch.sigmoid(out["nuc"].binary_map).cpu().numpy().squeeze()
            # Ray distances; assumed (n_rays, H, W) after squeezing -- TODO confirm.
            dist_map = out["nuc"].aux_map.cpu().numpy().squeeze()
        if prob_map.ndim == 3:
            prob_map = prob_map[0]

        labels = post_proc_stardist(
            dist_map=prob_map,
            stardist_map=dist_map,
            score_thresh=score_thresh,
            iou_thresh=iou_thresh,
        )

        for p in regionprops(labels):
            y, x = p.centroid  # regionprops centroids are (row, col)
            all_detections.append({"frame": frame_idx, "x": x + dx, "y": y + dy})

    df = pd.DataFrame(all_detections, columns=columns)
    os.makedirs(out_dir, exist_ok=True)
    csv_path = Path(out_dir) / "stardist_detections.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved {len(df)} detections to {csv_path}")
    return df

# Run on the last fold's best checkpoint; empty paths mean that fold saved nothing.
valid_model_paths = [p for p in best_model_paths if p]
if valid_model_paths:
    detections_df = generate_detections(valid_model_paths[-1])
else:
    print("No models trained.")
    # Always define detections_df so the tracking cells below still run on a fresh kernel.
    detections_df = pd.DataFrame(columns=["frame", "x", "y"])


In [None]:
def _default_btrack_config() -> Dict[str, Any]:
    """Return a complete btrack configuration dict (constant-velocity cell motion model).

    Uses btrack's JSON layout: everything lives under ``TrackerConfig``, and the
    motion model carries its Kalman matrices (A, H, P, G, R) as flattened
    ``matrix`` lists with an optional ``sigma`` scale.
    """
    return {
        "TrackerConfig": {
            "MotionModel": {
                "name": "cell_motion",
                "dt": 1.0,
                "measurements": 3,
                "states": 6,
                "accuracy": 7.5,
                "prob_not_assign": 0.1,
                "max_lost": 5,
                # State transition for state (x, y, z, vx, vy, vz): position += velocity.
                "A": {"matrix": [1, 0, 0, 1, 0, 0,
                                 0, 1, 0, 0, 1, 0,
                                 0, 0, 1, 0, 0, 1,
                                 0, 0, 0, 1, 0, 0,
                                 0, 0, 0, 0, 1, 0,
                                 0, 0, 0, 0, 0, 1]},
                # Only positions are observed.
                "H": {"matrix": [1, 0, 0, 0, 0, 0,
                                 0, 1, 0, 0, 0, 0,
                                 0, 0, 1, 0, 0, 0]},
                "P": {"sigma": 150.0,
                      "matrix": [0.1, 0, 0, 0, 0, 0,
                                 0, 0.1, 0, 0, 0, 0,
                                 0, 0, 0.1, 0, 0, 0,
                                 0, 0, 0, 1, 0, 0,
                                 0, 0, 0, 0, 1, 0,
                                 0, 0, 0, 0, 0, 1]},
                "G": {"sigma": 15.0, "matrix": [0.5, 0.5, 0.5, 1, 1, 1]},
                "R": {"sigma": 5.0,
                      "matrix": [1, 0, 0,
                                 0, 1, 0,
                                 0, 0, 1]},
            },
            "HypothesisModel": {
                "name": "cell_hypothesis",
                "hypotheses": ["P_FP", "P_init", "P_term", "P_link", "P_branch", "P_dead"],
                "lambda_time": 5.0,
                "lambda_dist": 3.0,
                "lambda_link": 10.0,
                "lambda_branch": 50.0,
                "eta": 1e-10,
                "theta_dist": 20.0,
                "theta_time": 5.0,
                "dist_thresh": 40.0,
                "time_thresh": 2.0,
                "apop_thresh": 5,
                "segmentation_miss_rate": 0.1,
                "apoptosis_rate": 0.001,
                "relax": True,
            },
        }
    }


def _ensure_btrack_config(config_file: str) -> None:
    """Write the default config to ``config_file`` unless it already has a ``TrackerConfig`` section."""
    if os.path.exists(config_file):
        try:
            with open(config_file) as f:
                if "TrackerConfig" in json.load(f):
                    return
        except (OSError, ValueError, TypeError):
            pass
        # A file without TrackerConfig (e.g. from an earlier version of this cell) cannot be loaded by btrack.
        print(f"'{config_file}' has no 'TrackerConfig' section; rewriting default btrack config...")
    else:
        print("Creating default btrack config...")
    with open(config_file, "w") as f:
        json.dump(_default_btrack_config(), f, indent=2)


def run_btrack_tracking(detections_df, config_file="btrack_config.json",
                        max_search_radius=50.0,
                        volume=((0, 1024), (0, 1024), (-1e5, 1e5))):
    """Link per-frame detections into tracks with btrack's Bayesian tracker.

    Args:
        detections_df: DataFrame with columns ``frame``, ``x``, ``y``.
        config_file: btrack JSON config; the default is written here if the file is
            missing or has no ``TrackerConfig`` section.
        max_search_radius: forwarded to ``tracker.max_search_radius``.
        volume: forwarded to ``tracker.volume`` as ``((x0, x1), (y0, y1), (z0, z1))``.

    Returns:
        DataFrame with one row per track point and columns ``track_id``, ``frame``,
        ``x``, ``y``. Empty when there are no detections or btrack is unavailable.
    """
    if detections_df is None or detections_df.empty:
        print("No detections to track.")
        return pd.DataFrame()

    if btrack is None:
        print("Btrack library not available.")
        return pd.DataFrame()

    # Local import: btrack is optional in this notebook.
    # NOTE(review): localizations_to_objects lives in btrack.io in the pinned 0.6.x line -- confirm on upgrade.
    from btrack.io import localizations_to_objects

    print("Running btrack...")
    # btrack localisations use 't' for the frame index and expect a z coordinate (2-D data -> z = 0).
    localizations = pd.DataFrame({
        "t": detections_df["frame"].to_numpy(dtype=int),
        "x": detections_df["x"].to_numpy(dtype=float),
        "y": detections_df["y"].to_numpy(dtype=float),
        "z": np.zeros(len(detections_df), dtype=float),
    })
    objects = localizations_to_objects(localizations)

    _ensure_btrack_config(config_file)

    with btrack.BayesianTracker() as tracker:
        tracker.configure(config_file)
        tracker.append(objects)
        tracker.max_search_radius = max_search_radius
        tracker.volume = volume
        tracker.track_interactive(step_size=100)
        tracker.optimize()
        # Flatten tracklets into one row per (track, time point) while the tracker is alive.
        rows = [
            {"track_id": trk.ID, "frame": t, "x": x, "y": y}
            for trk in tracker.tracks
            for t, x, y in zip(trk.t, trk.x, trk.y)
        ]

    tracks = pd.DataFrame(rows, columns=["track_id", "frame", "x", "y"])
    print(f"BTrack finished. Found {tracks['track_id'].nunique()} tracks.")
    return tracks

# Run BTrack
# run_btrack_tracking returns an empty DataFrame when there is nothing to track or btrack is missing.
btrack_tracks = run_btrack_tracking(detections_df)
if not btrack_tracks.empty:
    btrack_tracks.to_csv("btrack_results.csv", index=False)
    print("Saved btrack_results.csv")


In [None]:
def run_laptrack_tracking(detections_df):
    """Link detections into tracks with LapTrack (linear assignment).

    Frame-to-frame links cost the squared Euclidean distance, with a 30 px cutoff.
    Gaps of up to 2 frames may be closed. Splitting and merging are disabled.

    Args:
        detections_df: DataFrame with columns ``frame``, ``x``, ``y``.

    Returns:
        LapTrack's track DataFrame. An empty DataFrame is returned when there are
        no detections, when laptrack is unavailable, or when LapTrack raises; the
        error is printed in that last case.
    """
    nothing_to_track = detections_df is None or detections_df.empty
    if nothing_to_track:
        return pd.DataFrame()

    if laptrack is None:
        print("LapTrack not available.")
        return pd.DataFrame()

    print("Running LapTrack...")
    tracker_settings = dict(
        track_dist_metric="sqeuclidean",
        track_cost_cutoff=30.0**2,     # squared metric -> squared 30 px cutoff
        gap_closing_max_frame_count=2,
        splitting_cost_cutoff=False,   # no cell divisions for now
        merging_cost_cutoff=False,
    )
    lt = LapTrack(**tracker_settings)

    # Best effort: report the failure and return an empty frame so later cells still run.
    try:
        track_df, _split_df, _merge_df = lt.predict_dataframe(
            detections_df,
            coordinate_cols=["x", "y"],
            frame_col="frame",
            only_coordinate_cols=False,
            validate_frame_col=False
        )
        n_tracks = len(track_df['track_id'].unique())
        print(f"LapTrack finished. Found {n_tracks} tracks.")
        return track_df
    except Exception as err:
        print(f"LapTrack error: {err}")
        return pd.DataFrame()

# Run LapTrack
# run_laptrack_tracking returns an empty DataFrame on missing input, missing laptrack, or a LapTrack error.
lap_tracks = run_laptrack_tracking(detections_df)
if not lap_tracks.empty:
    lap_tracks.to_csv("laptrack_results.csv", index=False)
    print("Saved laptrack_results.csv")
