In [48]:
# Essential imports for deep learning and visualization
import os
import time
import random
import sys
import json
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field, asdict

from datetime import datetime

# Scientific computing and visualization
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import cv2
from PIL import Image

# Deep learning frameworks
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR

# HuggingFace Transformers for SegFormer
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

DATA_ROOT = "./datasets/NYUDepthv2"

if os.getcwd() == "/kaggle/working":
    DATA_ROOT = "/kaggle/input/nyudepthv2/NYUDepthv2"

COLAB = 'google.colab' in sys.modules
if COLAB:
    assert torch.cuda.is_available(), "Colab session must have a GPU enabled."
    
    from google.colab import drive
    drive.mount('/content/drive')

    DATA_ROOT = "drive/MyDrive/EGH444/datasets/NYUDepthv2"
    
assert os.path.exists(DATA_ROOT), f"Data root {DATA_ROOT} does not exist."

## 🎛️ Shared Configuration System

Let's create a comprehensive configuration system that allows us to easily switch between different experimental setups. This approach ensures reproducibility and makes it easy to compare different models.

In [None]:
@dataclass
class TrainingHistory:
    """Tracks comprehensive training metrics for analysis and visualization."""
    
    # Training metrics (per epoch)
    train_loss: List[float] = field(default_factory=list)
    train_miou: List[float] = field(default_factory=list)
    train_pixacc: List[float] = field(default_factory=list)
    
    # Validation metrics (per epoch)
    val_loss: List[float] = field(default_factory=list)
    val_miou: List[float] = field(default_factory=list)
    val_pixacc: List[float] = field(default_factory=list)
    val_iou_per_class: List[List[float]] = field(default_factory=list)
    
    # Training metadata
    learning_rates: List[float] = field(default_factory=list)
    epoch_times: List[float] = field(default_factory=list)
    
    def add_epoch(self, epoch_data: Dict[str, float]) -> None:
        """Add metrics for a completed epoch."""
        self.train_loss.append(epoch_data.get("train_loss", 0.0))
        self.train_miou.append(epoch_data.get("train_miou", 0.0))
        self.train_pixacc.append(epoch_data.get("train_pixacc", 0.0))
        
        self.val_loss.append(epoch_data.get("val_loss", 0.0))
        self.val_miou.append(epoch_data.get("val_miou", 0.0))
        self.val_pixacc.append(epoch_data.get("val_pixacc", 0.0))
        self.val_iou_per_class.append(epoch_data.get("val_iou_per_class", []))
        
        self.learning_rates.append(epoch_data.get("learning_rate", 0.0))
        self.epoch_times.append(epoch_data.get("epoch_time", 0.0))
    
    def get_best_epoch(self, metric: str = "val_miou") -> Tuple[int, float]:
        """Get the epoch number and value of the best performance."""
        values = getattr(self, metric, [])
        if not values:
            return 0, 0.0
        best_idx = max(range(len(values)), key=lambda i: values[i])
        return best_idx + 1, values[best_idx]  # 1-indexed epoch
    
    def __len__(self) -> int:
        """Return number of completed epochs."""
        return len(self.train_loss)


@dataclass
class TrainingConfig:
    """Self-contained training configuration that includes data setup."""
    
    model_name: str  # Simple model identifier for checkpoint organization
    epochs: int
    batch_size: int = 32
    
    processor: Optional[Any] = None  # SegFormer processor or similar (optional for custom datasets)
    data_root: str = DATA_ROOT
    image_size: Tuple[int, int] = (240, 320)  # (height, width)
    num_workers: int = 0
    pin_memory: bool = True
    
    learning_rate: float = 1e-4
    
    num_classes: int = 40
    ignore_index: int = 255
    device: str = "auto"
    seed: int = 42
    
    # Logging/checkpointing
    log_every: int = 5
    save_dir: str = "checkpoints"
    
    @property
    def model_dir(self) -> Path:
        """Get model-specific directory."""
        return Path(self.save_dir) / self.model_name
    
    def __post_init__(self):
        """Create save directory after initialization."""
        Path(self.save_dir).mkdir(parents=True, exist_ok=True)
        Path(self.model_dir).mkdir(parents=True, exist_ok=True)
    

print("✅ Configuration classes defined!")

✅ Configuration classes defined!


## ⚙️ Device Setup & Seed Management

Proper device detection and seed management are crucial for reproducible experiments.

---
# 📂 Section 2: Data Loading & Transformations

NYUv2 dataset loading with SegFormer-compatible preprocessing and data augmentation.

In [35]:
def visualize_batch(batch: Dict[str, torch.Tensor],
                    num_samples: int = 4) -> None:
    """Visualize samples from a batch - supports both SegFormer and TorchVision formats."""
    
    # ImageNet normalization constants
    IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
    IMAGENET_STD = np.array([0.229, 0.224, 0.225])
    
    def denormalize_imagenet(tensor: torch.Tensor) -> np.ndarray:
        """Denormalize ImageNet preprocessed image."""
        img = tensor.permute(1, 2, 0).cpu().numpy()
        img = (img * IMAGENET_STD + IMAGENET_MEAN).clip(0, 1)
        return (img * 255).astype(np.uint8)
    
    def colorize_mask(mask: np.ndarray) -> np.ndarray:
        """Create colored visualization of segmentation mask."""
        np.random.seed(42)  # Consistent colors
        colors = np.random.randint(0, 255, (41, 3), dtype=np.uint8)  # 40 classes + ignore
        colors[40] = [0, 0, 0]  # Ignore class = black
        
        colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
        for class_id in range(41):
            colored[mask == class_id] = colors[class_id]
        colored[mask == 255] = [128, 128, 128]  # Ignore = gray
        
        return colored
    
    # Detect batch format and extract data
    if "pixel_values" in batch:
        # SegFormer format
        images = batch["pixel_values"]
        labels = batch["labels"] 
        ids = batch["id"]
        format_name = "SegFormer"
    elif "image" in batch:
        # TorchVision format
        images = batch["image"]
        labels = batch["mask"]
        ids = batch["id"]
        format_name = "TorchVision"
    else:
        raise ValueError("Unknown batch format. Expected 'pixel_values' or 'image' keys.")
    
    num_samples = min(num_samples, images.shape[0])
    
    fig, axes = plt.subplots(2, num_samples, figsize=(num_samples * 4, 8))
    if num_samples == 1:
        axes = axes.reshape(-1, 1)
    
    for i in range(num_samples):
        # Denormalize image (both formats use ImageNet normalization)
        img = denormalize_imagenet(images[i])
        mask = labels[i].cpu().numpy()
        colored_mask = colorize_mask(mask)
        
        # Plot image
        axes[0, i].imshow(img)
        axes[0, i].set_title(f"Image: {ids[i]}")
        axes[0, i].axis("off")
        
        # Plot mask
        axes[1, i].imshow(colored_mask)
        axes[1, i].set_title(f"Mask: {mask.shape}")
        axes[1, i].axis("off")
    
    plt.suptitle(f"Batch Visualization ({format_name} format)", fontsize=14)
    plt.tight_layout()
    plt.show()

In [4]:
def to_nyu40_ids(mask_np: np.ndarray) -> np.ndarray:
    """Ensure mask uses NYUv2-40 label space [0..39]. Out-of-range → 255 (ignore)."""
    out = mask_np.astype(np.int64, copy=True)
    bad = (out < 0) | (out > 39)
    out[bad] = 255
    return out


