# 02 — Train & Evaluate YOLOv8-Seg Board Detection

**Pipeline:** YOLO-seg dataset (from notebook 01) → pretrained baseline →
fine-tune YOLOv8s-seg (`MODEL_VARIANT`) → validate (mask mAP) → qualitative evaluation →
pick checkpoint → export ONNX

**Inputs:**
- `prototyping/data/yolo-seg-board/` (relative to `PROJECT_ROOT`) — YOLO-seg dataset (377 train / 93 val, single class `board`)
- `yolov8s-seg.pt` (`MODEL_VARIANT`) — Ultralytics pretrained small segmentation weights (downloaded automatically)

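**Outputs** (under `RUNS_DIR/segment/board-detect/`, where `RUNS_DIR` = `PROJECT_ROOT/trevor-misc/data/fine-tuning-runs`):
- `weights/best.pt` and `weights/last.pt` — fine-tuned checkpoints; whichever scores higher on val mask mAP50-95 is chosen
- `weights/<chosen>.onnx` — ONNX export of the chosen checkpoint for deployment
- `metrics.json` — Key metrics for experiment tracking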

## 1. Setup & Configuration

In [None]:
import json
import math
import random
from datetime import datetime
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from tqdm.notebook import tqdm
from ultralytics import YOLO

# --- Reproducibility ---
SEED = 67
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Paths (same convention as notebook 01) ---
# PROJECT_ROOT is the parent of the notebook's working directory; DATA_DIR and RUNS_DIR
# are both resolved from it.
# NOTE(review): assumes the notebook is launched one directory below the repo root.
PROJECT_ROOT = Path.cwd().parent
DATA_DIR = PROJECT_ROOT / "prototyping/data"
DATASET_DIR = DATA_DIR / "yolo-seg-board"
DATA_YAML = DATASET_DIR / "data.yaml"
RUNS_DIR = PROJECT_ROOT / "trevor-misc/data/fine-tuning-runs"

# --- Training Config ---
MODEL_VARIANT = "yolov8s-seg.pt"  # SMALL model (11M params) - better capacity than nano (3.2M)
                                  # nano was too small → poor detection rate + bad masks
                                  # swap to yolov8m-seg.pt (27M) if small still struggles
EXPERIMENT_NAME = "board-detect"
EPOCHS = 100
IMGSZ = 640
BATCH = 16               # tune down to 8 if OOM; -1 for auto-batch
PATIENCE = 20            # early stopping patience
# The optimizer must be named explicitly for LEARNING_RATE to matter. With optimizer="auto",
# Ultralytics ignores lr0/momentum and picks its own optimizer and learning rate.
# SGD is the optimizer whose default lr0 (0.01) LEARNING_RATE below is measured against.
OPTIMIZER = "SGD"

# Device: prefer Apple Silicon (MPS), then the first CUDA GPU, else CPU. This lets the
# notebook run on machines without MPS. Override by reassigning DEVICE.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "0"
else:
    DEVICE = "cpu"

# --- Fine-Tuning Config ---
# For small datasets (<500 images), freezing early layers limits overfitting: the
# pretrained COCO features in those layers stay fixed while later layers adapt.
# `freeze=N` freezes layers 0..N-1. In the Ultralytics YOLOv8 model definition the
# backbone is layers 0-9, so freeze=10 freezes the whole backbone and leaves the neck
# and Segment head trainable. Larger N also starts freezing neck layers.
# NOTE(review): layer split taken from Ultralytics' yolov8-seg.yaml; re-check if the model family changes.
FREEZE_LAYERS = 10

# Initial learning rate for fine-tuning: half of Ultralytics' default lr0 (0.01), so the
# pretrained weights are adjusted gently. Only honored because OPTIMIZER is explicit.
LEARNING_RATE = 0.005

# --- Augmentation overrides ---
# Notebook 01's synthetic generator already applies: brightness (0.7–1.3),
# contrast (0.8–1.2), Gaussian noise (sigma 0–15), per-channel color shift (±10),
# rotation (±15°), and perspective (strength=0.05).
# Dial back YOLO's overlapping augmentations to avoid double-randomization.
AUG_OVERRIDES = dict(
    hsv_h=0.005,         # default 0.015 — slight hue jitter only
    hsv_s=0.3,           # default 0.7
    hsv_v=0.2,           # default 0.4
    degrees=5.0,         # default 0.0 — very mild (synth already does ±15)
    translate=0.05,      # default 0.1
    scale=0.2,           # default 0.5
    shear=0.0,           # keep off
    perspective=0.0,     # synth already does this
    flipud=0.0,          # boards don't appear upside-down in practice
    fliplr=0.5,          # keep — board is symmetric
    mosaic=0.5,          # default 1.0 — halve; synth images already vary a lot
    mixup=0.0,           # no sense for single-object seg
    copy_paste=0.0,      # keep off
)

print(f"Dataset:     {DATA_YAML}")
print(f"Model:       {MODEL_VARIANT}")
print(f"Device:      {DEVICE}")
print(f"Epochs:      {EPOCHS}")
print(f"Image size:  {IMGSZ}")
print(f"Batch:       {BATCH}")
print(f"Patience:    {PATIENCE}")
print(f"Optimizer:   {OPTIMIZER}")
print(f"Freeze:      {FREEZE_LAYERS} layers")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Runs dir:    {RUNS_DIR}")
print(f"Torch:       {torch.__version__}")
print(f"CUDA avail:  {torch.cuda.is_available()}")
print(f"MPS avail:   {torch.backends.mps.is_available()}")

## 2. Dataset Sanity Check — Verify YAML & Visualize Polygons

In [None]:
# Load and display data.yaml
with open(DATA_YAML, encoding="utf-8") as f:
    data_cfg = yaml.safe_load(f)

print("=== data.yaml ===")
for k, v in data_cfg.items():
    print(f"  {k}: {v}")

# Count images per split and type
for split in ["train", "val"]:
    img_dir = DATASET_DIR / "images" / split
    all_imgs = sorted(img_dir.glob("*.jpg"))
    real = [p for p in all_imgs if p.name.startswith("real_")]
    synth = [p for p in all_imgs if p.name.startswith("synth_")]
    print(f"\n{split}: {len(all_imgs)} images ({len(real)} real, {len(synth)} synthetic)")

# Verify image-label pairing integrity.
# Path.glob on a missing directory yields nothing, so an empty split is reported
# explicitly. Otherwise two empty sets would "match" and the check would pass silently.
for split in ["train", "val"]:
    img_dir = DATASET_DIR / "images" / split
    img_stems = {p.stem for p in img_dir.glob("*.jpg")}
    lbl_stems = {p.stem for p in (DATASET_DIR / "labels" / split).glob("*.txt")}
    if not img_stems:
        raise FileNotFoundError(f"{split}: no *.jpg images found under {img_dir}")
    unlabeled = sorted(img_stems - lbl_stems)
    orphaned = sorted(lbl_stems - img_stems)
    if unlabeled or orphaned:
        # Name (up to 5 of) the offending files so the dataset can actually be fixed.
        raise ValueError(
            f"{split}: image/label mismatch! "
            f"{len(unlabeled)} image(s) without a label {unlabeled[:5]}, "
            f"{len(orphaned)} label(s) without an image {orphaned[:5]}"
        )
print("\nIntegrity check passed: every image has a matching label.")

In [None]:
def visualize_yolo_seg_label(img_path, label_path, ax):
    """Draw the YOLO-seg label polygon(s) of one image onto a matplotlib axis.

    Each non-empty line of ``label_path`` is ``<class_id> x1 y1 x2 y2 ...`` with
    coordinates normalized to [0, 1]. Every polygon is outlined and lightly filled in
    green over the image.

    Args:
        img_path: Path to the image file.
        label_path: Path to the matching YOLO-seg ``.txt`` label.
        ax: Matplotlib axis to draw into.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
        ValueError: if a label line is not a class id followed by >= 3 (x, y) pairs.
    """
    img = cv2.imread(str(img_path))
    if img is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"Could not read image: {img_path}")
    h, w = img.shape[:2]

    # Parse line by line. Splitting the whole file at once would merge several polygons,
    # including the next line's class id, into one bogus coordinate list.
    polygons = []
    for line in label_path.read_text().splitlines():
        parts = line.split()
        if not parts:
            continue
        coords = parts[1:]
        if len(coords) < 6 or len(coords) % 2:
            raise ValueError(f"{label_path}: malformed polygon line: {line[:60]!r}")
        pts = np.asarray(coords, dtype=np.float64).reshape(-1, 2) * (w, h)
        polygons.append(pts.astype(np.int32).reshape(-1, 1, 2))

    vis = img.copy()
    if polygons:
        cv2.polylines(vis, polygons, True, (0, 255, 0), 3)
        overlay = vis.copy()
        cv2.fillPoly(overlay, polygons, (0, 255, 0))
        vis = cv2.addWeighted(overlay, 0.15, vis, 0.85, 0)
    ax.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    title = img_path.name if polygons else f"{img_path.name} (no polygon)"
    ax.set_title(title, fontsize=7)
    ax.axis("off")


# Visual spot-check: up to 4 random train images followed by up to 4 random val images.
images_root = DATASET_DIR / "images"
train_imgs = sorted(images_root.joinpath("train").glob("*.jpg"))
val_imgs = sorted(images_root.joinpath("val").glob("*.jpg"))
train_pick = random.sample(train_imgs, min(4, len(train_imgs)))  # drawn first, as before
val_pick = random.sample(val_imgs, min(4, len(val_imgs)))
sample = train_pick + val_pick

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
for img_path, ax in zip(sample, axes.flatten()):
    split_name = img_path.parent.name  # "train" or "val"
    label_file = DATASET_DIR / "labels" / split_name / (img_path.stem + ".txt")
    visualize_yolo_seg_label(img_path, label_file, ax)
plt.suptitle("Dataset Sanity Check (green = label polygon)", fontsize=14)
plt.tight_layout()
plt.show()

## 3. Pretrained Baseline — Before Fine-Tuning

Run the pretrained COCO model on our validation images to establish a baseline.
COCO does not have a "board" class, so we expect zero or near-zero detections —
this confirms the model needs fine-tuning and gives us a before/after comparison.

In [None]:
pretrained_model = YOLO(MODEL_VARIANT)

# Up to 8 random val images, shown in a 2x4 grid with the raw COCO model's detections.
val_sample = random.sample(val_imgs, min(8, len(val_imgs)))

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
for ax, img_path in zip(axes.flatten(), val_sample):
    first = pretrained_model.predict(str(img_path), imgsz=IMGSZ, conf=0.25, verbose=False)[0]
    ax.imshow(cv2.cvtColor(first.plot(), cv2.COLOR_BGR2RGB))
    n_det = 0 if first.boxes is None else len(first.boxes)
    ax.set_title(f"{img_path.name[:25]}\n{n_det} detections", fontsize=7)
    ax.axis("off")

plt.suptitle("Pretrained Baseline (COCO weights, no fine-tuning)", fontsize=14)
plt.tight_layout()
plt.show()

print("Expected: zero or near-zero relevant detections (COCO has no 'board' class).")

## 4. Train YOLOv8s-Seg — Fine-Tune on Board Dataset

Core training cell. This uses **YOLOv8s-seg (small, 11M params)** instead of nano (3.2M params)
because the nano model was too small:
- Only detected 25% of boards (4/16)
- Mask quality was very poor (IoU ~0.5, rectangular/boxy masks)
- Training was unstable (metrics bouncing wildly)

The small model has **3.5x more capacity** for learning precise boundaries and should give:
- Better detection rate (>80% of boards)
- Better mask quality (IoU >0.75)
- More stable training

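**Fine-tuning hyperparameters** to limit overfitting on our small dataset (377 train, 23 real):
- **freeze=10**: Freezes the first 10 layers (the YOLOv8 backbone), preserving pretrained COCO features while the neck and segmentation head train
- **lr0=0.005**: Lower initial learning rate (vs the default 0.01) to avoid catastrophic forgetting. Ultralytics ignores `lr0` when `optimizer="auto"`, so `OPTIMIZER` must name an optimizer explicitly for this setting to take effect.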

Augmentation is dialed back because the synthetic generator in notebook 01 already applies
heavy domain randomization (brightness, contrast, noise, color shift, rotation, perspective).

To swap model size: change `MODEL_VARIANT` in Section 1 to `yolov8m-seg.pt` (medium, 27M params)
if small still struggles. Training time scales roughly 2–3x per size step.

In [None]:
model = YOLO(MODEL_VARIANT)

# Non-augmentation training arguments, grouped by purpose.
train_kwargs = dict(
    # Data, schedule and hardware
    data=str(DATA_YAML),
    epochs=EPOCHS,
    imgsz=IMGSZ,
    batch=BATCH,
    patience=PATIENCE,
    optimizer=OPTIMIZER,
    device=DEVICE,
    # Reproducibility
    seed=SEED,
    deterministic=True,
    # Output location: <RUNS_DIR>/segment/<EXPERIMENT_NAME>, reused across reruns
    project=str(RUNS_DIR / "segment"),
    name=EXPERIMENT_NAME,
    exist_ok=True,
    pretrained=True,
    # Fine-tuning: frozen early layers + reduced initial learning rate
    freeze=FREEZE_LAYERS,
    lr0=LEARNING_RATE,
    # Logging and checkpointing
    verbose=True,
    save=True,
    save_period=-1,   # only best.pt and last.pt
    plots=True,       # save training plots
    val=True,         # validate after each epoch
    close_mosaic=10,  # disable mosaic for the final 10 epochs
)
# AUG_OVERRIDES is merged in as extra keyword arguments. A key that collides with
# train_kwargs raises TypeError, exactly as passing it twice explicitly would.
results = model.train(**train_kwargs, **AUG_OVERRIDES)

TRAIN_DIR = Path(model.trainer.save_dir)
print(f"\nTraining complete. Results saved to: {TRAIN_DIR}")

## 5. Training Curves

Review loss curves and mAP progression to verify training converged properly.

In [None]:
results_csv = TRAIN_DIR / "results.csv"
df = pd.read_csv(results_csv)
# Ultralytics pads the CSV header names with spaces; strip them so lookups by name work.
df.columns = df.columns.str.strip()

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Top row: train (blue) vs val (red) curve for each loss term.
loss_terms = [("box", "Box Loss"), ("seg", "Seg Loss"), ("cls", "Cls Loss")]
for ax, (term, title) in zip(axes[0], loss_terms):
    for split, line_style in [("train", "b-"), ("val", "r-")]:
        col = f"{split}/{term}_loss"
        if col in df.columns:
            ax.plot(df["epoch"], df[col], line_style, linewidth=1.5, label=split)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.legend()
    ax.grid(True, alpha=0.3)

# Bottom row: per-epoch validation metrics on a fixed [0, 1] axis.
metric_panels = {
    "metrics/precision(B)": "Precision",
    "metrics/recall(B)": "Recall",
    "metrics/mAP50-95(M)": "Mask mAP50-95",
}
for ax, (col, title) in zip(axes[1], metric_panels.items()):
    if col in df.columns:
        ax.plot(df["epoch"], df[col], "g-", linewidth=1.5)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)

plt.suptitle("Training Curves", fontsize=14)
plt.tight_layout()
plt.show()

# Report the epoch with the highest mask mAP50-95.
mask_map_col = "metrics/mAP50-95(M)"
if mask_map_col in df.columns:
    best_row = df.loc[df[mask_map_col].idxmax()]
    print(f"Best mask mAP50-95: {best_row[mask_map_col]:.4f} at epoch {int(best_row['epoch'])}")

## 6. Validate — Mask mAP Metrics on Val Set

Run `model.val()` on the best checkpoint to get official metrics.
Key metric for board segmentation: **mask mAP50-95** (averaged over IoU thresholds 0.50–0.95).

In [None]:
# Re-validate the best checkpoint on the val split to get the headline metrics.
best_weights = TRAIN_DIR / "weights" / "best.pt"
assert best_weights.exists(), f"best.pt not found at {best_weights}"

best_model = YOLO(str(best_weights))
val_results = best_model.val(
    data=str(DATA_YAML),
    imgsz=IMGSZ,
    batch=BATCH,
    device=DEVICE,
    split="val",
    verbose=True,
    plots=True,
)

box_metrics, seg_metrics = val_results.box, val_results.seg
box_map50, box_map5095 = box_metrics.map50, box_metrics.map      # box mAP @0.50 and @0.50-0.95
mask_map50, mask_map5095 = seg_metrics.map50, seg_metrics.map    # mask mAP @0.50 and @0.50-0.95

print("\n=== Validation Metrics (best.pt) ===")
print(f"  Box  mAP@50:    {box_map50:.4f}")
print(f"  Box  mAP@50-95: {box_map5095:.4f}")
print(f"  Mask mAP@50:    {mask_map50:.4f}")
print(f"  Mask mAP@50-95: {mask_map5095:.4f}")

In [None]:
last_weights = TRAIN_DIR / "weights" / "last.pt"
assert last_weights.exists(), f"last.pt not found at {last_weights}"

last_model = YOLO(str(last_weights))
last_val = last_model.val(
    data=str(DATA_YAML), imgsz=IMGSZ, batch=BATCH, device=DEVICE, split="val", verbose=False
)

# One fixed-width row format shared by both checkpoints.
row_fmt = "{:<12} {:>10.4f} {:>14.4f} {:>12.4f} {:>15.4f}"
print("=== Checkpoint Comparison ===")
print(f"{'Checkpoint':<12} {'Box mAP50':>10} {'Box mAP50-95':>14} {'Mask mAP50':>12} {'Mask mAP50-95':>15}")
print("-" * 67)
print(row_fmt.format("best.pt", box_map50, box_map5095, mask_map50, mask_map5095))
print(row_fmt.format(
    "last.pt", last_val.box.map50, last_val.box.map, last_val.seg.map50, last_val.seg.map
))

# Keep best.pt unless last.pt scores strictly higher on mask mAP50-95 (ties go to best.pt).
use_best = mask_map5095 >= last_val.seg.map
chosen, CHOSEN_WEIGHTS = ("best.pt", best_weights) if use_best else ("last.pt", last_weights)
print(f"\nChosen checkpoint: {chosen}")

## 7. Qualitative Evaluation — Visual Predictions on Val Set

Predict on validation images and display overlays. This is the most important
QA step — numbers can look good while masks are subtly wrong.

In [None]:
eval_model = YOLO(str(CHOSEN_WEIGHTS))

val_all = sorted((DATASET_DIR / "images" / "val").glob("*.jpg"))
n_display = min(16, len(val_all))
display_imgs = random.sample(val_all, n_display)

if n_display == 0:
    # plt.subplots(0, ...) raises, so report instead of building an empty grid.
    print("No val images to display.")
else:
    n_cols = 4
    n_rows = math.ceil(n_display / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
    axes = axes.flatten()

    for idx, img_path in enumerate(display_imgs):
        preds = eval_model.predict(str(img_path), imgsz=IMGSZ, conf=0.25, verbose=False)
        boxes = preds[0].boxes
        axes[idx].imshow(cv2.cvtColor(preds[0].plot(), cv2.COLOR_BGR2RGB))
        # Guard against a missing boxes object as well as an empty one; index 0 is the
        # first (top) detection.
        conf = boxes.conf[0].item() if boxes is not None and len(boxes) > 0 else 0.0
        tag = "real" if img_path.name.startswith("real_") else "synth"
        axes[idx].set_title(f"{img_path.name[:25]}\n[{tag}] conf={conf:.3f}", fontsize=7)
        axes[idx].axis("off")

    # Hide the unused cells of the last grid row.
    for ax in axes[n_display:]:
        ax.set_visible(False)

    plt.suptitle("Fine-Tuned Model Predictions on Val Set", fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# Focus on real-only val images (most important for production)
real_val = sorted(p for p in val_all if p.name.startswith("real_"))
print(f"Real val images: {len(real_val)}")

if len(real_val) > 0:
    n_real_display = min(len(real_val), 8)
    n_cols_r = min(4, n_real_display)
    n_rows_r = math.ceil(n_real_display / n_cols_r)
    fig, axes = plt.subplots(n_rows_r, n_cols_r, figsize=(5 * n_cols_r, 5 * n_rows_r))
    if n_real_display == 1:
        axes = [axes]
    else:
        axes = axes.flatten() if hasattr(axes, "flatten") else [axes]

    for idx, img_path in enumerate(real_val[:n_real_display]):
        preds = eval_model.predict(str(img_path), imgsz=IMGSZ, conf=0.25, verbose=False)
        annotated = preds[0].plot()
        axes[idx].imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
        conf = preds[0].boxes.conf[0].item() if len(preds[0].boxes) > 0 else 0.0
        axes[idx].set_title(f"{img_path.name}\nconf={conf:.3f}", fontsize=8)
        axes[idx].axis("off")

    for idx in range(n_real_display, len(axes)):
        axes[idx].set_visible(False)

    plt.suptitle("Predictions on Real Val Images (most important for production)", fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print("No real images in val set — all predictions are on synthetic data.")

## 8. Mask Quality — IoU Between Predicted and Ground-Truth Polygons

Compute per-image mask IoU for all val images to quantify mask accuracy beyond mAP.
Split by real vs synthetic to catch domain-transfer gaps.

In [None]:
def compute_mask_iou(pred_mask, gt_polygon, img_shape, threshold=0.5):
    """Return the IoU between a predicted mask and a ground-truth polygon.

    Args:
        pred_mask: 2-D array (bool, 0/1, or soft scores). It must cover the same field of
            view as the original image, e.g. Ultralytics ``masks.data`` from
            ``predict(..., retina_masks=True)``. A mask in letterboxed model-input space
            would be stretched by the resize below and misaligned wherever padding exists.
        gt_polygon: (N, 2) array of pixel (x, y) vertices.
        img_shape: image shape; only the first two entries (h, w) are used.
        threshold: ``pred_mask`` values strictly greater than this count as foreground.
            For masks that are already binary, the result matches plain truthiness.

    Returns:
        IoU as a Python float in [0, 1]; 0.0 when both masks are empty.
    """
    h, w = img_shape[:2]
    gt_mask = np.zeros((h, w), dtype=np.uint8)
    cv2.fillPoly(gt_mask, [np.asarray(gt_polygon, dtype=np.int32).reshape(-1, 1, 2)], 1)

    # Binarize before resizing. Otherwise soft scores would count as foreground when the
    # shapes already match, but be truncated to 0 by the uint8 cast when they don't.
    pred_bin = (np.asarray(pred_mask) > threshold).astype(np.uint8)
    if pred_bin.shape != (h, w):
        pred_bin = cv2.resize(pred_bin, (w, h), interpolation=cv2.INTER_NEAREST)

    pred_bool = pred_bin.astype(bool)
    gt_bool = gt_mask.astype(bool)
    union = np.logical_or(pred_bool, gt_bool).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(pred_bool, gt_bool).sum() / union)


iou_records = []
for img_path in tqdm(val_all, desc="Computing mask IoUs"):
    # Predict. retina_masks=True returns masks.data at the original image resolution,
    # mapped back through the letterbox. The default masks live in the padded
    # model-input frame, so resizing them straight to (w, h) misaligns them with the
    # ground truth whenever letterbox padding was added.
    result = eval_model.predict(
        str(img_path), imgsz=IMGSZ, conf=0.25, retina_masks=True, verbose=False
    )[0]
    h, w = result.orig_shape

    # Ground truth: the first non-empty label line. Parsing line by line keeps a second
    # polygon's class id from being read as a coordinate.
    lbl_path = DATASET_DIR / "labels" / "val" / f"{img_path.stem}.txt"
    label_lines = [line.split() for line in lbl_path.read_text().splitlines() if line.strip()]
    if not label_lines or len(label_lines[0]) < 7 or len(label_lines[0]) % 2 == 0:
        raise ValueError(f"{lbl_path}: expected '<cls> x1 y1 x2 y2 x3 y3 ...' on the first line")
    gt_pts = (
        np.asarray(label_lines[0][1:], dtype=np.float64).reshape(-1, 2) * (w, h)
    ).astype(np.int32)

    if result.masks is not None and len(result.masks) > 0:
        # Index 0 is the first (top) detection.
        pred_mask = result.masks.data[0].cpu().numpy()
        iou = float(compute_mask_iou(pred_mask, gt_pts, (h, w)))
        conf = result.boxes.conf[0].item()
    else:
        iou = 0.0
        conf = 0.0

    tag = "real" if img_path.name.startswith("real_") else "synth"
    iou_records.append({"filename": img_path.name, "type": tag, "iou": iou, "conf": conf})

# Explicit columns keep the frame well-formed even when there are no records.
iou_df = pd.DataFrame(iou_records, columns=["filename", "type", "iou", "conf"])
real_ious = iou_df.loc[iou_df["type"] == "real", "iou"]
synth_ious = iou_df.loc[iou_df["type"] == "synth", "iou"]

print("\n=== Mask IoU Summary ===")
if iou_df.empty:
    print("  No val images found — nothing to summarize.")
else:
    print(f"  Overall mean IoU: {iou_df['iou'].mean():.4f}")
    if len(real_ious) > 0:
        print(f"  Real mean IoU:    {real_ious.mean():.4f}")
    if len(synth_ious) > 0:
        print(f"  Synth mean IoU:   {synth_ious.mean():.4f}")
    print(f"  Min IoU:          {iou_df['iou'].min():.4f} ({iou_df.loc[iou_df['iou'].idxmin(), 'filename']})")
    print(f"  Images with IoU < 0.5: {(iou_df['iou'] < 0.5).sum()}")

    # Histogram
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.hist(iou_df["iou"], bins=20, edgecolor="black", alpha=0.7)
    ax.axvline(iou_df["iou"].mean(), color="red", linestyle="--", label=f"Mean: {iou_df['iou'].mean():.3f}")
    ax.set_xlabel("Mask IoU")
    ax.set_ylabel("Count")
    ax.set_title("Per-Image Mask IoU Distribution (Val Set)")
    ax.legend()
    plt.tight_layout()
    plt.show()

## 8b. Real vs Synthetic Performance Breakdown

Since we have only 5 real validation images vs 88 synthetic, the overall metrics may hide
domain-transfer issues. This cell explicitly compares real vs synthetic performance.

In [None]:
# Partition the (already sorted) val images by filename prefix: real_* vs synth_*.
real_val_imgs = [p for p in val_all if p.name.startswith("real_")]
synth_val_imgs = [p for p in val_all if p.name.startswith("synth_")]

for caption, subset in (
    ("Real validation images: ", real_val_imgs),
    ("Synth validation images:", synth_val_imgs),
):
    print(f"{caption} {len(subset)}")

# Compute metrics separately for real and synthetic
def eval_subset(img_list, subset_name):
    """Run ``eval_model`` over ``img_list`` and summarize detection, mask IoU and confidence.

    Labels are looked up under ``DATASET_DIR / "labels" / <image's parent dir name>``, so
    any split works, not only val.

    Args:
        img_list: image paths to evaluate.
        subset_name: label shown in the progress bar.

    Returns:
        None for an empty list. Otherwise a dict with ``count``, ``detected``,
        ``detection_rate``, ``mean_iou`` (misses count as IoU 0), ``mean_conf`` (over
        detected images only, 0.0 if none) and ``ious`` (per-image list).

    Raises:
        ValueError: if an image's label has no valid polygon on its first line.
    """
    if not img_list:
        return None

    ious = []
    det_confs = []  # top-detection confidence of each image with a detection

    for img_path in tqdm(img_list, desc=f"Evaluating {subset_name}"):
        # retina_masks=True: masks come back aligned with the original image instead of
        # the letterboxed model input, so they line up with the ground-truth polygon.
        result = eval_model.predict(
            str(img_path), imgsz=IMGSZ, conf=0.25, retina_masks=True, verbose=False
        )[0]
        h, w = result.orig_shape

        # Ground truth: first non-empty label line ("<cls> x1 y1 x2 y2 ...", normalized).
        lbl_path = DATASET_DIR / "labels" / img_path.parent.name / f"{img_path.stem}.txt"
        label_lines = [line.split() for line in lbl_path.read_text().splitlines() if line.strip()]
        if not label_lines or len(label_lines[0]) < 7 or len(label_lines[0]) % 2 == 0:
            raise ValueError(f"{lbl_path}: expected '<cls> x1 y1 x2 y2 x3 y3 ...' on the first line")
        gt_pts = (
            np.asarray(label_lines[0][1:], dtype=np.float64).reshape(-1, 2) * (w, h)
        ).astype(np.int32)

        if result.masks is not None and len(result.masks) > 0:
            pred_mask = result.masks.data[0].cpu().numpy()
            ious.append(float(compute_mask_iou(pred_mask, gt_pts, (h, w))))
            det_confs.append(result.boxes.conf[0].item())
        else:
            ious.append(0.0)

    return {
        "count": len(img_list),
        "detected": len(det_confs),
        "detection_rate": len(det_confs) / len(img_list),
        "mean_iou": float(np.mean(ious)),
        "mean_conf": float(np.mean(det_confs)) if det_confs else 0.0,
        "ious": ious,
    }

real_metrics = eval_subset(real_val_imgs, "real")
synth_metrics = eval_subset(synth_val_imgs, "synth")

print("\n=== Real vs Synthetic Performance ===")
print(f"{'Metric':<20} {'Real':>12} {'Synthetic':>12} {'Gap':>12}")
print("-" * 58)

# A comparison needs both subsets; eval_subset returns None for an empty list.
if real_metrics is None:
    print("No real validation images found.")
elif synth_metrics is None:
    print("No synthetic validation images found — nothing to compare against.")
else:
    r, s = real_metrics, synth_metrics
    print(f"{'Images':<20} {r['count']:>12} {s['count']:>12} {'-':>12}")
    print(f"{'Detection Rate':<20} {r['detection_rate']:>12.2%} {s['detection_rate']:>12.2%} {abs(r['detection_rate'] - s['detection_rate']):>11.2%}")
    print(f"{'Mean IoU':<20} {r['mean_iou']:>12.4f} {s['mean_iou']:>12.4f} {abs(r['mean_iou'] - s['mean_iou']):>12.4f}")
    print(f"{'Mean Confidence':<20} {r['mean_conf']:>12.4f} {s['mean_conf']:>12.4f} {abs(r['mean_conf'] - s['mean_conf']):>12.4f}")

    # Report the actual real-image count rather than a hard-coded number.
    if r["count"] < 20:
        print(f"\n⚠️  NOTE: With only {r['count']} real validation images, real metrics have high variance.")
        print("    Consider collecting 20+ real images for more reliable evaluation.")

    # Visual comparison of the per-image IoU distributions.
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (axes[0], r, 10, "coral", "red", f"Real Images (n={r['count']})"),
        (axes[1], s, 20, "lightblue", "blue", f"Synthetic Images (n={s['count']})"),
    )
    for ax, subset, bins, bar_color, mean_color, title in panels:
        ax.hist(subset["ious"], bins=bins, edgecolor="black", alpha=0.7, color=bar_color)
        ax.axvline(subset["mean_iou"], color=mean_color, linestyle="--",
                   label=f"Mean: {subset['mean_iou']:.3f}")
        ax.set_xlabel("IoU")
        ax.set_ylabel("Count")
        ax.set_title(title)
        ax.legend()
        ax.set_xlim(0, 1)

    plt.suptitle("IoU Distribution: Real vs Synthetic", fontsize=14)
    plt.tight_layout()
    plt.show()

## 9. Export to ONNX

Export the chosen checkpoint for deployment. ONNX is the target format for
inference in the web app (via ONNX Runtime in the browser or server-side).

In [None]:
# ONNX export options for the deployment artifact.
export_opts = {
    "format": "onnx",
    "imgsz": IMGSZ,
    "simplify": True,   # run onnxslim to optimize the graph
    "dynamic": False,   # fixed input shape for simpler deployment
    "half": False,      # FP32 for maximum compatibility
}
export_model = YOLO(str(CHOSEN_WEIGHTS))
export_path = export_model.export(**export_opts)

onnx_path = Path(export_path)
print(f"ONNX exported to: {onnx_path}")
size_mb = onnx_path.stat().st_size / 1024 / 1024
print(f"File size: {size_mb:.1f} MB")

In [None]:
# Smoke-test: load ONNX back through Ultralytics and compare to PyTorch
assert onnx_path.exists(), f"ONNX file not found: {onnx_path}"

onnx_model = YOLO(str(onnx_path))
test_img = random.choice(val_all)

pt_results = eval_model.predict(str(test_img), imgsz=IMGSZ, conf=0.25, verbose=False)
onnx_results = onnx_model.predict(str(test_img), imgsz=IMGSZ, conf=0.25, verbose=False)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

axes[0].imshow(cv2.cvtColor(pt_results[0].plot(), cv2.COLOR_BGR2RGB))
pt_conf = pt_results[0].boxes.conf[0].item() if len(pt_results[0].boxes) > 0 else 0
axes[0].set_title(f"PyTorch (conf={pt_conf:.3f})", fontsize=10)
axes[0].axis("off")

axes[1].imshow(cv2.cvtColor(onnx_results[0].plot(), cv2.COLOR_BGR2RGB))
onnx_conf = onnx_results[0].boxes.conf[0].item() if len(onnx_results[0].boxes) > 0 else 0
axes[1].set_title(f"ONNX (conf={onnx_conf:.3f})", fontsize=10)
axes[1].axis("off")

plt.suptitle(f"Smoke Test: PyTorch vs ONNX — {test_img.name}", fontsize=12)
plt.tight_layout()
plt.show()

print(f"Confidence difference: {abs(pt_conf - onnx_conf):.6f}")
if abs(pt_conf - onnx_conf) < 0.01:
    print("ONNX export verified — predictions match PyTorch within tolerance.")
else:
    print("WARNING: ONNX predictions differ significantly from PyTorch!")

## 10. Summary & Experiment Tracking

In [None]:
metrics = {
    "experiment": EXPERIMENT_NAME,
    # Local time with its UTC offset, so timestamps from different machines compare correctly.
    "timestamp": datetime.now().astimezone().isoformat(),
    "model_variant": MODEL_VARIANT,
    "epochs_configured": EPOCHS,
    # None (not EPOCHS) when the training-curves cell did not run: an unknown value must
    # not be recorded as a completed run.
    "epochs_completed": int(df["epoch"].max()) if "df" in globals() else None,
    "imgsz": IMGSZ,
    "batch": BATCH,
    "seed": SEED,
    "device": DEVICE,
    # Fine-tuning hyperparameters, so runs with different settings can be told apart.
    "optimizer": OPTIMIZER,
    "lr0": LEARNING_RATE,
    "freeze": FREEZE_LAYERS,
    "patience": PATIENCE,
    "augmentation_overrides": AUG_OVERRIDES,
    "dataset": {
        "train_images": len(list((DATASET_DIR / "images" / "train").glob("*.jpg"))),
        "val_images": len(list((DATASET_DIR / "images" / "val").glob("*.jpg"))),
        "classes": data_cfg["names"],
    },
    # float() turns numpy scalars into plain Python floats for JSON.
    "validation_metrics": {
        "box_mAP50": round(float(box_map50), 4),
        "box_mAP50_95": round(float(box_map5095), 4),
        "mask_mAP50": round(float(mask_map50), 4),
        "mask_mAP50_95": round(float(mask_map5095), 4),
    },
    "mask_iou": {
        "mean": round(float(iou_df["iou"].mean()), 4),
        "real_mean": round(float(real_ious.mean()), 4) if len(real_ious) > 0 else None,
        "synth_mean": round(float(synth_ious.mean()), 4) if len(synth_ious) > 0 else None,
        "min": round(float(iou_df["iou"].min()), 4),
    },
    "chosen_checkpoint": str(CHOSEN_WEIGHTS.name),
    "artifacts": {
        "weights_pt": str(CHOSEN_WEIGHTS),
        "weights_onnx": str(onnx_path),
        "train_dir": str(TRAIN_DIR),
    },
}

metrics_path = TRAIN_DIR / "metrics.json"
with open(metrics_path, "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

print(f"Metrics saved to: {metrics_path}")

In [None]:
print("=" * 60)
print("  TRAINING COMPLETE")
print("=" * 60)
print()
print(f"  Model:             {MODEL_VARIANT}")
print(f"  Chosen checkpoint: {CHOSEN_WEIGHTS.name}")
print(f"  Epochs completed:  {metrics['epochs_completed']}")
print()
print("  Validation Metrics:")
print(f"    Mask mAP@50:     {mask_map50:.4f}")
print(f"    Mask mAP@50-95:  {mask_map5095:.4f}")
print(f"    Mean mask IoU:   {iou_df['iou'].mean():.4f}")
if len(real_ious) > 0:
    print(f"    Real-img IoU:    {real_ious.mean():.4f}")
print()
print("  Artifacts:")
print(f"    PyTorch weights: {CHOSEN_WEIGHTS}")
print(f"    ONNX export:     {onnx_path}")
print(f"    Metrics JSON:    {metrics_path}")
print(f"    Training dir:    {TRAIN_DIR}")
print()
print("  Next steps:")
print("    - Copy best.onnx to the web app for inference")
print("    - If mask mAP50-95 < 0.85, consider:")
print("      - More real training photos")
print("      - Trying yolov8s-seg.pt (small) model")
print("      - Adjusting augmentation knobs")
print("=" * 60)