# SU2 Project - UNet++ Training (VS Code Ready)

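This notebook runs the training pipeline directly from this repository. It stores artifacts locally, so it works well with VS Code Jupyter (local Python or the Colab extension) without requiring a Google Drive mount. With the default configuration (`DETECTION_MODEL: "sam3"`, `SKIP_TRAINING: True`), UNet++ training is skipped and SAM 3 is used for detection.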


In [1]:
# Resolve workspace locations from SU2_* environment overrides (no Google Drive dependency).
import os
import sys
from pathlib import Path

IS_COLAB = 'google.colab' in sys.modules

_env = os.environ
REPO_ROOT = Path(_env.get('SU2_REPO_ROOT', Path.cwd())).resolve()
if not REPO_ROOT.exists():
    raise FileNotFoundError(f"Configured REPO_ROOT does not exist: {REPO_ROOT}")

# Later cells use relative paths (e.g. config.yaml, best_model.pth), so anchor them at the repo root.
os.chdir(REPO_ROOT)

# Each location can be overridden by its environment variable; otherwise it lives under REPO_ROOT.
SAVE_DIR, CHECKPOINT_DIR, VAL_DATA_DIR, CHAIN_PATH = (
    Path(_env.get(env_var, REPO_ROOT / default_name))
    for env_var, default_name in (
        ('SU2_SAVE_DIR', 'artifacts'),
        ('SU2_CHECKPOINT_DIR', 'checkpoints'),
        ('SU2_VAL_DATA_DIR', 'val_data'),
        ('SU2_CERT_CHAIN', 'chain-harica-cross.pem'),
    )
)

# CHAIN_PATH is a file, so only the three directories are created up front.
for directory in (SAVE_DIR, CHECKPOINT_DIR, VAL_DATA_DIR):
    directory.mkdir(parents=True, exist_ok=True)

for label, location in (
    ("Repository root", REPO_ROOT),
    ("Artifacts will be stored in", SAVE_DIR),
    ("Checkpoints directory", CHECKPOINT_DIR),
    ("Validation data directory", VAL_DATA_DIR),
):
    print(f"{label}: {location}")


Repository root: /content
Artifacts will be stored in: /content/artifacts
Checkpoints directory: /content/checkpoints
Validation data directory: /content/val_data


In [2]:
# Define Configuration
# You can modify these values directly here before running the training.
# Keep every top-level key unique: PyYAML does not reject duplicate keys and silently
# keeps the LAST occurrence, so a repeated key makes the earlier value dead and misleading.
# Relative paths (e.g. SAM3_CHECKPOINT) are relative to the working directory, i.e. REPO_ROOT.
config_content = """
TRAIN_SAMPLES: 5000
VAL_SAMPLES: 800
BATCH_SIZE: 16             # Reduced from 32 (more frequent updates, fits T4 memory)
LEARNING_RATE: 5e-4
WEIGHT_DECAY: 1e-4
DROPOUT_RATE: 0.15
EPOCHS: 300
PATIENCE: 15
SEED: 73

# Detection Model Configuration
DETECTION_MODEL: "sam3"    # Options: "unet", "sam3"
SAM3_CHECKPOINT: "checkpoints/sam3_hiera_large.pt"
SAM3_MODEL_TYPE: "vit_l"
SKIP_TRAINING: True        # True: skip UNet++ training and use SAM 3 only; False: train UNet++

# Data Generator Config
MIN_CELLS: 8
MAX_CELLS: 24
PATCH_SIZE: 128
SIM_CONFIG:
  na: 1.49
  wavelength: 512
  px_size: 0.07
  wiener_parameter: 0.1
  apo_cutoff: 2.0
  apo_bend: 0.9
"""

with open("config.yaml", "w", encoding="utf-8") as f:
    f.write(config_content)

print("Configuration saved to config.yaml")

Configuration saved to config.yaml


In [3]:
# Install third-party dependencies only if btrack or sam3 is not importable yet.
import subprocess
import sys


def pip_install(*packages):
    """Run `pip install` for the given requirement specifiers using this kernel's interpreter."""
    command = [sys.executable, '-m', 'pip', 'install']
    command.extend(packages)
    subprocess.check_call(command)


# One pip invocation per group, executed in this order.
_INSTALL_STEPS = (
    ('btrack==0.6.5', 'pydantic<2', 'pyyaml'),
    ('git+https://github.com/facebookresearch/sam3.git',),
    ('einops', 'decord', 'pycocotools', 'scipy'),
)

try:
    import btrack
    import sam3
except ImportError:
    print('Installing dependencies... (This may take a few minutes)')
    for install_group in _INSTALL_STEPS:
        pip_install(*install_group)
    print('Install complete. Restart the kernel only if VS Code prompts you to reload modules.')
else:
    print('Dependencies already installed.')


Installing dependencies... (This may take a few minutes)
Install complete. Restart the kernel only if VS Code prompts you to reload modules.


In [4]:
# Report whether the SAM 3 checkpoint is in place.
# With the default CHECKPOINT_DIR this is the same file as SAM3_CHECKPOINT in config.yaml.
CHECKPOINT_PATH = CHECKPOINT_DIR / 'sam3_hiera_large.pt'

_checkpoint_msg = (
    f"Checkpoint found at {CHECKPOINT_PATH}"
    if CHECKPOINT_PATH.exists()
    else f"SAM 3 checkpoint missing. Please place it at: {CHECKPOINT_PATH}"
)
print(_checkpoint_msg)


SAM 3 checkpoint missing. Please place it at: /content/checkpoints/sam3_hiera_large.pt


In [None]:
# Make the repository's `modules` package importable and load the pipeline entry points.
import importlib
import sys

_repo_root_str = str(REPO_ROOT)
if _repo_root_str not in sys.path:
    # Prepend so the local `modules` package takes precedence on sys.path.
    sys.path.insert(0, _repo_root_str)

import modules.config

# Re-execute modules.config so a re-run picks up edits made since the kernel started
# (presumably including the config.yaml written above — TODO confirm it reads that file).
importlib.reload(modules.config)

# NOTE(review): the star import supplies the settings used below (SEED, TRAIN_SAMPLES,
# DEVICE, ...); replace it with an explicit name list once the exported names are pinned down.
from modules.config import *
from modules.utils import plot_training_history, set_seed
from modules.training import train_unet_pipeline
from modules.tracking import run_tracking_on_validation


In [None]:
# Set random seeds from the configured SEED so stochastic steps are reproducible.
# NOTE(review): which RNGs are seeded (random / numpy / torch) is decided inside
# modules.utils.set_seed — confirm there.
set_seed(SEED)

In [None]:
# Train UNet++ unless the configuration opts to rely on SAM 3 alone.
# globals().get(...) tolerates SKIP_TRAINING being absent; absence means "train".
if globals().get('SKIP_TRAINING', False):
    print("Skipping UNet++ training (SKIP_TRAINING=True)")
    model, history = None, {}
else:
    print("Starting Training Pipeline...")
    # NOTE(review): DROPOUT_RATE and SEED from the config are not passed here;
    # confirm train_unet_pipeline reads them on its own.
    model, history = train_unet_pipeline(
        train_samples=TRAIN_SAMPLES,
        val_samples=VAL_SAMPLES,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        patience=PATIENCE,
        device=DEVICE,
    )


In [None]:
# Plot History
# When SKIP_TRAINING=True the training cell leaves model=None and history={}, so there
# are no curves to draw; guard on `model is None` (same check as the save cell) instead
# of handing an empty history to the plotter.
if model is None:
    print('SKIP_TRAINING=True so there is no UNet++ training history to plot.')
else:
    plot_training_history(history)

In [None]:
# Persist the trained UNet++ weights (and the best checkpoint, if present) into SAVE_DIR.
import shutil
from pathlib import Path

import torch

if model is not None:
    final_model_path = SAVE_DIR / 'final_model.pth'
    torch.save(model.state_dict(), final_model_path)
    print(f"Model saved to {final_model_path}")

    # NOTE(review): assumes the training pipeline writes its best checkpoint to
    # ./best_model.pth (relative to REPO_ROOT) — confirm in modules.training.
    best_checkpoint = Path('best_model.pth')
    if best_checkpoint.exists():
        best_copy_path = SAVE_DIR / 'best_model.pth'
        shutil.copy(best_checkpoint, best_copy_path)
        print(f"Best model saved to {best_copy_path}")
else:
    print('SKIP_TRAINING=True so there is no trained UNet++ model to save.')


In [None]:
# Download validation data (stored locally)
from modules.utils import download_and_unzip
import requests

print('1) Downloading SSL certificate chain...')
cert_url = 'https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem'
# The whole PEM body is read via `.content`, so streaming buys nothing. The context
# manager releases the connection even if raise_for_status() raises.
with requests.get(cert_url, timeout=10) as response:
    response.raise_for_status()
    CHAIN_PATH.write_bytes(response.content)
# Keep the escape inside the literal: a raw line break here is a SyntaxError.
print('2) Certificate chain downloaded.\n')

# NOTE(review): the chain is presumably used by download_and_unzip to verify the TLS
# certificate of su2.utia.cas.cz — confirm in modules.utils.
zip_url = 'https://su2.utia.cas.cz/files/labs/final2025/val_and_sota.zip'
download_and_unzip(zip_url, str(VAL_DATA_DIR), str(CHAIN_PATH))


In [None]:
# Define Parameter Grids for Sweep
# Each key maps to the list of candidate values to try for that parameter.
from modules.tracking import DetectionParams, BTrackParams  # NOTE(review): not used in this cell

# 1. Detection: low thresholds (0.25 - 0.30) favour recall; the resulting noise is
#    removed later by the min_track_len filter in the tracking grid.
det_param_grid = {
    "threshold": [0.25, 0.3],
    "min_area": [4],
    "nms_min_dist": [3.0],
}

# 2. Tracking: aggressive filtering of short tracks. Global optimization is currently
#    OFF; add True to "do_optimize" to include it in the sweep.
btrack_param_grid = {
    "do_optimize": [False],              # Global optimization disabled
    "max_search_radius": [20.0],
    "dist_thresh": [15.0],
    "time_thresh": [4, 6],               # Allow gaps
    "min_track_len": [10, 15],           # Filter noise from low threshold
    "segmentation_miss_rate": [0.1],
    "apoptosis_rate": [0.001],
    "allow_divisions": [False],
}

print("Parameter grids defined.")

In [None]:
# Run tracking sweep and generate GIF
from modules.sweep import sweep_and_save_gif

val_tif_path = (VAL_DATA_DIR / 'val.tif').resolve()
gif_output_path = SAVE_DIR / 'best_tracking.gif'

# Fail fast with an actionable message instead of an opaque error deep inside the sweep.
if not val_tif_path.exists():
    raise FileNotFoundError(
        f"Validation movie not found at {val_tif_path}. Please run the download cell first."
    )

# NOTE(review): `model` is None when SKIP_TRAINING=True; presumably sweep_and_save_gif
# then falls back to SAM 3 detection — confirm in modules.sweep.
best_det, best_bt, best_tracks = sweep_and_save_gif(
    model,
    det_param_grid,
    btrack_param_grid,
    gif_output=str(gif_output_path),
    val_tif_path=str(val_tif_path),
)

# Only claim success when the GIF is actually on disk.
if gif_output_path.exists():
    print(f"Best tracking GIF saved to: {gif_output_path}")
else:
    print(f"Sweep finished but no GIF was found at: {gif_output_path}")


In [None]:
# Visualize SAM 3 predictions on the first validation frame (input, mask, overlay)
import matplotlib.pyplot as plt
import numpy as np
from modules.sam_detector import SAM3Detector
from modules.utils import open_tiff_file

print('Visualizing SAM 3 predictions...')

val_tif_path = VAL_DATA_DIR / 'val.tif'
if val_tif_path.exists():
    val_images = open_tiff_file(str(val_tif_path))
    sample_image = val_images[0]

    detector = SAM3Detector()
    if detector.model is not None:
        mask, detections = detector.detect(sample_image, text_prompt='cell')

        # Mask out background (0) pixels. Otherwise the overlay tints the WHOLE frame
        # with the colormap's zero colour instead of highlighting only detections.
        mask_overlay = np.ma.masked_equal(np.asarray(mask, dtype=float), 0)

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(sample_image, cmap='gray')
        axes[0].set_title('Input Image')

        axes[1].imshow(mask, cmap='gray')
        axes[1].set_title('SAM 3 Prediction')

        axes[2].imshow(sample_image, cmap='gray')
        axes[2].imshow(mask_overlay, alpha=0.5, cmap='jet')
        axes[2].set_title('Overlay')

        for ax in axes:
            ax.axis('off')
        fig.tight_layout()
        plt.show()
    else:
        print('SAM 3 model not loaded.')
else:
    print('Validation data not found. Please run the download cell first.')


In [None]:
# Visualize SAM 3 predictions with GT overlay
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import ListedColormap
from modules.sam_detector import SAM3Detector
from modules.utils import open_tiff_file

GT_MARKER_RADIUS = 2  # half-width (px) of the square drawn around each GT point

print('Visualizing SAM 3 predictions...')

val_tif_path = VAL_DATA_DIR / 'val.tif'
val_csv_path = VAL_DATA_DIR / 'val.csv'

if val_tif_path.exists():
    val_images = open_tiff_file(str(val_tif_path))
    sample_idx = 0
    sample_image = val_images[sample_idx]

    detector = SAM3Detector()
    if detector.model is not None:
        mask, detections = detector.detect(sample_image, text_prompt='cell')

        # Rasterize this frame's GT point annotations as small filled squares.
        # NOTE(review): assumes val.csv has `frame`, `y`, `x` columns in pixel units — confirm.
        gt_mask = np.zeros_like(mask, dtype=np.uint8)
        if val_csv_path.exists():
            gt_points = pd.read_csv(val_csv_path)
            frame_points = gt_points.loc[gt_points['frame'] == sample_idx, ['y', 'x']].to_numpy()
            height, width = gt_mask.shape[:2]
            r = GT_MARKER_RADIUS
            # astype(int) truncates like int(); points outside the frame are skipped.
            for y, x in frame_points.astype(int):
                if 0 <= y < height and 0 <= x < width:
                    gt_mask[max(0, y - r):min(height, y + r + 1),
                            max(0, x - r):min(width, x + r + 1)] = 1

        # Mask out zeros so each overlay colours only its foreground. Unmasked, 'jet' tints
        # the whole frame blue and 'spring' tints it magenta, hiding both signals.
        sam_overlay = np.ma.masked_equal(np.asarray(mask, dtype=float), 0)
        gt_overlay = np.ma.masked_equal(gt_mask, 0)

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        axes[0].imshow(sample_image, cmap='gray')
        axes[0].set_title('Input Image')

        axes[1].imshow(mask, cmap='gray')
        axes[1].set_title('SAM 3 Prediction')

        axes[2].imshow(gt_mask, cmap='gray')
        axes[2].set_title('Ground Truth (Approx)')

        axes[3].imshow(sample_image, cmap='gray')
        axes[3].imshow(sam_overlay, alpha=0.4, cmap='jet')
        # Single-colour map so GT markers really are pink, matching the title.
        axes[3].imshow(gt_overlay, alpha=0.4, cmap=ListedColormap(['magenta']))
        axes[3].set_title('Overlay (SAM=Jet, GT=Pink)')

        for ax in axes:
            ax.axis('off')
        fig.tight_layout()
        plt.show()
    else:
        print('SAM 3 model not loaded.')
else:
    print('Validation data not found. Please run the download cell first.')
