# Semantic Segmentation with Modern Architectures
### Computer Vision Graduate Course Tutorial

In this tutorial, we'll explore modern approaches to semantic segmentation by combining:
- State-of-the-art backbone architectures (ConvNeXt/PVT/Swin)
- Feature Pyramid Network (FPN) decoder
- ADE20K dataset

## Learning Objectives
1. Understand the architecture of modern CNN and vision transformer backbones and how they feed a segmentation decoder
2. Implement and analyze Feature Pyramid Networks for multi-scale feature fusion
3. Build a complete segmentation pipeline using PyTorch best practices
4. Gain hands-on experience with real-world dataset handling

## Prerequisites
- Understanding of CNNs and attention mechanisms
- Basic PyTorch knowledge
- Familiarity with image segmentation concepts

## Setup and Imports

In [None]:
import torch
import matplotlib.pyplot as plt
from pathlib import Path

# Import our modular implementation
from seg_model.models import SegmentationModel
from seg_model.datasets import ADE20KDataset
from seg_model.utils import visualize_batch, calculate_metrics

# Set random seeds for reproducibility
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may select non-deterministic kernels

## Part 1: Understanding Modern Backbones

Let's examine the key architectural elements of our available backbones:

1. **ConvNeXt**
   - Modernized CNN architecture inspired by Vision Transformers
   - Uses depthwise convolutions with large 7x7 kernels
   - Keeps the CNN inductive biases (locality, translation equivariance) while adopting Transformer design choices such as LayerNorm, GELU, and inverted bottlenecks

2. **PVT (Pyramid Vision Transformer)**
   - Hierarchical transformer that generates multi-scale features
   - Progressive shrinking of sequence length for efficiency
   - Spatial-reduction attention (SRA) downsamples keys and values to reduce attention memory and compute

3. **Swin Transformer**
   - Hierarchical architecture with shifted windows
   - Local self-attention within windows
   - Complexity linear in image size for a fixed window size (global self-attention is quadratic)
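
To make two of the building blocks above concrete, here is a minimal sketch of a ConvNeXt-style block and of Swin-style window partitioning. It is a simplified illustration, not the code used inside `SegmentationModel`: the class name `ConvNeXtStyleBlock` and the helper `window_partition` are defined here for illustration only, and layer scale and stochastic depth from the ConvNeXt paper are omitted.

```python
import torch
import torch.nn as nn


class ConvNeXtStyleBlock(nn.Module):
    """Simplified ConvNeXt-style block (no layer scale, no stochastic depth)."""

    def __init__(self, dim: int, expansion: int = 4):
        super().__init__()
        # Depthwise 7x7 convolution: groups=dim gives one filter per channel.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim)
        # Inverted bottleneck: expand channels, apply GELU, project back.
        # Linear layers over the channel axis act as 1x1 convolutions.
        self.pwconv1 = nn.Linear(dim, expansion * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(expansion * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv2(self.act(self.pwconv1(x)))
        x = x.permute(0, 3, 1, 2)  # back to (N, C, H, W)
        return residual + x


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Split (B, H, W, C) into non-overlapping windows of shape
    (B * num_windows, window_size, window_size, C).
    H and W must be divisible by window_size."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


block = ConvNeXtStyleBlock(dim=96)
out = block(torch.randn(2, 96, 56, 56))
print(out.shape)      # shape is preserved by the block

windows = window_partition(torch.randn(2, 56, 56, 96), window_size=7)
print(windows.shape)  # 64 windows per image, each 7x7
```

Self-attention in Swin is computed inside each 7x7 window, so the cost grows with the number of windows rather than with the square of the number of pixels.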

### Exercise 1: Backbone Feature Analysis

In [None]:
def analyze_backbone_features(backbone_name: str):
    """Analyze the feature maps from different backbone stages"""
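    model = SegmentationModel(num_classes=150, backbone_name=backbone_name)
    model.eval()  # inference behavior for normalization and dropout layers
    
    # Create sample input (random noise, so the statistics are only indicative)
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        features = model.extract_backbone_features(x)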
    
    # Print feature statistics
    print(f"\nFeature analysis for {backbone_name}:")
    for idx, feat in enumerate(features):
        print(f"Stage {idx+1}:")
        print(f"  Shape: {feat.shape}")
        print(f"  Mean activation: {feat.mean():.4f}")
        print(f"  Std deviation: {feat.std():.4f}")

# Compare different backbones
for backbone in ['convnext_tiny', 'pvt_v2_b0', 'swin_tiny_patch4_window7_224']:
    analyze_backbone_features(backbone)

### Discussion Questions:
1. How do the feature map resolutions differ between backbones?
2. What are the trade-offs between these architectures?
3. Which backbone might be most suitable for segmentation? Why?

## Part 2: Feature Pyramid Network Deep Dive

In [None]:
def visualize_fpn_features(model, image):
    """Visualize FPN feature maps at different scales"""
    model.eval()
    with torch.no_grad():
        fpn_features = model.extract_fpn_features(image)
    
    fig, axes = plt.subplots(2, len(fpn_features), figsize=(15, 6))
    
    for i, feat in enumerate(fpn_features):
        # First channel of the feature map
        axes[0, i].imshow(feat[0, 0].detach().cpu())
        axes[0, i].set_title(f'P{i} Feature Map (channel 0)')
        
        # Mean activation over channels: a saliency-like summary,
        # not a learned attention map
        channel_mean = feat.mean(dim=1)[0]
        axes[1, i].imshow(channel_mean.detach().cpu())
        axes[1, i].set_title(f'P{i} Channel Mean')
    
    plt.tight_layout()
    plt.show()

# Load a sample image and visualize features
dataset = ADE20KDataset(root_dir='ADEChallengeData2016', split='training')
image, _ = dataset[0]
model = SegmentationModel(num_classes=150)
visualize_fpn_features(model, image.unsqueeze(0))

### Exercise 2: FPN Ablation Study

Experiment with different FPN configurations:
1. Remove lateral connections
2. Change upsampling mode
3. Modify number of output channels
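
As a starting point for these ablations, the sketch below shows a minimal FPN in which each of the three options is a constructor argument. It is an assumption-laden stand-in, not the FPN inside `SegmentationModel`: the class `SimpleFPN` and its parameters `upsample_mode` and `use_lateral` are hypothetical names, and the channel counts in the usage lines are only example values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleFPN(nn.Module):
    """Minimal top-down FPN. `feats` must be ordered from high to low resolution."""

    def __init__(self, in_channels, out_channels=256,
                 upsample_mode="nearest", use_lateral=True):
        super().__init__()
        self.upsample_mode = upsample_mode  # "nearest" or "bilinear"
        self.use_lateral = use_lateral
        # 1x1 convs project every backbone stage to a common channel count.
        self.lateral = nn.ModuleList(
            nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels
        )
        # 3x3 convs smooth each merged map.
        self.smooth = nn.ModuleList(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            for _ in in_channels
        )

    def forward(self, feats):
        laterals = [conv(f) for conv, f in zip(self.lateral, feats)]
        # align_corners is only valid for interpolating modes such as bilinear.
        kwargs = {"align_corners": False} if self.upsample_mode == "bilinear" else {}
        merged = [laterals[-1]]  # start from the coarsest level
        for lat in reversed(laterals[:-1]):
            top_down = F.interpolate(
                merged[0], size=lat.shape[-2:], mode=self.upsample_mode, **kwargs
            )
            # Ablation: without lateral connections, finer levels receive only
            # the upsampled coarse signal.
            merged.insert(0, lat + top_down if self.use_lateral else top_down)
        return [conv(m) for conv, m in zip(self.smooth, merged)]


# Example stage shapes for a 224x224 input (strides 4, 8, 16, 32).
channels, sizes = (96, 192, 384, 768), (56, 28, 14, 7)
feats = [torch.randn(1, c, s, s) for c, s in zip(channels, sizes)]
fpn = SimpleFPN(in_channels=channels, out_channels=128)
outs = fpn(feats)
print([tuple(o.shape) for o in outs])
```

Switching `use_lateral=False`, `upsample_mode="bilinear"`, or `out_channels` gives the three variants listed above without touching the rest of the pipeline.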

## Part 3: Training Pipeline

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Single training epoch with detailed monitoring"""
    model.train()
    total_loss = 0
    metrics = {'iou': [], 'pixel_acc': []}
    
    for batch_idx, (images, masks) in enumerate(train_loader):
        images, masks = images.to(device), masks.to(device)
        
        # Forward pass with intermediate feature extraction
        predictions = model(images)
        loss = criterion(predictions['p0'], masks)
        
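        # Backward pass with gradient monitoring
        optimizer.zero_grad()
        loss.backward()
        
        # Collect gradient statistics on logging steps only; calling .item()
        # for every parameter on every batch forces repeated GPU syncs
        grad_stats = {}
        if batch_idx % 10 == 0:
            for name, param in model.named_parameters():
                if param.grad is not None:
                    grad_stats[name] = {
                        'mean': param.grad.mean().item(),
                        'std': param.grad.std().item()
                    }
        
        optimizer.step()
        
        # Calculate metrics on detached predictions (no autograd graph needed)
        batch_metrics = calculate_metrics(predictions['p0'].detach(), masks)
        metrics['iou'].append(batch_metrics['iou'])
        metrics['pixel_acc'].append(batch_metrics['pixel_acc'])
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            # Average of the absolute per-parameter gradient means
            avg_grad = sum(abs(s['mean']) for s in grad_stats.values()) / max(len(grad_stats), 1)
            print(f'Batch {batch_idx}: Loss = {loss.item():.4f}, '
                  f'IoU = {batch_metrics["iou"]:.4f}, '
                  f'avg |grad mean| = {avg_grad:.2e}')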
            
    return total_loss / len(train_loader), metrics

# Training configuration
config = {
    'num_epochs': 100,
    'batch_size': 16,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4
}

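# Initialize training components
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SegmentationModel(num_classes=150).to(device)
criterion = torch.nn.CrossEntropyLoss()
# Pass only optimizer hyperparameters: AdamW rejects keys such as 'num_epochs'
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'],
                              weight_decay=config['weight_decay'])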
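
The cell above defines a per-epoch function and a `num_epochs` setting but no outer loop. The sketch below shows one way such a loop could look, combined with a polynomial learning-rate decay, which is a common choice for segmentation. Everything in it is a toy stand-in chosen so the snippet runs on its own: `toy_model`, the random tensors, and the small step counts are assumptions, not part of `seg_model`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a 1x1-conv "model" and random batches replace
# SegmentationModel and an ADE20K DataLoader.
num_classes, num_epochs, steps_per_epoch = 5, 3, 4
toy_model = nn.Conv2d(3, num_classes, kernel_size=1)
toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-4, weight_decay=1e-4)

# Polynomial decay, stepped once per batch:
# lr = base_lr * (1 - step / total_steps) ** 0.9
total_steps = num_epochs * steps_per_epoch
scheduler = torch.optim.lr_scheduler.LambdaLR(
    toy_optimizer, lr_lambda=lambda step: (1 - step / total_steps) ** 0.9
)

lr_history = []
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for _ in range(steps_per_epoch):
        images = torch.randn(2, 3, 16, 16)
        masks = torch.randint(0, num_classes, (2, 16, 16))

        loss = toy_criterion(toy_model(images), masks)
        toy_optimizer.zero_grad()
        loss.backward()
        toy_optimizer.step()
        scheduler.step()  # after optimizer.step(), once per batch
        epoch_loss += loss.item()

    lr_history.append(toy_optimizer.param_groups[0]["lr"])
    print(f"Epoch {epoch + 1}: loss = {epoch_loss / steps_per_epoch:.4f}, "
          f"lr = {lr_history[-1]:.2e}")
```

In the real pipeline the inner loop would be replaced by a call to `train_epoch`, with the scheduler stepped either inside it (per batch) or after it (per epoch, with `total_steps` set to the number of epochs).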

### Advanced Topics for Exploration:

1. **Loss Functions**
   - Implement and compare different segmentation losses
   - Combine multiple loss terms
   - Analyze class imbalance handling

2. **Optimization Strategies**
   - Learning rate scheduling
   - Gradient accumulation
   - Mixed precision training

3. **Performance Analysis**
   - Error analysis by class
   - Confusion matrix visualization
   - Boundary accuracy metrics
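
For the per-class error analysis, a confusion matrix is a convenient starting point because per-class IoU and pixel accuracy both fall out of it. The sketch below is one possible implementation and is not necessarily what `calculate_metrics` does internally. The function names and the `ignore_index=255` convention are assumptions for illustration.

```python
import torch


def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """Rows are ground-truth classes, columns are predicted classes.
    Pixels labelled `ignore_index` (an assumed convention) are skipped."""
    valid = target != ignore_index
    idx = target[valid] * num_classes + pred[valid]
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def per_class_iou(cm):
    """IoU = TP / (TP + FP + FN); classes absent from both pred and target give NaN."""
    tp = cm.diag().float()
    union = cm.sum(0).float() + cm.sum(1).float() - tp
    nan = torch.full_like(tp, float("nan"))
    return torch.where(union > 0, tp / union.clamp(min=1), nan)


# Tiny worked example with 3 classes; the last pixel is ignored.
target = torch.tensor([[0, 0, 1], [1, 2, 255]])
pred = torch.tensor([[0, 1, 1], [1, 2, 0]])

cm = confusion_matrix(pred, target, num_classes=3)
iou = per_class_iou(cm)
miou = iou[~iou.isnan()].mean()
pixel_acc = cm.diag().sum().float() / cm.sum().float()

print(cm)         # [[1, 1, 0], [0, 2, 0], [0, 0, 1]]
print(iou)        # approximately [0.5, 0.667, 1.0]
print(miou)       # approximately 0.722
print(pixel_acc)  # 0.8
```

Accumulating `cm` over the whole validation set before computing IoU avoids the bias of averaging per-batch IoU values.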

## Assignments

1. **Architecture Modification (30 points)**
   - Modify the FPN architecture to include attention mechanisms
   - Compare performance with baseline
   - Analyze computational overhead

2. **Dataset Analysis (30 points)**
   - Analyze class distribution in ADE20K
   - Implement class-balanced sampling
   - Evaluate impact on performance

3. **Model Analysis (40 points)**
   - Conduct ablation studies
   - Visualize feature maps and attention
   - Compare different backbones
   - Write a technical report
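
As a possible starting point for the dataset analysis assignment, the sketch below counts pixels per class and turns the counts into per-image sampling weights for `torch.utils.data.WeightedRandomSampler`. The toy masks and the weighting rule (largest inverse-frequency weight among the classes present in an image) are assumptions chosen for illustration; other rules are equally valid and worth comparing.

```python
import torch
from torch.utils.data import WeightedRandomSampler

num_classes = 4
# Toy masks (assumption): three 4x4 label maps standing in for ADE20K
# annotations. Class 3 appears only in the last one.
masks = [
    torch.zeros(4, 4, dtype=torch.long),
    torch.tensor([[0, 0, 1, 1]] * 4),
    torch.tensor([[0, 2, 2, 3]] * 4),
]

# Pixel counts per class over the whole (toy) dataset.
# Real masks with an ignore label would need filtering before bincount.
pixel_counts = torch.zeros(num_classes, dtype=torch.long)
for m in masks:
    pixel_counts += torch.bincount(m.flatten(), minlength=num_classes)
class_freq = pixel_counts.float() / pixel_counts.sum()

# Inverse-frequency class weights; the clamp avoids division by zero.
class_weight = 1.0 / class_freq.clamp(min=1e-8)

# Per-image weight: the largest class weight among the classes it contains,
# so images with rare classes are drawn more often.
image_weights = torch.stack([class_weight[m.unique()].max() for m in masks])

sampler = WeightedRandomSampler(image_weights, num_samples=len(masks), replacement=True)
sampled_indices = list(sampler)

print(pixel_counts.tolist())  # [28, 8, 8, 4]
print(image_weights)
```

The sampler can be passed to a `DataLoader` through its `sampler` argument in place of `shuffle=True`.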

## Additional Resources
- [ConvNeXt Paper](https://arxiv.org/abs/2201.03545)
- [PVT Paper](https://arxiv.org/abs/2102.12122)
- [Swin Transformer Paper](https://arxiv.org/abs/2103.14030)
- [ADE20K Dataset](https://groups.csail.mit.edu/vision/datasets/ADE20K/)