# SU2 Project - UNet++ Training on Colab

This notebook runs the training pipeline using the modularized code structure.

In [None]:
# Mount Google Drive at /content/drive
import os

from google.colab import drive

drive.mount('/content/drive')

# Directory on the mounted Drive where results are written
# (created if missing, left untouched if it already exists)
SAVE_DIR = "/content/drive/MyDrive/SU2_Project"
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"Results will be saved to: {SAVE_DIR}")

In [None]:
# Define Configuration
# You can modify these values directly here before running the training.
#
# Every key must appear exactly once: YAML mappings do not allow duplicate
# keys, and loaders that tolerate them (e.g. PyYAML) keep only the last
# occurrence, which silently hides the earlier value.
#
# NOTE(review): the file is written to the current working directory, and a
# later cell changes directory into the cloned repository — confirm that
# modules.config reads the file from this location.
config_content = """
TRAIN_SAMPLES: 5000
VAL_SAMPLES: 800
BATCH_SIZE: 16             # Reduced from 32 (more frequent updates, fits T4 memory)
LEARNING_RATE: 5e-4
WEIGHT_DECAY: 1e-4
DROPOUT_RATE: 0.15
EPOCHS: 300
PATIENCE: 15
SEED: 73

# Detection Model Configuration
DETECTION_MODEL: "sam3"  # Options: "unet", "sam3"
SAM3_CHECKPOINT: "checkpoints/sam3_hiera_large.pt"
SAM3_MODEL_TYPE: "vit_l"
SKIP_TRAINING: True # Set to True to skip UNet++ training and just use SAM 3

# Data Generator Config
MIN_CELLS: 8
MAX_CELLS: 24
PATCH_SIZE: 128
SIM_CONFIG:
  na: 1.49
  wavelength: 512
  px_size: 0.07
  wiener_parameter: 0.1
  apo_cutoff: 2.0
  apo_bend: 0.9
"""

# Persist the configuration text as config.yaml in the current working directory.
with open("config.yaml", "w") as f:
    f.write(config_content)

print("Configuration saved to config.yaml")

In [None]:
# Install dependencies with auto-restart
import os
import sys

def install_dependencies():
    print("Installing dependencies...")
    !pip install btrack==0.6.5 "pydantic<2" pyyaml
    !pip install git+https://github.com/facebookresearch/sam3.git
    !pip install einops decord pycocotools scipy
    print("Dependencies installed. Restarting runtime to apply changes...")
    os.kill(os.getpid(), 9)

try:
    import btrack
    import sam3
    print("Dependencies already installed.")
except ImportError:
    install_dependencies()

In [None]:
# Check for the SAM 3 checkpoint (this cell only checks; it does not download)
import os

CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "sam3_hiera_large.pt")

# NOTE(review): both paths are relative to the current working directory, and
# a later cell changes directory into the cloned repository — confirm the
# detector looks for the checkpoint in this same location.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

if os.path.exists(CHECKPOINT_PATH):
    print(f"Checkpoint found at {CHECKPOINT_PATH}")
else:
    print("Please upload the SAM 3 checkpoint manually to 'checkpoints/sam3_hiera_large.pt' or authenticate to download.")

In [None]:
# Clone Repository with Authentication
from google.colab import userdata
import os

try:
    token = userdata.get('GITHUB_TOKEN')
    repo_url = f"https://{token}@github.com/veselm73/SU2.git"
except Exception:
    print("GITHUB_TOKEN not found in secrets, using public clone...")
    repo_url = "https://github.com/veselm73/SU2.git"

REPO_DIR = "/content/SU2"

if not os.path.exists(REPO_DIR):
    print(f"Cloning repository...")
    !git clone {repo_url} {REPO_DIR}
else:
    print(f"Repository already exists at {REPO_DIR}")

%cd {REPO_DIR}


In [None]:
# Import modules
import sys
import os
import importlib

# Make the repository directory importable so that `modules.*` resolves.
REPO_ROOT = '/content/SU2'
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# Reload config to ensure fresh load
# (the reload has to happen before the star-import below, otherwise the names
# copied into this namespace would come from the previously loaded module)
import modules.config
importlib.reload(modules.config)

# Star-import: later cells use names such as SEED and DEVICE that are not
# defined in this notebook; presumably they come from modules.config — TODO confirm.
from modules.config import *
from modules.utils import set_seed, plot_training_history
from modules.training import train_unet_pipeline
from modules.tracking import run_tracking_on_validation

