
# SU2 Full Pipeline (Synthetic Training + Tracking Sweep)

This notebook sets up the SU2 workspace, generates synthetic CCP data, trains the U-Net++ model, and evaluates tracking with a HOTA parameter sweep on validation sequences. It runs in local Jupyter, VS Code, or a remote kernel such as Colab. Run the cells in order: later cells rely on names defined by earlier ones.


In [None]:
# Workspace setup (paths + folders)
import os
import sys
import subprocess
from pathlib import Path

# Heuristic: treat the session as Colab when the `google.colab` module is
# already loaded in this interpreter.
IS_COLAB = 'google.colab' in sys.modules


def _env_path(var_name, default):
    """Return the path stored in environment variable *var_name*, or *default*.

    A leading ``~`` is expanded so that overrides behave the same way as
    SU2_REPO_ROOT does below.
    """
    return Path(os.environ.get(var_name, default)).expanduser()


# An explicit SU2_REPO_ROOT wins; otherwise assume the notebook was started
# from inside the repository.
repo_root_env = os.environ.get('SU2_REPO_ROOT')
default_root = Path(repo_root_env).expanduser().resolve() if repo_root_env else Path.cwd().resolve()
repo_root = default_root

# The presence of a `modules` directory is what identifies a usable checkout.
if not (repo_root / 'modules').exists():
    if IS_COLAB:
        colab_root = Path('/content/SU2').resolve()
        if not colab_root.exists():
            repo_url = os.environ.get('SU2_REPO_URL', 'https://github.com/veselm73/SU2.git')
            print(f"Cloning repository into {colab_root}...")
            subprocess.run(['git', 'clone', repo_url, str(colab_root)], check=True)
        repo_root = colab_root
    else:
        raise FileNotFoundError(
            "Could not locate the SU2 repository. Set SU2_REPO_ROOT to the repo path or run inside the repo root."
        )

# A stale or partial checkout would otherwise only surface later as an
# import error when the `modules` package is loaded.
if not (repo_root / 'modules').exists():
    raise FileNotFoundError(
        f"{repo_root} does not contain a 'modules' directory. "
        "Remove the incomplete checkout or set SU2_REPO_ROOT to a valid repo path."
    )

REPO_ROOT = repo_root

# Make `import modules...` work regardless of where the kernel was started.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Later cells use relative paths (config.yaml, checkpoint files), so the
# working directory is pinned to the repository root.
os.chdir(REPO_ROOT)

SAVE_DIR = _env_path('SU2_SAVE_DIR', REPO_ROOT / 'artifacts')
CHECKPOINT_DIR = _env_path('SU2_CHECKPOINT_DIR', REPO_ROOT / 'checkpoints')
VAL_DATA_DIR = _env_path('SU2_VAL_DATA_DIR', REPO_ROOT / 'val_data')
CHAIN_PATH = _env_path('SU2_CERT_CHAIN', REPO_ROOT / 'chain-harica-cross.pem')

# CHAIN_PATH is a file, not a directory, so it is not created here.
for path_obj in (SAVE_DIR, CHECKPOINT_DIR, VAL_DATA_DIR):
    path_obj.mkdir(parents=True, exist_ok=True)

print(f"Repository root: {REPO_ROOT}")
print(f"Artifacts directory: {SAVE_DIR}")
print(f"Checkpoints directory: {CHECKPOINT_DIR}")
print(f"Validation data directory: {VAL_DATA_DIR}")


In [None]:

# Pipeline configuration (write config.yaml before importing modules)
import yaml

# Single source of truth for the run. It is serialized to config.yaml below so
# that it is on disk before the `modules` package is imported in the next cell.
PIPELINE_CONFIG = dict(
    TRAIN_SAMPLES=4000,
    VAL_SAMPLES=600,
    BATCH_SIZE=16,
    LEARNING_RATE=5e-4,
    WEIGHT_DECAY=1e-4,
    DROPOUT_RATE=0.15,
    EPOCHS=200,
    PATIENCE=20,
    SEED=73,
    DETECTION_MODEL="unet",
    SAM3_CHECKPOINT=str(CHECKPOINT_DIR / "sam3_hiera_large.pt"),
    SAM3_MODEL_TYPE="vit_l",
    SKIP_TRAINING=False,
    MIN_CELLS=8,
    MAX_CELLS=24,
    PATCH_SIZE=128,
    # Passed to SyntheticCCPDataset as `sim_config` in the preview cell.
    SIM_CONFIG=dict(
        na=1.49,
        wavelength=512,
        px_size=0.07,
        wiener_parameter=0.1,
        apo_cutoff=2.0,
        apo_bend=0.9,
    ),
)

