# 🌿 Off-Road Semantic Segmentation with SegFormer-B5
## Hack for Green Bharat — God-Level Pipeline

**Architecture:** SegFormer-B5 (Mix Transformer Encoder + All-MLP Decoder)  
**Backbone:** `nvidia/segformer-b5-finetuned-ade-640-640` (pretrained on ADE20K, 150 classes → fine-tuned to 10)  
**Hardware:** Kaggle H100 GPU (80GB VRAM)  

### Why SegFormer?
- **Hierarchical Transformer Encoder (MiT-B5):** Generates multi-scale features without positional encoding → resolution-agnostic
- **Lightweight All-MLP Decoder:** Drastically fewer params than ASPP/atrous convolutions (DeepLabv3+), yet higher mIoU
- **State-of-the-art:** 84.0 mIoU on Cityscapes, outperforms DeepLabv3+, Swin-based methods & DINOv2 linear probes
- **Transfer learning friendly:** ADE20K pretrained weights transfer beautifully to off-road terrain segmentation

### References
- [1] Xie et al., "SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers," NeurIPS 2021
- [2] Chen et al., "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation," ECCV 2018
- [3] Maturana et al., "Real-time Semantic Mapping for Autonomous Off-Road Navigation," 2017
- [4-6] Bozinovski (1976), Bozinovski (2020), Pan & Yang (2010) — Transfer Learning foundations
- [7] Minhas, "Transfer Learning for Semantic Segmentation using PyTorch DeepLabv3," 2019

---
## 1. Environment Setup & Installations

In [None]:
%%capture
!pip install -q transformers accelerate datasets evaluate segmentation-models-pytorch albumentations timm
!pip install -q matplotlib seaborn tqdm pillow opencv-python-headless

In [None]:
import os
import sys
import json
import glob
import time
import shutil
import zipfile
import random
import warnings
import datetime
from pathlib import Path
from collections import OrderedDict

import numpy as np
import pandas as pd
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms as T

import albumentations as A
from albumentations.pytorch import ToTensorV2

from transformers import (
    SegformerForSemanticSegmentation,
    SegformerConfig,
    SegformerImageProcessor,
)

from tqdm.auto import tqdm

warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')

# ──────────────────────── Reproducibility ────────────────────────
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Favour speed over bit-exact reproducibility: non-deterministic cuDNN
# algorithms are allowed and `benchmark` auto-tunes kernels for the fixed
# input size. (TF32 is controlled separately by the allow_tf32 flags below.)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# ──────────────────────── Device ────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üñ•Ô∏è  Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties attribute is `total_memory` (bytes).
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # Enable TF32 for float32 matmuls and cuDNN convolutions on Ampere+ GPUs
    # (H100): large speedup at slightly reduced float32 precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print(f"   TF32: Enabled ‚úÖ")

---
## 2. Configuration

In [None]:
# ══════════════════════════════════════════════════════════════════
#  MASTER CONFIGURATION — Paths matched to Kaggle dataset structure
# ══════════════════════════════════════════════════════════════════

class CFG:
    """Notebook-wide settings, read as class attributes (e.g. ``CFG.LR``).

    Public (non-underscore) attributes are printed in the configuration cell
    and copied into the ``'config'`` entry of checkpoints written by
    ``CheckpointManager.save_checkpoint``.
    """
    # ── Paths (Kaggle — already extracted, double-nested) ──
    # Dataset name on Kaggle: "hack_for_bharat". NOTE(review): the mount path
    # below uses hyphens; if it does not exist, the path-verification cell
    # searches /kaggle/input for the split folders instead.
    DATA_ROOT        = '/kaggle/input/hack-for-bharat'
    
    # Training / validation splits (double-nested folder structure)
    TRAIN_DIR        = os.path.join(DATA_ROOT, 'Offroad_Segmentation_Training_Dataset',
                                    'Offroad_Segmentation_Training_Dataset', 'train')
    VAL_DIR          = os.path.join(DATA_ROOT, 'Offroad_Segmentation_Training_Dataset',
                                    'Offroad_Segmentation_Training_Dataset', 'val')
    # Test split (also double-nested)
    TEST_DIR         = os.path.join(DATA_ROOT, 'Offroad_Segmentation_testImages',
                                    'Offroad_Segmentation_testImages')
    
    # Writable output locations (figures, checkpoints, results)
    OUTPUT_DIR       = '/kaggle/working/outputs'
    CHECKPOINT_DIR   = '/kaggle/working/checkpoints'
    RESULTS_DIR      = '/kaggle/working/results'
    
    # ── Model ──
    # ADE20K-finetuned SegFormer-B5; its 150-class head is replaced by a NUM_CLASSES head
    MODEL_NAME       = 'nvidia/segformer-b5-finetuned-ade-640-640'
    NUM_CLASSES      = 10           # must agree with CLASS_NAMES / VALUE_MAP / COLOR_PALETTE
    
    # ── Training ──
    IMG_SIZE         = 640          # square resize/crop size; matches the 640x640 pretraining resolution
    BATCH_SIZE       = 8            # per forward pass; chosen for an H100 80GB at 640x640
    ACCUMULATION     = 2            # gradient-accumulation steps → effective batch = 16
    NUM_EPOCHS       = 100          # upper bound; early stopping may end training sooner
    LR               = 2e-4         # AdamW peak LR (decoder groups; encoder groups use 0.1x this)
    MIN_LR           = 1e-7         # cosine floor, applied as the ratio MIN_LR / LR of each group's LR
    WEIGHT_DECAY     = 0.01
    WARMUP_EPOCHS    = 3            # length of the linear LR warmup
    
    # ── Early Stopping ──
    # Training stops only once BOTH patience counters below are exhausted.
    PATIENCE_MIOU    = 15           # epochs without a val mIoU improvement
    PATIENCE_LOSS    = 20           # epochs without a val loss improvement
    MIN_DELTA        = 1e-4         # minimum change that counts as an improvement (both metrics)
    
    # ── Checkpointing ──
    SAVE_EVERY       = 5            # save checkpoint every N epochs — NOTE(review): not referenced in the code visible here
    KEEP_TOP_K       = 3            # epoch checkpoints kept by CheckpointManager, ranked by val mIoU
    
    # ── Augmentation ──
    USE_MIXUP        = False        # Mixup augmentation — NOTE(review): not referenced in the code visible here
    AUG_PROB         = 0.5          # p of ShiftScaleRotate and of the colour OneOf in the train pipeline
    
    # ── Mixed Precision ──
    USE_AMP          = True         # torch.cuda.amp autocast (its default FP16) + GradScaler during training
    
    # ── Workers ──
    NUM_WORKERS      = 4            # DataLoader worker processes
    PIN_MEMORY       = True         # page-locked host memory for faster host→GPU copies

# Create every output directory up front so later cells can write into them.
for out_dir in (CFG.OUTPUT_DIR, CFG.CHECKPOINT_DIR, CFG.RESULTS_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Echo the public configuration (class internals such as __module__ are skipped).
print("Configuration:")
for key, value in vars(CFG).items():
    if key.startswith('_'):
        continue
    print(f"  {key}: {value}")

---
## 3. Verify Dataset Paths

In [None]:
# ══════════════════════════════════════════════════════════════════
#  VERIFY DATASET PATHS (already extracted on Kaggle)
# ══════════════════════════════════════════════════════════════════

# Kaggle dataset structure (double-nested):
# /kaggle/input/hack-for-bharat/
#   ├── Offroad_Segmentation_Scripts/
#   ├── Offroad_Segmentation_Training_Dataset/
#   │   └── Offroad_Segmentation_Training_Dataset/
#   │       ├── train/
#   │       │   ├── Color_Images/
#   │       │   └── Segmentation/
#   │       └── val/
#   │           ├── Color_Images/
#   │           └── Segmentation/
#   └── Offroad_Segmentation_testImages/
#       └── Offroad_Segmentation_testImages/
#           ├── Color_Images/
#           └── Segmentation/

TRAIN_DIR = CFG.TRAIN_DIR
VAL_DIR   = CFG.VAL_DIR
TEST_DIR  = CFG.TEST_DIR

# Auto-detect: if a configured split folder is missing (e.g. the dataset was
# attached under a different slug), search /kaggle/input for it. Train/val and
# test are checked independently, so a missing test folder is still found when
# the training folders are at their default location. glob() returns matches
# in arbitrary order, so candidates are sorted to make the choice deterministic.
if not os.path.exists(TRAIN_DIR):
    print("‚ö†Ô∏è  Default paths not found, auto-detecting...")
    candidates = sorted(glob.glob('/kaggle/input/**/train/Color_Images', recursive=True))
    if candidates:
        TRAIN_DIR = os.path.dirname(candidates[0])               # .../train
        VAL_DIR   = os.path.join(os.path.dirname(TRAIN_DIR), 'val')
        print(f"   Found TRAIN_DIR: {TRAIN_DIR}")
        print(f"   Found VAL_DIR:   {VAL_DIR}")
    else:
        print("   No train/Color_Images folder found under /kaggle/input")

if not os.path.exists(TEST_DIR):
    print("   Test path not found, auto-detecting...")
    test_candidates = sorted(glob.glob(
        '/kaggle/input/**/Offroad_Segmentation_testImages/**/Color_Images', recursive=True))
    if test_candidates:
        TEST_DIR = os.path.dirname(test_candidates[0])
        print(f"   Found TEST_DIR:  {TEST_DIR}")
    else:
        print("   No Offroad_Segmentation_testImages/**/Color_Images folder found under /kaggle/input")

# Verify
print("\nüìÇ Dataset Paths:")
for name, path in [('Train', TRAIN_DIR), ('Val', VAL_DIR), ('Test', TEST_DIR)]:
    img_dir = os.path.join(path, 'Color_Images')
    seg_dir = os.path.join(path, 'Segmentation')
    n_img = len(os.listdir(img_dir)) if os.path.exists(img_dir) else 0
    n_seg = len(os.listdir(seg_dir)) if os.path.exists(seg_dir) else 0
    status = "‚úÖ" if n_img > 0 else "‚ùå"
    print(f"  {status} {name:5s}: {n_img} images, {n_seg} masks  ‚Üí  {path}")

---
## 4. Class Definitions & Color Palette

In [None]:
# ══════════════════════════════════════════════════════════════════
#  CLASS MAPPING — From raw mask pixel values to class IDs
# ══════════════════════════════════════════════════════════════════

# Raw pixel value stored in the mask files → contiguous class ID used in training.
VALUE_MAP = {
    0    : 0,   # Background
    100  : 1,   # Trees
    200  : 2,   # Lush Bushes
    300  : 3,   # Dry Grass
    500  : 4,   # Dry Bushes
    550  : 5,   # Ground Clutter
    700  : 6,   # Logs
    800  : 7,   # Rocks
    7100 : 8,   # Landscape
    10000: 9,   # Sky
}

# Human-readable name for each class ID (index == ID).
CLASS_NAMES = [
    'Background', 'Trees', 'Lush Bushes', 'Dry Grass', 'Dry Bushes',
    'Ground Clutter', 'Logs', 'Rocks', 'Landscape', 'Sky'
]

# RGB visualisation colour for each class ID, in CLASS_NAMES order.
COLOR_PALETTE = np.array([
    [  0,   0,   0],   # Background  — black
    [ 34, 139,  34],   # Trees       — forest green
    [  0, 255,   0],   # Lush Bushes — lime
    [210, 180, 140],   # Dry Grass   — tan
    [139,  90,  43],   # Dry Bushes  — brown
    [128, 128,   0],   # Ground Clutter — olive
    [139,  69,  19],   # Logs        — saddle brown
    [128, 128, 128],   # Rocks       — gray
    [160,  82,  45],   # Landscape   — sienna
    [135, 206, 235],   # Sky         — sky blue
], dtype=np.uint8)

# Label maps handed to the HuggingFace model config.
ID2LABEL = dict(enumerate(CLASS_NAMES))
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}