In [None]:
# Set random seeds
# SEED is not defined in this cell; presumably it arrives through the
# star-import from modules.config — TODO confirm. Which random number
# generators are seeded is decided by set_seed in modules.utils.
set_seed(SEED)

In [None]:
# Train UNet++ unless the configuration asks to skip it.
# A missing SKIP_TRAINING name is treated the same as a falsy one: training runs.
if globals().get('SKIP_TRAINING', False):
    print("Skipping UNet++ training (SKIP_TRAINING=True)")
    # Placeholders so that later cells can still reference both names.
    model, history = None, {}
else:
    # Run Training Pipeline
    print("Starting Training Pipeline...")
    # NOTE(review): DROPOUT_RATE from the configuration is not forwarded here —
    # confirm whether train_unet_pipeline picks it up by other means.
    training_kwargs = dict(
        train_samples=TRAIN_SAMPLES,
        val_samples=VAL_SAMPLES,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        patience=PATIENCE,
        device=DEVICE,
    )
    model, history = train_unet_pipeline(**training_kwargs)


In [None]:
# Plot History
# The training cell sets `history` to an empty dict when training is skipped;
# in that case there is nothing to plot, so the plotting helper is not called.
if history:
    plot_training_history(history)
else:
    print("No training history to plot.")

In [None]:
# Save Model to Drive
import shutil
import torch

# The training cell sets `model` to None when training is skipped; calling
# state_dict() on None would raise AttributeError, so only save a real model.
if model is not None:
    save_path = os.path.join(SAVE_DIR, "final_model.pth")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")
else:
    print("No trained model to save (training was skipped).")

# Also save the best model if it exists locally (from training loop)
if os.path.exists("best_model.pth"):
    best_save_path = os.path.join(SAVE_DIR, "best_model.pth")
    shutil.copy("best_model.pth", best_save_path)
    print(f"Best model saved to {best_save_path}")

In [None]:
# Download Validation Data
from modules.utils import download_and_unzip
import requests

chain_path = "/content/chain-harica-cross.pem"
cert_url = "https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem"

# Step 1: fetch the certificate chain and store it at chain_path; the path is
# handed to download_and_unzip below.
print("1) Downloading SSL certificate chain...")
cert_response = requests.get(cert_url, stream=True, timeout=10)
cert_response.raise_for_status()  # stop here on an HTTP error status
with open(chain_path, "wb") as chain_file:
    chain_file.write(cert_response.content)
print("2) Certificate chain downloaded.\n")

# Step 2: download the validation archive and extract it.
zip_url = "https://su2.utia.cas.cz/files/labs/final2025/val_and_sota.zip"
extract_directory = "/content/val_data"
download_and_unzip(zip_url, extract_directory, chain_path)

In [None]:
# Define Parameter Grids for Sweep
from modules.tracking import DetectionParams, BTrackParams

# 1. Detection: low thresholds (0.25 and 0.30), intended for high recall
det_param_grid = dict(
    threshold=[0.25, 0.3],
    min_area=[4],
    nms_min_dist=[3.0],
)

# 2. Tracking: global optimization off, gaps allowed, short tracks filtered out
btrack_param_grid = dict(
    do_optimize=[False],              # global optimization is disabled (only False is swept)
    max_search_radius=[20.0],
    dist_thresh=[15.0],
    time_thresh=[4, 6],               # allow gaps
    min_track_len=[10, 15],           # filter noise caused by the low detection threshold
    segmentation_miss_rate=[0.1],
    apoptosis_rate=[0.001],
    allow_divisions=[False],
)

print("Parameter grids defined.")

In [None]:
# Run Tracking Sweep and Generate GIF
from modules.sweep import sweep_and_save_gif

# The GIF is written into the results directory.
gif_output_path = os.path.join(SAVE_DIR, "best_tracking.gif")

# NOTE(review): `model` is None when training was skipped — confirm that
# sweep_and_save_gif accepts that.
sweep_result = sweep_and_save_gif(
    model,
    det_param_grid,
    btrack_param_grid,
    gif_output=gif_output_path,
)
best_det, best_bt, best_tracks = sweep_result

print(f"Best tracking GIF saved to: {gif_output_path}")

In [None]:
# AUTO-DISCONNECT to save runtime units
# NOTE(review): this cell sits before the SAM 3 visualization cells; once the
# runtime is released, cells further down presumably cannot run in the same
# session — confirm the intended cell order.
from google.colab import runtime
print("Training finished. Disconnecting runtime to save units...")
runtime.unassign()