class SegFormerTransforms:
    """Enhanced data augmentation transforms for SegFormer training."""
    
    class Train:
        """Training augmentations with RGB and depth support."""
        def __init__(self, hflip_p=0.5, rotation_p=0.3, rotation_angle=15.0, 
                     color_jitter_p=0.4, brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1,
                     crop_p=0.3, crop_scale=(0.8, 1.0), 
                     use_depth=False, depth_noise_p=0.3, depth_noise_std=0.02, depth_dropout_p=0.2):
            self.hflip_p = hflip_p
            self.rotation_p = rotation_p
            self.rotation_angle = rotation_angle
            self.color_jitter_p = color_jitter_p
            self.brightness = brightness
            self.contrast = contrast
            self.saturation = saturation
            self.hue = hue
            self.crop_p = crop_p
            self.crop_scale = crop_scale
            self.use_depth = use_depth
            self.depth_noise_p = depth_noise_p
            self.depth_noise_std = depth_noise_std
            self.depth_dropout_p = depth_dropout_p
        
        def _apply_crop(self, rgb: np.ndarray, label: np.ndarray, depth: np.ndarray = None):
            """Apply random crop with zoom."""
            h, w = rgb.shape[:2]
            scale = random.uniform(*self.crop_scale)
            new_h, new_w = int(h * scale), int(w * scale)
            
            # Random crop position
            top = random.randint(0, max(0, h - new_h))
            left = random.randint(0, max(0, w - new_w))
            
            # Crop and resize back
            rgb_crop = rgb[top:top+new_h, left:left+new_w]
            label_crop = label[top:top+new_h, left:left+new_w]
            rgb = cv2.resize(rgb_crop, (w, h), interpolation=cv2.INTER_LINEAR)
            label = cv2.resize(label_crop.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(np.int64)
            
            if depth is not None:
                depth_crop = depth[top:top+new_h, left:left+new_w]
                depth = cv2.resize(depth_crop, (w, h), interpolation=cv2.INTER_NEAREST)
                return rgb, label, depth
            
            return rgb, label
        
        def _apply_color_jitter(self, rgb: np.ndarray) -> np.ndarray:
            """Apply color jittering augmentations."""
            rgb = rgb.astype(np.float32)
            
            # Brightness
            if self.brightness > 0:
                brightness_factor = random.uniform(1 - self.brightness, 1 + self.brightness)
                rgb *= brightness_factor
            
            # Contrast
            if self.contrast > 0:
                contrast_factor = random.uniform(1 - self.contrast, 1 + self.contrast)
                mean = rgb.mean()
                rgb = (rgb - mean) * contrast_factor + mean
            
            # Saturation (convert to HSV)
            if self.saturation > 0:
                hsv = cv2.cvtColor(np.clip(rgb, 0, 255).astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
                saturation_factor = random.uniform(1 - self.saturation, 1 + self.saturation)
                hsv[:, :, 1] *= saturation_factor
                hsv[:, :, 1] = np.clip(hsv[:, :, 1], 0, 255)
                rgb = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32)
            
            # Hue shift
            if self.hue > 0:
                hsv = cv2.cvtColor(np.clip(rgb, 0, 255).astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
                hue_shift = random.uniform(-self.hue, self.hue) * 179  # OpenCV hue is 0-179
                hsv[:, :, 0] = (hsv[:, :, 0] + hue_shift) % 180
                rgb = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32)
            
            return np.clip(rgb, 0, 255).astype(np.uint8)
        
        def _apply_depth_augmentation(self, depth: np.ndarray) -> np.ndarray:
            """Apply depth-specific augmentations."""
            depth = depth.copy().astype(np.float32)
            valid_mask = depth > 0
            
            # Gaussian noise
            if random.random() < self.depth_noise_p and valid_mask.any():
                noise = np.random.normal(0, self.depth_noise_std * 255, depth.shape).astype(np.float32)
                depth[valid_mask] += noise[valid_mask]
                depth = np.clip(depth, 0, 255)
            
            # Depth dropout (simulate sensor holes)
            if random.random() < self.depth_dropout_p and valid_mask.any():
                dropout_mask = np.random.random(depth.shape) < 0.05  # 5% dropout
                depth[dropout_mask & valid_mask] = 0
            
            return depth
        
        def __call__(self, rgb: np.ndarray, label: np.ndarray, depth: np.ndarray = None) -> tuple:
            h, w = rgb.shape[:2]
            
            # Random crop with zoom
            if random.random() < self.crop_p:
                if depth is not None:
                    rgb, label, depth = self._apply_crop(rgb, label, depth)
                else:
                    rgb, label = self._apply_crop(rgb, label)
            
            # Horizontal flip
            if random.random() < self.hflip_p:
                rgb = np.ascontiguousarray(np.flip(rgb, axis=1))
                label = np.ascontiguousarray(np.flip(label, axis=1))
                if depth is not None:
                    depth = np.ascontiguousarray(np.flip(depth, axis=1))
            
            # Rotation
            if random.random() < self.rotation_p:
                angle = random.uniform(-self.rotation_angle, self.rotation_angle)
                center = (w // 2, h // 2)
                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                rgb = cv2.warpAffine(rgb, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
                label = cv2.warpAffine(label.astype(np.uint8), M, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=255)
                label = label.astype(np.int64)
                if depth is not None:
                    depth = cv2.warpAffine(depth, M, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
            
            # Color jitter
            if random.random() < self.color_jitter_p:
                rgb = self._apply_color_jitter(rgb)
            
            # Depth augmentation
            if depth is not None and self.use_depth:
                depth = self._apply_depth_augmentation(depth)
                return rgb, label, depth
            
            return rgb, label
    
    class Eval:
        """No augmentations for evaluation."""
        def __call__(self, rgb: np.ndarray, label: np.ndarray, depth: np.ndarray = None) -> tuple:
            if depth is not None:
                return rgb, label, depth
            return rgb, label

In [None]:
from torchvision.transforms._presets import SemanticSegmentation

class NYUDepthDataset(Dataset):
    """NYUv2 dataset with support for both SegFormer and TorchVision preprocessing."""

    def __init__(self, base_dir: str, split: str, processor: Any,
                 transform=None, image_size: Optional[Tuple[int, int]] = None):
        self.base = Path(base_dir)
        self.processor = processor
        self.transform = transform
        self.image_size = image_size

        # Detect processor type
        self.is_segformer = isinstance(processor, SegformerImageProcessor) if processor else False
        self.is_torchvision = isinstance(processor, SemanticSegmentation) if processor else False
        
        assert self.is_segformer or self.is_torchvision, "Processor must be either SegFormer or TorchVision type."

        # ensure split is valid
        assert split in ["train", "val", "test"], "split must be 'train', 'val', or 'test'"
        folder_split = "train" if split == "train" or split == "val" else "test"

        # Load split files
        with open(self.base / f"{folder_split}.txt") as f:
            stems = [Path(line.split()[0]).stem for line in f if line.strip()]

        # If val split, take 20% of train set
        if split == "val" or split == "train":
            # Create reproducible random split
            np.random.seed(42)  # Fixed seed for reproducibility
            indices = np.random.permutation(len(stems))
            val_size = int(0.2 * len(stems))

            if split == "val":
                stems = [stems[i] for i in indices[-val_size:]]  # Random 20%
            else:  # train
                stems = [stems[i] for i in indices[:-val_size]]  # Random 80%

        # Find valid RGB + Label pairs
        self.items = []
        for s in stems:
            rgb_path = self.base / "RGB" / f"{s}.jpg"
            if not rgb_path.exists():
                rgb_path = self.base / "RGB" / f"{s}.png"
            label_path = self.base / "Label" / f"{s}.png"

            if rgb_path.exists() and label_path.exists():
                self.items.append((s, rgb_path, label_path))

        if not self.items:
            raise RuntimeError(f"No valid samples found in {base_dir}/{folder_split}.txt")

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        stem, rgb_path, label_path = self.items[idx]

        # Load RGB and convert BGR→RGB
        rgb = cv2.imread(str(rgb_path), cv2.IMREAD_COLOR)
        rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)

        # Load label and apply NYU40 processing
        label = cv2.imread(str(label_path), cv2.IMREAD_GRAYSCALE).astype(np.int64)
        label = to_nyu40_ids(label)

        # Apply transforms if provided (before final processing)
        if self.transform:
            rgb, label = self.transform(rgb, label)

        # Resize if custom size specified
        if self.image_size:
            h, w = self.image_size
            rgb = cv2.resize(rgb, (w, h), interpolation=cv2.INTER_LINEAR)
            label = cv2.resize(label.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(np.int64)

        # Handle different processor types
        if self.is_segformer:
            # SegFormer preprocessing
            inputs = self.processor(images=rgb, return_tensors="pt")
            pixel_values = inputs["pixel_values"].squeeze(0).contiguous()

            # Process labels to match processor output size
            processor_size = self.processor.size
            if isinstance(processor_size, dict):
                target_h, target_w = processor_size["height"], processor_size["width"]
            else:
                target_h = target_w = processor_size

            label_resized = cv2.resize(
                label.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST
            )
            labels = torch.from_numpy(label_resized.astype(np.int64)).contiguous()

            return {
                "pixel_values": pixel_values,
                "labels": labels,
                "id": stem
            }

        elif self.is_torchvision:
            # TorchVision preprocessing
            from PIL import Image
            rgb_pil = Image.fromarray(rgb)
            label_pil = Image.fromarray(label.astype(np.uint8))

            # Apply TorchVision transforms
            image_tensor = self.processor(rgb_pil)

            # For labels, we need to convert to tensor manually since TorchVision transforms are for images
            labels = torch.from_numpy(label.astype(np.int64)).contiguous()

            return {
                "image": image_tensor,
                "mask": labels,
                "id": stem
            }

        else:
            raise ValueError("Processor must be provided and be either SegFormer or TorchVision type.")

In [6]:
def show_augmentations(dataset: NYUDepthDataset, sample_idx: int = 0, num_augmentations: int = 6, use_depth: bool = False):
    """Show multiple augmentations of a single sample without SegFormer preprocessing."""
    
    # Get raw sample without transformss
    stem, rgb_path, label_path = dataset.items[sample_idx]
    
    # Load RGB and label
    rgb = cv2.imread(str(rgb_path), cv2.IMREAD_COLOR)
    rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
    label = cv2.imread(str(label_path), cv2.IMREAD_GRAYSCALE).astype(np.int64)
    label = to_nyu40_ids(label)
    
    # Load depth if requested
    depth = None
    if use_depth:
        depth_path = dataset.base / "Depth" / f"{stem}.png"
        if depth_path.exists():
            depth = cv2.imread(str(depth_path), cv2.IMREAD_GRAYSCALE).astype(np.float32)
        else:
            print(f"⚠️ Depth file not found: {depth_path}")
            use_depth = False
    
    # Create transform with enhanced settings for visualization
    transform = SegFormerTransforms.Train(
        hflip_p=0.8, rotation_p=0.8, rotation_angle=20,
        color_jitter_p=0.8, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.15,
        crop_p=0.6, crop_scale=(0.7, 1.0),
        use_depth=use_depth, depth_noise_p=0.8, depth_noise_std=0.03, depth_dropout_p=0.5
    )
    
    def colorize_mask(mask: np.ndarray) -> np.ndarray:
        """Create colored visualization of segmentation mask."""
        np.random.seed(42)  # Consistent colors
        colors = np.random.randint(0, 255, (41, 3), dtype=np.uint8)
        colors[40] = [0, 0, 0]  # Ignore class = black
        
        colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
        for class_id in range(41):
            colored[mask == class_id] = colors[class_id]
        colored[mask == 255] = [128, 128, 128]  # Ignore = gray
        return colored
    
    # Create visualization
    num_cols = num_augmentations + 1  # +1 for original
    num_rows = 3 if use_depth else 2  # RGB, Mask, (Depth)
    
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 4, num_rows * 3))
    if num_rows == 1:
        axes = axes.reshape(1, -1)
    
    # Show original
    axes[0, 0].imshow(rgb)
    axes[0, 0].set_title("Original RGB")
    axes[0, 0].axis("off")
    
    axes[1, 0].imshow(colorize_mask(label))
    axes[1, 0].set_title("Original Mask")
    axes[1, 0].axis("off")
    
    if use_depth and depth is not None:
        axes[2, 0].imshow(depth, cmap='plasma')
        axes[2, 0].set_title("Original Depth")
        axes[2, 0].axis("off")
    
    # Show augmentations
    for i in range(num_augmentations):
        # Apply transform
        if use_depth and depth is not None:
            rgb_aug, label_aug, depth_aug = transform(rgb.copy(), label.copy(), depth.copy())
        else:
            rgb_aug, label_aug = transform(rgb.copy(), label.copy())
        
        # Display RGB
        axes[0, i+1].imshow(rgb_aug)
        axes[0, i+1].set_title(f"Aug {i+1} RGB")
        axes[0, i+1].axis("off")
        
        # Display mask
        axes[1, i+1].imshow(colorize_mask(label_aug))
        axes[1, i+1].set_title(f"Aug {i+1} Mask")
        axes[1, i+1].axis("off")
        
        # Display depth if available
        if use_depth and depth is not None:
            axes[2, i+1].imshow(depth_aug, cmap='plasma')
            axes[2, i+1].set_title(f"Aug {i+1} Depth")
            axes[2, i+1].axis("off")
    
    plt.tight_layout()
    plt.suptitle(f"Augmentations for sample: {stem}", y=1.02, fontsize=16)
    plt.show()

In [7]:
def create_data_loaders(
    data_root: str,
    batch_size: int,
    image_size: Tuple[int, int],
    num_workers: int = 0,
    processor = None,  # Can be SegFormer processor or TorchVision transforms
    pin_memory: bool = True
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create training, validation, and test data loaders.
    
    Args:
        data_root: Path to dataset directory
        processor: SegFormer processor or TorchVision transforms
        batch_size: Batch size for data loaders
        image_size: (height, width) for resizing
        num_workers: Number of worker processes for data loading
        pin_memory: Whether to pin memory for faster GPU transfer
    
    Returns:
        train_loader, val_loader, test_loader
    """
    
    # Detect processor type and print appropriate info
    if processor is not None:
        if hasattr(processor, 'size'):
            print(f"📏 Using SegFormer processor with size: {processor.size}")
        elif hasattr(processor, 'transforms'):
            print(f"📏 Using TorchVision transforms")
        else:
            print(f"📏 Using custom processor: {type(processor).__name__}")
    else:
        print("📏 No processor provided - using default ImageNet normalization")
    
    # Create transforms
    train_transform = SegFormerTransforms.Train()
    val_transform = SegFormerTransforms.Eval()
    
    # Create datasets
    train_dataset = NYUDepthDataset(
        data_root, "train", processor, 
        transform=train_transform, image_size=image_size
    )
    val_dataset = NYUDepthDataset(
        data_root, "val", processor,
        transform=val_transform, image_size=image_size
    )
    test_dataset = NYUDepthDataset(
        data_root, "test", processor,
        transform=val_transform, image_size=image_size
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    print(f"📊 Train: {len(train_dataset)} samples,  Val: {len(val_dataset)} samples,  Test: {len(test_dataset)} samples")
    print(f"📦 Train batches: {len(train_loader)},  Val batches: {len(val_loader)},  Test batches: {len(test_loader)}")

    return train_loader, val_loader, test_loader

print("✅ Universal data loader creation function defined")

✅ Universal data loader creation function defined


In [8]:
train_dataset = NYUDepthDataset(
    DATA_ROOT, "train", 
    processor=SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
)
# show_augmentations(train_dataset, sample_idx=10, num_augmentations=4, use_depth=False)
# show_augmentations(train_dataset, sample_idx=10, num_augmentations=4, use_depth=True)

  image_processor = cls(**image_processor_dict)


---
# 📊 Section 3: Universal Trainer Class

A single, model-agnostic trainer that works with any PyTorch model and data loaders.

In [None]:
# Example: Data exploration and visualization using standalone function
model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
processor = SegformerImageProcessor.from_pretrained(model_name)

# Create data loaders for exploration
train_loader, val_loader, test_loader = create_data_loaders(
    data_root=DATA_ROOT,
    processor=processor,
    batch_size=8,
    image_size=(240, 320),
    num_workers=0,
    pin_memory=True
)

# Test data loading and visualization
print("\\n🧪 Testing data loading...")
sample_batch = next(iter(train_loader))
print(f"✅ Batch loaded successfully!")
print(f"📏 pixel_values shape: {sample_batch['pixel_values'].shape}")
print(f"📏 labels shape: {sample_batch['labels'].shape}")
print(f"🏷️ Sample IDs: {sample_batch['id'][:3]}")

# Visualize samples for Segformer
print("\\n🖼️ Visualizing batch samples...")
# visualize_batch(sample_batch, num_samples=4)

📏 Using SegFormer processor with size: {'height': 512, 'width': 512}
📊 Train: 636 samples,  Val: 159 samples,  Test: 654 samples
📦 Train batches: 80,  Val batches: 20,  Test batches: 82
\n🧪 Testing data loading...
✅ Batch loaded successfully!
📏 pixel_values shape: torch.Size([8, 3, 512, 512])
📏 labels shape: torch.Size([8, 512, 512])
🏷️ Sample IDs: ['391', '1016', '256']
\n🖼️ Visualizing batch samples...




In [None]:
class Trainer:
    """Universal trainer class with automatic checkpoint resume."""

    def __init__(self, config: TrainingConfig, verbose: bool = True):
        """
        Initialize trainer with configuration.
        
        Args:
            config: Training configuration with data and model setup
            verbose: Whether to print detailed logs 
        """
        self.config = config
        self.verbose = verbose
        self.device = self._setup_device()
        self.history = TrainingHistory()
        
        # Set random seeds for reproducibility
        self._set_random_seeds()
        
        # Create data loaders from config
        self.train_loader, self.val_loader, self.test_loader = self._create_data_loaders()
        
        # Auto-resume setup
        self.latest_checkpoint_path = self.config.model_dir / "latest.pth"
        self.start_epoch = 1
        self.loaded_checkpoint = None
        
        # Check for existing checkpoint
        if self.latest_checkpoint_path.exists():
            self._load_latest_checkpoint()
    
    def _setup_device(self) -> torch.device:
        """Setup training device."""
        if self.config.device == "auto":
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                device = torch.device("cpu")  # Temporary fix for MPS issues
            else:
                device = torch.device("cpu")
        else:
            device = torch.device(self.config.device)
        print(f"Using device: {device}")
        return device
    
    def _set_random_seeds(self) -> None:
        """Set random seeds for reproducible results."""
        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        torch.cuda.manual_seed_all(self.config.seed)
        
        # Make CuDNN deterministic (slower but reproducible)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
        print(f"🎲 Random seeds set to {self.config.seed} for reproducibility")
    
    def _create_data_loaders(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Create training, validation, and test data loaders from config."""
        
        # Handle processor-based datasets (like SegFormer)
        if self.config.processor is not None:
            # Use the standalone create_data_loaders function
            return create_data_loaders(
                data_root=self.config.data_root,
                processor=self.config.processor,
                batch_size=self.config.batch_size,
                image_size=self.config.image_size,
                num_workers=self.config.num_workers,
                pin_memory=self.config.pin_memory
            )
        else:
            # For custom datasets without processor, you would implement your own dataset creation here
            # This is a placeholder for when processor=None
            raise NotImplementedError(
                "Custom dataset creation (processor=None) not implemented yet. "
                "Please provide a processor or implement custom dataset logic."
            )
    
    def _load_latest_checkpoint(self):
        """Load latest checkpoint and prepare for resume."""
        print(f"📂 Found checkpoint: {self.latest_checkpoint_path}")
        try:
            self.loaded_checkpoint = torch.load(self.latest_checkpoint_path, map_location='cpu', weights_only=False)
            self.start_epoch = self.loaded_checkpoint.get('epoch', 0) + 1
            
            # Load history if available
            if 'history' in self.loaded_checkpoint:
                history_dict = self.loaded_checkpoint['history']
                # Reconstruct TrainingHistory from dict
                self.history = TrainingHistory()
                for key, value in history_dict.items():
                    if hasattr(self.history, key):
                        setattr(self.history, key, value)
            
            if self.verbose: 
                print(f"✅ Will resume from epoch {self.start_epoch}")
                print(f"📈 Loaded {len(self.history)} epochs of training history")
            
        except Exception as e:
            print(f"❌ Failed to load checkpoint: {e}")
            print("🆕 Starting fresh training")
            self.loaded_checkpoint = None
            self.start_epoch = 1
    
    def _save_latest_checkpoint(self, model, optimizer, scaler, scheduler, epoch, metrics):
        """Save latest checkpoint (every epoch)."""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
            "history": self.history.__dict__,
            "config": self.config,
            "metrics": metrics,
            "timestamp": datetime.now().isoformat()
        }
        
        torch.save(checkpoint, self.latest_checkpoint_path)
    
    def _save_best_model(self, model, optimizer, scheduler, epoch, metrics, best_miou):
        """Save best model checkpoint."""
        best_checkpoint_path = self.config.model_dir / "best.pth"
        best_checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
            "config": self.config,
            "metrics": metrics,
            "history": self.history,
            "best_miou": best_miou
        }
        
        torch.save(checkpoint, best_checkpoint_path)
        print(f"✅ Saved best model to {best_checkpoint_path} (mIoU: {best_miou:.4f})")
    
    @torch.no_grad()
    def evaluate(self, model, loader=None, criterion=None):
        """
        Evaluate model on validation data.
        
        Args:
            model: Model to evaluate
            loader: DataLoader to use (defaults to self.val_loader)
            criterion: Loss function to use for validation loss calculation
            
        Returns:
            Dictionary with evaluation metrics including loss if criterion provided
        """
        model.eval()
        loader = loader or self.val_loader
        
        cm = np.zeros((self.config.num_classes, self.config.num_classes), dtype=np.int64)
        total, correct = 0, 0
        running_val_loss = 0.0
        num_batches = 0
        
        for batch in loader:
            # Handle different batch formats
            if "pixel_values" in batch:
                # SegFormer format
                inputs = batch["pixel_values"].to(self.device, non_blocking=True)
                targets = batch["labels"].to(self.device, non_blocking=True)
            elif "image" in batch:
                # TorchVision format
                inputs = batch["image"].to(self.device, non_blocking=True)
                targets = batch["mask"].to(self.device, non_blocking=True)
            else:
                raise ValueError("Unknown batch format. Expected 'pixel_values' or 'image' keys.")
            
            # Forward pass with mixed precision
            with autocast(device_type=self.device.type, enabled=(self.device.type == "cuda")):
                outputs = model(inputs)
                
                # Handle different model output formats
                if hasattr(outputs, 'logits'):
                    logits = outputs.logits  # SegFormer format
                elif isinstance(outputs, dict) and "out" in outputs:
                    logits = outputs["out"]  # TorchVision DeepLab format
                else:
                    logits = outputs  # Standard tensor output
                
                # Upsample logits to target size if needed
                if logits.shape[-2:] != targets.shape[-2:]:
                    logits = torch.nn.functional.interpolate(
                        logits, size=targets.shape[-2:], mode="bilinear", align_corners=False
                    )
                
                # Calculate validation loss if criterion provided
                if criterion is not None:
                    val_loss = criterion(logits, targets)
                    running_val_loss += val_loss.item()
                    num_batches += 1
            
            pred = logits.argmax(1)  # [B,H,W]
            
            valid = targets != self.config.ignore_index
            total += valid.sum().item()
            correct += (pred[valid] == targets[valid]).sum().item()
            
            p = pred[valid].view(-1).cpu().numpy()
            t = targets[valid].view(-1).cpu().numpy()
            for i in range(p.shape[0]):
                if 0 <= t[i] < self.config.num_classes and 0 <= p[i] < self.config.num_classes:
                    cm[t[i], p[i]] += 1
        
        # Calculate metrics
        inter = np.diag(cm).astype(np.float64)
        gt = cm.sum(1).astype(np.float64)
        pr = cm.sum(0).astype(np.float64)
        union = gt + pr - inter
        iou = inter / np.maximum(union, 1)
        miou = float(np.nanmean(iou)) if iou.size > 0 else 0.0
        pixacc = float(correct / max(total, 1))
        
        metrics = {"mIoU": miou, "PixelAcc": pixacc, "IoU_per_class": iou}
        
        # Add validation loss if calculated
        if criterion is not None and num_batches > 0:
            avg_val_loss = running_val_loss / num_batches
            metrics["loss"] = avg_val_loss
        
        return metrics
    
    def train(self, model, optimizer=None, scheduler=None, criterion=None,
              gradient_clip_val=1.0, early_stopping_patience=None):
        """
        Train the provided model with automatic checkpoint resume.

        Args:
            model: PyTorch model to train
            optimizer: Custom optimizer (defaults to AdamW)
            scheduler: Custom scheduler (defaults to None)
            criterion: Custom loss function (defaults to CrossEntropyLoss)
            gradient_clip_val: Gradient clipping value (None to disable)
            early_stopping_patience: Stop training if no improvement for N epochs (None to disable)

        Returns:
            Trained model
        """
        # Prepare model
        model = model.to(self.device)
        
        # Setup training components with defaults if not provided
        if criterion is None:
            criterion = nn.CrossEntropyLoss(ignore_index=self.config.ignore_index)
            
        if optimizer is None:
            optimizer = optim.AdamW(
                model.parameters(), lr=self.config.learning_rate
            )
            print(f"📊 Using default AdamW optimizer (LR: {self.config.learning_rate:.2e})")
        else:
            print(f"📊 Using custom optimizer: {type(optimizer).__name__}")
            
        scaler = GradScaler(enabled=(self.device.type == "cuda"))
        
        if scheduler:
            print(f"📈 Using custom scheduler: {type(scheduler).__name__}")
        else:
            print("📈 No LR scheduling (constant learning rate)")
        
        # Early stopping setup
        if early_stopping_patience is not None:
            print(f"🛑 Early stopping enabled (patience: {early_stopping_patience} epochs)")
            best_miou_for_early_stop = -1.0
            patience_counter = 0
            min_delta = 0.001  # Minimum improvement threshold
        
        # Load checkpoint states if resuming
        if self.loaded_checkpoint:
            print("🔄 Restoring training state from checkpoint...")
            
            # Load model state
            model.load_state_dict(self.loaded_checkpoint['model_state_dict'])
            
            # Load optimizer state
            if 'optimizer_state_dict' in self.loaded_checkpoint:
                optimizer.load_state_dict(self.loaded_checkpoint['optimizer_state_dict'])
            
            # Load scheduler state
            if scheduler and 'scheduler_state_dict' in self.loaded_checkpoint and self.loaded_checkpoint['scheduler_state_dict']:
                scheduler.load_state_dict(self.loaded_checkpoint['scheduler_state_dict'])
            
            # Load scaler state
            if 'scaler_state_dict' in self.loaded_checkpoint:
                scaler.load_state_dict(self.loaded_checkpoint['scaler_state_dict'])
            
            print("✅ Training state restored")
        
        # Print parameter information
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        
        if gradient_clip_val is not None:
            print(f"✂️ Gradient clipping enabled (max_norm: {gradient_clip_val})")
        
        # Training setup
        best_miou = max(self.history.val_miou) if self.history.val_miou else -1.0
        
        print(f"🚀 Starting training for {self.config.epochs} epochs...")
        print(f"📊 Training batches: {len(self.train_loader)}, Validation batches: {len(self.val_loader)}")
        print(f"▶️ Resuming from epoch {self.start_epoch}")
        
        for epoch in range(self.start_epoch, self.config.epochs + 1):
            model.train()
            running_loss = 0.0
            epoch_start = time.time()
            
            print(f"\n📈 Epoch {epoch}/{self.config.epochs}")
            
            for it, batch in enumerate(self.train_loader, start=1):
                # Handle different batch formats
                if "pixel_values" in batch:
                    # SegFormer format
                    inputs = batch["pixel_values"].to(self.device, non_blocking=True)
                    targets = batch["labels"].to(self.device, non_blocking=True)
                elif "image" in batch:
                    # TorchVision format
                    inputs = batch["image"].to(self.device, non_blocking=True)
                    targets = batch["mask"].to(self.device, non_blocking=True)
                else:
                    raise ValueError("Unknown batch format. Expected 'pixel_values' or 'image' keys.")
                
                optimizer.zero_grad(set_to_none=True)
                
                # Forward pass with mixed precision
                with autocast(device_type=self.device.type, enabled=(self.device.type == "cuda")):
                    outputs = model(inputs)
                    
                    # Handle different model output formats
                    if hasattr(outputs, 'logits'):
                        logits = outputs.logits  # SegFormer format
                    elif isinstance(outputs, dict) and "out" in outputs:
                        logits = outputs["out"]  # TorchVision DeepLab format
                    else:
                        logits = outputs  # Standard tensor output
                    
                    # Upsample logits to target size if needed
                    if logits.shape[-2:] != targets.shape[-2:]:
                        logits = torch.nn.functional.interpolate(
                            logits, size=targets.shape[-2:], mode="bilinear", align_corners=False
                        )
                    
                    loss = criterion(logits, targets)
                
                # Backward pass
                scaler.scale(loss).backward()
                
                # Gradient clipping
                if gradient_clip_val is not None:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
                
                scaler.step(optimizer)
                scaler.update()
                
                running_loss += loss.item()
                
                # Log periodically
                if it % self.config.log_every == 0:
                    avg_loss = running_loss / it
                    current_lr = optimizer.param_groups[0]['lr']
                    print(f"   Batch {it:05d}/{len(self.train_loader)} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
            
            # Step scheduler if provided
            if scheduler is not None:
                scheduler.step()
            
            # Validation
            print("🔍 Running validation...")
            val_metrics = self.evaluate(model, criterion=criterion)
            current_miou = val_metrics["mIoU"]
            
            # Early stopping check
            if early_stopping_patience is not None:
                if current_miou > best_miou_for_early_stop + min_delta:
                    best_miou_for_early_stop = current_miou
                    patience_counter = 0
                else:
                    patience_counter += 1
                
                print(f"📊 Early stopping: {patience_counter}/{early_stopping_patience} (best: {best_miou_for_early_stop:.4f})")
                
                if patience_counter >= early_stopping_patience:
                    print(f"🛑 Early stopping triggered at epoch {epoch}")
                    print(f"   No improvement in validation mIoU for {early_stopping_patience} epochs")
                    break
            
            # Calculate epoch metrics
            epoch_time = time.time() - epoch_start
            avg_train_loss = running_loss / len(self.train_loader)
            current_lr = optimizer.param_groups[0]['lr']
            
            # Add to history
            self.history.add_epoch({
                "train_loss": avg_train_loss,
                "train_miou": 0.0,  # Not calculated during training for performance
                "train_pixacc": 0.0,  # Not calculated during training for performance
                "val_loss": val_metrics.get("loss", 0.0),
                "val_miou": val_metrics["mIoU"],
                "val_pixacc": val_metrics["PixelAcc"],
                "val_iou_per_class": val_metrics["IoU_per_class"].tolist(),
                "learning_rate": current_lr,
                "epoch_time": epoch_time
            })
            
            # Print epoch summary
            print(f"\n📊 Epoch {epoch} Results:")
            print(f"   • Train Loss: {avg_train_loss:.4f}")
            if "loss" in val_metrics:
                print(f"   • Val Loss: {val_metrics['loss']:.4f}")
            print(f"   • Val mIoU: {val_metrics['mIoU']:.4f}")
            print(f"   • Val PixelAcc: {val_metrics['PixelAcc']:.4f}")
            print(f"   • Epoch Time: {epoch_time:.1f}s")
            print(f"   • Learning Rate: {current_lr:.2e}")
            
            # Save best model
            if val_metrics["mIoU"] > best_miou:
                best_miou = val_metrics["mIoU"]
                self._save_best_model(model, optimizer, scheduler, epoch, val_metrics, best_miou)
            
            # Save latest checkpoint every epoch
            self._save_latest_checkpoint(model, optimizer, scaler, scheduler, epoch, val_metrics)
            
            # Save every 5 epochs 
            if epoch % 5 == 0:
                epoch_checkpoint_path = self.config.model_dir / f"checkpoint_epoch_{epoch}.pth"
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
                    "scaler_state_dict": scaler.state_dict(),
                    "history": self.history.__dict__,
                    "config": self.config,
                    "metrics": val_metrics,
                    "timestamp": datetime.now().isoformat()
                }, epoch_checkpoint_path)
                print(f"💾 Saved periodic checkpoint to {epoch_checkpoint_path}")

            print("-" * 80)
        
        print(f"\n🎉 Training Complete!")
        print(f"   • Best Validation mIoU: {best_miou:.4f}")
        print(f"   • Total Training Time: {sum(self.history.epoch_times):.1f}s")
        
        return model