print(f"Number of classes: {CFG.NUM_CLASSES}")
for idx, (label, rgb) in enumerate(zip(CLASS_NAMES, COLOR_PALETTE)):
    print(f"  {idx}: {label} (color: {rgb.tolist()})")

---
## 5. Dataset & Augmentations

In [None]:
def convert_mask(mask_np, value_map=None):
    """Convert a raw mask (pixel values 0, 100, 200, ..., 10000) to class IDs.

    Args:
        mask_np: integer array of raw mask values, as read from disk.
        value_map: optional ``{raw_value: class_id}`` mapping; defaults to the
            module-level ``VALUE_MAP`` (looked up at call time).

    Returns:
        ``uint8`` array with the same shape as ``mask_np``. Raw values that are
        not in the mapping become 0.

    Raises:
        ValueError: if ``mask_np`` is None, which is what ``cv2.imread`` returns
            for a missing or unreadable file. Without this check,
            ``np.zeros_like(None)`` silently yields a 0-d array that only fails
            much later, far from the cause.
    """
    if mask_np is None:
        raise ValueError("convert_mask received None (mask file missing or unreadable?)")
    if value_map is None:
        value_map = VALUE_MAP
    mask_np = np.asarray(mask_np)
    out = np.zeros_like(mask_np, dtype=np.uint8)
    for raw_val, class_id in value_map.items():
        out[mask_np == raw_val] = class_id
    return out


def mask_to_color(mask_np):
    """Render a 2-D class-ID mask as an RGB ``uint8`` image using COLOR_PALETTE.

    Pixels whose ID is outside ``0..CFG.NUM_CLASSES-1`` stay black.
    """
    rows, cols = mask_np.shape
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id in range(CFG.NUM_CLASSES):
        hits = mask_np == class_id
        rgb[hits] = COLOR_PALETTE[class_id]
    return rgb


# ── Albumentations pipelines (written against the albumentations >= 2.0 API) ──
def get_train_transforms(img_size=CFG.IMG_SIZE):
    """Build the training augmentation pipeline.

    Geometric transforms are applied jointly to image and mask. Photometric
    ones (noise, blur, colour, dropout) touch only the image. The pipeline ends
    with ImageNet normalisation and conversion to a CHW tensor.

    Args:
        img_size: side length of the square output crop.

    Notes:
        Argument names follow albumentations 2.x (``fill``/``fill_mask``,
        ``std_range``, ``num_holes_range``/``hole_*_range``). The 1.x names
        (``value``, ``mask_value``, ``var_limit``, ``max_holes``, ...) are not
        valid in 2.x. They are ignored with a warning, which the global
        ``warnings.filterwarnings('ignore')`` hides, so the transforms fall
        back to their defaults. For example, GaussNoise then defaults to a
        noise std of 20–44 % of the intensity range instead of ~1–3 %.
    """
    # 1.x GaussNoise(var_limit=(10, 50)) on uint8 images meant a std of
    # sqrt(10)..sqrt(50) ≈ 3.2..7.1 grey levels; 2.x takes std as a fraction
    # of the maximum value (255).
    noise_std_range = (10 ** 0.5 / 255.0, 50 ** 0.5 / 255.0)

    return A.Compose([
        # Resize first, then random scale + crop for multi-scale training
        A.Resize(height=img_size, width=img_size),
        A.RandomScale(scale_limit=(-0.5, 0.5), p=0.5),
        # NOTE: padded pixels receive mask label 0, which is also the
        # 'Background' class ID, so they are trained as Background rather than
        # ignored. ShiftScaleRotate's constant border below behaves the same.
        A.PadIfNeeded(min_height=img_size, min_width=img_size,
                      border_mode=cv2.BORDER_CONSTANT, fill=0, fill_mask=0),
        A.RandomCrop(height=img_size, width=img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.RandomRotate90(p=0.25),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15,
                           p=CFG.AUG_PROB, border_mode=cv2.BORDER_CONSTANT),
        A.OneOf([
            A.GaussNoise(std_range=noise_std_range, p=1),
            A.GaussianBlur(blur_limit=(3, 7), p=1),
            A.MotionBlur(blur_limit=7, p=1),
        ], p=0.3),
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=25, val_shift_limit=15, p=1),
            A.CLAHE(clip_limit=4.0, p=1),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1),
        ], p=CFG.AUG_PROB),
        # 1.x semantics of max_holes=8 / max_height=max_width=32 with unset
        # minimums: exactly 8 holes of 32x32 px; the mask is left untouched.
        A.CoarseDropout(num_holes_range=(8, 8), hole_height_range=(32, 32),
                        hole_width_range=(32, 32), p=0.2),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])


