# 08 - Modern Computer Vision: Transfer Learning, Detection, and Segmentation

## Learning Objectives

By the end of this notebook, you will:

1. **Master transfer learning** - Using pretrained models, fine-tuning strategies, feature extraction
2. **Understand object detection** - Region proposals, anchor boxes, YOLO concepts
3. **Implement semantic segmentation** - FCN, U-Net architecture, pixel-wise classification
4. **Use torchvision models** - Leveraging the model zoo for various CV tasks
5. **Apply modern techniques** - Data augmentation, test-time augmentation, model ensembling

---

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.transforms as T
from torchvision import models
from torchvision.datasets import CIFAR10, VOCSegmentation
from torchvision.models import resnet18, resnet50, ResNet18_Weights, ResNet50_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
from typing import List, Tuple, Optional, Dict
import os

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")

torch.manual_seed(42)
np.random.seed(42)

---

## 1. Transfer Learning

Transfer learning reuses the weights of a model trained on a large dataset (such as ImageNet) as the starting point for a new task. It works because early layers learn general features (edges, textures) that transfer well, while later layers are more task-specific and usually need adapting.

### 1.1 Loading Pretrained Models

In [None]:
# Modern way to load pretrained models (weights enum API, torchvision 0.13+)

# Load ResNet-18 with ImageNet weights
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Inspect the architecture
print("ResNet-18 Architecture:")
print(f"  Input: 3 channels (RGB)")
print(f"  conv1: {model.conv1}")
print(f"  fc (classifier): {model.fc}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Get the preprocessing transforms for the model
weights = ResNet18_Weights.IMAGENET1K_V1
preprocess = weights.transforms()

print("Preprocessing transforms:")
print(preprocess)

In [None]:
# Test on a sample image

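def load_image_from_url(url: str, timeout: float = 10.0) -> Image.Image:
    """Load an RGB image from a URL, raising on HTTP errors or timeouts"""
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # fail loudly instead of passing an error page to PIL
    return Image.open(BytesIO(response.content)).convert('RGB')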


# Load a sample image (using a public domain image)
sample_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/1200px-Cat_November_2010-1a.jpg"

try:
    img = load_image_from_url(sample_url)
    
    # Preprocess and predict
    model.eval()
    with torch.no_grad():
        x = preprocess(img).unsqueeze(0)
        output = model(x)
        probs = F.softmax(output, dim=1)
    
    # Get top 5 predictions
    categories = weights.meta["categories"]
    top5_probs, top5_idx = probs[0].topk(5)
    
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title('Input Image')
    
    plt.subplot(1, 2, 2)
    y_pos = range(5)
    plt.barh(y_pos, top5_probs.numpy())
    plt.yticks(y_pos, [categories[i] for i in top5_idx])
    plt.xlabel('Probability')
    plt.title('Top 5 Predictions')
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Could not load image: {e}")
    print("Skipping visualization...")

### 1.2 Transfer Learning Strategies

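There are three main strategies:

1. **Feature Extraction**: Freeze all pretrained layers and train only the new classifier
2. **Full Fine-tuning**: Train all layers, typically with a smaller learning rate for the pretrained backbone than for the new head
3. **Gradual Unfreezing**: Start with only the classifier trainable, then unfreeze backbone layers from last to first as training progresses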

In [None]:
# Strategy 1: Feature Extraction

def create_feature_extractor(num_classes: int) -> nn.Module:
    """
    Create a model for feature extraction.
    Freezes all pretrained weights, only trains the classifier.
    """
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    
    # Freeze all layers
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace classifier
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    
    # Count trainable parameters
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
    
    return model


model_fe = create_feature_extractor(num_classes=10)
print(f"New classifier: {model_fe.fc}")
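
The trainable share printed above rounds to `0.0%` at one decimal place, which can look like a bug. The arithmetic below checks it by hand. It is a sketch that assumes torchvision's published ResNet-18 parameter count of 11,689,512 for the 1000-class model; the helper name `linear_params` is illustrative and not part of the notebook.

```python
# Worked parameter count for the feature-extraction setup (plain Python, no PyTorch).
# Assumption: torchvision's ResNet-18 has 11,689,512 parameters with its 1000-class head.

def linear_params(in_features: int, out_features: int) -> int:
    """Parameters in a fully connected layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


RESNET18_TOTAL = 11_689_512
backbone = RESNET18_TOTAL - linear_params(512, 1000)  # everything except the original fc
new_head = linear_params(512, 10)                     # replacement 10-class classifier
total = backbone + new_head

print(f"Backbone (frozen):    {backbone:,}")
print(f"New head (trainable): {new_head:,}")
print(f"Trainable share:      {100 * new_head / total:.3f}%")
```

Under these assumptions only a few thousand of roughly eleven million parameters receive gradients, which is why feature extraction trains quickly and why the percentage needs more decimals to be visible.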

In [None]:
# Strategy 2: Full Fine-tuning

def create_finetuned_model(num_classes: int) -> nn.Module:
    """
    Create a model for full fine-tuning.
    All layers are trainable, use smaller learning rate for pretrained layers.
    """
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    
    # Replace classifier
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    
    # All parameters trainable
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"All parameters trainable: {trainable:,}")
    
    return model


def get_param_groups(model, lr_backbone=1e-4, lr_head=1e-3):
    """
    Create parameter groups with different learning rates.
    Backbone (pretrained): smaller LR
    Head (new): larger LR
    """
    backbone_params = []
    head_params = []
    
    for name, param in model.named_parameters():
        if name.startswith('fc.'):  # match the classifier head only, not any name containing 'fc'
            head_params.append(param)
        else:
            backbone_params.append(param)
    
    return [
        {'params': backbone_params, 'lr': lr_backbone},
        {'params': head_params, 'lr': lr_head}
    ]


model_ft = create_finetuned_model(num_classes=10)
param_groups = get_param_groups(model_ft)
print(f"Backbone params: {len(param_groups[0]['params'])} tensors, LR={param_groups[0]['lr']}")
print(f"Head params: {len(param_groups[1]['params'])} tensors, LR={param_groups[1]['lr']}")

In [None]:
# Strategy 3: Gradual Unfreezing

class GradualUnfreezer:
    """
    Gradually unfreeze layers during training.
    Start with only classifier trainable, then unfreeze deeper layers.
    """
    
    def __init__(self, model, layer_groups: List[str]):
        """
        Args:
            model: PyTorch model
            layer_groups: List of layer name prefixes to unfreeze in order
        """
        self.model = model
        self.layer_groups = layer_groups
        self.current_group = -1
        
        # Initially freeze all except classifier
        for param in model.parameters():
            param.requires_grad = False
        for param in model.fc.parameters():
            param.requires_grad = True
    
    def unfreeze_next(self):
        """Unfreeze the next layer group"""
        if self.current_group + 1 >= len(self.layer_groups):
            print("All layers already unfrozen")
            return
        
        self.current_group += 1
        
        group_name = self.layer_groups[self.current_group]
        
        for name, param in self.model.named_parameters():
            if name.startswith(group_name):
                param.requires_grad = True
        
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Unfroze {group_name}, trainable params: {trainable:,}")


# Example usage with ResNet
model_gradual = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model_gradual.fc = nn.Linear(model_gradual.fc.in_features, 10)

# Define layer groups (from last to first); 'bn1' is listed so the stem's BatchNorm is unfrozen too
layer_groups = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1', 'bn1']
unfreezer = GradualUnfreezer(model_gradual, layer_groups)

print("Initial state:")
print(f"Trainable: {sum(p.numel() for p in model_gradual.parameters() if p.requires_grad):,}")

print("\nUnfreezing layers:")
for _ in range(3):
    unfreezer.unfreeze_next()
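
`GradualUnfreezer` leaves open when `unfreeze_next()` should be called. One possible answer is a fixed epoch-based schedule, sketched here in plain Python. The names `groups_to_unfreeze`, `epochs_per_stage` and `demo_groups` are hypothetical, and the two-epoch stage length is an arbitrary choice for illustration.

```python
# Sketch of an epoch-based unfreezing schedule (plain Python, no PyTorch).
# All names here are hypothetical helpers, not part of the notebook's classes.
from typing import List


def groups_to_unfreeze(epoch: int, layer_groups: List[str], epochs_per_stage: int = 2) -> List[str]:
    """Return the layer groups that should be trainable at the start of `epoch` (0-based).

    The first `epochs_per_stage` epochs train only the classifier; after that,
    one more group is unfrozen every `epochs_per_stage` epochs, last layer first.
    """
    n_unfrozen = min(epoch // epochs_per_stage, len(layer_groups))
    return layer_groups[:n_unfrozen]


demo_groups = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']
for epoch in range(0, 12, 2):
    print(f"epoch {epoch:2d}: {groups_to_unfreeze(epoch, demo_groups)}")
```

In a training loop one would call `unfreezer.unfreeze_next()` whenever the returned list grows, and make sure the optimizer was built over the newly trainable parameters as well.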

### 1.3 Training with Transfer Learning

In [None]:
# Prepare CIFAR-10 for transfer learning

# CIFAR-10 images are 32x32; upsample to 224x224 to match the resolution the ImageNet weights were trained on
train_transform = T.Compose([
    T.Resize(224),
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = T.Compose([
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load datasets
train_dataset = CIFAR10('../data', train=True, download=True, transform=train_transform)
test_dataset = CIFAR10('../data', train=False, download=True, transform=test_transform)

# Use smaller subset for faster training
train_subset, _ = random_split(train_dataset, [5000, 45000])
test_subset, _ = random_split(test_dataset, [1000, 9000])

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_subset, batch_size=32, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_subset)}")
print(f"Test samples: {len(test_subset)}")

In [None]:
# Training functions

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
    
    return total_loss / len(loader), 100. * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
    
    return total_loss / len(loader), 100. * correct / total

In [None]:
# Train with feature extraction

model = create_feature_extractor(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

print("Training with Feature Extraction (frozen backbone):")
for epoch in range(5):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Epoch {epoch+1}: Train Acc: {train_acc:.1f}% | Test Acc: {test_acc:.1f}%")

In [None]:
# Fine-tune with different learning rates

model = create_finetuned_model(num_classes=10).to(device)
param_groups = get_param_groups(model, lr_backbone=1e-5, lr_head=1e-3)
optimizer = torch.optim.Adam(param_groups)

print("\nTraining with Full Fine-tuning (differential LR):")
for epoch in range(5):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Epoch {epoch+1}: Train Acc: {train_acc:.1f}% | Test Acc: {test_acc:.1f}%")

---

## 2. Object Detection

Object detection finds and classifies multiple objects in an image. Key concepts:

- **Bounding boxes**: (x, y, width, height) or (x1, y1, x2, y2)
- **IoU (Intersection over Union)**: Measures overlap between boxes
- **Anchor boxes**: Predefined box shapes the model refines
- **Non-max suppression**: Removes duplicate detections
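
The two box formats listed above are easy to mix up, so a small conversion sketch may help. It assumes `(x, y)` in the width/height format is the top-left corner (the COCO convention); other formats, such as YOLO labels, put `(x, y)` at the box centre. The function names are illustrative.

```python
# Converting between the two box formats (plain Python).
# Assumption: in (x, y, w, h), the point (x, y) is the TOP-LEFT corner.
from typing import Tuple

Box = Tuple[float, float, float, float]


def xywh_to_xyxy(box: Box) -> Box:
    """(x, y, width, height) -> (x1, y1, x2, y2)."""
    x, y, w, h = box
    return (x, y, x + w, y + h)


def xyxy_to_xywh(box: Box) -> Box:
    """(x1, y1, x2, y2) -> (x, y, width, height)."""
    x1, y1, x2, y2 = box
    return (x1, y1, x2 - x1, y2 - y1)


corner_box = xywh_to_xyxy((10, 10, 40, 40))
print(corner_box)
print(xyxy_to_xywh(corner_box))
```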

### 2.1 Understanding Bounding Boxes

In [None]:
def calculate_iou(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
    """
    Calculate Intersection over Union (IoU) between two boxes.
    
    Args:
        box1, box2: Tensors of shape (4,) with [x1, y1, x2, y2]
    
    Returns:
        IoU value between 0 and 1
    """
    # Calculate intersection
    x1 = torch.max(box1[0], box2[0])
    y1 = torch.max(box1[1], box2[1])
    x2 = torch.min(box1[2], box2[2])
    y2 = torch.min(box1[3], box2[3])
    
    intersection = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    
    # Calculate union
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection
    
    return intersection / (union + 1e-6)


def visualize_iou():
    """Visualize IoU between two boxes"""
    from matplotlib.patches import Rectangle  # imported once, not on every loop iteration
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    
    test_cases = [
        (torch.tensor([10, 10, 50, 50]), torch.tensor([30, 30, 70, 70])),  # Partial overlap
        (torch.tensor([10, 10, 50, 50]), torch.tensor([60, 60, 100, 100])),  # No overlap
        (torch.tensor([10, 10, 50, 50]), torch.tensor([15, 15, 45, 45])),  # High overlap
    ]
    
    for ax, (box1, box2) in zip(axes, test_cases):
        iou = calculate_iou(box1, box2).item()
        
        # Draw boxes
        ax.add_patch(Rectangle((box1[0], box1[1]), box1[2]-box1[0], box1[3]-box1[1],
                               fill=False, edgecolor='blue', linewidth=2, label='Box 1'))
        ax.add_patch(Rectangle((box2[0], box2[1]), box2[2]-box2[0], box2[3]-box2[1],
                               fill=False, edgecolor='red', linewidth=2, label='Box 2'))
        ax.set_xlim(0, 110)
        ax.set_ylim(0, 110)
        ax.set_aspect('equal')
        ax.set_title(f'IoU = {iou:.3f}')
        ax.legend()
    
    plt.tight_layout()
    plt.show()

visualize_iou()
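
The IoU values in the plot titles can be verified by hand. The sketch below redoes the three test cases in plain Python, without the `1e-6` stabiliser used in `calculate_iou`, so results may differ in the last decimals. `iou_xyxy` is an illustrative name.

```python
# Hand-checking the three IoU test cases (plain Python, boxes as (x1, y1, x2, y2)).

def iou_xyxy(a, b):
    """Intersection over Union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


partial = iou_xyxy((10, 10, 50, 50), (30, 30, 70, 70))    # 400 / (1600 + 1600 - 400)
disjoint = iou_xyxy((10, 10, 50, 50), (60, 60, 100, 100))  # no intersection
nested = iou_xyxy((10, 10, 50, 50), (15, 15, 45, 45))     # 900 / 1600, inner box fully inside

print(partial, disjoint, nested)
```

The nested case shows that a box lying entirely inside another still scores well below 1, because IoU penalises the area the larger box covers alone.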

### 2.2 Non-Maximum Suppression

In [None]:
def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float = 0.5) -> torch.Tensor:
    """
    Non-Maximum Suppression to remove duplicate detections.
    
    Args:
        boxes: (N, 4) tensor of boxes [x1, y1, x2, y2]
        scores: (N,) tensor of confidence scores
        iou_threshold: Boxes with IoU > threshold are suppressed
    
    Returns:
        Indices of boxes to keep
    """
    # Sort by score (descending)
    order = scores.argsort(descending=True)
    
    keep = []
    while order.numel() > 0:
        # Keep highest scoring box
        i = order[0].item()
        keep.append(i)
        
        if order.numel() == 1:
            break
        
        # Calculate IoU with remaining boxes
        remaining = order[1:]
        ious = torch.stack([calculate_iou(boxes[i], boxes[j]) for j in remaining])
        
        # Keep boxes with IoU below threshold
        mask = ious <= iou_threshold
        order = remaining[mask]
    
    return torch.tensor(keep, dtype=torch.long)


# Example
boxes = torch.tensor([
    [10, 10, 50, 50],
    [12, 12, 52, 52],  # Almost same as box 0
    [100, 100, 150, 150],
    [102, 102, 152, 152],  # Almost same as box 2
], dtype=torch.float32)

scores = torch.tensor([0.9, 0.8, 0.95, 0.85])

keep_indices = nms(boxes, scores, iou_threshold=0.5)
print(f"Original boxes: {len(boxes)}")
print(f"After NMS: {len(keep_indices)}")
print(f"Kept indices: {keep_indices.tolist()}")
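
The `nms` function above is class-agnostic: a high-scoring box suppresses any overlapping box, even one of a different class. Detectors usually apply NMS per class instead (torchvision provides `torchvision.ops.nms` and `torchvision.ops.batched_nms` for production use). The following is a plain-Python sketch of the per-class idea. The names `per_class_nms` and `_nms_on_indices` and the example labels are made up for illustration.

```python
# Per-class NMS sketch (plain Python). Boxes are (x1, y1, x2, y2) tuples.
from collections import defaultdict


def iou(a, b):
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def _nms_on_indices(boxes, scores, indices, iou_threshold):
    """Greedy NMS restricted to the given box indices."""
    order = sorted(indices, key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # Drop every remaining box that overlaps the kept one too much
        order = [j for j in order if iou(boxes[best], boxes[j]) <= iou_threshold]
    return keep


def per_class_nms(boxes, scores, labels, iou_threshold=0.5):
    """Run NMS separately for each class label; return kept indices by descending score."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    keep = []
    for indices in by_class.values():
        keep.extend(_nms_on_indices(boxes, scores, indices, iou_threshold))
    return sorted(keep, key=lambda i: scores[i], reverse=True)


demo_boxes = [(10, 10, 50, 50), (12, 12, 52, 52), (100, 100, 150, 150), (102, 102, 152, 152)]
demo_scores = [0.9, 0.8, 0.95, 0.85]
demo_labels = [1, 2, 3, 3]  # hypothetical: boxes 0 and 1 overlap but belong to different classes

print(per_class_nms(demo_boxes, demo_scores, demo_labels))
```

With these labels, boxes 0 and 1 both survive because they are never compared, while box 3 is still suppressed by box 2.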

### 2.3 Using Pretrained Detection Model

In [None]:
# Load Faster R-CNN pretrained on COCO

detection_model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
detection_model.eval()
detection_model.to(device)

# COCO class names
COCO_CLASSES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

print(f"Faster R-CNN loaded; label map has {len(COCO_CLASSES)} entries (80 object classes plus background and unused 'N/A' IDs)")

In [None]:
def detect_objects(model, image: Image.Image, threshold: float = 0.5):
    """
    Run object detection on an image.
    
    Returns:
        boxes, labels, scores
    """
    # Preprocess
    transform = T.Compose([T.ToTensor()])
    img_tensor = transform(image).unsqueeze(0).to(device)
    
    # Detect
    with torch.no_grad():
        predictions = model(img_tensor)
    
    pred = predictions[0]
    
    # Filter by confidence
    mask = pred['scores'] > threshold
    boxes = pred['boxes'][mask].cpu()
    labels = pred['labels'][mask].cpu()
    scores = pred['scores'][mask].cpu()
    
    return boxes, labels, scores


def visualize_detections(image: Image.Image, boxes, labels, scores, classes):
    """Visualize detection results"""
    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(image)
    
    colors = plt.cm.hsv(np.linspace(0, 1, len(classes)))
    
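    for box, label, score in zip(boxes, labels, scores):
        # Convert tensors to plain Python numbers for indexing and plotting
        x1, y1, x2, y2 = box.tolist()
        label, score = int(label), float(score)
        color = colors[label % len(colors)]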
        
        # Draw box
        rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                            fill=False, edgecolor=color, linewidth=2)
        ax.add_patch(rect)
        
        # Draw label
        class_name = classes[label] if label < len(classes) else f'class_{label}'
        ax.text(x1, y1-5, f'{class_name}: {score:.2f}',
               color='white', fontsize=10,
               bbox=dict(boxstyle='round', facecolor=color, alpha=0.8))
    
    ax.axis('off')
    plt.tight_layout()
    plt.show()


# Test detection on sample image
try:
    img = load_image_from_url(sample_url)
    boxes, labels, scores = detect_objects(detection_model, img)
    print(f"Detected {len(boxes)} objects")
    visualize_detections(img, boxes, labels, scores, COCO_CLASSES)
except Exception as e:
    print(f"Could not run detection: {e}")

### 2.4 Simple YOLO-style Detection Head

In [None]:
class SimpleDetectionHead(nn.Module):
    """
    Simplified YOLO-style detection head.
    
    Divides image into SxS grid cells.
    Each cell predicts B bounding boxes (each with its own confidence) and one
    set of C class probabilities shared by those boxes, as in YOLOv1.
    
    Output per cell: B * 5 + C values
        - 5 = (x, y, w, h, confidence) per box
        - C = class probabilities per cell
    """
    
    def __init__(self, in_channels: int, num_classes: int = 20, 
                 grid_size: int = 7, num_boxes: int = 2):
        super().__init__()
        
        self.S = grid_size
        self.B = num_boxes
        self.C = num_classes
        
        # Output channels: B * 5 (box params) + C (class probs)
        out_channels = self.B * 5 + self.C
        
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 1024, 3, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1),
            nn.Conv2d(1024, out_channels, 1),
        )
        
        # Adaptive pool to get SxS grid
        self.pool = nn.AdaptiveAvgPool2d(grid_size)
    
    def forward(self, features):
        """
        Args:
            features: (batch, channels, H, W) from backbone
        
        Returns:
            (batch, S, S, B*5 + C) predictions
        """
        x = self.conv(features)
        x = self.pool(x)
        # Reshape to (batch, S, S, B*5 + C)
        x = x.permute(0, 2, 3, 1)
        return x


# Test
backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# Drop the global average pool and classifier (the last two children) to keep spatial feature maps
backbone = nn.Sequential(*list(backbone.children())[:-2])

detection_head = SimpleDetectionHead(in_channels=512, num_classes=20)

x = torch.randn(4, 3, 448, 448)
features = backbone(x)
print(f"Backbone output: {features.shape}")

detections = detection_head(features)
print(f"Detection output: {detections.shape}")
print(f"  Grid: {detection_head.S}x{detection_head.S}")
print(f"  Per cell: {detection_head.B}*5 + {detection_head.C} = {detection_head.B*5 + detection_head.C}")
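
The head above outputs raw numbers and does not specify how `(x, y, w, h)` should be interpreted. Assuming the YOLOv1 convention, where `x, y` are offsets inside the grid cell in `[0, 1]` and `w, h` are fractions of the whole image, one box could be decoded to pixel coordinates as sketched below. `decode_cell_box` is a hypothetical helper, and the example values are made up.

```python
# Decoding one YOLO-style cell prediction to pixel coordinates (plain Python).
# Assumption: YOLOv1 parameterisation (x, y relative to the cell; w, h relative to the image).

def decode_cell_box(row, col, x, y, w, h, grid_size=7, img_size=448):
    """Return (x1, y1, x2, y2) in pixels for a box predicted by cell (row, col)."""
    cell = img_size / grid_size            # pixel size of one grid cell
    cx = (col + x) * cell                  # box centre, x
    cy = (row + y) * cell                  # box centre, y
    bw, bh = w * img_size, h * img_size    # box width and height in pixels
    return (cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2)


# A box centred in the middle cell, a quarter of the image wide and tall
box = decode_cell_box(row=3, col=3, x=0.5, y=0.5, w=0.25, h=0.25)
print(box)

# Size of the full prediction tensor per image for S=7, B=2, C=20
values_per_image = 7 * 7 * (2 * 5 + 20)
print(values_per_image)
```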

---

## 3. Semantic Segmentation

Semantic segmentation assigns a class label to each pixel in the image.

### 3.1 Understanding Segmentation

In [None]:
# Key components of segmentation networks:
# 1. Encoder: Extracts features (like classification backbone)
# 2. Decoder: Upsamples to original resolution

# Upsampling methods
x = torch.randn(1, 64, 8, 8)

# Method 1: Bilinear interpolation
up_bilinear = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
print(f"Bilinear upsampling: {x.shape} -> {up_bilinear.shape}")

# Method 2: Transposed convolution (learnable)
up_conv = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
up_transposed = up_conv(x)
print(f"Transposed conv: {x.shape} -> {up_transposed.shape}")

# Method 3: Pixel shuffle (sub-pixel convolution)
# Rearranges (C*r^2, H, W) to (C, H*r, W*r)
x_shuffle = torch.randn(1, 256, 8, 8)  # 256 = 64 * 2^2
up_shuffle = nn.PixelShuffle(upscale_factor=2)(x_shuffle)
print(f"Pixel shuffle: {x_shuffle.shape} -> {up_shuffle.shape}")
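
The output sizes printed above follow from simple formulas. The sketch below applies the output-size formula given in the PyTorch `ConvTranspose2d` documentation, plus the pixel-shuffle rearrangement, in plain Python. The function names are illustrative.

```python
# Output-size arithmetic for learnable upsampling layers (plain Python).

def conv_transpose2d_out(size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    """Spatial output size of a transposed convolution along one dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def pixel_shuffle_out(channels, size, upscale_factor):
    """(channels, size) after PixelShuffle: channels shrink by r^2, size grows by r."""
    assert channels % (upscale_factor ** 2) == 0, "channels must be divisible by r^2"
    return channels // (upscale_factor ** 2), size * upscale_factor


print(conv_transpose2d_out(8, kernel_size=4, stride=2, padding=1))  # the layer used above
print(conv_transpose2d_out(8, kernel_size=2, stride=2))             # kernel 2, stride 2: also doubles
print(conv_transpose2d_out(8, kernel_size=3, stride=2, padding=1))  # one pixel short of doubling
print(pixel_shuffle_out(256, 8, upscale_factor=2))
```

The third case shows why `kernel_size=3, stride=2, padding=1` needs `output_padding=1` to double the resolution exactly.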

### 3.2 U-Net Architecture

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive conv-bn-relu blocks"""
    
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """
    U-Net architecture for semantic segmentation.
    
    Features:
    - Encoder-decoder structure
    - Skip connections between encoder and decoder
    - Gradually reduces then increases spatial resolution
    """
    
    def __init__(self, in_channels: int = 3, num_classes: int = 21):
        super().__init__()
        
        # Encoder (downsampling)
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)
        
        # Decoder (upsampling)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)  # 512 + 512 from skip
        
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        
        # Output
        self.out = nn.Conv2d(64, num_classes, 1)
        
        self.pool = nn.MaxPool2d(2)
    
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        
        # Bottleneck
        b = self.bottleneck(self.pool(e4))
        
        # Decoder with skip connections
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        
        return self.out(d1)


# Test U-Net
unet = UNet(in_channels=3, num_classes=21)
x = torch.randn(2, 3, 256, 256)
out = unet(x)
print(f"U-Net: {x.shape} -> {out.shape}")
print(f"Parameters: {sum(p.numel() for p in unet.parameters()):,}")
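
The test above uses a 256x256 input, which hides a constraint of this U-Net: with four `MaxPool2d(2)` stages, each spatial dimension should be divisible by 16, otherwise the upsampled tensor and its skip connection differ in size and `torch.cat` fails. The sketch below traces the sizes in plain Python. `unet_skip_shapes_match` and `pad_to_multiple` are hypothetical helpers.

```python
# Tracing spatial sizes through the U-Net above (plain Python, one dimension).

def unet_skip_shapes_match(size: int, depth: int = 4) -> bool:
    """True if every upsampled feature map matches the size of its skip connection."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)       # MaxPool2d(2) floors odd sizes
    current = sizes[-1]                     # bottleneck resolution
    for skip in reversed(sizes[:-1]):
        current *= 2                        # ConvTranspose2d(kernel 2, stride 2) doubles
        if current != skip:
            return False                    # torch.cat would raise a size mismatch here
    return True


def pad_to_multiple(size: int, multiple: int = 16) -> int:
    """Smallest size >= `size` that is divisible by `multiple`."""
    return -(-size // multiple) * multiple


for s in (256, 250, 240):
    print(s, unet_skip_shapes_match(s), pad_to_multiple(s))
```

For inputs that fail the check, padding or resizing to the next multiple of 16 before the forward pass is a common workaround.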

### 3.3 Using Pretrained Segmentation Models

In [None]:
# Load pretrained FCN
seg_model = fcn_resnet50(weights=FCN_ResNet50_Weights.DEFAULT)
seg_model.eval()
seg_model.to(device)

# VOC class names
VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]

print(f"FCN-ResNet50 loaded with {len(VOC_CLASSES)} classes")

In [None]:
def segment_image(model, image: Image.Image) -> torch.Tensor:
    """
    Run segmentation on an image.
    
    Returns:
        Segmentation mask (H, W) with class indices
    """
    # Preprocess
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    img_tensor = transform(image).unsqueeze(0).to(device)
    
    # Segment
    with torch.no_grad():
        output = model(img_tensor)['out']
    
    # Get class predictions
    mask = output.argmax(dim=1).squeeze(0).cpu()  # drop only the batch dimension
    
    return mask


def visualize_segmentation(image: Image.Image, mask: torch.Tensor, classes: list):
    """Visualize segmentation results"""
    # Create color map
    num_classes = len(classes)
    colors = plt.cm.tab20(np.linspace(0, 1, num_classes))
    
    # Create colored mask (index with a NumPy array rather than a torch tensor)
    mask_np = mask.numpy()
    colored_mask = np.zeros((*mask_np.shape, 3))
    for i in range(num_classes):
        colored_mask[mask_np == i] = colors[i, :3]
    
    # Resize mask to image size
    mask_resized = Image.fromarray((colored_mask * 255).astype(np.uint8))
    mask_resized = mask_resized.resize(image.size, Image.NEAREST)
    
    # Plot
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')
    
    axes[1].imshow(mask_resized)
    axes[1].set_title('Segmentation Mask')
    axes[1].axis('off')
    
    # Overlay
    axes[2].imshow(image)
    axes[2].imshow(mask_resized, alpha=0.5)
    axes[2].set_title('Overlay')
    axes[2].axis('off')
    
    # Add legend for present classes
    unique_classes = torch.unique(mask).tolist()
    legend_elements = [plt.Rectangle((0, 0), 1, 1, facecolor=colors[i, :3], 
                                      label=classes[i]) 
                       for i in unique_classes if i < len(classes)]
    axes[1].legend(handles=legend_elements, loc='upper right', fontsize=8)
    
    plt.tight_layout()
    plt.show()


# Test segmentation
try:
    img = load_image_from_url(sample_url)
    mask = segment_image(seg_model, img)
    visualize_segmentation(img, mask, VOC_CLASSES)
except Exception as e:
    print(f"Could not run segmentation: {e}")

### 3.4 Segmentation Loss Functions

In [None]:
class DiceLoss(nn.Module):
    """
    Dice Loss for segmentation.
    
    Dice = 2 * |A ∩ B| / (|A| + |B|)
    
    Works well for imbalanced classes.
    """
    
    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: (N, C, H, W) logits
            target: (N, H, W) class indices
        """
        num_classes = pred.shape[1]
        
        # One-hot encode target
        target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
        
        # Softmax predictions
        pred_soft = F.softmax(pred, dim=1)
        
        # Calculate Dice per class
        intersection = (pred_soft * target_one_hot).sum(dim=(2, 3))
        union = pred_soft.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
        
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        
        # Average over classes and batch
        return 1 - dice.mean()


class FocalLoss(nn.Module):
    """
    Focal Loss for handling class imbalance.
    
    FL(p) = -α(1-p)^γ * log(p)
    
    Down-weights easy examples, focuses on hard ones.
    """
    
    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: (N, C, H, W) logits
            target: (N, H, W) class indices
        """
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        p = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - p) ** self.gamma * ce_loss
        return focal_loss.mean()


# Test losses
pred = torch.randn(4, 21, 64, 64)
target = torch.randint(0, 21, (4, 64, 64))

ce_loss = nn.CrossEntropyLoss()(pred, target)
dice_loss = DiceLoss()(pred, target)
focal_loss = FocalLoss()(pred, target)

print(f"Cross Entropy Loss: {ce_loss.item():.4f}")
print(f"Dice Loss: {dice_loss.item():.4f}")
print(f"Focal Loss: {focal_loss.item():.4f}")
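
To see what "down-weights easy examples" means numerically, the focal loss formula from the docstring can be evaluated for a single pixel at a few confidence levels. This is a plain-Python sketch with `alpha = 1` and `gamma = 2`, matching the defaults above; `focal_term` is an illustrative name.

```python
# Focal loss for a single prediction, FL(p) = -alpha * (1 - p)^gamma * log(p) (plain Python).
import math


def focal_term(p: float, alpha: float = 1.0, gamma: float = 2.0) -> float:
    """Focal loss when the true class is predicted with probability p."""
    return -alpha * (1 - p) ** gamma * math.log(p)


for p in (0.9, 0.5, 0.1):
    ce = -math.log(p)                      # ordinary cross entropy for the same prediction
    fl = focal_term(p)
    print(f"p={p:.1f}  CE={ce:.4f}  focal={fl:.4f}  focal/CE={fl / ce:.2f}")
```

The ratio `focal/CE` equals `(1 - p)^gamma`, so a confident correct prediction (`p = 0.9`) keeps about 1% of its cross-entropy loss while a badly wrong one (`p = 0.1`) keeps about 81%. Setting `gamma = 0` recovers plain cross entropy.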

---

## 4. Advanced Techniques

### 4.1 Test-Time Augmentation (TTA)

In [None]:
from typing import Callable


class TTAWrapper:
    """
    Test-Time Augmentation wrapper.
    
    Applies augmentations at test time and averages predictions.
    """
    
    def __init__(self, model, augmentations: List[Callable[[torch.Tensor], torch.Tensor]]):
        self.model = model
        self.augmentations = augmentations
    
    @torch.no_grad()
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor (N, C, H, W)
        
        Returns:
            Averaged predictions
        """
        self.model.eval()
        predictions = []
        
        for aug in self.augmentations:
            # Apply augmentation
            x_aug = aug(x)
            
            # Get prediction
            pred = self.model(x_aug)
            
            # Reverse augmentation on prediction if needed
            # (for segmentation, need to flip mask back)
            pred = self._reverse_aug(pred, aug)
            
            predictions.append(pred)
        
        # Average predictions
        return torch.stack(predictions).mean(dim=0)
    
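    def _reverse_aug(self, pred, aug):
        """Undo a spatial augmentation on spatial outputs (e.g. segmentation maps)"""
        # Classification outputs (N, num_classes) have no spatial axes; flipping
        # them would reverse the class order, so return them unchanged
        if pred.dim() < 4:
            return pred
        name = getattr(aug, '__name__', '')
        if 'hflip' in name:
            return torch.flip(pred, dims=[-1])
        if 'vflip' in name:
            return torch.flip(pred, dims=[-2])
        return pred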


# Define augmentations
def identity(x):
    return x

def hflip(x):
    return torch.flip(x, dims=[-1])
hflip.__name__ = 'hflip'

def vflip(x):
    return torch.flip(x, dims=[-2])


# Example usage
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 10)
model.eval()

tta_model = TTAWrapper(model, [identity, hflip])

x = torch.randn(4, 3, 224, 224)
pred_tta = tta_model(x)
print(f"TTA prediction shape: {pred_tta.shape}")

### 4.2 Model Ensembling

In [None]:
class EnsembleModel(nn.Module):
    """
    Ensemble multiple models by averaging their predictions.
    """
    
    def __init__(self, models: List[nn.Module], weights: Optional[List[float]] = None):
        super().__init__()
        self.models = nn.ModuleList(models)
        
        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        self.weights = weights
    
    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        predictions = []
        
        for model, weight in zip(self.models, self.weights):
            model.eval()
            pred = model(x)
            predictions.append(weight * F.softmax(pred, dim=1))
        
        # Weighted average of softmax probabilities
        return torch.stack(predictions).sum(dim=0)


# Create an ensemble of two ResNet-18 models (one pretrained, one randomly initialized)
model1 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model1.fc = nn.Linear(model1.fc.in_features, 10)

model2 = resnet18()  # Random weights
model2.fc = nn.Linear(model2.fc.in_features, 10)

ensemble = EnsembleModel([model1, model2], weights=[0.7, 0.3])

x = torch.randn(4, 3, 224, 224)
pred_ensemble = ensemble(x)
print(f"Ensemble prediction shape: {pred_ensemble.shape}")
print(f"Probabilities sum to 1: {pred_ensemble[0].sum().item():.4f}")
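
The averaged output is a valid probability distribution only when the ensemble weights sum to 1. A minimal sketch of normalizing arbitrary non-negative weights first is shown below; `normalize_weights` is a hypothetical helper, and the random logits stand in for the outputs of two models.

```python
import torch
import torch.nn.functional as F

def normalize_weights(weights):
    """Rescale non-negative ensemble weights so they sum to 1."""
    total = float(sum(weights))
    if total <= 0 or any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative with a positive sum")
    return [w / total for w in weights]

torch.manual_seed(0)
logits_a = torch.randn(4, 10)  # stand-ins for two models' raw outputs
logits_b = torch.randn(4, 10)

ens_weights = normalize_weights([2.0, 1.0])  # becomes [2/3, 1/3]
ens_probs = (ens_weights[0] * F.softmax(logits_a, dim=1)
             + ens_weights[1] * F.softmax(logits_b, dim=1))

print(ens_probs.sum(dim=1))  # each row sums to 1, up to float error
```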

### 4.3 Mixed Precision Training

In [None]:
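# Note: newer PyTorch releases deprecate these torch.cuda.amp names in favor of
# torch.amp.autocast('cuda') and torch.amp.GradScaler('cuda'); the old names
# still work but emit a FutureWarning.
from torch.cuda.amp import autocast, GradScaler

def train_with_mixed_precision(model, loader, criterion, optimizer, device, epochs=1):
    """
    Train with automatic mixed precision (AMP).
    
    autocast runs eligible ops in FP16 while the weights stay in FP32;
    GradScaler scales the loss so small FP16 gradients do not underflow.
    Typically faster and lighter on memory on GPUs with Tensor Cores.
    """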
    scaler = GradScaler()
    model.train()
    
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass with autocast
            with autocast():
                output = model(data)
                loss = criterion(output, target)
            
            # Backward pass with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}: Loss = {loss.item():.4f}")
            
            if batch_idx >= 20:  # Short demo
                break


if torch.cuda.is_available():
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    print("Training with Mixed Precision:")
    train_with_mixed_precision(model, train_loader, criterion, optimizer, device)
else:
    print("Skipping: this mixed precision demo requires a CUDA GPU")
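
A common extension is gradient clipping, which needs the gradients unscaled before the norm is measured. The sketch below shows one training step under that pattern. It is an illustration under assumptions, not the notebook's training loop: the toy `nn.Linear` model, the `max_norm=1.0` threshold, and all variable names are made up for the example, and AMP is switched off automatically when no GPU is present so the same code also runs on CPU.

```python
import torch
import torch.nn as nn

amp_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = amp_device.type == 'cuda'  # AMP is disabled on CPU in this sketch

torch.manual_seed(0)
tiny_model = nn.Linear(16, 4).to(amp_device)  # hypothetical toy model
tiny_opt = torch.optim.SGD(tiny_model.parameters(), lr=0.1)
tiny_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

inputs = torch.randn(8, 16, device=amp_device)
labels = torch.randint(0, 4, (8,), device=amp_device)
weights_before = tiny_model.weight.detach().clone()

tiny_opt.zero_grad()
with torch.autocast(device_type=amp_device.type, enabled=use_amp):
    step_loss = nn.functional.cross_entropy(tiny_model(inputs), labels)

tiny_scaler.scale(step_loss).backward()
# Unscale first so the clipping threshold applies to the true gradient norms
tiny_scaler.unscale_(tiny_opt)
torch.nn.utils.clip_grad_norm_(tiny_model.parameters(), max_norm=1.0)
tiny_scaler.step(tiny_opt)   # skips the update if inf/NaN gradients were found
tiny_scaler.update()

print(f"loss = {step_loss.item():.4f}")
```

With `enabled=False`, the scaler's `scale`, `unscale_`, `step`, and `update` calls reduce to a plain `backward()` and `optimizer.step()`, which is what makes the step device-agnostic.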

---

## Exercises

### Exercise 1: Fine-tune for Custom Dataset

Create a complete transfer learning pipeline for a custom classification task.

In [None]:
# Exercise 1: Implement a complete transfer learning pipeline

class TransferLearningPipeline:
    """
    Complete pipeline for transfer learning.
    
    Steps:
    1. Load pretrained model
    2. Modify for new task
    3. Set up training with appropriate LR strategy
    4. Train with early stopping
    5. Evaluate with TTA
    """
    
    def __init__(self, num_classes: int, strategy: str = 'finetune'):
        """
        Args:
            num_classes: Number of output classes
            strategy: 'feature_extract' or 'finetune'
        """
        # YOUR CODE HERE
        pass
    
    def train(self, train_loader, val_loader, epochs: int = 10):
        """Train the model with early stopping"""
        # YOUR CODE HERE
        pass
    
    def predict(self, x: torch.Tensor, use_tta: bool = False) -> torch.Tensor:
        """Make predictions, optionally with TTA"""
        # YOUR CODE HERE
        pass

### Exercise 2: Build a Detection Dataset

Create a PyTorch dataset for object detection that handles bounding boxes.

In [None]:
# Exercise 2: Create a detection dataset

class DetectionDataset(Dataset):
    """
    Dataset for object detection.
    
    Should handle:
    - Loading images and annotations
    - Applying transforms to both image AND boxes
    - Converting box format (xyxy, xywh, cxcywh)
    """
    
    def __init__(self, images: List, annotations: List, transform=None):
        """
        Args:
            images: List of image paths or PIL images
            annotations: List of dicts with 'boxes' and 'labels'
            transform: Optional transform to apply
        """
        # YOUR CODE HERE
        pass
    
    def __len__(self):
        # YOUR CODE HERE
        pass
    
    def __getitem__(self, idx):
        # YOUR CODE HERE
        # Should return: image, target dict with 'boxes', 'labels'
        pass

### Exercise 3: Implement Segmentation Metrics

Implement IoU (Jaccard) and Dice metrics for segmentation evaluation.

In [None]:
# Exercise 3: Implement segmentation metrics

class SegmentationMetrics:
    """
    Calculate segmentation metrics:
    - IoU (Intersection over Union) per class
    - Mean IoU (mIoU)
    - Dice coefficient per class
    - Pixel accuracy
    """
    
    def __init__(self, num_classes: int, ignore_index: int = 255):
        # YOUR CODE HERE
        pass
    
    def update(self, pred: torch.Tensor, target: torch.Tensor):
        """Update confusion matrix with batch predictions"""
        # YOUR CODE HERE
        pass
    
    def compute(self) -> Dict[str, float]:
        """Compute all metrics"""
        # YOUR CODE HERE
        pass
    
    def reset(self):
        """Reset the confusion matrix"""
        # YOUR CODE HERE
        pass

---

## Solutions

In [None]:
# Solution 1: Transfer Learning Pipeline

class TransferLearningPipeline:
    def __init__(self, num_classes: int, strategy: str = 'finetune'):
        self.num_classes = num_classes
        self.strategy = strategy
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Load model
        self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        
        if strategy == 'feature_extract':
            for param in self.model.parameters():
                param.requires_grad = False
        
        # Replace classifier
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.model.to(self.device)
        
        # Setup optimizer
        if strategy == 'finetune':
            self.optimizer = torch.optim.Adam([
                {'params': self.model.fc.parameters(), 'lr': 1e-3},
                {'params': [p for n, p in self.model.named_parameters() 
                           if 'fc' not in n and p.requires_grad], 'lr': 1e-5}
            ])
        else:
            self.optimizer = torch.optim.Adam(self.model.fc.parameters(), lr=1e-3)
        
        self.criterion = nn.CrossEntropyLoss()
        self.best_acc = 0
        self.patience_counter = 0
    
    def train(self, train_loader, val_loader, epochs: int = 10, patience: int = 3):
        for epoch in range(epochs):
            # Train
            self.model.train()
            train_loss = 0
            for data, target in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()
            
            # Validate
            val_acc = self._evaluate(val_loader)
            avg_loss = train_loss / max(len(train_loader), 1)
            print(f"Epoch {epoch+1}: Train Loss = {avg_loss:.4f}, Val Acc = {val_acc:.2f}%")
            
            # Early stopping
            if val_acc > self.best_acc:
                self.best_acc = val_acc
                self.patience_counter = 0
            else:
                self.patience_counter += 1
                if self.patience_counter >= patience:
                    print("Early stopping!")
                    break
    
    @torch.no_grad()
    def _evaluate(self, loader):
        self.model.eval()
        correct = 0
        total = 0
        for data, target in loader:
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
        return 100. * correct / max(total, 1)  # guard against an empty loader
    
    @torch.no_grad()
    def predict(self, x: torch.Tensor, use_tta: bool = False) -> torch.Tensor:
        self.model.eval()
        x = x.to(self.device)
        
        if use_tta:
            preds = []
            preds.append(F.softmax(self.model(x), dim=1))
            preds.append(F.softmax(self.model(torch.flip(x, [-1])), dim=1))
            return torch.stack(preds).mean(dim=0)
        else:
            return F.softmax(self.model(x), dim=1)


print("TransferLearningPipeline defined!")
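
The early stopping in Solution 1 halts training but leaves the model at its last epoch rather than its best one. One possible refinement, sketched here with a hypothetical `EarlyStopper` helper, is to keep a deep copy of the best state alongside the score. The plain dict below stands in for `model.state_dict()` so the logic can be followed without training anything.

```python
import copy

class EarlyStopper:
    """Track the best validation score and keep a copy of the matching weights."""

    def __init__(self, patience: int = 3):
        self.patience = patience
        self.best_score = float('-inf')
        self.best_state = None
        self.bad_epochs = 0

    def step(self, score: float, state: dict) -> bool:
        """Record one epoch; return True when training should stop."""
        if score > self.best_score:
            self.best_score = score
            # Deep copy: a state_dict holds references to the live tensors
            self.best_state = copy.deepcopy(state)
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Simulated validation accuracies; the dict stands in for model.state_dict()
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, acc in enumerate([70.0, 72.0, 71.0, 71.5, 71.9, 73.0], start=1):
    if stopper.step(acc, {'epoch': epoch}):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_score, stopper.best_state)
```

In a real run, calling `model.load_state_dict(stopper.best_state)` after the loop would restore the best weights.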

In [None]:
# Solution 2: Detection Dataset

class DetectionDataset(Dataset):
    def __init__(self, images: List, annotations: List, transform=None):
        self.images = images
        self.annotations = annotations
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # Load image
        if isinstance(self.images[idx], str):
            image = Image.open(self.images[idx]).convert('RGB')
        else:
            image = self.images[idx]
        
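        # Get annotations; reshape keeps an image with no boxes at shape (0, 4)
        ann = self.annotations[idx]
        boxes = torch.tensor(ann['boxes'], dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(ann['labels'], dtype=torch.int64)
        
        # Apply transforms
        if self.transform:
            # Only the image is transformed here, so the transform must not
            # change geometry (no flips, crops, or resizes) or the boxes go stale
            image = self.transform(image)
        elif not isinstance(image, torch.Tensor):
            # PIL images become tensors; tensors are passed through unchanged
            image = T.ToTensor()(image)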
        
        target = {
            'boxes': boxes,
            'labels': labels,
        }
        
        return image, target
    
    @staticmethod
    def collate_fn(batch):
        """Custom collate for a variable number of boxes per image.

        Images are stacked, so they must all share one size; targets stay a list.
        """
        images = [item[0] for item in batch]
        targets = [item[1] for item in batch]
        return torch.stack(images), targets


# Example usage
dummy_images = [torch.randn(3, 224, 224) for _ in range(10)]
dummy_annotations = [
    {'boxes': [[10, 10, 50, 50], [100, 100, 150, 150]], 'labels': [1, 2]}
    for _ in range(10)
]

det_dataset = DetectionDataset(dummy_images, dummy_annotations)
print(f"Detection dataset size: {len(det_dataset)}")
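
The solution above transforms only the image. As a rough sketch of what the "transform boxes too" and "convert box format" requirements involve, the two hypothetical helpers below mirror boxes for a horizontal flip and convert corner format to center format. They assume continuous pixel coordinates in `(x1, y1, x2, y2)` order; torchvision also ships `torchvision.ops.box_convert` for the format conversion.

```python
import torch

def hflip_boxes(boxes_xyxy: torch.Tensor, image_width: float) -> torch.Tensor:
    """Mirror (x1, y1, x2, y2) boxes to match a horizontally flipped image."""
    flipped = boxes_xyxy.clone()
    # The old right edge becomes the new left edge, and vice versa
    flipped[:, 0] = image_width - boxes_xyxy[:, 2]
    flipped[:, 2] = image_width - boxes_xyxy[:, 0]
    return flipped

def xyxy_to_cxcywh(boxes_xyxy: torch.Tensor) -> torch.Tensor:
    """Convert corner format to (center_x, center_y, width, height)."""
    x1, y1, x2, y2 = boxes_xyxy.unbind(dim=-1)
    return torch.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], dim=-1)

demo_boxes = torch.tensor([[10., 10., 50., 50.], [100., 100., 150., 150.]])
flipped_boxes = hflip_boxes(demo_boxes, image_width=224)
center_boxes = xyxy_to_cxcywh(demo_boxes)

print(flipped_boxes)  # first box becomes [174, 10, 214, 50]
print(center_boxes)   # first box becomes [30, 30, 40, 40]
```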

In [None]:
# Solution 3: Segmentation Metrics

class SegmentationMetrics:
    def __init__(self, num_classes: int, ignore_index: int = 255):
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.confusion_matrix = torch.zeros(num_classes, num_classes)
    
    @torch.no_grad()
    def update(self, pred: torch.Tensor, target: torch.Tensor):
        """Update confusion matrix (rows = target class, columns = predicted class)"""
        # Get predictions
        if pred.dim() == 4:  # (N, C, H, W)
            pred = pred.argmax(dim=1)
        
        # Flatten
        pred = pred.flatten().long()
        target = target.flatten().long()
        
        # Drop ignored pixels and any labels outside [0, num_classes)
        n = self.num_classes
        mask = (target != self.ignore_index) & (target >= 0) & (target < n) & (pred >= 0) & (pred < n)
        
        # Vectorized update: encode each (target, pred) pair as one index and count
        counts = torch.bincount(target[mask] * n + pred[mask], minlength=n * n)
        self.confusion_matrix += counts.reshape(n, n).to(self.confusion_matrix)
    
    def compute(self) -> Dict[str, float]:
        """Compute all metrics"""
        cm = self.confusion_matrix
        
        # Per-class IoU
        intersection = cm.diag()
        union = cm.sum(dim=1) + cm.sum(dim=0) - cm.diag()
        iou_per_class = intersection / (union + 1e-6)
        
        # Per-class Dice
        dice_per_class = 2 * intersection / (cm.sum(dim=1) + cm.sum(dim=0) + 1e-6)
        
        # Pixel accuracy
        pixel_acc = cm.diag().sum() / (cm.sum() + 1e-6)
        
        return {
            'iou_per_class': iou_per_class.tolist(),
            'mean_iou': iou_per_class.mean().item(),
            'dice_per_class': dice_per_class.tolist(),
            'mean_dice': dice_per_class.mean().item(),
            'pixel_accuracy': pixel_acc.item(),
        }
    
    def reset(self):
        self.confusion_matrix.zero_()


# Test
metrics = SegmentationMetrics(num_classes=3)

# Simulated predictions and targets
pred = torch.randint(0, 3, (4, 64, 64))
target = torch.randint(0, 3, (4, 64, 64))

metrics.update(pred, target)
results = metrics.compute()

print("Segmentation Metrics:")
print(f"  Mean IoU: {results['mean_iou']:.4f}")
print(f"  Mean Dice: {results['mean_dice']:.4f}")
print(f"  Pixel Accuracy: {results['pixel_accuracy']:.4f}")
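
Random inputs exercise the code but say little about correctness. A small hand-checkable case is sketched below, using standalone variables rather than the class above: four pixels, two classes, one mistake. With rows as target and columns as prediction, the confusion matrix is `[[1, 0], [1, 2]]`, which gives IoU of 1/2 and 2/3, Dice of 2/3 and 4/5, and pixel accuracy of 3/4.

```python
import torch

toy_pred = torch.tensor([0, 0, 1, 1])
toy_target = torch.tensor([0, 1, 1, 1])
n_cls = 2

# Confusion matrix with rows = target class, columns = predicted class
toy_cm = torch.bincount(toy_target * n_cls + toy_pred, minlength=n_cls ** 2)
toy_cm = toy_cm.reshape(n_cls, n_cls).float()

tp = toy_cm.diag()
pred_plus_target = toy_cm.sum(dim=0) + toy_cm.sum(dim=1)
toy_iou = tp / (pred_plus_target - tp)
toy_dice = 2 * tp / pred_plus_target
toy_acc = tp.sum() / toy_cm.sum()

print(toy_cm)
print(toy_iou, toy_dice, toy_acc)
```

The two metrics are tied by Dice = 2·IoU / (1 + IoU), so they always rank a single class the same way; they differ only in how harshly partial overlap is penalized.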

---

## Summary

### Key Takeaways

1. **Transfer Learning**:
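   - Feature extraction (frozen backbone, train only the new head): fast, and a good fit for small datasets
   - Fine-tuning (update some or all pretrained layers): usually more accurate, but needs more data to avoid overfitting
   - Use differential learning rates (smaller for pretrained layers, larger for the new head)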

2. **Object Detection**:
   - IoU measures box overlap quality
   - NMS removes duplicate detections
   - Modern detectors use anchor boxes and feature pyramids

3. **Semantic Segmentation**:
   - Encoder-decoder architecture (like U-Net)
   - Skip connections preserve spatial information
   - Dice/Focal loss for imbalanced classes

4. **Advanced Techniques**:
   - TTA averages the model's predictions over augmented copies of each input
   - Model ensembling averages the (optionally weighted) softmax outputs of several models
   - Mixed precision speeds up training and reduces memory use on CUDA GPUs
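
As a compact refresher on the detection bullets, here is a dependency-free sketch of IoU and greedy NMS. The function names and the three example boxes are invented for illustration, and this is not how torchvision implements it; in practice `torchvision.ops.nms` is the usual choice.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def greedy_nms(boxes, scores, iou_threshold=0.5):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap too much with any kept box
        if all(box_iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep

nms_boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
nms_scores = [0.9, 0.8, 0.7]

print(round(box_iou(nms_boxes[0], nms_boxes[1]), 4))  # 81 / 119
print(greedy_nms(nms_boxes, nms_scores))
```

The second box overlaps the first with IoU 81/119 (about 0.68), above the 0.5 threshold, so it is suppressed; the third box is disjoint from the first and survives.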

### When to Use What

| Task | Architecture | Loss Function |
|------|--------------|---------------|
| Classification | ResNet, EfficientNet | Cross Entropy |
| Detection | Faster R-CNN, YOLO | Smooth L1 + CE |
| Segmentation | U-Net, DeepLab | Dice + CE |
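
The "Dice + CE" entry for segmentation can be sketched as a single loss function. This is one common formulation (soft Dice computed on softmax probabilities, averaged over classes), offered as an illustration rather than a reference implementation. The name `dice_ce_loss`, the `dice_weight` parameter, and the smoothing constant are assumptions, and the sketch assumes targets contain no ignore label such as 255.

```python
import torch
import torch.nn.functional as F

def dice_ce_loss(logits, target, dice_weight=1.0, eps=1e-6):
    """Cross entropy plus soft Dice loss for (N, C, H, W) logits and (N, H, W) labels."""
    num_classes = logits.shape[1]
    ce = F.cross_entropy(logits, target)

    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

    # Sum over batch and spatial axes, leaving one value per class
    dims = (0, 2, 3)
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)

    return ce + dice_weight * (1 - dice.mean())

torch.manual_seed(0)
seg_target = torch.randint(0, 3, (2, 8, 8))
random_logits = torch.randn(2, 3, 8, 8)
# Confident, correct logits: a large value on the true class, zero elsewhere
perfect_logits = F.one_hot(seg_target, 3).permute(0, 3, 1, 2).float() * 20

loss_random = dice_ce_loss(random_logits, seg_target)
loss_perfect = dice_ce_loss(perfect_logits, seg_target)
print(f"random: {loss_random.item():.4f}, perfect: {loss_perfect.item():.6f}")
```

Cross entropy gives smooth per-pixel gradients, while the Dice term is normalized per class and so is less dominated by large background regions.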