print("✅ Universal Trainer class with dual model support defined!")

✅ Universal Trainer class with dual model support defined!


In [11]:
def print_model_status(model):
    """Show which parts of the model are frozen/unfrozen."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    
    print(f"📊 Model status: {trainable_params:,} trainable / {total_params:,} total")
    print(f"   ({frozen_params:,} frozen parameters)")


In [12]:
segformer_training_config = TrainingConfig(
    data_root=DATA_ROOT,
    epochs=300,
    batch_size=32,
    learning_rate=1e-4,  # Conservative learning rate
    model_name="segformer_b0",
    processor=SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
)

# Create fresh SegFormer model
segformer_model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=40,
    ignore_mismatched_sizes=True
)

# freeze all encoder layers
for param in segformer_model.segformer.encoder.parameters():
    param.requires_grad = False

print_model_status(segformer_model)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([40]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([40, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


📊 Model status: 405,032 trainable / 3,724,424 total
   (3,319,392 frozen parameters)


In [13]:
def plot_training_history(history: TrainingHistory):
    """Plot training and validation loss and mIoU over epochs."""
    epochs = range(1, len(history) + 1)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot Loss
    ax1.plot(epochs, history.train_loss, label='Train Loss', color='blue')
    ax1.plot(epochs, history.val_loss, label='Val Loss', color='orange')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # Plot mIoU
    ax2.plot(epochs, history.val_miou, label='Val mIoU', color='green')
    ax2.plot(epochs,history.val_pixacc, label='Val PixelAcc', color='red')
    ax2.set_title('Validation mIoU')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('mIoU')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()

In [14]:
# # create cosine annealing scheduler
# scheduler = optim.lr_scheduler.CosineAnnealingLR(
#     optim.AdamW(segformer_model.parameters(), lr=segformer_training_config.learning_rate),
#     T_max=segformer_training_config.epochs
# )

# trainer = Trainer(segformer_training_config)
# trained_model = trainer.train(segformer_model, scheduler=scheduler, early_stopping_patience=5)

# plot_training_history(trainer.history)

In [None]:
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights # fcn_resnet50, FCN_ResNet50_Weights

dl_mnv3large_training_config = TrainingConfig(
    epochs=20,
    batch_size=64,
    learning_rate=1e-4,
    model_name="deeplabv3_mobilenet_v3_large",
    processor=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT.transforms()
)

model_dl_mnv3large = deeplabv3_mobilenet_v3_large(weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT)

# Modify the final layer to match NYU40 classes
model_dl_mnv3large.classifier[-1] = nn.Conv2d(256, 40, kernel_size=(1, 1), stride=(1, 1))  # 40 classes for NYU40

print("Layers in the classifier:")
for i, layer in enumerate(model_dl_mnv3large.classifier):
    print(f"Layer {i}: {type(layer).__name__}")


Layers in the classifier:
Layer 0: ASPP
Layer 1: Conv2d
Layer 2: BatchNorm2d
Layer 3: ReLU
Layer 4: Conv2d


In [46]:
# Create data loaders for exploration
train_loader, val_loader, test_loader = create_data_loaders(
    data_root=DATA_ROOT,
    processor=dl_mnv3large_training_config.processor,
    batch_size=8,
    image_size=(240, 320),
    num_workers=0,
    pin_memory=True
)

# Test data loading and visualization
print("\\n🧪 Testing data loading...")
sample_batch = next(iter(train_loader))
print(f"✅ Batch loaded successfully!")

# Visualize samples for Segformer
print("\\n🖼️ Visualizing batch samples...")
# visualize_batch(sample_batch, num_samples=8)

📏 Using custom processor: SemanticSegmentation
SemanticSegmentation(
    resize_size=[520]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
SemanticSegmentation(
    resize_size=[520]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
SemanticSegmentation(
    resize_size=[520]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
📊 Train: 636 samples,  Val: 159 samples,  Test: 654 samples
📦 Train batches: 80,  Val batches: 20,  Test batches: 82
\n🧪 Testing data loading...
✅ Batch loaded successfully!
\n🖼️ Visualizing batch samples...




In [None]:
# Freeze all parameters
for param in model_dl_mnv3large.parameters():
    param.requires_grad = False
    
# Phase 1: train only classifier head
print("Unfreezing classifier: final classifier layer")
for param in model_dl_mnv3large.classifier[-1].parameters():
    param.requires_grad = True
print_model_status(model_dl_mnv3large)
dl_mnv3large_training_config.epochs = 20
trainer = Trainer(dl_mnv3large_training_config)
trainer.train(model_dl_mnv3large, early_stopping_patience=5)
plot_training_history(trainer.history)

In [None]:
# Phase 2: Unfreeze ASPP and classifier head
print("Unfreezing classifier: ASPP and classifier head")
for param in model_dl_mnv3large.classifier[0].parameters():
    param.requires_grad = True
print_model_status(model_dl_mnv3large)

dl_mnv3large_training_config.epochs = 30
dl_mnv3large_training_config.learning_rate = 5e-5  # Lower learning rate for fine-tuning
trainer = Trainer(dl_mnv3large_training_config)
trainer.train(model_dl_mnv3large, early_stopping_patience=5)
plot_training_history(trainer.history)

In [None]:
# Phase 3: Unfreeze Conv block
print("Unfreezing classifier: Conv block, ASPP, and classifier head")
for param in model_dl_mnv3large.classifier[1].parameters():
    param.requires_grad = True
print_model_status(model_dl_mnv3large)

dl_mnv3large_training_config.epochs = 30
dl_mnv3large_training_config.learning_rate = 1e-5  # Lower learning rate for fine-tuning
trainer = Trainer(dl_mnv3large_training_config)
trainer.train(model_dl_mnv3large, early_stopping_patience=5)
plot_training_history(trainer.history)

In [None]:
# Phase 4: All encoder unfrozen (ie unfreeze BN layer only)
print("Unfreezing classifier: Conv block, ASPP, and classifier head all classifier layers")
for param in model_dl_mnv3large.classifier.parameters():
    param.requires_grad = True
print_model_status(model_dl_mnv3large)

dl_mnv3large_training_config.epochs = 20
trainer = Trainer(dl_mnv3large_training_config)
trainer.train(model_dl_mnv3large, early_stopping_patience=5)
plot_training_history(trainer.history)

In [None]:
# Phase 5: unfreeze all params
print("Unfreezing entire model for fine-tuning")
for param in model_dl_mnv3large.parameters():
    param.requires_grad = True
print_model_status(model_dl_mnv3large)
dl_mnv3large_training_config.epochs = 200
dl_mnv3large_training_config.learning_rate = 5e-6  # Lower learning rate for unfrozen training
trainer = Trainer(dl_mnv3large_training_config)
trainer.train(model_dl_mnv3large, early_stopping_patience=5)
plot_training_history(trainer.history)

In [18]:
class ModelEvaluator:
    """Lightweight evaluator for loading and testing trained models from 