def get_val_transforms(img_size=CFG.IMG_SIZE):
    """Deterministic eval pipeline: resize → ImageNet normalisation → CHW tensor."""
    steps = [
        A.Resize(height=img_size, width=img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)


class OffroadSegmentationDataset(Dataset):
    """Off-road terrain segmentation dataset.

    Expects images in ``<data_dir>/Color_Images/<name>`` and a same-named mask
    in ``<data_dir>/Segmentation/<name>``. Raw mask values are converted to
    class IDs with ``convert_mask`` (see VALUE_MAP).

    Each item is ``(image, mask, filename)``:
      * with ``transforms``: the image as returned by the pipeline (a
        normalised CHW float tensor for the pipelines in this notebook) and
        the mask as an int64 tensor;
      * without ``transforms``: the raw RGB image as a CHW uint8 tensor and
        the mask as an int64 tensor.

    Args:
        data_dir: split directory containing ``Color_Images`` and ``Segmentation``.
        transforms: optional albumentations pipeline, called as
            ``transforms(image=..., mask=...)``.
        is_test: if True, a missing mask is replaced by an all-zero placeholder
            so unlabelled images can still go through inference; otherwise a
            missing mask raises FileNotFoundError.
    """
    
    def __init__(self, data_dir, transforms=None, is_test=False):
        self.image_dir = os.path.join(data_dir, 'Color_Images')
        self.mask_dir  = os.path.join(data_dir, 'Segmentation')
        self.transforms = transforms
        self.is_test = is_test
        
        self.image_files = sorted(os.listdir(self.image_dir))
        print(f"  Loaded {len(self.image_files)} samples from {data_dir}")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        fname = self.image_files[idx]
        
        # Load image. cv2.imread returns None (rather than raising) on failure.
        img_path = os.path.join(self.image_dir, fname)
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Load mask unchanged: VALUE_MAP contains values > 255, so masks are
        # presumably stored at more than 8 bits per pixel.
        mask_path = os.path.join(self.mask_dir, fname)
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            if not self.is_test:
                raise FileNotFoundError(f"Could not read mask: {mask_path}")
            # Unlabelled test image: placeholder mask (all class 0) so the
            # transform pipeline and collation still work.
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
        else:
            # Multi-channel masks: keep the first channel (cv2 order is BGR).
            if mask.ndim == 3:
                mask = mask[:, :, 0]
            # Convert raw values → class IDs
            mask = convert_mask(mask)
        
        # Apply augmentations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']   # (C, H, W) float tensor
            mask  = augmented['mask']    # (H, W) tensor
        else:
            # No pipeline: still return tensors so batches collate the same way.
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        
        # torch.as_tensor handles both the tensor (ToTensorV2) and NumPy cases.
        mask = torch.as_tensor(mask).long()
        return image, mask, fname


# Datasets: augmented pipeline for training, deterministic resize for val/test.
print("Creating datasets...")
train_dataset = OffroadSegmentationDataset(TRAIN_DIR, transforms=get_train_transforms())
val_dataset   = OffroadSegmentationDataset(VAL_DIR, transforms=get_val_transforms())
test_dataset  = OffroadSegmentationDataset(TEST_DIR, transforms=get_val_transforms(), is_test=True)

# Loader settings shared by all three splits.
_loader_common = dict(
    batch_size=CFG.BATCH_SIZE,
    num_workers=CFG.NUM_WORKERS,
    pin_memory=CFG.PIN_MEMORY,
)
# Only training shuffles and drops the ragged last batch.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_common)
val_loader   = DataLoader(val_dataset, shuffle=False, **_loader_common)
test_loader  = DataLoader(test_dataset, shuffle=False, **_loader_common)

print(f"\nüìä Dataloaders created:")
for split_label, split_ds, split_dl in (
    ("Train:", train_dataset, train_loader),
    ("Val:", val_dataset, val_loader),
    ("Test:", test_dataset, test_loader),
):
    print(f"  {split_label:<6} {len(split_ds)} samples ‚Üí {len(split_dl)} batches")

---
## 6. Visualize Samples

In [None]:
def visualize_samples(dataset, n=4, title='Samples'):
    """Show ``n`` random dataset items as (image | colour mask | overlay) rows.

    Items are expected as ``(image, mask, fname)``, with an ImageNet-normalised
    CHW image tensor and an integer mask tensor. The figure is saved to
    ``CFG.OUTPUT_DIR/sample_visualization.png`` and then shown.
    """
    fig, axes = plt.subplots(n, 3, figsize=(15, 5*n))
    if n == 1:
        # A single row comes back as a 1-D axes array; restore 2-D indexing.
        axes = axes[np.newaxis, :]
    
    chosen = random.sample(range(len(dataset)), n)
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std  = np.array([0.229, 0.224, 0.225])
    
    for row_axes, sample_idx in zip(axes, chosen):
        img, mask, fname = dataset[sample_idx]
        
        # Undo Normalize (x * std + mean) and return to 0..255 uint8 HWC.
        hwc = img.numpy().transpose(1, 2, 0)
        rgb = np.clip((hwc * imagenet_std + imagenet_mean) * 255, 0, 255).astype(np.uint8)
        
        colored = mask_to_color(mask.numpy().astype(np.uint8))
        blended = cv2.addWeighted(rgb, 0.6, colored, 0.4, 0)
        
        panels = (
            (rgb, f'Image: {fname}'),
            (colored, 'Ground Truth Mask'),
            (blended, 'Overlay'),
        )
        for ax, (picture, caption) in zip(row_axes, panels):
            ax.imshow(picture)
            ax.set_title(caption)
            ax.axis('off')
    
    # Shared class legend under the grid.
    legend_handles = [
        mpatches.Patch(color=COLOR_PALETTE[cid]/255., label=CLASS_NAMES[cid])
        for cid in range(CFG.NUM_CLASSES)
    ]
    fig.legend(handles=legend_handles, loc='lower center', ncol=5, fontsize=10, bbox_to_anchor=(0.5, -0.02))
    
    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(os.path.join(CFG.OUTPUT_DIR, 'sample_visualization.png'), dpi=150, bbox_inches='tight')
    plt.show()

visualize_samples(train_dataset, n=4, title='Training Samples')

---
## 7. Class Distribution Analysis

In [None]:
def _inverse_frequency_weights(freq, low=0.5, high=10.0, eps=1e-6):
    """Inverse-frequency weights with mean 1 over observed classes, clipped to [low, high].

    Classes with zero frequency get ``high``. If no class was observed at all,
    neutral weights of 1 are returned.
    """
    freq = np.asarray(freq, dtype=np.float64)
    present = freq > 0
    if not present.any():
        return np.ones_like(freq)
    weights = np.full_like(freq, high)
    inverse = 1.0 / (freq[present] + eps)
    weights[present] = inverse / inverse.sum() * present.sum()
    return np.clip(weights, low, high)


def compute_class_weights(dataset, num_samples=200):
    """Estimate per-class loss weights from pixel frequencies of sampled masks.

    Draws up to ``num_samples`` random items from ``dataset`` and counts
    pixels per class ID. The frequencies become inverse-frequency weights,
    normalised to mean 1 over the classes that occur and clipped to
    [0.5, 10]. When every class occurs, this equals normalising over all
    classes.

    Classes that never occur in the sample get the upper clip value. Giving
    them ``1/(0 + 1e-6)`` would dominate the normalisation and push every
    observed class down to the 0.5 floor. Their weight only matters for
    pixels of that class, which the sample did not contain.

    Also prints a text histogram and saves/shows a bar chart at
    ``CFG.OUTPUT_DIR/class_distribution.png``.

    Args:
        dataset: indexable returning ``(image, mask, name)`` with an integer
            class-ID mask tensor.
        num_samples: maximum number of items to scan.

    Returns:
        ``torch.float32`` tensor of shape ``(CFG.NUM_CLASSES,)``.
    """
    num_classes = CFG.NUM_CLASSES
    print("Computing class pixel frequencies...")
    pixel_counts = np.zeros(num_classes, dtype=np.float64)
    
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
    for idx in tqdm(indices, desc='Scanning'):
        _, mask, _ = dataset[idx]
        labels = mask.numpy().ravel()
        # Only valid class IDs are counted; one bincount covers every class.
        labels = labels[(labels >= 0) & (labels < num_classes)]
        pixel_counts += np.bincount(labels, minlength=num_classes)
    
    total_pixels = pixel_counts.sum()
    freq = pixel_counts / total_pixels if total_pixels > 0 else np.zeros(num_classes)
    
    weights = _inverse_frequency_weights(freq, low=0.5, high=10.0)
    
    print("\nClass Distribution:")
    for i in range(num_classes):
        bar = '‚ñà' * int(freq[i] * 100)
        print(f"  {CLASS_NAMES[i]:<16}: {freq[i]*100:6.2f}%  {bar}  (weight: {weights[i]:.3f})")
    
    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    colors = [COLOR_PALETTE[i]/255. for i in range(num_classes)]
    
    ax1.barh(CLASS_NAMES, freq * 100, color=colors, edgecolor='black')
    ax1.set_xlabel('Pixel Percentage (%)')
    ax1.set_title('Class Distribution')
    
    ax2.barh(CLASS_NAMES, weights, color=colors, edgecolor='black')
    ax2.set_xlabel('Weight')
    ax2.set_title('Class Weights (Inverse Frequency)')
    
    plt.tight_layout()
    plt.savefig(os.path.join(CFG.OUTPUT_DIR, 'class_distribution.png'), dpi=150, bbox_inches='tight')
    plt.show()
    
    return torch.tensor(weights, dtype=torch.float32)

# Estimate weights from deterministically resized masks. The training pipeline
# randomly down-scales and then pads the mask with label 0 (= Background),
# which would inflate the Background frequency and make the weights depend on
# augmentation randomness.
class_stats_dataset = OffroadSegmentationDataset(TRAIN_DIR, transforms=get_val_transforms())
class_weights = compute_class_weights(class_stats_dataset, num_samples=300)
print(f"\nClass weights tensor: {class_weights}")

---
## 8. SegFormer Model Setup

In [None]:
def build_segformer_model():
    """Load the pretrained SegFormer with a fresh ``CFG.NUM_CLASSES``-way head.

    ``ignore_mismatched_sizes=True`` lets the checkpoint load even though its
    decode-head classifier has a different label count; transformers
    initialises the mismatched weights anew. Prints parameter counts and
    returns the model moved to ``device``.
    """
    print(f"üîß Loading SegFormer: {CFG.MODEL_NAME}")
    print(f"   Fine-tuning for {CFG.NUM_CLASSES} classes")
    
    model = SegformerForSemanticSegmentation.from_pretrained(
        CFG.MODEL_NAME,
        num_labels=CFG.NUM_CLASSES,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        ignore_mismatched_sizes=True,  # the classifier head changes size
    )
    
    n_total = 0
    n_trainable = 0
    for tensor in model.parameters():
        n_total += tensor.numel()
        if tensor.requires_grad:
            n_trainable += tensor.numel()
    print(f"   Total params:     {n_total/1e6:.1f}M")
    print(f"   Trainable params: {n_trainable/1e6:.1f}M")
    
    return model.to(device)

model = build_segformer_model()

---
## 9. Loss Functions, Metrics & Optimizer

In [None]:
# ══════════════════════════════════════════════════════════════════
#  COMBINED LOSS: CrossEntropy + Dice + Focal
# ══════════════════════════════════════════════════════════════════

class DiceLoss(nn.Module):
    """Soft multi-class Dice loss: ``1 - mean over (sample, class) of Dice``.

    ``pred`` holds raw logits of shape (B, C, H, W); softmax is applied here.
    ``target`` holds integer class IDs of shape (B, H, W) in
    ``[0, num_classes)``. ``smooth`` keeps the ratio finite when both sums are 0.
    """
    def __init__(self, smooth=1e-6, num_classes=10):
        super().__init__()
        self.smooth = smooth
        self.num_classes = num_classes
    
    def forward(self, pred, target):
        probs = torch.softmax(pred, dim=1)
        # (B, H, W, C) one-hot → channel-first (B, C, H, W) to line up with probs
        onehot = F.one_hot(target, self.num_classes).movedim(-1, 1).float()
        
        spatial = (2, 3)
        overlap = torch.sum(probs * onehot, dim=spatial)
        mass = torch.sum(probs, dim=spatial) + torch.sum(onehot, dim=spatial)
        
        per_sample_class = (2. * overlap + self.smooth) / (mass + self.smooth)
        return 1. - per_sample_class.mean()


class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    Per element: ``alpha[target] * (1 - p_t) ** gamma * (-log p_t)``, where
    ``p_t`` is the softmax probability of the true class.

    Args:
        alpha: optional 1-D tensor of per-class weights indexed by target ID,
            or None for no weighting.
        gamma: focusing parameter; ``gamma=0`` gives (weighted) cross-entropy.
        reduction: 'mean' (average over all elements), 'sum', or anything else
            to return the unreduced per-element loss.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, pred, target):
        # p_t must come from the *unweighted* cross-entropy. With class weights,
        # exp(-weighted_ce) equals p_t ** alpha[target], not p_t, which would
        # distort the (1 - p_t) ** gamma modulating factor.
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.alpha is not None:
            alpha = self.alpha.to(device=focal_loss.device, dtype=focal_loss.dtype)
            focal_loss = alpha[target] * focal_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


class CombinedLoss(nn.Module):
    """Weighted sum of cross-entropy, soft Dice and focal losses.

    ``forward(pred, target)`` returns ``(total, parts)``:
      * ``total`` — the differentiable scalar
        ``ce_weight*CE + dice_weight*Dice + focal_weight*Focal``;
      * ``parts`` — the three component values as Python floats, keyed
        'ce' / 'dice' / 'focal', for logging.
    ``class_weights`` (moved to the global ``device``) weights the CE and
    focal terms; Dice is unweighted.
    """
    def __init__(self, class_weights=None, ce_weight=1.0, dice_weight=1.0, focal_weight=0.5):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        
        per_class = None if class_weights is None else class_weights.to(device)
        self.ce_loss = nn.CrossEntropyLoss(weight=per_class)
        self.dice_loss = DiceLoss(num_classes=CFG.NUM_CLASSES)
        self.focal_loss = FocalLoss(alpha=per_class, gamma=2.0)
    
    def forward(self, pred, target):
        terms = {
            'ce': self.ce_loss(pred, target),
            'dice': self.dice_loss(pred, target),
            'focal': self.focal_loss(pred, target),
        }
        total = (self.ce_weight * terms['ce']
                 + self.dice_weight * terms['dice']
                 + self.focal_weight * terms['focal'])
        return total, {key: term.item() for key, term in terms.items()}


# ── Metrics ──
def compute_iou(pred, target, num_classes=CFG.NUM_CLASSES, smooth=1e-6):
    """Compute per-class IoU and mean IoU for one batch.

    Args:
        pred: scores of shape (B, C, H, W); the predicted class is ``argmax`` over dim 1.
        target: integer class IDs of shape (B, H, W).
        num_classes: classes to score (IDs ``0..num_classes-1``).
        smooth: small constant added to the denominator.

    Returns:
        ``(mean_iou, iou_per_class)``. ``iou_per_class`` is a list of Python
        floats, with ``nan`` for classes absent from both prediction and
        target. ``mean_iou`` is their ``np.nanmean``.

    All classes are evaluated at once with boolean (num_classes, pixels)
    tensors. The only device→host transfer is the final ``tolist()``, which
    matters because this runs on every training batch.
    """
    pred_cls = pred.argmax(dim=1).reshape(-1)
    target_flat = target.reshape(-1)
    
    classes = torch.arange(num_classes, device=pred_cls.device).unsqueeze(1)  # (C, 1)
    pred_hit = pred_cls.unsqueeze(0) == classes      # (C, N)
    tgt_hit = target_flat.unsqueeze(0) == classes    # (C, N)
    
    inter = (pred_hit & tgt_hit).sum(dim=1).float()
    union = (pred_hit | tgt_hit).sum(dim=1).float()
    iou = inter / (union + smooth)
    # Classes that appear in neither prediction nor target are undefined (NaN).
    iou = torch.where(union > 0, iou, torch.full_like(iou, float('nan')))
    iou_per_class = iou.tolist()
    
    return np.nanmean(iou_per_class), iou_per_class


