In [None]:
# import library
import os
import torch
from torchvision.datasets import VOCSegmentation
import random
import numpy as np

def set_seed(seed=42):
    """Seed every RNG this notebook draws from.

    Covers Python's ``random``, NumPy's legacy global generator, the PyTorch CPU
    generator and the generators of all CUDA devices, in that order.

    Args:
        seed (int): value handed to each seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

# Use the default CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(42) # For reproducibility purposes, please do not modify this.

## Helper functions and dataset setup

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch.nn import functional as F

# Data preprocessing and augmentation
# NOTE: this cell reads `train_dataset`, `val_dataset` and `VOC_CLASSES`, which are created in
# the "1. Download dataset" cell further down — run that cell first on a fresh kernel.
class VOCDataset(torch.utils.data.Dataset):
    """Wrap an (image, mask) dataset with optional paired, image-only and mask-only transforms.

    Args:
        dataset: indexable returning ``(image, mask)`` pairs (here a
            ``torchvision.datasets.VOCSegmentation``, which yields PIL images).
        transform: callable applied to the image only (photometric ops, tensor conversion).
        target_transform: callable applied to the mask only (tensor conversion, relabelling).
        joint_transform: optional callable ``(image, mask) -> (image, mask)`` applied first.
            Random *spatial* augmentation (flip/crop/rotate) must go here so that the image
            and its mask receive exactly the same geometric change.
    """

    def __init__(self, dataset, transform=None, target_transform=None, joint_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        self.joint_transform = joint_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, target = self.dataset[idx]

        if self.joint_transform is not None:
            image, target = self.joint_transform(image, target)

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target = self.target_transform(target)

        return image, target


def paired_random_hflip(image, target, p=0.5):
    """Horizontally flip a PIL image and its PIL mask together with probability ``p``.

    Draws from torch's RNG, the same source ``transforms.RandomHorizontalFlip`` uses.
    """
    if torch.rand(1).item() < p:
        image = transforms.functional.hflip(image)
        target = transforms.functional.hflip(target)
    return image, target


# Define transforms.
# Only photometric (image-only) augmentation belongs in train_transform; the horizontal flip
# is applied by paired_random_hflip so the mask is flipped together with the image.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def target_transform(target):
    """Resize a PIL mask to 224x224 (nearest) and return it as a LongTensor of class ids."""
    target = target.resize((224, 224), Image.NEAREST)
    target = torch.tensor(np.array(target), dtype=torch.long)
    # Convert 255 (void) to 0 (background)
    target[target == 255] = 0
    return target

# Create datasets
train_voc = VOCDataset(train_dataset, transform=train_transform, target_transform=target_transform,
                       joint_transform=paired_random_hflip)
val_voc = VOCDataset(val_dataset, transform=val_transform, target_transform=target_transform)

# Create data loaders
batch_size = 16
train_loader = DataLoader(train_voc, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_voc, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_voc)}")
print(f"Val samples: {len(val_voc)}")
print(f"Number of classes: {len(VOC_CLASSES)}")

## 1. Download dataset
Please refer to [this function](https://docs.pytorch.org/vision/main/generated/torchvision.datasets.VOCSegmentation.html) from TorchVision to download the Pascal VOC Segmentation Dataset.

Note that you can change the inputs of the provided code to match your requirements.

Because the Pascal VOC 2012 Segmentation dataset only provides a `train` set and a `val` set, you are required to train on the `train` set only and then test the model on the `val` set.

**Note:** There is a void class with index 255 in the dataset. You can treat pixels with this label as background, or simply ignore them when calculating the loss value. [Refer to this post for suggestions](https://discuss.pytorch.org/t/having-trouble-with-voc-2012-segmentation-with-the-void-255-label/46486/7)

In [None]:
voc_dir = './data'
os.makedirs(voc_dir, exist_ok=True)

# Pascal VOC 2012 only ships `train` and `val` segmentation splits: train on `train`,
# evaluate on `val`. Both are downloaded (if missing) into voc_dir, train first.
train_dataset, val_dataset = (
    VOCSegmentation(root=voc_dir, year="2012", image_set=split, download=True)
    for split in ("train", "val")
)

# Class names; list position is the class id used in the masks (0 = background).
VOC_CLASSES = [
    "background",
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "potted plant", "sheep", "sofa", "train", "tv/monitor",
]

# RGB colour for each class, in the same order as VOC_CLASSES.
VOC_COLORMAP = [
    [0, 0, 0],        # background
    [128, 0, 0],      # aeroplane
    [0, 128, 0],      # bicycle
    [128, 128, 0],    # bird
    [0, 0, 128],      # boat
    [128, 0, 128],    # bottle
    [0, 128, 128],    # bus
    [128, 128, 128],  # car
    [64, 0, 0],       # cat
    [192, 0, 0],      # chair
    [64, 128, 0],     # cow
    [192, 128, 0],    # diningtable
    [64, 0, 128],     # dog
    [192, 0, 128],    # horse
    [64, 128, 128],   # motorbike
    [192, 128, 128],  # person
    [0, 64, 0],       # potted plant
    [128, 64, 0],     # sheep
    [0, 192, 0],      # sofa
    [128, 192, 0],    # train
    [0, 64, 128],     # tv/monitor
]

100%|██████████| 2.00G/2.00G [00:51<00:00, 38.6MB/s]


## 2. Helper function
You are required to use this helper function to calculate the mean IoU score.

In [None]:
# Provided meanIoU score
import numpy as np
from sklearn.metrics import confusion_matrix

def calculate_segmentation_metrics(preds, masks, num_classes, ignore_index=0):
    """
    Computes class-averaged Precision, Recall, IoU and Dice, plus overall Pixel Accuracy.

    Pixels whose ground-truth label equals ``ignore_index`` are removed from both tensors
    before anything is counted. A class contributes to the averages only if it appears in
    the remaining ground truth or predictions. Because ignored pixels are gone from
    ``masks``, the ``ignore_index`` class itself only contributes (with all-zero scores)
    when it is predicted on non-ignored pixels.

    Args:
        preds (Tensor): predicted class indices, e.g. (B, H, W); any layout/contiguity.
        masks (Tensor): ground-truth class indices with the same number of elements.
        num_classes (int): number of classes including background.
        ignore_index (int): ground-truth label excluded from evaluation
            (e.g. it should be the index of the background).

    Returns:
        dict: float values under ``"precision"``, ``"recall"``, ``"iou"``, ``"dice"``
        (means over contributing classes; 0.0 when no class contributes, e.g. every pixel
        is ignored) and ``"pixel_accuracy"`` (correct / non-ignored pixels).
    """
    eps = 1e-6  # for numerical stability
    # reshape (not view) so non-contiguous inputs such as transposed/sliced batches work
    preds = preds.reshape(-1)
    masks = masks.reshape(-1)
    valid = masks != ignore_index

    preds = preds[valid]
    masks = masks[valid]

    total_correct = 0
    total_pixels = valid.sum().item()

    precision_list = []
    recall_list = []
    iou_list = []
    dice_list = []

    for cls in range(num_classes):
        pred_inds = preds == cls
        target_inds = masks == cls

        TP = (pred_inds & target_inds).sum().item()
        FP = (pred_inds & ~target_inds).sum().item()
        FN = (~pred_inds & target_inds).sum().item()

        union = TP + FP + FN
        pred_sum = pred_inds.sum().item()
        target_sum = target_inds.sum().item()

        # Class absent from both ground truth and prediction: undefined, skip it.
        if target_sum == 0 and pred_sum == 0:
            continue

        precision_list.append(TP / (TP + FP + eps))
        recall_list.append(TP / (TP + FN + eps))
        iou_list.append(TP / (union + eps))
        dice_list.append((2 * TP) / (pred_sum + target_sum + eps))

        total_correct += TP

    def _mean(values):
        # No contributing class -> report 0.0 instead of dividing by zero.
        return sum(values) / len(values) if values else 0.0

    pixel_accuracy = total_correct / (total_pixels + eps)

    return {
        "precision": _mean(precision_list),
        "recall": _mean(recall_list),
        "iou": _mean(iou_list),
        "dice": _mean(dice_list),
        "pixel_accuracy": pixel_accuracy,
    }

# Task 1: Build a baseline Fully Convolutional Network (FCN) model for semantic segmentation (5 marks)

In [None]:
# Note: You can modify this code to load the backbone, just make sure you use model and weights from Nvidia
backbone_efficientnet = torch.hub.load(
    repo_or_dir="NVIDIA/DeepLearningExamples:torchhub",
    model="nvidia_efficientnet_b0",
    pretrained=True,  # fetch NVIDIA's published checkpoint (cached under ~/.cache/torch/hub)
)

Downloading: "https://github.com/NVIDIA/DeepLearningExamples/zipball/torchhub" to /root/.cache/torch/hub/torchhub.zip
Downloading: "https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b0_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-b0_210412.pth" to /root/.cache/torch/hub/checkpoints/nvidia_efficientnet-b0_210412.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 128MB/s] 


In [None]:
# Your code starts from here
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Assuming backbone_efficientnet and device are already defined
# Also assuming VOC_CLASSES is defined and contains the class names

# Baseline FCN Model
class FCN(nn.Module):
    """Baseline fully convolutional segmentation network.

    The backbone's top-level children, minus the last two, form the feature extractor.
    A 3x3 conv -> ReLU -> Dropout2d -> 1x1 conv head maps features to per-class logits,
    which are bilinearly upsampled back to the input resolution.

    Args:
        backbone (nn.Module): pretrained classifier; its last two top-level children are
            dropped (assumed to be the pooling/classification head — TODO confirm for the
            NVIDIA EfficientNet hub model).
        num_classes (int): number of output classes (channels of the logits).

    Shape:
        input ``(N, 3, H, W)`` -> output ``(N, num_classes, H, W)``.
    """

    def __init__(self, backbone, num_classes=21):
        super(FCN, self).__init__()
        # Kept registered so checkpoint keys stay the same; the backbone's unused head layers
        # therefore remain part of parameters() / state_dict().
        self.backbone = backbone
        self.num_classes = num_classes

        # Remove the classifier and avgpool layers
        self.features = nn.Sequential(*list(backbone.children())[:-2])

        # Size the head from one dummy forward pass through the feature extractor.
        self.feature_channels, self.feature_size = self._probe_feature_shape()

        # FCN classifier head
        self.classifier = nn.Sequential(
            nn.Conv2d(self.feature_channels, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )

    def _probe_feature_shape(self, probe_hw=(224, 224)):
        """Return ``(channels, spatial_size)`` of ``self.features`` for a ``probe_hw`` input.

        The probe runs in eval mode because ``torch.no_grad`` alone does not stop BatchNorm
        from folding the random input into its pretrained running statistics. Each
        submodule's train/eval flag is restored afterwards. The input is moved to the device
        of the extractor's parameters, since the backbone may already sit on the GPU.
        """
        saved_modes = [(module, module.training) for module in self.features.modules()]
        first_param = next(self.features.parameters(), None)
        probe_device = first_param.device if first_param is not None else torch.device("cpu")
        # Draw on the CPU generator so the seeded RNG stream is the same whatever the device.
        sample_input = torch.randn(1, 3, *probe_hw).to(probe_device)
        self.features.eval()
        try:
            with torch.no_grad():
                feature_map = self.features(sample_input)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return feature_map.shape[1], feature_map.shape[2:]

    def forward(self, x):
        """Return per-pixel class logits at the spatial resolution of ``x``."""
        features = self.features(x)
        out = self.classifier(features)
        # Upsample to the input's own size (224x224 with this notebook's transforms).
        return nn.functional.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)

# Create the baseline FCN model
model = FCN(backbone_efficientnet, num_classes=len(VOC_CLASSES)).to(device)

# Loss function and optimizer.
# ignore_index=0 drops background pixels from the loss (void 255 is remapped to 0 by
# target_transform), matching the ignore_index=0 used by the metric helper.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# Divide the learning rate by 10 every 10 scheduler steps (one step per epoch below).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model created with {n_params_total} parameters")
print(f"Trainable parameters: {n_params_trainable}")

# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    The model is put in training mode. For each ``(images, targets)`` batch: move it to
    ``device``, compute ``criterion(model(images), targets)``, back-propagate and step
    ``optimizer``. The batch loss is printed every 50 batches.

    Returns:
        float: per-sample weighted mean of the batch losses.
    """
    model.train()
    loss_sum, seen = 0.0, 0

    for step, (images, targets) in enumerate(train_loader):
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(images), targets)
        batch_loss.backward()
        optimizer.step()

        n_in_batch = images.size(0)
        loss_sum += batch_loss.item() * n_in_batch
        seen += n_in_batch

        if not step % 50:
            print(f'Batch {step}/{len(train_loader)}, Loss: {batch_loss.item():.4f}')

    return loss_sum / seen