checkpoints."""

    def __init__(self, model, checkpoint_dir_name):
        """
        Initialize evaluator with a model and checkpoint directory.
        
        Args:
            model: The model instance to evaluate (should match checkpoint 
architecture)
            checkpoint_dir_name: Name of the checkpoint directory (e.g., 
"deeplabv3_resnet50")
        """
        self.model = model
        self.checkpoint_dir = Path("checkpoints") / checkpoint_dir_name
        self.device = self._setup_device()
        self.config = None
        self.loaded_checkpoint = None

    def _setup_device(self) -> torch.device:
        """Setup evaluation device (reuse Trainer logic)."""
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("cpu")  # Temporary fix for MPS issues
        else:
            device = torch.device("cpu")
        print(f"🔧 Evaluation device: {device}")
        return device

    def load_checkpoint(self, checkpoint_name="best.pth"):
        """
        Load model weights and config from checkpoint.
        
        Args:
            checkpoint_name: Name of checkpoint file ("best.pth", "latest.pth", etc.)
        """
        checkpoint_path = self.checkpoint_dir / checkpoint_name

        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        print(f"📂 Loading checkpoint: {checkpoint_path}")

        # Load checkpoint
        self.loaded_checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Extract config
        self.config = self.loaded_checkpoint['config']

        # Load model weights
        self.model.load_state_dict(self.loaded_checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Print checkpoint info
        epoch = self.loaded_checkpoint.get('epoch', 'Unknown')
        metrics = self.loaded_checkpoint.get('metrics', {})
        print(f"✅ Loaded checkpoint from epoch {epoch}")
        if 'mIoU' in metrics:
            print(f"📊 Checkpoint mIoU: {metrics['mIoU']:.4f}")

        return self

    def evaluate_on_sets(self):
        """
        Evaluate model on validation and test sets using saved config.
        
        Returns:
            Dict with validation and test results
        """
        if self.config is None:
            raise ValueError("Must load checkpoint first using load_checkpoint()")

        print(f"🔍 Creating data loaders from saved config...")

        # Create data loaders using saved config
        train_loader, val_loader, test_loader = create_data_loaders(
            data_root=self.config.data_root,
            processor=self.config.processor,
            batch_size=self.config.batch_size,
            image_size=self.config.image_size,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory
        )

        # Create a minimal trainer instance just for the evaluate method

        temp_trainer = Trainer(self.config, verbose=False)

        # Evaluate on validation set
        print(f"\n📊 Evaluating on validation set...")
        val_metrics = temp_trainer.evaluate(self.model, val_loader)

        # Evaluate on test set  
        print(f"📊 Evaluating on test set...")
        test_metrics = temp_trainer.evaluate( self.model, test_loader)

        # Format results
        results = {
            "validation": val_metrics,
            "test": test_metrics,
            "checkpoint_info": {
                "epoch": self.loaded_checkpoint.get('epoch'),
                "checkpoint_metrics": self.loaded_checkpoint.get('metrics', {})
            }
        }

        # Print summary
        print(f"\n🎯 Evaluation Results:")
        print(f"{'Set':<12} {'mIoU':<8} {'PixelAcc':<10}")
        print("-" * 32)
        print(f"{'Validation':<12} {val_metrics['mIoU']:<8.4f} {val_metrics['PixelAcc']:<10.4f}")
        print(f"{'Test':<12} {test_metrics['mIoU']:<8.4f} {test_metrics['PixelAcc']:<10.4f}")

        return results

print("✅ ModelEvaluator class defined!")

✅ ModelEvaluator class defined!


In [19]:
segformer_model_eval = ModelEvaluator(segformer_model, "segformer_b0")

segformer_model_eval.load_checkpoint("best.pth")
segformer_model_eval.evaluate_on_sets()

🔧 Evaluation device: cpu
📂 Loading checkpoint: checkpoints/segformer_b0/best.pth
✅ Loaded checkpoint from epoch 84
📊 Checkpoint mIoU: 0.2439
🔍 Creating data loaders from saved config...
📏 Using SegFormer processor with size: {'height': 512, 'width': 512}
📊 Train: 636 samples,  Val: 159 samples,  Test: 654 samples
📦 Train batches: 20,  Val batches: 5,  Test batches: 21
Using device: cpu
🎲 Random seeds set to 42 for reproducibility
📏 Using SegFormer processor with size: {'height': 512, 'width': 512}
📊 Train: 636 samples,  Val: 159 samples,  Test: 654 samples
📦 Train batches: 20,  Val batches: 5,  Test batches: 21
📂 Found checkpoint: checkpoints/segformer_b0/latest.pth
❌ Failed to load checkpoint: 'Trainer' object has no attribute 'verbose'
🆕 Starting fresh training

📊 Evaluating on validation set...
📊 Evaluating on test set...

🎯 Evaluation Results:
Set          mIoU     PixelAcc  
--------------------------------
Validation   0.2439   0.6611    
Test         0.3317   0.6499    


{'validation': {'mIoU': 0.24392526795392336,
  'PixelAcc': 0.6611169348869812,
  'IoU_per_class': array([0.37849045, 0.69339775, 0.71812873, 0.36620149, 0.72410672,
         0.46212947, 0.56167901, 0.42161724, 0.17991459, 0.34473374,
         0.25828629, 0.53749807, 0.14937577, 0.37125376, 0.05110136,
         0.06615942, 0.30263649, 0.20773736, 0.26192   , 0.10157913,
         0.30809409, 0.2482833 , 0.61231308, 0.12936746, 0.        ,
         0.30012957, 0.20733366, 0.        , 0.        , 0.02480961,
         0.        , 0.26239706, 0.13460407, 0.        , 0.        ,
         0.15390921, 0.        , 0.        , 0.14669873, 0.07112406])},
 'test': {'mIoU': 0.331694428039324,
  'PixelAcc': 0.6499031482781346,
  'IoU_per_class': array([0.39428265, 0.66724136, 0.71529508, 0.52249209, 0.65283061,
         0.45790927, 0.55360578, 0.34235176, 0.257325  , 0.38311996,
         0.36877846, 0.53099914, 0.47119343, 0.50339252, 0.14686693,
         0.06762825, 0.49665052, 0.42461858, 0.3118506

In [None]:
model_dl_mnv3large_eval = ModelEvaluator(model_dl_mnv3large, "deeplabv3_resnet50")
model_dl_mnv3large_eval.load_checkpoint("best.pth")
model_dl_mnv3large_eval.evaluate_on_sets()

🔧 Evaluation device: cpu
📂 Loading checkpoint: checkpoints/deeplabv3_resnet50/best.pth
✅ Loaded checkpoint from epoch 1
📊 Checkpoint mIoU: 0.0037
🔍 Creating data loaders from saved config...
📏 Using custom processor: SemanticSegmentation
📊 Train: 636 samples,  Val: 159 samples,  Test: 654 samples
📦 Train batches: 20,  Val batches: 5,  Test batches: 21
Using device: cpu
🎲 Random seeds set to 42 for reproducibility
📏 Using custom processor: SemanticSegmentation
📊 Train: 636 samples,  Val: 159 samples,  Test: 654 samples
📦 Train batches: 20,  Val batches: 5,  Test batches: 21
📂 Found checkpoint: checkpoints/deeplabv3_resnet50/latest.pth
❌ Failed to load checkpoint: 'Trainer' object has no attribute 'verbose'
🆕 Starting fresh training

📊 Evaluating on validation set...
📊 Evaluating on test set...

🎯 Evaluation Results:
Set          mIoU     PixelAcc  
--------------------------------
Validation   0.0037   0.0253    
Test         0.0046   0.0250    


{'validation': {'mIoU': 0.0036749019225885164,
  'PixelAcc': 0.025326054077795257,
  'IoU_per_class': array([2.01587630e-02, 3.08297507e-02, 8.08725702e-03, 2.28449855e-02,
         1.05620294e-03, 1.49315430e-03, 0.00000000e+00, 1.50002625e-05,
         2.34294558e-04, 0.00000000e+00, 2.53473691e-03, 0.00000000e+00,
         0.00000000e+00, 3.06587878e-02, 0.00000000e+00, 0.00000000e+00,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
         1.55034307e-02, 0.00000000e+00, 1.62816729e-05, 0.00000000e+00,
         0.00000000e+00, 1.24630109e-03, 0.00000000e+00, 0.00000000e+00,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
         1.49115250e-04, 0.00000000e+00, 0.00000000e+00, 4.45306490e-03,
         0.00000000e+00, 0.00000000e+00, 5.27164625e-03, 2.44330405e-03])},
 'test': {'mIoU': 0.00463618206942854,
  'PixelAcc': 0.02498658294723772,
  'IoU_per_class': array([3.37432214e-02, 2.98771322e-02, 8.09790839e-03, 2.67364434e-02,
   

## Visualise Predictions

In [31]:
def get_model_predictions(model, dataset, device, sample_indices=[0, 1, 2, 3]):
    """
    Get model predictions for specified samples.
    
    Returns:
        List of dicts with keys: 'rgb', 'ground_truth', 'prediction', 'sample_id'
    """
    
    def denormalize_imagenet(tensor: torch.Tensor) -> np.ndarray:
        """Denormalize ImageNet preprocessed image."""
        IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
        IMAGENET_STD = np.array([0.229, 0.224, 0.225])
        
        img = tensor.permute(1, 2, 0).cpu().numpy()
        img = (img * IMAGENET_STD + IMAGENET_MEAN).clip(0, 1)
        return (img * 255).astype(np.uint8)
    
    model.eval()
    results = []
    
    with torch.no_grad():
        for idx in sample_indices:
            # Get original image size by loading raw image
            stem, rgb_path, label_path = dataset.items[idx]
            
            # Load original image to get true dimensions
            import cv2
            original_rgb = cv2.imread(str(rgb_path), cv2.IMREAD_COLOR)
            original_rgb = cv2.cvtColor(original_rgb, cv2.COLOR_BGR2RGB)
            original_height, original_width = original_rgb.shape[:2]
            
            # Get processed sample from dataset
            sample = dataset[idx]
            
            # Extract data based on format
            if "pixel_values" in sample:
                # SegFormer format
                image_tensor = sample["pixel_values"].unsqueeze(0).to(device)
                label = sample["labels"].cpu().numpy()
                sample_id = sample["id"]
            elif "image" in sample:
                # TorchVision format
                image_tensor = sample["image"].unsqueeze(0).to(device)
                label = sample["mask"].cpu().numpy()
                sample_id = sample["id"]
            else:
                raise ValueError("Unknown sample format")
            
            # Get model prediction
            outputs = model(image_tensor)
            
            # Handle different model output formats
            if hasattr(outputs, 'logits'):
                logits = outputs.logits  # SegFormer
            elif isinstance(outputs, dict) and "out" in outputs:
                logits = outputs["out"]  # TorchVision DeepLab
            else:
                logits = outputs  # Standard tensor
            
            # Always resize prediction to ORIGINAL image size for fair comparison
            if logits.shape[-2:] != (original_height, original_width):
                logits = torch.nn.functional.interpolate(
                    logits, size=(original_height, original_width), mode="bilinear", align_corners=False
                )
            
            prediction = logits.argmax(1).squeeze(0).cpu().numpy()
            
            # Denormalize input image and resize to original size for display
            rgb_display = denormalize_imagenet(image_tensor.squeeze(0))
            if rgb_display.shape[:2] != (original_height, original_width):
                rgb_display = cv2.resize(rgb_display, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
            
            # Also resize ground truth label to original size for consistent comparison
            if label.shape != (original_height, original_width):
                label_resized = cv2.resize(
                    label.astype(np.uint8), (original_width, original_height), interpolation=cv2.INTER_NEAREST
                ).astype(np.int64)
            else:
                label_resized = label
            
            results.append({
                'rgb': rgb_display,
                'ground_truth': label_resized,
                'prediction': prediction,
                'sample_id': sample_id,
                'original_size': (original_height, original_width)
            })
    
    return results

In [32]:
def visualize_multi_model_predictions(predictions_dict, num_classes=40):
    """
    Visualize predictions from multiple models side by side.
    
    Args:
        predictions_dict: Dict where keys are model names and values are prediction results
        num_classes: Number of classes for colorization
    """
    
    def colorize_mask(mask: np.ndarray, num_classes: int = 40) -> np.ndarray:
        """Create colored visualization of segmentation mask."""
        np.random.seed(42)  # Consistent colors across models
        colors = np.random.randint(0, 255, (num_classes + 1, 3), dtype=np.uint8)
        colors[-1] = [128, 128, 128]  # Ignore class = gray
        
        colored = np.zeros((*mask.shape, 3), dtype=np.uint8)
        for class_id in range(num_classes + 1):
            if class_id == 255:  # Handle ignore index
                colored[mask == 255] = [128, 128, 128]
            else:
                colored[mask == class_id] = colors[class_id]
        return colored
    
    model_names = list(predictions_dict.keys())
    num_models = len(model_names)
    
    # Get number of samples from first model's predictions
    num_samples = len(predictions_dict[model_names[0]])
    
    # Rows: RGB + Ground Truth + Model predictions
    num_rows = 2 + num_models  
    
    fig, axes = plt.subplots(num_rows, num_samples, figsize=(num_samples * 4, num_rows * 3))
    if num_samples == 1:
        axes = axes.reshape(-1, 1)
    if num_rows == 1:
        axes = axes.reshape(1, -1)
    
    for i in range(num_samples):
        # Use first model's data for RGB and ground truth (should be same across models)
        first_model_data = predictions_dict[model_names[0]][i]
        
        # Plot RGB image
        axes[0, i].imshow(first_model_data['rgb'])
        axes[0, i].set_title(f"RGB\n{first_model_data['sample_id']}")
        axes[0, i].axis("off")
        
        # Plot ground truth
        colored_gt = colorize_mask(first_model_data['ground_truth'], num_classes)
        axes[1, i].imshow(colored_gt)
        axes[1, i].set_title("Ground Truth")
        axes[1, i].axis("off")
        
        # Plot each model's prediction
        for j, model_name in enumerate(model_names):
            prediction = predictions_dict[model_name][i]['prediction']
            colored_pred = colorize_mask(prediction, num_classes)
            
            row_idx = 2 + j
            axes[row_idx, i].imshow(colored_pred)
            axes[row_idx, i].set_title(f"{model_name}")
            axes[row_idx, i].axis("off")
    
    plt.suptitle("Multi-Model Predictions Comparison", fontsize=16)
    plt.tight_layout()
    plt.show()

print("✅ Multi-model prediction and visualization functions defined!")

✅ Multi-model prediction and visualization functions defined!


In [30]:
segformer_model.load_state_dict(torch.load("checkpoints/segformer_b0/best.pth", weights_only=False, map_location="cpu")['model_state_dict'])
model_dl_mnv3large.load_state_dict(torch.load("checkpoints/deeplabv3_resnet50/dl_best2.pth", weights_only=False, map_location="cpu")['model_state_dict'])

<All keys matched successfully>

In [45]:
# Example Usage: Multi-Model Prediction Comparison

# Step 1: Create datasets for each model
segformer_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
torchvision_processor = DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT.transforms()

val_dataset_segformer = NYUDepthDataset(DATA_ROOT, "test", processor=segformer_processor)
val_dataset_torchvision = NYUDepthDataset(DATA_ROOT, "test", processor=torchvision_processor)

# Step 2: Get predictions from all models
device = torch.device("cpu")  # or "cuda" if available
sample_indices = [0, 5, 10, 15, 20, 25]  # Which samples to compare

print("🔮 Getting SegFormer predictions...")
segformer_predictions = get_model_predictions(
    segformer_model, val_dataset_segformer, device, sample_indices
)

print("🔮 Getting DeepLabV3 predictions...")
deeplabv3_predictions = get_model_predictions(
    model_dl_mnv3large, val_dataset_torchvision, device, sample_indices  
)

# Step 3: Visualize all models together
predictions_dict = {
    "SegFormer-B0": segformer_predictions,
    "DeepLabV3-ResNet50": deeplabv3_predictions
}

print("🖼️ Creating multi-model comparison visualization...")
# visualize_multi_model_predictions(predictions_dict)  # Now automatically shows all samples

print("✅ Multi-model prediction comparison complete!")

loading configuration file preprocessor_config.json from cache at /Users/reubendrummond/.cache/huggingface/hub/models--nvidia--segformer-b0-finetuned-ade-512-512/snapshots/489d5cd81a0b59fab9b7ea758d3548ebe99677da/preprocessor_config.json
  image_processor = cls(**image_processor_dict)
size should be a dictionary on of the following set of keys: ({'height', 'width'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}, {'max_width', 'max_height'}), got 512. Converted to {'height': 512, 'width': 512}.
Image processor SegformerImageProcessor {
  "do_normalize": true,
  "do_reduce_labels": false,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SegformerImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 512,
    "width": 512
  }
}



SegformerImageProcessor {
  "do_normalize": true,
  "do_reduce_labels": false,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SegformerImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 512,
    "width": 512
  }
}

SemanticSegmentation(
    resize_size=[520]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
🔮 Getting SegFormer predictions...
🔮 Getting DeepLabV3 predictions...
🖼️ Creating multi-model comparison visualization...
✅ Multi-model prediction comparison complete!