def compute_dice(pred, target, num_classes=CFG.NUM_CLASSES, smooth=1e-6):
    """Compute per-class Dice and mean Dice for one batch.

    Args:
        pred: scores of shape (B, C, H, W); the predicted class is ``argmax`` over dim 1.
        target: integer class IDs of shape (B, H, W).
        num_classes: classes to score (IDs ``0..num_classes-1``).
        smooth: small constant added to numerator and denominator.

    Returns:
        ``(mean_dice, dice_per_class)``. ``dice_per_class`` is a list of
        Python floats. Classes absent from both prediction and target are
        ``nan`` and are excluded from the ``np.nanmean`` (the same convention
        as ``compute_iou``). Scoring them as a perfect 1.0 would inflate the
        mean on batches where only a few classes occur.

    Vectorised over classes, so the only device→host transfer is the final
    ``tolist()``.
    """
    pred_cls = pred.argmax(dim=1).reshape(-1)
    target_flat = target.reshape(-1)
    
    classes = torch.arange(num_classes, device=pred_cls.device).unsqueeze(1)  # (C, 1)
    pred_hit = pred_cls.unsqueeze(0) == classes      # (C, N)
    tgt_hit = target_flat.unsqueeze(0) == classes    # (C, N)
    
    inter = (pred_hit & tgt_hit).sum(dim=1).float()
    denom = pred_hit.sum(dim=1).float() + tgt_hit.sum(dim=1).float()
    dice = (2. * inter + smooth) / (denom + smooth)
    dice = torch.where(denom > 0, dice, torch.full_like(dice, float('nan')))
    dice_per_class = dice.tolist()
    
    return np.nanmean(dice_per_class), dice_per_class


def compute_pixel_accuracy(pred, target):
    """Return the fraction of pixels whose argmax class (dim 1 of ``pred``) equals ``target``."""
    correct = torch.eq(pred.argmax(dim=1), target)
    return correct.float().mean().item()


# ── Instantiate ──
# CE and Dice at full weight, focal at half weight; class weights feed CE + focal.
criterion = CombinedLoss(
    class_weights=class_weights,
    ce_weight=1.0,
    dice_weight=1.0,
    focal_weight=0.5,
)
print("‚úÖ Combined Loss: CrossEntropy + Dice + Focal")

In [None]:
# ══════════════════════════════════════════════════════════════════
#  OPTIMIZER & SCHEDULER
# ══════════════════════════════════════════════════════════════════

# Parameter groups: reduced LR for the pretrained encoder, full LR for the new head.
def get_parameter_groups(model, lr=CFG.LR, wd=CFG.WEIGHT_DECAY, lr_decay=0.9):
    """Split trainable parameters into AdamW param groups.

    * Encoder parameters (name contains ``'encoder'``) use ``lr * 0.1``; all
      other parameters (the decode head) use ``lr``. The encoder multiplier is
      uniform; there is no per-layer decay. ``lr_decay`` is accepted for
      backward compatibility and is currently unused.
    * Weight decay ``wd`` is applied only to tensors with 2+ dimensions
      (linear/conv kernels). Biases and normalisation scales are 1-D and get
      no decay. The substring list alone does not match all SegFormer norm
      parameter names (e.g. ``layer_norm_1.weight``, ``layer_norm.0.weight``,
      BatchNorm in the decode head), so dimensionality is checked as well.
    * Empty groups are omitted.

    Returns:
        list of param-group dicts for a ``torch.optim`` optimizer.
    """
    no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
    
    encoder_params_decay = []
    encoder_params_no_decay = []
    decoder_params_decay = []
    decoder_params_no_decay = []
    
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_no_decay = param.ndim <= 1 or any(nd in name for nd in no_decay)
        
        if 'encoder' in name:
            if is_no_decay:
                encoder_params_no_decay.append(param)
            else:
                encoder_params_decay.append(param)
        else:
            if is_no_decay:
                decoder_params_no_decay.append(param)
            else:
                decoder_params_decay.append(param)
    
    encoder_lr = lr * 0.1
    param_groups = [
        {'params': encoder_params_decay,     'lr': encoder_lr, 'weight_decay': wd},
        {'params': encoder_params_no_decay,  'lr': encoder_lr, 'weight_decay': 0.0},
        {'params': decoder_params_decay,     'lr': lr,         'weight_decay': wd},
        {'params': decoder_params_no_decay,  'lr': lr,         'weight_decay': 0.0},
    ]
    
    print(f"  Encoder params (with decay):    {len(encoder_params_decay)} tensors, LR={lr*0.1:.1e}")
    print(f"  Encoder params (no decay):      {len(encoder_params_no_decay)} tensors, LR={lr*0.1:.1e}")
    print(f"  Decoder params (with decay):    {len(decoder_params_decay)} tensors, LR={lr:.1e}")
    print(f"  Decoder params (no decay):      {len(decoder_params_no_decay)} tensors, LR={lr:.1e}")
    
    return [group for group in param_groups if group['params']]


param_groups = get_parameter_groups(model)
optimizer = optim.AdamW(params=param_groups, lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)

# Cosine decay with linear warmup (one cycle, no warm restarts). Both counts
# are in optimizer updates, i.e. batches divided by the accumulation factor.
batches_per_epoch = len(train_loader)
total_steps = (batches_per_epoch * CFG.NUM_EPOCHS) // CFG.ACCUMULATION
warmup_steps = (batches_per_epoch * CFG.WARMUP_EPOCHS) // CFG.ACCUMULATION

def cosine_warmup_scheduler(optimizer, warmup_steps, total_steps, min_lr=CFG.MIN_LR):
    """Build a LambdaLR with linear warmup followed by a single cosine decay.

    The multiplier rises linearly from 0 to 1 over ``warmup_steps``. After
    that it follows ``0.5 * (1 + cos(pi * progress))`` towards 0, floored at
    ``min_lr / CFG.LR``. LambdaLR scales each param group's initial LR by this
    multiplier, so groups that start below ``CFG.LR`` bottom out
    proportionally lower.
    """
    warmup_span = float(max(1, warmup_steps))
    decay_span = float(max(1, total_steps - warmup_steps))

    def lr_lambda(step):
        if step >= warmup_steps:
            progress = float(step - warmup_steps) / decay_span
            cosine = 0.5 * (1.0 + np.cos(np.pi * progress))
            return max(min_lr / CFG.LR, cosine)
        return float(step) / warmup_span

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scheduler = cosine_warmup_scheduler(optimizer, warmup_steps, total_steps)
# With enabled=False the scaler is a pass-through, so the loop code is the same either way.
scaler = GradScaler(enabled=CFG.USE_AMP)

amp_state = 'Enabled' if CFG.USE_AMP else 'Disabled'
print(f"\n‚úÖ Optimizer: AdamW (layerwise LR decay)")
print(f"‚úÖ Scheduler: Cosine with {CFG.WARMUP_EPOCHS} warmup epochs")
print(f"‚úÖ AMP: {amp_state}")
print(f"   Total steps: {total_steps}, Warmup steps: {warmup_steps}")

---
## 10. Early Stopping & Checkpoint Manager

In [None]:
class EarlyStopping:
    """Dual-metric early stopping on validation mIoU and validation loss.

    Each metric has its own patience counter. A counter resets when its metric
    improves by more than ``min_delta`` (mIoU upward, loss downward).
    ``should_stop`` becomes True only once BOTH counters reach their patience.
    """
    
    def __init__(self, patience_miou=CFG.PATIENCE_MIOU, patience_loss=CFG.PATIENCE_LOSS,
                 min_delta=CFG.MIN_DELTA):
        self.patience_miou = patience_miou
        self.patience_loss = patience_loss
        self.min_delta = min_delta
        
        # mIoU tracking (higher is better)
        self.best_miou = -np.inf
        self.best_epoch_miou = 0
        self.miou_counter = 0
        # Loss tracking (lower is better)
        self.best_loss = np.inf
        self.best_epoch_loss = 0
        self.loss_counter = 0
        
        self.should_stop = False
    
    def __call__(self, epoch, val_miou, val_loss):
        """Update both counters with one epoch's metrics.

        Returns:
            True if ``val_miou`` set a new best (beyond ``min_delta``), else False.
        """
        miou_improved = bool(val_miou > self.best_miou + self.min_delta)
        if miou_improved:
            self.best_miou, self.best_epoch_miou, self.miou_counter = val_miou, epoch, 0
        else:
            self.miou_counter += 1
        
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss, self.best_epoch_loss, self.loss_counter = val_loss, epoch, 0
        else:
            self.loss_counter += 1
        
        both_stalled = (self.miou_counter >= self.patience_miou
                        and self.loss_counter >= self.patience_loss)
        if both_stalled:
            self.should_stop = True
            print(f"\nüõë EARLY STOPPING at epoch {epoch+1}")
            print(f"   mIoU: no improvement for {self.miou_counter} epochs (best: {self.best_miou:.4f} at epoch {self.best_epoch_miou+1})")
            print(f"   Loss: no improvement for {self.loss_counter} epochs (best: {self.best_loss:.4f} at epoch {self.best_epoch_loss+1})")
        
        return miou_improved
    
    def status(self):
        """Human-readable summary of both patience counters."""
        parts = (f"mIoU patience: {self.miou_counter}/{self.patience_miou}",
                 f"Loss patience: {self.loss_counter}/{self.patience_loss}")
        return " | ".join(parts)