# Validation function
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Predictions (argmax over the class dimension) and targets for the whole loader are
    collected on the CPU and scored once with ``calculate_segmentation_metrics`` over
    ``len(VOC_CLASSES)`` classes, with ``ignore_index=0``.

    Returns:
        tuple: (per-sample weighted mean loss, metrics dict from the helper).
    """
    model.eval()
    loss_sum, seen = 0.0, 0
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(device)
            targets = targets.to(device)

            logits = model(images)
            batch_loss = criterion(logits, targets)

            n_in_batch = images.size(0)
            loss_sum += batch_loss.item() * n_in_batch
            seen += n_in_batch

            pred_chunks.append(torch.argmax(logits, dim=1).cpu())
            target_chunks.append(targets.cpu())

    metrics = calculate_segmentation_metrics(
        torch.cat(pred_chunks), torch.cat(target_chunks), len(VOC_CLASSES), ignore_index=0
    )
    return loss_sum / seen, metrics

# Training loop for baseline model: one train pass and one validation pass per epoch,
# followed by a StepLR update; per-epoch history feeds the plots below.
num_epochs = 15
train_losses, val_losses, val_ious = [], [], []

print("Training Baseline FCN Model...")
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_metrics = validate_epoch(model, val_loader, criterion, device)
    scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_ious.append(val_metrics['iou'])

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    for metric_label, metric_key in (("IoU", "iou"), ("Pixel Accuracy", "pixel_accuracy")):
        print(f"Val {metric_label}: {val_metrics[metric_key]:.4f}")

print(f"\nBaseline FCN Training Complete!")
# NOTE(review): `val` doubles as the test split in this notebook, so the "best" epoch is
# selected on test data (optimistic); the checkpoint below holds the LAST epoch's weights.
print(f"Best Validation IoU: {max(val_ious):.4f}")

# Save the baseline model
torch.save(model.state_dict(), 'baseline_fcn_model.pth')
print("Baseline model saved as 'baseline_fcn_model.pth'")

# Plot training curves. The x-axis is 1-based to match the "Epoch k/N" log lines.
baseline_epoch_axis = range(1, len(train_losses) + 1)
fig, (ax_loss, ax_iou) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(baseline_epoch_axis, train_losses, label='Train Loss')
ax_loss.plot(baseline_epoch_axis, val_losses, label='Val Loss')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
ax_loss.legend()

ax_iou.plot(baseline_epoch_axis, val_ious, label='Val IoU', color='green')
ax_iou.set(xlabel='Epoch', ylabel='IoU Score', title='Validation IoU Score')
ax_iou.legend()

fig.tight_layout()
plt.show()

# Task 2: Improve the baseline FCN model (8 marks)

In [None]:
# Improved FCN Model with Skip Connections and Advanced Architecture
class ImprovedFCN(nn.Module):
    """FCN with a multi-scale decoder.

    The backbone's top-level children are split positionally into four stages:
    ``[:3]``, ``[3:5]``, ``[5:7]`` and ``[7:-2]``; the last two children are dropped as the
    classification head. Processing order:

    1. The deepest stage's output goes through a conv head.
    2. It is progressively upsampled and fused with 1x1-projected skip features from
       stages 3, 2 and 1.
    3. A sigmoid spatial-attention map gates the result.
    4. A 1x1 conv classifies it, and the logits are upsampled to the input resolution.

    NOTE(review): the positional split assumes every sliced group of top-level children
    returns a 4-D (N, C, H, W) feature map. Confirm this for the NVIDIA EfficientNet hub
    model. Construction raises ``ValueError`` if any stage output is not 4-D.

    Args:
        backbone (nn.Module): pretrained classifier supplying the stages.
        num_classes (int): number of output classes.

    Shape:
        input ``(N, 3, H, W)`` -> output ``(N, num_classes, H, W)``.
    """

    def __init__(self, backbone, num_classes=21):
        super(ImprovedFCN, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes

        # Extract features at multiple scales for skip connections
        children = list(backbone.children())
        self.layer1 = nn.Sequential(*children[:3])    # shallowest stage
        self.layer2 = nn.Sequential(*children[3:5])
        self.layer3 = nn.Sequential(*children[5:7])
        self.layer4 = nn.Sequential(*children[7:-2])  # deepest stage (head removed)

        # Get feature dimensions from one probe pass.
        f1, f2, f3, f4 = self._probe_stages()
        for stage_name, fmap in (("layer1", f1), ("layer2", f2), ("layer3", f3), ("layer4", f4)):
            if fmap.dim() != 4:
                raise ValueError(
                    f"ImprovedFCN: {stage_name} must output a 4-D (N, C, H, W) feature map, "
                    f"got shape {tuple(fmap.shape)}; adjust how the backbone's children are "
                    f"split into stages"
                )

        self.f1_channels = f1.shape[1]
        self.f2_channels = f2.shape[1]
        self.f3_channels = f3.shape[1]
        self.f4_channels = f4.shape[1]

        # Reduce channel dimensions for skip connections
        self.skip1 = nn.Conv2d(self.f1_channels, 64, 1)
        self.skip2 = nn.Conv2d(self.f2_channels, 128, 1)
        self.skip3 = nn.Conv2d(self.f3_channels, 256, 1)

        # Main classifier path
        self.classifier = nn.Sequential(
            nn.Conv2d(self.f4_channels, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Fusion layers for combining features
        self.fusion3 = nn.Sequential(
            nn.Conv2d(256 + 256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.fusion2 = nn.Sequential(
            nn.Conv2d(256 + 128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.fusion1 = nn.Sequential(
            nn.Conv2d(128 + 64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Final classification layer
        self.final_conv = nn.Conv2d(64, num_classes, 1)

        # Spatial attention: one sigmoid weight per pixel, broadcast over channels.
        self.attention = nn.Sequential(
            nn.Conv2d(64, 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 1, 1),
            nn.Sigmoid()
        )

    def _probe_stages(self, probe_hw=(224, 224)):
        """Run a random ``probe_hw`` input through layer1..layer4 and return the four outputs.

        The probe runs in eval mode because ``torch.no_grad`` alone does not stop BatchNorm
        from folding the random input into its pretrained running statistics. Each
        submodule's train/eval flag is restored afterwards. The input is moved to the device
        of the stage parameters, since the backbone may already be on the GPU if another
        model moved it there.
        """
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        saved_modes = [(module, module.training) for stage in stages for module in stage.modules()]
        first_param = next((p for stage in stages for p in stage.parameters()), None)
        probe_device = first_param.device if first_param is not None else torch.device("cpu")
        # Draw on the CPU generator so the seeded RNG stream is the same whatever the device.
        x = torch.randn(1, 3, *probe_hw).to(probe_device)
        outputs = []
        try:
            for stage in stages:
                stage.eval()
            with torch.no_grad():
                for stage in stages:
                    x = stage(x)
                    outputs.append(x)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return tuple(outputs)

    def forward(self, x):
        """Return per-pixel class logits at the spatial resolution of ``x``."""
        # Multi-scale feature extraction (shallow -> deep)
        f1 = self.layer1(x)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        f4 = self.layer4(f3)

        # Main classification path
        out = self.classifier(f4)

        # Progressive upsampling with skip connections
        # Stage 1: Upsample and fuse with f3
        out = F.interpolate(out, size=f3.shape[2:], mode='bilinear', align_corners=False)
        out = self.fusion3(torch.cat([out, self.skip3(f3)], dim=1))

        # Stage 2: Upsample and fuse with f2
        out = F.interpolate(out, size=f2.shape[2:], mode='bilinear', align_corners=False)
        out = self.fusion2(torch.cat([out, self.skip2(f2)], dim=1))

        # Stage 3: Upsample and fuse with f1
        out = F.interpolate(out, size=f1.shape[2:], mode='bilinear', align_corners=False)
        out = self.fusion1(torch.cat([out, self.skip1(f1)], dim=1))

        # Apply attention
        out = out * self.attention(out)

        # Final classification
        out = self.final_conv(out)

        # Upsample to the input's own size (224x224 with this notebook's transforms).
        return F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)

# Create improved model.
# Load a fresh copy of NVIDIA's pretrained EfficientNet-B0 rather than reusing
# `backbone_efficientnet`. That module is shared by reference with the baseline `model`:
# FCN keeps it as `model.backbone` and its children in `model.features`. Reusing it would
# start the improved model from the baseline's fine-tuned weights, and training would also
# overwrite the baseline's in-memory weights. The hub repo and checkpoint are cached by the
# earlier download.
improved_backbone = torch.hub.load("NVIDIA/DeepLearningExamples:torchhub", "nvidia_efficientnet_b0", pretrained=True)
improved_model = ImprovedFCN(improved_backbone, num_classes=len(VOC_CLASSES))
improved_model = improved_model.to(device)

# Advanced loss function - Focal Loss for handling class imbalance
class FocalLoss(nn.Module):
    """Multi-class focal loss ``alpha * (1 - p_t) ** gamma * CE`` for dense logits.

    Args:
        alpha (float): constant weight applied to every pixel.
        gamma (float): focusing exponent; ``gamma=0`` gives ``alpha`` times cross-entropy.
        ignore_index (int): target label excluded from the loss and from the average.
    """

    def __init__(self, alpha=1, gamma=2, ignore_index=0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return the mean focal loss over non-ignored targets.

        ``inputs`` are logits and ``targets`` class indices, as for ``F.cross_entropy``.
        """
        ce_loss = F.cross_entropy(inputs, targets, ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        # Average over non-ignored elements only, like nn.CrossEntropyLoss(ignore_index=...).
        # Ignored elements have ce_loss == 0, so a plain .mean() would shrink the loss by the
        # fraction of ignored (background) pixels in the batch.
        n_valid = (targets != self.ignore_index).sum().clamp(min=1)  # avoid 0/0 -> NaN
        return focal_loss.sum() / n_valid

# Combined loss function
class CombinedLoss(nn.Module):
    """Blend ``0.7 * focal + 0.3 * cross-entropy``; both terms skip ``ignore_index`` targets."""

    def __init__(self, ignore_index=0):
        super().__init__()
        self.focal_loss = FocalLoss(ignore_index=ignore_index)
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, inputs, targets):
        focal_term = self.focal_loss(inputs, targets)
        ce_term = self.ce_loss(inputs, targets)
        weighted_focal = 0.7 * focal_term
        return weighted_focal + 0.3 * ce_term