In [None]:
# Visualize SAM 3 predictions on the first validation frame
# (input, predicted mask and overlay; no ground truth is drawn in this cell)
import matplotlib.pyplot as plt
import numpy as np
from modules.sam_detector import SAM3Detector
from modules.utils import open_tiff_file

print("Visualizing SAM 3 predictions...")

val_tif_path = "/content/val_data/val.tif"
if not os.path.exists(val_tif_path):
    print("Validation data not found. Please run the download cell first.")
else:
    # Load the validation stack and pick its first frame
    val_images = open_tiff_file(val_tif_path)
    sample_image = val_images[0]

    detector = SAM3Detector()
    if detector.model is None:
        print("SAM 3 model not loaded.")
    else:
        # Run detection with the text prompt "cell"
        mask, detections = detector.detect(sample_image, text_prompt="cell")

        plt.figure(figsize=(15, 5))

        # Panels 1 and 2: the raw frame and the predicted mask, both in grayscale
        grayscale_panels = [(sample_image, "Input Image"), (mask, "SAM 3 Prediction")]
        for panel_idx, (panel_image, panel_title) in enumerate(grayscale_panels, start=1):
            plt.subplot(1, 3, panel_idx)
            plt.imshow(panel_image, cmap='gray')
            plt.title(panel_title)

        # Panel 3: the mask drawn semi-transparently on top of the frame
        plt.subplot(1, 3, 3)
        plt.imshow(sample_image, cmap='gray')
        plt.imshow(mask, alpha=0.5, cmap='jet')
        plt.title("Overlay")
        plt.show()

In [None]:
# Visualize SAM 3 Predictions vs Ground Truth
import matplotlib.pyplot as plt
import numpy as np
from modules.sam_detector import SAM3Detector
from modules.utils import open_tiff_file
import pandas as pd

print("Visualizing SAM 3 predictions...")

# Load validation data
val_tif_path = "/content/val_data/val.tif"
val_csv_path = "/content/val_data/val.csv"

# Half-width of the square drawn around each ground-truth point;
# 2 yields a 5x5 block (clipped at the image border).
GT_MARKER_HALF_SIZE = 2

if os.path.exists(val_tif_path):
    val_images = open_tiff_file(val_tif_path)
    sample_idx = 0
    sample_image = val_images[sample_idx]

    # Initialize SAM 3
    detector = SAM3Detector()
    if detector.model is not None:
        # Run detection with the text prompt "cell"
        mask, detections = detector.detect(sample_image, text_prompt="cell")

        # Ground-truth mask with the same shape and dtype as the prediction;
        # it stays all-zero when no CSV is available.
        gt_mask = np.zeros_like(mask)
        if os.path.exists(val_csv_path):
            df = pd.read_csv(val_csv_path)
            # Keep only the rows of the displayed frame. Rows with a missing
            # coordinate are dropped because int(NaN) would raise.
            frame_df = df[df['frame'] == sample_idx].dropna(subset=['y', 'x'])
            gt_height, gt_width = gt_mask.shape[0], gt_mask.shape[1]
            for y_value, x_value in zip(frame_df['y'], frame_df['x']):
                y, x = int(y_value), int(x_value)
                # Points outside the image are ignored
                if 0 <= y < gt_height and 0 <= x < gt_width:
                    # Square marker around the point (approximate GT, for visibility)
                    gt_mask[
                        max(0, y - GT_MARKER_HALF_SIZE):min(gt_height, y + GT_MARKER_HALF_SIZE + 1),
                        max(0, x - GT_MARKER_HALF_SIZE):min(gt_width, x + GT_MARKER_HALF_SIZE + 1),
                    ] = 1

        # Plot: input, prediction, ground truth, overlay
        plt.figure(figsize=(20, 5))
        plt.subplot(1, 4, 1)
        plt.imshow(sample_image, cmap='gray')
        plt.title("Input Image")

        plt.subplot(1, 4, 2)
        plt.imshow(mask, cmap='gray')
        plt.title("SAM 3 Prediction")

        plt.subplot(1, 4, 3)
        plt.imshow(gt_mask, cmap='gray')
        plt.title("Ground Truth (Approx)")

        # Overlay: prediction and ground truth drawn semi-transparently on the frame
        plt.subplot(1, 4, 4)
        plt.imshow(sample_image, cmap='gray')
        plt.imshow(mask, alpha=0.4, cmap='jet')
        plt.imshow(gt_mask, alpha=0.4, cmap='spring') # GT in pink/spring
        plt.title("Overlay (SAM=Jet, GT=Pink)")
        plt.show()
    else:
        print("SAM 3 model not loaded.")
else:
    print("Validation data not found. Please run the download cell first.")