class CheckpointManager:
    """Manage training checkpoints.

    * ``save_checkpoint`` writes ``checkpoint_epoch_XXX.pt`` and keeps only the
      ``keep_top_k`` of these files with the highest val mIoU. When flagged
      ``is_best``, it also writes ``best_model.pt`` plus a HuggingFace-format
      copy in ``best_model_hf/``.
    * ``save_last`` overwrites ``last_checkpoint.pt`` for resuming.
    * ``load_checkpoint`` restores model (and optionally optimizer/scheduler/
      scaler) state.
    """
    
    def __init__(self, checkpoint_dir, keep_top_k=CFG.KEEP_TOP_K):
        self.checkpoint_dir = checkpoint_dir
        self.keep_top_k = keep_top_k
        self.best_models = []  # (miou, path) pairs, sorted ascending by miou
        os.makedirs(checkpoint_dir, exist_ok=True)
    
    @staticmethod
    def _build_state(model, optimizer, scheduler, scaler, epoch, metrics, include_config):
        """Assemble a resumable state dict; optionally embed the public CFG values."""
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict() if scaler else None,
            'metrics': metrics,
        }
        if include_config:
            state['config'] = {k: v for k, v in vars(CFG).items() if not k.startswith('_')}
        return state
    
    def save_checkpoint(self, model, optimizer, scheduler, scaler, epoch, metrics, is_best=False):
        """Save a full, resumable checkpoint for ``epoch`` and update best/top-K files.

        Returns:
            Path of the epoch checkpoint. The file may already have been pruned
            if it did not make the top-K.
        """
        state = self._build_state(model, optimizer, scheduler, scaler, epoch, metrics,
                                  include_config=True)
        
        # Save periodic checkpoint
        path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1:03d}.pt')
        torch.save(state, path)
        
        # Save best model
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
            torch.save(state, best_path)
            print(f"  üíæ Saved BEST model (mIoU: {metrics['val_miou']:.4f})")
            
            # Also save HuggingFace-format for easy loading
            hf_path = os.path.join(self.checkpoint_dir, 'best_model_hf')
            model.save_pretrained(hf_path)
        
        # Top-K management. Saving the same epoch again overwrites the same
        # file, so drop any stale entry for this path first. Otherwise pruning
        # the stale entry would delete the file the fresh entry still refers to.
        miou = metrics.get('val_miou', 0)
        self.best_models = [entry for entry in self.best_models if entry[1] != path]
        self.best_models.append((miou, path))
        self.best_models.sort(key=lambda x: x[0])
        
        while len(self.best_models) > self.keep_top_k:
            _, remove_path = self.best_models.pop(0)
            if os.path.exists(remove_path) and 'best_model' not in remove_path:
                os.remove(remove_path)
        
        return path
    
    def save_last(self, model, optimizer, scheduler, scaler, epoch, metrics):
        """Overwrite ``last_checkpoint.pt`` with the current training state."""
        state = self._build_state(model, optimizer, scheduler, scaler, epoch, metrics,
                                  include_config=False)
        path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pt')
        torch.save(state, path)
        return path
    
    @staticmethod
    def load_checkpoint(path, model, optimizer=None, scheduler=None, scaler=None):
        """Load a checkpoint into ``model`` and optionally restore optimizer/scheduler/scaler.

        Returns:
            ``(epoch, metrics)`` stored in the checkpoint.
        """
        print(f"üìÇ Loading checkpoint: {path}")
        # weights_only=False: these checkpoints also pickle metric/config values
        # (metrics may hold NumPy scalars; compute_iou returns np.nanmean). The
        # weights-only unpickler, the default since PyTorch 2.6, rejects them.
        # Full unpickling can execute code, so only load checkpoints you trust.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model_state_dict'])
        
        if optimizer and 'optimizer_state_dict' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        if scheduler and 'scheduler_state_dict' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        if scaler and ckpt.get('scaler_state_dict'):
            scaler.load_state_dict(ckpt['scaler_state_dict'])
        
        epoch = ckpt.get('epoch', 0)
        metrics = ckpt.get('metrics', {})
        print(f"   Resumed from epoch {epoch+1}, val_miou={metrics.get('val_miou', 'N/A')}")
        return epoch, metrics


# Instantiate the checkpoint manager (creates CFG.CHECKPOINT_DIR) and the early stopper.
ckpt_manager = CheckpointManager(CFG.CHECKPOINT_DIR)
early_stopping = EarlyStopping()
print("‚úÖ Early Stopping (dual mIoU+Loss) & Checkpoint Manager initialized")

---
## 11. Training Loop (God-Level)

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, scheduler, scaler, epoch):
    """Run one training epoch with AMP and gradient accumulation.

    Each batch is forwarded under autocast. The loss is divided by
    ``CFG.ACCUMULATION`` before the scaled backward pass. Once every
    ``CFG.ACCUMULATION`` batches, gradients are unscaled, clipped to norm 1.0,
    and the optimizer and scheduler are stepped. Gradients from a trailing
    partial group are not stepped; the next call starts with
    ``optimizer.zero_grad()``.

    Returns:
        dict with the epoch-average total / CE / Dice / Focal losses, the mean
        of the per-batch mIoU, Dice and pixel accuracy, and the last LR.
    """
    accum = CFG.ACCUMULATION
    model.train()

    loss_sum = 0.0
    ce_sum = 0.0
    dice_loss_sum = 0.0
    focal_sum = 0.0
    n_seen = 0
    iou_hist, dice_hist, acc_hist = [], [], []

    optimizer.zero_grad()
    progress = tqdm(loader, desc=f'Epoch {epoch+1} [Train]', leave=False)

    for idx, (imgs, tgts, _) in enumerate(progress):
        imgs = imgs.to(device, non_blocking=True)
        tgts = tgts.to(device, non_blocking=True)

        with autocast(enabled=CFG.USE_AMP):
            # Logits come out below mask resolution; resize to the mask size.
            raw_logits = model(pixel_values=imgs).logits
            up = F.interpolate(raw_logits, size=tgts.shape[-2:], mode='bilinear', align_corners=False)
            batch_loss, parts = criterion(up, tgts)
            batch_loss = batch_loss / accum

        scaler.scale(batch_loss).backward()

        if (idx + 1) % accum == 0:
            scaler.unscale_(optimizer)  # clip the real (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        # Undo the accumulation division so the logged loss is per batch.
        loss_sum += batch_loss.item() * accum
        ce_sum += parts['ce']
        dice_loss_sum += parts['dice']
        focal_sum += parts['focal']
        n_seen += 1

        with torch.no_grad():
            batch_iou, _ = compute_iou(up, tgts)
            batch_dice, _ = compute_dice(up, tgts)
            batch_acc = compute_pixel_accuracy(up, tgts)
            iou_hist.append(batch_iou)
            dice_hist.append(batch_dice)
            acc_hist.append(batch_acc)

        progress.set_postfix({
            'loss': f'{loss_sum/n_seen:.4f}',
            'mIoU': f'{np.mean(iou_hist):.3f}',
            'lr': f'{scheduler.get_last_lr()[0]:.2e}',
        })

    return {
        'loss': loss_sum / n_seen,
        'ce_loss': ce_sum / n_seen,
        'dice_loss': dice_loss_sum / n_seen,
        'focal_loss': focal_sum / n_seen,
        'miou': np.mean(iou_hist),
        'dice': np.mean(dice_hist),
        'pixel_acc': np.mean(acc_hist),
        'lr': scheduler.get_last_lr()[0],
    }