# Improved training setup: focal/CE blend, AdamW, and cosine decay of the learning rate
# from 1e-3 towards 1e-6 over T_max=20 scheduler steps (stepped once per epoch below).
criterion_improved = CombinedLoss(ignore_index=0)
optimizer_improved = optim.AdamW(
    improved_model.parameters(), lr=0.001, weight_decay=1e-4
)
scheduler_improved = optim.lr_scheduler.CosineAnnealingLR(
    optimizer_improved, T_max=20, eta_min=1e-6
)

improved_param_count = sum(p.numel() for p in improved_model.parameters())
print(f"Improved model created with {improved_param_count} parameters")

# Enhanced data augmentation for improved model.
# Random *spatial* ops must move the image and its mask together, so they are applied to the
# (image, mask) pair by paired_augment_improved. train_transform_improved keeps only the
# photometric ops, which are image-only by nature.
def paired_augment_improved(image, target):
    """Apply one shared random resize/crop/flip/rotation to a PIL image and its PIL mask.

    Spatial pipeline, in order:
    1. Resize to 256x256.
    2. Random 224x224 crop.
    3. Horizontal flip with p=0.5.
    4. Rotation by an angle drawn from [-10, 10] degrees.

    Each random parameter is sampled once and applied to both inputs. The mask uses
    nearest-neighbour resampling so its values stay valid class ids.
    """
    tf = transforms.functional
    nearest = transforms.InterpolationMode.NEAREST

    image = tf.resize(image, [256, 256])
    target = tf.resize(target, [256, 256], interpolation=nearest)

    top, left, height, width = transforms.RandomCrop.get_params(image, output_size=(224, 224))
    image = tf.crop(image, top, left, height, width)
    target = tf.crop(target, top, left, height, width)

    if torch.rand(1).item() < 0.5:
        image = tf.hflip(image)
        target = tf.hflip(target)

    angle = transforms.RandomRotation.get_params([-10.0, 10.0])
    image = tf.rotate(image, angle)
    # Corners rotated in from outside the mask become 0 (background), which the loss ignores.
    target = tf.rotate(target, angle, interpolation=nearest, fill=0)
    return image, target