# sort_keys=False keeps the YAML in the same order as the dictionary above.
with open('config.yaml', 'w', encoding='utf-8') as config_file:
    yaml.safe_dump(PIPELINE_CONFIG, config_file, sort_keys=False)

# Echo only the headline training parameters.
print("config.yaml updated. Key parameters:")
summary_keys = ("TRAIN_SAMPLES", "VAL_SAMPLES", "BATCH_SIZE", "LEARNING_RATE", "EPOCHS", "PATIENCE")
print("\n".join(f"  {summary_key}: {PIPELINE_CONFIG[summary_key]}" for summary_key in summary_keys))


In [None]:

# Imports (after config is written)
import json
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from modules.config import *  # noqa: F401,F403
from modules.dataset import SyntheticCCPDataset
from modules.training import train_unet_pipeline
from modules.utils import set_seed, plot_training_history, open_tiff_file, download_and_unzip
from modules.sweep import sweep_and_save_gif
from modules.tracking import DetectionParams, BTrackParams

# SEED is not assigned in the cells of this notebook; it arrives through the
# star import from modules.config (presumably loaded from config.yaml, which
# stores "SEED" - confirm in modules/config.py).
set_seed(SEED)
print("Imports ready. Random seed initialized.")


In [None]:

# Download validation data if missing
import requests

# Skip the download entirely when the validation stack is already on disk.
val_tif_path = VAL_DATA_DIR / 'val.tif'
if val_tif_path.exists():
    print(f"Validation TIFF already present at {val_tif_path}")
else:
    print("Downloading SSL certificate chain...")
    cert_url = 'https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem'
    # The chain is a small file read in one go via `.content`, so a streamed
    # response brings no benefit here.
    response = requests.get(cert_url, timeout=10)
    response.raise_for_status()
    # SU2_CERT_CHAIN may point outside the repository; the workspace setup
    # only creates the artifact/checkpoint/validation folders.
    CHAIN_PATH.parent.mkdir(parents=True, exist_ok=True)
    CHAIN_PATH.write_bytes(response.content)
    print("Certificate ready. Fetching validation archive...")
    zip_url = 'https://su2.utia.cas.cz/files/labs/final2025/val_and_sota.zip'
    # The chain file is handed to the helper, presumably for TLS verification
    # of the archive host - confirm in modules/utils.py.
    download_and_unzip(zip_url, str(VAL_DATA_DIR), str(CHAIN_PATH))
    # Report success only when the file the later cells need actually exists.
    if val_tif_path.exists():
        print("Validation data downloaded.")
    else:
        print(f"Warning: download finished but {val_tif_path} is still missing; check the archive layout.")


In [None]:

# Side-by-side preview: one synthetic image/mask pair next to the first frame
# of the real validation stack (when it is available).
synthetic_dataset = SyntheticCCPDataset(
    min_n=PIPELINE_CONFIG['MIN_CELLS'],
    max_n=PIPELINE_CONFIG['MAX_CELLS'],
    patch_size=PIPELINE_CONFIG['PATCH_SIZE'],
    sim_config=PIPELINE_CONFIG['SIM_CONFIG'],
)
synth_img, synth_mask = synthetic_dataset.data_sample()

# The real frame is optional; the third panel falls back to a text notice.
val_tif = VAL_DATA_DIR / 'val.tif'
if val_tif.exists():
    val_stack = open_tiff_file(str(val_tif))
    val_img = val_stack[0]
else:
    val_img = None

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# The two synthetic panels share the same rendering recipe.
synthetic_panels = (
    (synth_img, 'Synthetic CCP Image'),
    (synth_mask, 'Synthetic Mask'),
)
for panel_ax, (panel_img, panel_title) in zip(axes, synthetic_panels):
    panel_ax.imshow(panel_img, cmap='gray')
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

real_ax = axes[2]
if val_img is None:
    real_ax.text(0.5, 0.5, 'Validation TIFF missing', ha='center', va='center')
    real_ax.set_axis_off()
else:
    real_ax.imshow(val_img, cmap='gray')
    real_ax.set_title('Validation Frame (val.tif)')

plt.tight_layout()
plt.show()


In [None]:

# Train U-Net++ on synthetic data
# Each keyword argument of the training entry point maps to the upper-case
# PIPELINE_CONFIG key of the same name; DEVICE comes from the modules.config
# star import.
config_backed_args = (
    'train_samples',
    'val_samples',
    'epochs',
    'batch_size',
    'learning_rate',
    'weight_decay',
    'patience',
)
training_kwargs = {arg_name: PIPELINE_CONFIG[arg_name.upper()] for arg_name in config_backed_args}
training_kwargs['device'] = DEVICE

training_result = train_unet_pipeline(**training_kwargs)
model, history = training_result
print('Training finished.')


In [None]:

# Persist best epoch by validation Dice
# Look the scores up defensively: a history without 'val_dice' (or no history
# at all) should reach the explanatory message below instead of raising.
try:
    val_dice_scores = history['val_dice']
except (KeyError, TypeError):
    val_dice_scores = None

# len() rather than truthiness so that array-like score containers work too.
if val_dice_scores is not None and len(val_dice_scores) > 0:
    # argmax is 0-based; the checkpoint file name below uses 1-based epochs.
    best_epoch = int(np.argmax(val_dice_scores)) + 1
    best_score = val_dice_scores[best_epoch - 1]
    # Assumes training writes checkpoint_epoch_<n>.pth into the current
    # working directory - confirm against modules.training.
    checkpoint_path = Path(f'checkpoint_epoch_{best_epoch}.pth')
    best_dice_path = SAVE_DIR / 'best_val_dice_model.pth'
    if checkpoint_path.exists():
        shutil.copy(checkpoint_path, best_dice_path)
        print(f'Best validation Dice epoch: {best_epoch} (score={best_score:.4f})')
        print(f'Checkpoint copied to {best_dice_path}')
    else:
        print(f'Checkpoint {checkpoint_path} not found; cannot copy best Dice weights.')
else:
    print('History missing val_dice information.')


In [None]:

# Plot training curves
# `history` is the second value returned by train_unet_pipeline in the
# training cell; the plotting itself is delegated to
# modules.utils.plot_training_history.
plot_training_history(history)


In [None]:

# Save final model snapshot
# Only the weights (state_dict) are written to the artifacts directory;
# nothing is saved when no model object is available.
if model is None:
    print('Model object is None; skipping save.')
else:
    final_path = SAVE_DIR / 'final_model.pth'
    torch.save(model.state_dict(), final_path)
    print(f'Final model saved to {final_path}')


In [None]:

# Define detection/tracking search spaces
from modules.tracking import DetectionParams, BTrackParams

# Candidate values per detection parameter; the sweep cell hands this grid to
# sweep_and_save_gif.
det_param_grid = dict(
    threshold=[0.25, 0.3, 0.35],
    min_area=[3, 5],
    nms_min_dist=[3.0, 4.0],
)

# Candidate values per tracker parameter. Single-element lists pin a setting
# while keeping the same "list of candidates" shape as the swept ones.
btrack_param_grid = dict(
    do_optimize=[False],
    max_search_radius=[20.0, 25.0],
    dist_thresh=[12.0, 15.0],
    time_thresh=[4, 6],
    min_track_len=[8, 12],
    segmentation_miss_rate=[0.1],
    apoptosis_rate=[0.001],
    allow_divisions=[False],
)
print('Parameter grids prepared.')


In [None]:

# Run HOTA sweep + GIF export
# Inputs: the trained model and both parameter grids. The helper receives the
# validation TIFF and the GIF destination as plain strings.
gif_output_path = SAVE_DIR / 'best_tracking.gif'
val_tif_path = (VAL_DATA_DIR / 'val.tif').resolve()

sweep_result = sweep_and_save_gif(
    model,
    det_param_grid,
    btrack_param_grid,
    gif_output=str(gif_output_path),
    val_tif_path=str(val_tif_path),
)
best_det, best_bt, best_tracks = sweep_result

# Report the winning configuration and where the animation was written.
for report_label, report_value in (
    ('Best detection params:', best_det),
    ('Best tracking params:', best_bt),
):
    print(report_label, report_value)
print(f'GIF stored at: {gif_output_path}')


In [None]:

# Display best tracking GIF
# `display` is imported explicitly instead of relying on the name IPython
# injects into interactive sessions, so the cell also works when the notebook
# is exported and run as a plain script.
from IPython.display import Image, display

gif_path = SAVE_DIR / 'best_tracking.gif'
if gif_path.exists():
    display(Image(filename=str(gif_path)))
else:
    print('best_tracking.gif not found.')