@torch.no_grad()
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        dict with the mean loss per batch; the mean of the per-batch mIoU,
        Dice and pixel accuracy; and ``class_iou``, the per-class IoU averaged
        over batches with ``np.nanmean``.
    """
    model.eval()
    total_loss, batches = 0.0, 0
    ious, dices, accs, per_class = [], [], [], []

    bar = tqdm(loader, desc='Validating', leave=False)
    for imgs, tgts, _ in bar:
        imgs = imgs.to(device, non_blocking=True)
        tgts = tgts.to(device, non_blocking=True)

        with autocast(enabled=CFG.USE_AMP):
            up = F.interpolate(model(pixel_values=imgs).logits, size=tgts.shape[-2:],
                               mode='bilinear', align_corners=False)
            batch_loss, _ = criterion(up, tgts)

        total_loss += batch_loss.item()
        batches += 1

        batch_miou, batch_class_iou = compute_iou(up, tgts)
        ious.append(batch_miou)
        batch_dice, _ = compute_dice(up, tgts)
        dices.append(batch_dice)
        accs.append(compute_pixel_accuracy(up, tgts))
        per_class.append(batch_class_iou)

        # Postfix shows the running loss and the *current* batch's mIoU.
        bar.set_postfix({'loss': f'{total_loss/batches:.4f}', 'mIoU': f'{batch_miou:.3f}'})

    class_mean = np.nanmean(per_class, axis=0)

    return {
        'loss': total_loss / batches,
        'miou': np.mean(ious),
        'dice': np.mean(dices),
        'pixel_acc': np.mean(accs),
        'class_iou': class_mean,
    }

print("‚úÖ Training and validation functions defined")

In [None]:
# ══════════════════════════════════════════════════════════════════
#  MAIN TRAINING LOOP
# ══════════════════════════════════════════════════════════════════
# Trains for CFG.NUM_EPOCHS epochs, resuming from last_checkpoint.pt when it
# exists. Each epoch, per-epoch metrics are appended to ``history``, and the
# checkpoints and history JSON are written. The loop stops when
# ``early_stopping`` requests it.

history = {
    'train_loss': [], 'val_loss': [],
    'train_miou': [], 'val_miou': [],
    'train_dice': [], 'val_dice': [],
    'train_pixel_acc': [], 'val_pixel_acc': [],
    'train_ce': [], 'train_dice_loss': [], 'train_focal': [],
    'lr': [],
}

best_miou = 0.0
start_epoch = 0
hist_path = os.path.join(CFG.OUTPUT_DIR, 'training_history.json')

# ── Resume from checkpoint if exists ──
resume_path = os.path.join(CFG.CHECKPOINT_DIR, 'last_checkpoint.pt')
if os.path.exists(resume_path):
    print("üîÑ Resuming from last checkpoint...")
    start_epoch, prev_metrics = CheckpointManager.load_checkpoint(
        resume_path, model, optimizer, scheduler, scaler
    )
    start_epoch += 1  # the checkpointed epoch has already been completed
    # Fallback only: last_checkpoint.pt holds the *last* epoch's val mIoU,
    # which can be lower than the best seen so far.
    best_miou = prev_metrics.get('val_miou', 0)
    # Load history if saved
    if os.path.exists(hist_path):
        with open(hist_path, 'r') as f:
            history = json.load(f)
        # Never keep more epochs than the checkpoint covers, so entries
        # appended from here on line up with their epoch index.
        history = {k: (v[:start_epoch] if isinstance(v, list) else v)
                   for k, v in history.items()}
        print(f"   Loaded training history ({len(history['train_loss'])} epochs)")
        # Restore the true best so far. Otherwise a later epoch that only beats
        # the *last* epoch would be flagged "best" and overwrite best_model.pt.
        if history.get('val_miou'):
            best_miou = max(best_miou, max(history['val_miou']))
    early_stopping.best_miou = best_miou

print(f"\n{'='*80}")
print(f"  TRAINING: SegFormer-B5 | {CFG.NUM_EPOCHS} epochs | BS={CFG.BATCH_SIZE}x{CFG.ACCUMULATION}={CFG.BATCH_SIZE*CFG.ACCUMULATION}")
print(f"  Starting from epoch {start_epoch+1}")
print(f"{'='*80}\n")

training_start = time.time()

for epoch in range(start_epoch, CFG.NUM_EPOCHS):
    epoch_start = time.time()

    # ── Train ──
    train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch)

    # ── Validate ──
    val_metrics = validate(model, val_loader, criterion)

    epoch_time = time.time() - epoch_start

    # ── Record history ──
    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    history['train_miou'].append(train_metrics['miou'])
    history['val_miou'].append(val_metrics['miou'])
    history['train_dice'].append(train_metrics['dice'])
    history['val_dice'].append(val_metrics['dice'])
    history['train_pixel_acc'].append(train_metrics['pixel_acc'])
    history['val_pixel_acc'].append(val_metrics['pixel_acc'])
    history['train_ce'].append(train_metrics['ce_loss'])
    history['train_dice_loss'].append(train_metrics['dice_loss'])
    history['train_focal'].append(train_metrics['focal_loss'])
    history['lr'].append(train_metrics['lr'])

    # ── Print epoch summary ──
    is_best = val_metrics['miou'] > best_miou
    best_marker = ' ‚òÖ NEW BEST' if is_best else ''
    if is_best:
        best_miou = val_metrics['miou']

    print(f"\nEpoch {epoch+1:3d}/{CFG.NUM_EPOCHS} ({epoch_time:.0f}s) | "
          f"Train Loss: {train_metrics['loss']:.4f} | Val Loss: {val_metrics['loss']:.4f} | "
          f"Train mIoU: {train_metrics['miou']:.4f} | Val mIoU: {val_metrics['miou']:.4f} | "
          f"Val Dice: {val_metrics['dice']:.4f} | Val Acc: {val_metrics['pixel_acc']:.4f} | "
          f"LR: {train_metrics['lr']:.2e}{best_marker}")

    # Per-class IoU (every 5 epochs and on every new best); NaN classes are skipped.
    if (epoch + 1) % 5 == 0 or is_best:
        print("  Per-class IoU: " + " | ".join(
            f"{CLASS_NAMES[i]}: {val_metrics['class_iou'][i]:.3f}"
            for i in range(CFG.NUM_CLASSES) if not np.isnan(val_metrics['class_iou'][i])
        ))

    # ── Checkpointing ──
    epoch_metrics = {
        'val_miou': val_metrics['miou'],
        'val_loss': val_metrics['loss'],
        'val_dice': val_metrics['dice'],
        'val_pixel_acc': val_metrics['pixel_acc'],
        'train_loss': train_metrics['loss'],
    }

    # Always save last checkpoint (for resume)
    ckpt_manager.save_last(model, optimizer, scheduler, scaler, epoch, epoch_metrics)

    # Save periodic / best checkpoints
    if is_best or (epoch + 1) % CFG.SAVE_EVERY == 0:
        ckpt_manager.save_checkpoint(
            model, optimizer, scheduler, scaler, epoch, epoch_metrics, is_best=is_best
        )

    # Save history to JSON (crash-safe): write a temp file, then atomically
    # replace it, so an interrupted write can't leave a truncated JSON that
    # would break json.load on the next resume.
    tmp_hist_path = hist_path + '.tmp'
    with open(tmp_hist_path, 'w') as f:
        json.dump(history, f, indent=2)
    os.replace(tmp_hist_path, hist_path)

    # ── Early Stopping ──
    early_stopping(epoch, val_metrics['miou'], val_metrics['loss'])
    print(f"  Early stopping: {early_stopping.status()}")

    if early_stopping.should_stop:
        print(f"\nüõë Training stopped early at epoch {epoch+1}")
        break

total_time = time.time() - training_start
print(f"\n{'='*80}")
print(f"  TRAINING COMPLETE in {total_time/3600:.2f} hours")
print(f"  Best Val mIoU: {best_miou:.4f}")
print(f"{'='*80}")

---
## 12. Training Curves & Analysis

In [None]:
def plot_training_curves(history, output_dir):
    """Render a 2x3 training dashboard and save it as ``training_curves_all.png``.

    Top row: total loss, mean IoU (best validation epoch marked by a dashed
    green line), and Dice. Bottom row: pixel accuracy, the CE / Dice / Focal
    training loss components, and the learning rate on a log scale. The figure
    is written to ``output_dir`` at 200 dpi and then shown.
    """
    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    def finish(ax, title, with_legend=True):
        # Shared title / x-label / legend / grid styling for a panel.
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        if with_legend:
            ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3)

    # Panels 1-4: train vs. validation curves of the same metric.
    train_val_panels = [
        (axes[0, 0], 'loss', 'Total Loss'),
        (axes[0, 1], 'miou', 'Mean IoU'),
        (axes[0, 2], 'dice', 'Dice Score'),
        (axes[1, 0], 'pixel_acc', 'Pixel Accuracy'),
    ]
    for ax, key, title in train_val_panels:
        ax.plot(epochs, history[f'train_{key}'], 'b-', label='Train', linewidth=2)
        ax.plot(epochs, history[f'val_{key}'], 'r-', label='Val', linewidth=2)
        if key == 'miou':
            best = np.argmax(history['val_miou'])
            ax.axvline(x=best+1, color='green', linestyle='--', alpha=0.7,
                       label=f'Best: {history["val_miou"][best]:.4f}')
        finish(ax, title)

    # Panel 5: the three training loss components.
    comp_ax = axes[1, 1]
    for key, label in (('train_ce', 'CE Loss'),
                       ('train_dice_loss', 'Dice Loss'),
                       ('train_focal', 'Focal Loss')):
        comp_ax.plot(epochs, history[key], label=label, linewidth=2)
    finish(comp_ax, 'Component Losses (Train)')

    # Panel 6: learning rate (log scale, no legend).
    lr_ax = axes[1, 2]
    lr_ax.plot(epochs, history['lr'], 'g-', linewidth=2)
    finish(lr_ax, 'Learning Rate Schedule', with_legend=False)
    lr_ax.set_yscale('log')

    plt.suptitle('SegFormer-B5 Training Curves', fontsize=18, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves_all.png'), dpi=200, bbox_inches='tight')
    plt.show()
    print(f"Saved training curves to {output_dir}/training_curves_all.png")

plot_training_curves(history, CFG.OUTPUT_DIR)

In [None]:
# Save detailed metrics report
def save_metrics_report(history, output_dir):
    """Write a plain-text training report to ``<output_dir>/evaluation_metrics.txt``.

    The report lists the best value (with its 1-based epoch) of each validation
    metric, the final-epoch train/val metrics, and a per-epoch table. The file
    is written as UTF-8 explicitly because the banner is non-ASCII and the
    platform default encoding may not be able to represent it. An empty
    history yields a short note instead of raising on ``max([])``.

    Args:
        history: Dict of per-epoch lists with the keys used by the training
            loop (``train_loss``, ``val_loss``, ``train_miou``, ``val_miou``,
            ``train_dice``, ``val_dice``, ``train_pixel_acc``,
            ``val_pixel_acc``, ``lr``).
        output_dir: Existing directory that receives the report.
    """
    filepath = os.path.join(output_dir, 'evaluation_metrics.txt')
    n_epochs = len(history['train_loss'])

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write("‚ïê" * 100 + "\n")
        f.write("  SEGFORMER-B5 TRAINING REPORT\n")
        f.write(f"  Generated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("‚ïê" * 100 + "\n\n")

        if n_epochs == 0:
            # Nothing was trained; max()/argmax() below would raise on empty lists.
            f.write("No epochs recorded - nothing to report.\n")
        else:
            # Best results
            f.write("BEST RESULTS:\n")
            f.write("-" * 50 + "\n")
            f.write(f"  Best Val mIoU:      {max(history['val_miou']):.4f} (Epoch {np.argmax(history['val_miou'])+1})\n")
            f.write(f"  Best Val Dice:      {max(history['val_dice']):.4f} (Epoch {np.argmax(history['val_dice'])+1})\n")
            f.write(f"  Best Val Accuracy:  {max(history['val_pixel_acc']):.4f} (Epoch {np.argmax(history['val_pixel_acc'])+1})\n")
            f.write(f"  Lowest Val Loss:    {min(history['val_loss']):.4f} (Epoch {np.argmin(history['val_loss'])+1})\n")
            f.write("\n")

            # Final results
            f.write("FINAL EPOCH RESULTS:\n")
            f.write("-" * 50 + "\n")
            f.write(f"  Train Loss:     {history['train_loss'][-1]:.4f}\n")
            f.write(f"  Val Loss:       {history['val_loss'][-1]:.4f}\n")
            f.write(f"  Train mIoU:     {history['train_miou'][-1]:.4f}\n")
            f.write(f"  Val mIoU:       {history['val_miou'][-1]:.4f}\n")
            f.write(f"  Train Dice:     {history['train_dice'][-1]:.4f}\n")
            f.write(f"  Val Dice:       {history['val_dice'][-1]:.4f}\n")
            f.write(f"  Train Acc:      {history['train_pixel_acc'][-1]:.4f}\n")
            f.write(f"  Val Acc:        {history['val_pixel_acc'][-1]:.4f}\n")
            f.write("\n")

            # Per-epoch table
            f.write("PER-EPOCH HISTORY:\n")
            f.write("-" * 120 + "\n")
            headers = ['Epoch', 'Train Loss', 'Val Loss', 'Train mIoU', 'Val mIoU', 'Train Dice', 'Val Dice', 'Train Acc', 'Val Acc', 'LR']
            f.write("{:<8} {:<12} {:<12} {:<12} {:<12} {:<12} {:<12} {:<12} {:<12} {:<12}\n".format(*headers))
            f.write("-" * 120 + "\n")

            for i in range(n_epochs):
                f.write("{:<8} {:<12.4f} {:<12.4f} {:<12.4f} {:<12.4f} {:<12.4f} {:<12.4f} {:<12.4f} {:<12.4f} {:<12.2e}\n".format(
                    i+1,
                    history['train_loss'][i], history['val_loss'][i],
                    history['train_miou'][i], history['val_miou'][i],
                    history['train_dice'][i], history['val_dice'][i],
                    history['train_pixel_acc'][i], history['val_pixel_acc'][i],
                    history['lr'][i]
                ))

    print(f"üìù Saved metrics report to {filepath}")

# Write the text report for the history recorded by the training loop.
save_metrics_report(history, CFG.OUTPUT_DIR)

---
## 13. Load Best Model for Inference

In [None]:
# ══════════════════════════════════════════════════════════════════
#  LOAD BEST MODEL
# ══════════════════════════════════════════════════════════════════
# Restore best_model.pt (falling back to last_checkpoint.pt) before evaluation
# and inference.


def _load_trusted_checkpoint(path):
    """Load a checkpoint written by this notebook onto ``device``.

    The metrics stored with the checkpoints come from ``np.mean`` (NumPy
    scalars). ``torch.load``'s ``weights_only=True`` default (PyTorch >= 2.6)
    refuses to unpickle these, so full unpickling is requested explicitly.
    Only use this for trusted, locally written files.
    """
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` argument.
        return torch.load(path, map_location=device)


best_ckpt_path = os.path.join(CFG.CHECKPOINT_DIR, 'best_model.pt')

if os.path.exists(best_ckpt_path):
    print("Loading best model for inference...")
    ckpt = _load_trusted_checkpoint(best_ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    print(f"‚úÖ Best model loaded (Epoch {ckpt['epoch']+1}, Val mIoU: {ckpt['metrics']['val_miou']:.4f})")
else:
    # Try loading last checkpoint
    last_path = os.path.join(CFG.CHECKPOINT_DIR, 'last_checkpoint.pt')
    if os.path.exists(last_path):
        ckpt = _load_trusted_checkpoint(last_path)
        model.load_state_dict(ckpt['model_state_dict'])
        print(f"‚úÖ Last checkpoint loaded (Epoch {ckpt['epoch']+1})")
    else:
        print("‚ö†Ô∏è  No checkpoint found ‚Äî using current model state")

model.eval()
print("Model set to eval mode ‚úÖ")

---
## 14. Full Evaluation on Validation Set

In [None]:
@torch.no_grad()
def full_evaluation(model, loader, dataset_name='Validation'):
    """Evaluate ``model`` on ``loader`` and report per-class IoU and Dice.

    Per-class scores are averaged over batches with ``np.nanmean`` and then
    over classes, also with ``np.nanmean``. This applies to Dice as well, so a
    class with a NaN score no longer turns the mean Dice into NaN, matching how
    mean IoU was already computed. NaN entries are printed as 'N/A' and drawn
    as 0 in the bar charts. The chart is saved to
    ``CFG.OUTPUT_DIR/<dataset_name lower-cased>_per_class_metrics.png``.

    Returns:
        dict with ``mean_iou``, ``mean_dice``, ``pixel_acc`` (mean over
        batches) and the per-class arrays ``class_iou`` / ``class_dice``.
    """
    model.eval()
    all_class_iou = []
    all_class_dice = []
    all_acc = []

    pbar = tqdm(loader, desc=f'Evaluating {dataset_name}')
    for images, masks, _ in pbar:
        images = images.to(device, non_blocking=True)
        masks  = masks.to(device, non_blocking=True)

        with autocast(enabled=CFG.USE_AMP):
            outputs = model(pixel_values=images)
            logits_up = F.interpolate(outputs.logits, size=masks.shape[-2:], mode='bilinear', align_corners=False)

        _, class_iou = compute_iou(logits_up, masks)
        _, class_dice = compute_dice(logits_up, masks)
        acc = compute_pixel_accuracy(logits_up, masks)

        all_class_iou.append(class_iou)
        all_class_dice.append(class_dice)
        all_acc.append(acc)

    avg_iou = np.nanmean(all_class_iou, axis=0)
    avg_dice = np.nanmean(all_class_dice, axis=0)
    mean_iou = np.nanmean(avg_iou)
    # nanmean (was np.mean): one NaN class must not poison the overall Dice.
    mean_dice = np.nanmean(avg_dice)
    mean_acc = np.mean(all_acc)

    # Print results
    print(f"\n{'‚ïê'*60}")
    print(f"  {dataset_name.upper()} RESULTS")
    print(f"{'‚ïê'*60}")
    print(f"  Mean IoU:       {mean_iou:.4f}")
    print(f"  Mean Dice:      {mean_dice:.4f}")
    print(f"  Pixel Accuracy: {mean_acc:.4f}")
    print(f"{'‚îÄ'*60}")
    print(f"  {'Class':<20} {'IoU':>8} {'Dice':>8}")
    print(f"  {'‚îÄ'*38}")
    for i in range(CFG.NUM_CLASSES):
        iou_str = f"{avg_iou[i]:.4f}" if not np.isnan(avg_iou[i]) else 'N/A'
        dice_str = f"{avg_dice[i]:.4f}" if not np.isnan(avg_dice[i]) else 'N/A'
        print(f"  {CLASS_NAMES[i]:<20} {iou_str:>8} {dice_str:>8}")
    print(f"{'‚ïê'*60}")

    # Bar chart
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    colors = [COLOR_PALETTE[i]/255. for i in range(CFG.NUM_CLASSES)]

    # NaN scores are drawn as empty (0) bars in both charts.
    valid_iou = [v if not np.isnan(v) else 0 for v in avg_iou]
    valid_dice = [v if not np.isnan(v) else 0 for v in avg_dice]

    ax1.barh(CLASS_NAMES, valid_iou, color=colors, edgecolor='black')
    ax1.axvline(x=mean_iou, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_iou:.4f}')
    ax1.set_xlabel('IoU')
    ax1.set_title(f'Per-Class IoU ({dataset_name})')
    ax1.legend()
    ax1.set_xlim(0, 1)

    ax2.barh(CLASS_NAMES, valid_dice, color=colors, edgecolor='black')
    ax2.axvline(x=mean_dice, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_dice:.4f}')
    ax2.set_xlabel('Dice Score')
    ax2.set_title(f'Per-Class Dice ({dataset_name})')
    ax2.legend()
    ax2.set_xlim(0, 1)

    plt.tight_layout()
    plt.savefig(os.path.join(CFG.OUTPUT_DIR, f'{dataset_name.lower()}_per_class_metrics.png'), dpi=150, bbox_inches='tight')
    plt.show()

    return {'mean_iou': mean_iou, 'mean_dice': mean_dice, 'pixel_acc': mean_acc, 'class_iou': avg_iou, 'class_dice': avg_dice}

val_results = full_evaluation(model, val_loader, 'Validation')

---
## 15. Inference — Validation Set Predictions

In [None]:
@torch.no_grad()
def run_inference(model, loader, output_dir, save_comparisons=10, dataset_name='val'):
    """Predict every image in ``loader`` and write the results to disk.

    For every image, two files are written under ``output_dir``:
      * ``{dataset_name}_masks/<stem>_pred.png``: argmax class-index mask (uint8);
      * ``{dataset_name}_masks_color/<stem>_pred_color.png``: colorized mask.
    For the first ``save_comparisons`` images, a 4-panel figure (input, ground
    truth, prediction, overlay) is also saved to ``{dataset_name}_comparisons``.

    Returns:
        Number of images processed.
    """
    model.eval()

    masks_dir, color_dir, comp_dir = (
        os.path.join(output_dir, f'{dataset_name}_{suffix}')
        for suffix in ('masks', 'masks_color', 'comparisons')
    )
    for folder in (masks_dir, color_dir, comp_dir):
        os.makedirs(folder, exist_ok=True)

    # Per-channel stats used to undo input normalization for display
    # (ImageNet mean/std; assumed to match the dataset transform).
    px_mean = np.array([0.485, 0.456, 0.406])
    px_std = np.array([0.229, 0.224, 0.225])
    legend_handles = [
        mpatches.Patch(color=COLOR_PALETTE[c]/255., label=CLASS_NAMES[c])
        for c in range(CFG.NUM_CLASSES)
    ]
    n_done = 0

    for images, masks, fnames in tqdm(loader, desc=f'Inference ({dataset_name})'):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        with autocast(enabled=CFG.USE_AMP):
            raw = model(pixel_values=images).logits
            resized = F.interpolate(raw, size=masks.shape[-2:], mode='bilinear', align_corners=False)

        class_maps = resized.argmax(dim=1).cpu().numpy().astype(np.uint8)

        for i, pred_mask in enumerate(class_maps):
            fname = fnames[i]
            stem = os.path.splitext(fname)[0]

            # Raw class-index mask, then its colorized rendering (cv2 expects BGR).
            Image.fromarray(pred_mask).save(os.path.join(masks_dir, f'{stem}_pred.png'))
            pred_color = mask_to_color(pred_mask)
            cv2.imwrite(os.path.join(color_dir, f'{stem}_pred_color.png'),
                        cv2.cvtColor(pred_color, cv2.COLOR_RGB2BGR))

            if n_done < save_comparisons:
                rgb = images[i].cpu().numpy().transpose(1, 2, 0)
                rgb = ((rgb * px_std + px_mean) * 255).clip(0, 255).astype(np.uint8)
                gt_color = mask_to_color(masks[i].cpu().numpy().astype(np.uint8))
                overlay = cv2.addWeighted(rgb, 0.5, pred_color, 0.5, 0)

                fig, axes = plt.subplots(1, 4, figsize=(20, 5))
                panels = (('Input Image', rgb), ('Ground Truth', gt_color),
                          ('Prediction', pred_color), ('Overlay', overlay))
                for ax, (title, picture) in zip(axes, panels):
                    ax.imshow(picture)
                    ax.set_title(title)
                    ax.axis('off')
                fig.legend(handles=legend_handles, loc='lower center', ncol=5, fontsize=9, bbox_to_anchor=(0.5, -0.05))

                plt.suptitle(f'{fname}', fontsize=14)
                plt.tight_layout()
                plt.savefig(os.path.join(comp_dir, f'comparison_{n_done:04d}.png'), dpi=150, bbox_inches='tight')
                plt.close()

            n_done += 1

    print(f"\n‚úÖ Inference complete: {n_done} images processed")
    print(f"   Masks:       {masks_dir}")
    print(f"   Color masks: {color_dir}")
    print(f"   Comparisons: {comp_dir}")
    return n_done

# Run on validation set
run_inference(model, val_loader, CFG.RESULTS_DIR, save_comparisons=20, dataset_name='val')

---
## 16. Inference — Test Set Predictions

In [None]:
# Run on test set
# Same pipeline as the validation run; comparison figures for the first 30 test images.
# NOTE(review): the "Ground Truth" panel shows whatever masks test_loader yields;
# confirm the test split ships real labels, otherwise those panels are meaningless.
run_inference(model, test_loader, CFG.RESULTS_DIR, save_comparisons=30, dataset_name='test')

---
## 17. Confusion Matrix & Error Analysis

In [None]:
def _confusion_counts(targets, preds, num_classes):
    """Return a (num_classes, num_classes) count matrix for flat integer label arrays."""
    # Skip pixels whose label is outside [0, num_classes) (e.g. an ignore label).
    valid = (targets >= 0) & (targets < num_classes) & (preds >= 0) & (preds < num_classes)
    flat = targets[valid] * num_classes + preds[valid]
    return np.bincount(flat, minlength=num_classes * num_classes).reshape(num_classes, num_classes)


@torch.no_grad()
def compute_confusion_matrix(model, loader, num_classes=CFG.NUM_CLASSES):
    """Accumulate a pixel-level confusion matrix over ``loader``.

    Rows index the ground-truth class and columns the predicted class.
    Pixels whose target or prediction lies outside ``[0, num_classes)`` are
    skipped. Previously a negative target (e.g. -1) silently indexed the last
    row. Counting is vectorized with ``np.bincount`` instead of a Python loop
    over every pixel.

    Returns:
        ``(num_classes, num_classes)`` int64 array of pixel counts.
    """
    model.eval()
    conf_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    for images, masks, _ in tqdm(loader, desc='Computing confusion matrix'):
        images = images.to(device)
        masks  = masks.to(device)

        with autocast(enabled=CFG.USE_AMP):
            outputs = model(pixel_values=images)
            logits_up = F.interpolate(outputs.logits, size=masks.shape[-2:], mode='bilinear', align_corners=False)

        preds = logits_up.argmax(dim=1).cpu().numpy().ravel().astype(np.int64)
        targets = masks.cpu().numpy().ravel().astype(np.int64)
        conf_matrix += _confusion_counts(targets, preds, num_classes)

    return conf_matrix

conf_matrix = compute_confusion_matrix(model, val_loader)

# Normalize each row (true class) to a distribution over predicted classes.
# Classes with no ground-truth pixels keep an all-zero row: without ``out=``,
# np.divide leaves the entries masked out by ``where`` uninitialized (garbage).
conf_norm = conf_matrix.astype(np.float32)
row_sums = conf_norm.sum(axis=1, keepdims=True)
conf_norm = np.divide(conf_norm, row_sums, out=np.zeros_like(conf_norm), where=row_sums != 0)

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(22, 9))

sns.heatmap(conf_norm, annot=True, fmt='.2f', cmap='Blues', xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES, ax=ax1, cbar_kws={'shrink': 0.8})
ax1.set_xlabel('Predicted', fontsize=12)
ax1.set_ylabel('True', fontsize=12)
ax1.set_title('Normalized Confusion Matrix', fontsize=14, fontweight='bold')

# Raw counts on a log1p scale so rare classes remain visible.
sns.heatmap(np.log1p(conf_matrix), annot=False, cmap='YlOrRd', xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES, ax=ax2, cbar_kws={'shrink': 0.8})
ax2.set_xlabel('Predicted', fontsize=12)
ax2.set_ylabel('True', fontsize=12)
ax2.set_title('Confusion Matrix (log scale)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(CFG.OUTPUT_DIR, 'confusion_matrix.png'), dpi=200, bbox_inches='tight')
plt.show()

---
## 18. Multi-Scale Test Time Augmentation (TTA)

In [None]:
@torch.no_grad()
def tta_inference(model, image, scales=(0.75, 1.0, 1.25), flip=True):
    """Average logits over multi-scale and (optionally) horizontally flipped views.

    Each view of ``image`` is run through the model. Its logits are resized
    back to the input's ``H x W`` (and un-flipped for flipped views), then
    summed in float32. The result is the mean of the raw logits over all views.

    Args:
        model: Model whose output exposes ``.logits``.
        image: Batch tensor shaped ``(B, C, H, W)``.
        scales: Non-empty iterable of resize factors (a tuple by default,
            avoiding a mutable default argument).
        flip: Also evaluate the horizontally flipped view at every scale.

    Returns:
        Float32 tensor ``(B, num_classes, H, W)`` of averaged logits, where
        ``num_classes`` is taken from the model output.

    Raises:
        ValueError: if ``scales`` is empty (previously this returned NaNs from
            a division by zero).
    """
    scales = tuple(scales)
    if not scales:
        raise ValueError("tta_inference: 'scales' must contain at least one scale factor")
    model.eval()
    H, W = image.shape[-2:]
    views = (False, True) if flip else (False,)

    summed = None
    n_aug = 0
    for scale in scales:
        sh, sw = int(H * scale), int(W * scale)
        scaled = F.interpolate(image, size=(sh, sw), mode='bilinear', align_corners=False)

        for mirrored in views:
            inp = torch.flip(scaled, dims=[-1]) if mirrored else scaled
            with autocast(enabled=CFG.USE_AMP):
                out = model(pixel_values=inp).logits
            # Back to input resolution; accumulate in float32 (under AMP the
            # model output may be float16).
            logits = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False).float()
            if mirrored:
                logits = torch.flip(logits, dims=[-1])  # undo the horizontal flip
            summed = logits if summed is None else summed + logits
            n_aug += 1

    return summed / n_aug


# Run TTA on val set
@torch.no_grad()
def evaluate_with_tta(model, loader, scales=(0.75, 1.0, 1.25, 1.5), baseline=None):
    """Evaluate ``model`` with multi-scale + flip TTA and print a comparison.

    Per-class IoU and Dice are averaged over batches and then over classes
    (with ``np.nanmean``). This is the same two-stage aggregation
    ``full_evaluation`` uses for its numbers, so the "no TTA" baseline is
    comparable. Previously the TTA side averaged per-batch means, a different
    statistic. Pixel accuracy is the mean over batches, as before.

    Args:
        model: Model to evaluate.
        loader: Yields ``(images, masks, _)`` batches.
        scales: Scale factors forwarded to ``tta_inference``.
        baseline: Dict with ``mean_iou``, ``mean_dice`` and ``pixel_acc``
            to compare against; defaults to the global ``val_results``.

    Returns:
        dict with ``mean_iou``, ``mean_dice``, ``pixel_acc``, ``class_iou``
        and ``class_dice`` for the TTA predictions.
    """
    model.eval()
    all_class_iou = []
    all_class_dice = []
    all_acc = []

    for images, masks, _ in tqdm(loader, desc='TTA Evaluation'):
        images = images.to(device)
        masks  = masks.to(device)

        logits = tta_inference(model, images, scales=list(scales))

        _, class_iou = compute_iou(logits, masks)
        _, class_dice = compute_dice(logits, masks)
        all_class_iou.append(class_iou)
        all_class_dice.append(class_dice)
        all_acc.append(compute_pixel_accuracy(logits, masks))

    class_iou = np.nanmean(all_class_iou, axis=0)
    class_dice = np.nanmean(all_class_dice, axis=0)
    results = {
        'mean_iou': np.nanmean(class_iou),
        'mean_dice': np.nanmean(class_dice),
        'pixel_acc': np.mean(all_acc),
        'class_iou': class_iou,
        'class_dice': class_dice,
    }

    if baseline is None:
        baseline = val_results

    print(f"\nüî¨ TTA Results:")
    print(f"  mIoU:     {results['mean_iou']:.4f} (no TTA: {baseline['mean_iou']:.4f})")
    print(f"  Dice:     {results['mean_dice']:.4f} (no TTA: {baseline['mean_dice']:.4f})")
    print(f"  Accuracy: {results['pixel_acc']:.4f} (no TTA: {baseline['pixel_acc']:.4f})")
    return results

evaluate_with_tta(model, val_loader)

---
## 19. Package Everything as ZIP

In [None]:
def create_results_zip(zip_path='/kaggle/working/offroad_segformer_results.zip'):
    """Package outputs, results and checkpoints into a single downloadable ZIP.

    Files under ``CFG.OUTPUT_DIR``, ``CFG.RESULTS_DIR`` and
    ``CFG.CHECKPOINT_DIR`` are stored as ``outputs/``, ``results/`` and
    ``checkpoints/`` in the archive. Any ``/kaggle/working/*.ipynb`` notebooks
    are added at the top level. The archive itself is skipped while walking,
    so it is never added to itself if it lives inside one of those directories.

    Args:
        zip_path: Destination of the archive (the default is the previous
            hard-coded path).

    Returns:
        ``zip_path``.
    """
    from collections import Counter

    print("üì¶ Creating results ZIP...")
    zip_abs = os.path.abspath(zip_path)

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        # Add all outputs
        for root_dir, dir_name in [
            (CFG.OUTPUT_DIR, 'outputs'),
            (CFG.RESULTS_DIR, 'results'),
            (CFG.CHECKPOINT_DIR, 'checkpoints'),
        ]:
            if not os.path.exists(root_dir):
                continue
            for dirpath, _dirnames, filenames in os.walk(root_dir):
                for filename in filenames:
                    filepath = os.path.join(dirpath, filename)
                    if os.path.abspath(filepath) == zip_abs:
                        continue  # never add the archive being written to itself
                    arcname = os.path.join(dir_name, os.path.relpath(filepath, root_dir))
                    zf.write(filepath, arcname)

        # Add this notebook if available
        for nb in glob.glob('/kaggle/working/*.ipynb'):
            zf.write(nb, os.path.basename(nb))

    size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    print(f"\n‚úÖ Results ZIP created: {zip_path}")
    print(f"   Size: {size_mb:.1f} MB")
    print(f"\nüì• Download from Kaggle: Output tab ‚Üí {os.path.basename(zip_path)}")

    # List contents: count entries per top-level component exactly (a plain
    # prefix match would also count e.g. "outputs_x.ipynb" under "outputs").
    with zipfile.ZipFile(zip_path, 'r') as zf:
        names = zf.namelist()
    counts = Counter(n.split('/', 1)[0] for n in names)
    dirs = {n.split('/', 1)[0] for n in names if '/' in n}
    print(f"\n   Contents ({len(names)} files):")
    for entry in sorted(counts):
        label = f"{entry}/" if entry in dirs else entry  # top-level files get no slash
        print(f"     {label}: {counts[entry]} files")

    return zip_path

zip_path = create_results_zip()

---
## 20. Final Summary

In [None]:
# ── Final summary: run configuration, best validation scores, output paths, features ──
_summary_rule = "‚ïê" * 80
print("\n" + _summary_rule)
print("  üåø OFF-ROAD SEGMENTATION ‚Äî FINAL SUMMARY")
print(_summary_rule)
print()
print(f"  Model:          SegFormer-B5 ({CFG.MODEL_NAME})")
print(f"  Dataset:        {len(train_dataset)} train / {len(val_dataset)} val / {len(test_dataset)} test")
print(f"  Image Size:     {CFG.IMG_SIZE}x{CFG.IMG_SIZE}")
print(f"  Classes:        {CFG.NUM_CLASSES}")
print(f"  Epochs Trained: {len(history['train_loss'])}")
print()
print(f"  ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê")
print(f"  ‚îÇ  Best Val mIoU:       {max(history['val_miou']):.4f}              ‚îÇ")
print(f"  ‚îÇ  Best Val Dice:       {max(history['val_dice']):.4f}              ‚îÇ")
print(f"  ‚îÇ  Best Val Accuracy:   {max(history['val_pixel_acc']):.4f}              ‚îÇ")
print(f"  ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò")
print()
print(f"  üìÅ Outputs:")
for _out_line in (
    f"     {CFG.OUTPUT_DIR}/ ‚Äî plots, metrics",
    f"     {CFG.RESULTS_DIR}/ ‚Äî predictions (masks + color + comparisons)",
    f"     {CFG.CHECKPOINT_DIR}/ ‚Äî model weights (best + last + top-{CFG.KEEP_TOP_K})",
    f"     {zip_path} ‚Äî everything as ZIP",
):
    print(_out_line)
print()

# Pipeline features, each printed as a check-marked bullet.
_features = [
    "SegFormer-B5 (MiT encoder + MLP decoder)",
    "Transfer Learning from ADE20K",
    "Combined Loss (CE + Dice + Focal)",
    "Class-weighted loss (imbalanced data)",
    "Heavy augmentations (Albumentations)",
    "LayerWise LR Decay (encoder vs decoder)",
    "Cosine Warmup Scheduler",
    f"Gradient Accumulation (effective BS={CFG.BATCH_SIZE*CFG.ACCUMULATION})",
    "Mixed Precision (AMP)",
    "TF32 for H100",
    "Multi-metric Early Stopping (mIoU + Loss)",
    "Top-K Checkpoint Manager",
    "Resumable Training (last_checkpoint.pt)",
    "Multi-Scale TTA",
    "Confusion Matrix",
    "Full ZIP Export",
]
print(f"  Features:")
for _feat in _features:
    print(f"     ‚úÖ {_feat}")
print(_summary_rule)