train_transform_improved = transforms.Compose([
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def target_transform_improved(target):
    """Turn an already augmented 224x224 PIL mask into a LongTensor (void 255 -> background 0)."""
    target = torch.tensor(np.array(target), dtype=torch.long)
    target[target == 255] = 0
    return target


class PairedAugmentVOCDataset(VOCDataset):
    """VOCDataset that applies ``joint_transform(image, mask)`` to each raw pair first.

    The image-only ``transform`` and the mask-only ``target_transform`` run afterwards.
    """

    def __init__(self, dataset, joint_transform, transform=None, target_transform=None):
        super().__init__(dataset, transform=transform, target_transform=target_transform)
        self.joint_transform = joint_transform

    def __getitem__(self, idx):
        image, target = self.joint_transform(*self.dataset[idx])
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)
        return image, target


# Create improved datasets (validation keeps the deterministic baseline preprocessing)
train_voc_improved = PairedAugmentVOCDataset(
    train_dataset, paired_augment_improved,
    transform=train_transform_improved, target_transform=target_transform_improved,
)
val_voc_improved = VOCDataset(val_dataset, transform=val_transform, target_transform=target_transform)

train_loader_improved = DataLoader(train_voc_improved, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader_improved = DataLoader(val_voc_improved, batch_size=batch_size, shuffle=False, num_workers=2)

# Training loop for improved model: same train / validate / scheduler-step cycle as the
# baseline, additionally logging precision, recall and Dice each epoch.
num_epochs_improved = 20
train_losses_improved, val_losses_improved, val_ious_improved = [], [], []

print("\nTraining Improved FCN Model...")
for epoch in range(num_epochs_improved):
    print(f"\nEpoch {epoch+1}/{num_epochs_improved}")

    train_loss = train_epoch(improved_model, train_loader_improved, criterion_improved, optimizer_improved, device)
    val_loss, val_metrics = validate_epoch(improved_model, val_loader_improved, criterion_improved, device)
    scheduler_improved.step()

    train_losses_improved.append(train_loss)
    val_losses_improved.append(val_loss)
    val_ious_improved.append(val_metrics['iou'])

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    for metric_label, metric_key in (
        ("IoU", "iou"),
        ("Pixel Accuracy", "pixel_accuracy"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("Dice", "dice"),
    ):
        print(f"Val {metric_label}: {val_metrics[metric_key]:.4f}")

best_iou_improved = max(val_ious_improved)
print(f"\nImproved FCN Training Complete!")
print(f"Best Validation IoU: {best_iou_improved:.4f}")
print(f"Improvement over baseline: {best_iou_improved - max(val_ious):.4f}")

# Save the improved model (last epoch's weights)
torch.save(improved_model.state_dict(), 'improved_fcn_model.pth')
print("Improved model saved as 'improved_fcn_model.pth'")

# Compare results. X-axes are 1-based epochs to match the "Epoch k/N" log lines.
# NOTE: the baseline loss is plain cross-entropy while the improved loss is
# 0.7 * focal + 0.3 * CE, so absolute loss values are not directly comparable.
epochs = range(1, len(val_ious) + 1)
epochs_improved = range(1, len(val_ious_improved) + 1)
best_ious = [max(val_ious), max(val_ious_improved)]

fig, (ax_loss, ax_iou, ax_best) = plt.subplots(1, 3, figsize=(15, 5))

ax_loss.plot(epochs, train_losses, label='Baseline Train', alpha=0.7)
ax_loss.plot(epochs, val_losses, label='Baseline Val', alpha=0.7)
ax_loss.plot(epochs_improved, train_losses_improved, label='Improved Train')
ax_loss.plot(epochs_improved, val_losses_improved, label='Improved Val')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss Comparison')
ax_loss.legend()

ax_iou.plot(epochs, val_ious, label='Baseline IoU', alpha=0.7)
ax_iou.plot(epochs_improved, val_ious_improved, label='Improved IoU')
ax_iou.set(xlabel='Epoch', ylabel='IoU Score', title='Validation IoU Comparison')
ax_iou.legend()

ax_best.bar(['Baseline', 'Improved'], best_ious, color=['lightblue', 'darkblue'], alpha=0.8)
ax_best.set(ylabel='Best IoU Score', title='Best Validation IoU Comparison')
for bar_idx, best in enumerate(best_ious):
    ax_best.text(bar_idx, best + 0.001, f'{best:.4f}', ha='center', va='bottom')

fig.tight_layout()
plt.show()

# Display sample predictions
def _pascal_palette(size=256):
    """Return the standard PASCAL VOC colour palette as a ``(size, 3)`` uint8 array.

    Generated with the usual bit-interleaving scheme. Entries 0-20 equal VOC_COLORMAP, and
    every uint8 label (including void 255) gets a colour.
    """
    palette = np.zeros((size, 3), dtype=np.uint8)
    for label in range(size):
        r = g = b = 0
        code = label
        for bit in range(8):
            r |= ((code >> 0) & 1) << (7 - bit)
            g |= ((code >> 1) & 1) << (7 - bit)
            b |= ((code >> 2) & 1) << (7 - bit)
            code >>= 3
        palette[label] = (r, g, b)
    return palette


def visualize_predictions(model, dataloader, num_samples=3, device=None):
    """Plot image / ground truth / prediction for the first item of the first ``num_samples`` batches.

    Masks are coloured through one fixed palette, so a class gets the same colour in the
    ground-truth and prediction panels. A per-panel autoscaled colormap would re-map
    colours independently for each panel.

    Args:
        model: segmentation network returning ``(N, C, H, W)`` logits.
        dataloader: iterable of ``(images, targets)`` batches; images are assumed to be
            normalised with the ImageNet mean/std used by this notebook's transforms.
        num_samples (int): number of rows (batches) to display.
        device: where to run the model; defaults to the device of its first parameter.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    palette = _pascal_palette()
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    model.eval()
    # squeeze=False keeps `axes` 2-D so axes[i, j] also works when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples * 3), squeeze=False)

    with torch.no_grad():
        for i, (images, targets) in enumerate(dataloader):
            if i >= num_samples:
                break

            outputs = model(images.to(device))
            pred = torch.argmax(outputs, dim=1)[0].cpu()

            # Undo normalisation on the first image of the batch
            img = torch.clamp(images[0].cpu() * std + mean, 0, 1)
            target = targets[0].cpu()

            axes[i, 0].imshow(img.permute(1, 2, 0))
            axes[i, 0].set_title('Original Image')

            axes[i, 1].imshow(palette[target.numpy()])
            axes[i, 1].set_title('Ground Truth')

            axes[i, 2].imshow(palette[pred.numpy()])
            axes[i, 2].set_title('Prediction')

    # Hide ticks everywhere, including rows left empty by a short dataloader.
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

print("\nSample predictions from improved model:")
visualize_predictions(improved_model, val_loader_improved)