# SU2 Project - Unet++ Training on Colab

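This notebook runs the full pipeline from the project's modularized code (the `modules` package): it trains the model, saves the weights to Google Drive, downloads the validation data, and runs a tracking parameter sweep that produces a GIF of the best result.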

In [None]:
# Mount Google Drive and prepare the output folder
import os

from google.colab import drive

# Everything this notebook persists (model weights, tracking GIF) goes here.
SAVE_DIR = "/content/drive/MyDrive/SU2_Project"

# Drive has to be mounted before the folder can be created inside it.
drive.mount('/content/drive')

# exist_ok=True keeps this cell safe to re-run.
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"Results will be saved to: {SAVE_DIR}")

In [None]:
# Install dependencies
!pip install btrack==0.6.5 "pydantic<2"

In [None]:
# Clone Repository and Import Modules
import sys
import os

# TODO: Replace with your GitHub username and repo name
REPO_URL = "https://github.com/Mateusz/SU2_Project.git"
REPO_DIR = "/content/SU2_Project"

if not os.path.exists(REPO_DIR):
    print(f"Cloning repository from {REPO_URL}...")
    !git clone {REPO_URL} {REPO_DIR}
else:
    print(f"Repository already exists at {REPO_DIR}")

# Ensure the repo directory is in path so we can import modules
sys.path.append(REPO_DIR)
# Change working directory to the repo so relative paths work
os.chdir(REPO_DIR)

from modules.config import *
from modules.utils import set_seed, plot_training_history
from modules.training import train_unet_pipeline
from modules.tracking import run_tracking_on_validation

In [None]:
# Set random seeds
# SEED comes from modules.config; set_seed is the project helper from
# modules.utils. Run this before the training cell so the run starts from a
# seeded state.
# NOTE(review): which RNGs set_seed covers (random / numpy / torch / CUDA) is
# not visible from this notebook — confirm in modules/utils.
set_seed(SEED)

In [None]:
# Run Training Pipeline
# All hyperparameters are taken from the project configuration (modules.config)
# and gathered in one place so they are easy to inspect before the run starts.
training_kwargs = dict(
    train_samples=TRAIN_SAMPLES,
    val_samples=VAL_SAMPLES,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    patience=PATIENCE,
    device=DEVICE,
)

print("Starting Training Pipeline...")
# `model` and `history` are reused by the plotting, saving and sweep cells below.
model, history = train_unet_pipeline(**training_kwargs)

In [None]:
# Plot History
# `history` is the second value returned by train_unet_pipeline in the
# training cell, so that cell must have been run first.
plot_training_history(history)

In [None]:
# Save Model to Drive
import shutil

import torch

# Persist the weights of the model returned by the training cell.
save_path = os.path.join(SAVE_DIR, "final_model.pth")
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

# If a best-checkpoint file is present in the current working directory
# (e.g. left there by the training loop), copy it to Drive as well.
local_best_path = "best_model.pth"
if os.path.exists(local_best_path):
    best_save_path = os.path.join(SAVE_DIR, "best_model.pth")
    shutil.copy(local_best_path, best_save_path)
    print(f"Best model saved to {best_save_path}")

In [None]:
# Download Validation Data
from modules.utils import download_and_unzip
import requests

# Step 1: fetch the certificate chain file that is handed to the archive
# download below.
chain_path = "/content/chain-harica-cross.pem"
print("1) Downloading SSL certificate chain...")
cert_url = "https://pki.cesnet.cz/_media/certs/chain-harica-rsa-ov-crosssigned-root.pem"
# No stream=True here: `.content` reads the whole body anyway, so streaming had
# no effect. The `with` block closes the response even if the status check raises.
with requests.get(cert_url, timeout=10) as cert_response:
    cert_response.raise_for_status()
    cert_bytes = cert_response.content
with open(chain_path, "wb") as cert_file:
    cert_file.write(cert_bytes)
print("2) Certificate chain downloaded.\n")

# Step 2: download and extract the validation archive.
# NOTE(review): chain_path is presumably used by download_and_unzip as the CA
# bundle for verifying su2.utia.cas.cz — confirm in modules/utils.
zip_url = "https://su2.utia.cas.cz/files/labs/final2025/val_and_sota.zip"
extract_directory = "/content/val_data"
download_and_unzip(zip_url, extract_directory, chain_path)

In [None]:
# Define Parameter Grids for Sweep
from modules.tracking import DetectionParams, BTrackParams

# Each key maps to the list of candidate values handed to the sweep.
# NOTE(review): presumably the sweep evaluates combinations of these values —
# confirm in modules/sweep.

# 1. Detection: High Recall (thresholds 0.25 and 0.30)
det_param_grid = dict(
    threshold=[0.25, 0.3],
    min_area=[4],
    nms_min_dist=[3.0],
)

# 2. Tracking: Optimize + Aggressive Filtering
btrack_param_grid = dict(
    # NOTE(review): the only candidate is False, i.e. optimization is never
    # enabled in this sweep, although the heading above mentions optimizing.
    # Confirm which is intended before changing the value.
    do_optimize=[False],
    max_search_radius=[20.0],
    dist_thresh=[15.0],
    time_thresh=[4, 6],            # Allow gaps
    min_track_len=[10, 15],        # Filter noise from low threshold
    segmentation_miss_rate=[0.1],
    apoptosis_rate=[0.001],
    allow_divisions=[False],
)

print("Parameter grids defined.")

In [None]:
# Run Tracking Sweep and Generate GIF
from modules.sweep import sweep_and_save_gif

# The GIF is written straight to the Drive results folder.
gif_output_path = os.path.join(SAVE_DIR, "best_tracking.gif")

# Uses the trained `model` and the two parameter grids defined in the cells above.
sweep_result = sweep_and_save_gif(
    model, det_param_grid, btrack_param_grid, gif_output=gif_output_path
)
# The sweep returns three values: detection params, tracking params and tracks
# (as named here).
best_det, best_bt, best_tracks = sweep_result

print(f"Best tracking GIF saved to: {gif_output_path}")

In [None]:
# AUTO-DISCONNECT to save runtime units
from google.colab import drive, runtime

# Flush outstanding writes to Google Drive before the runtime is released.
# Without this, files written to the mounted Drive shortly before the
# disconnect (model weights, tracking GIF) may not have been uploaded yet.
drive.flush_and_unmount()

print("Training finished. Disconnecting runtime to save units...")
runtime.unassign()