# Semantic Segmentation with ConvNeXt + FPN

This notebook demonstrates how to create a semantic segmentation model using:
- ConvNeXt backbone from timm (could be swapped with PVT or Swin)
- Feature Pyramid Network (FPN) decoder
- PyTorch training pipeline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
from typing import List, Dict

# For visualization
import matplotlib.pyplot as plt

## 1. Define the FPN Decoder

In [None]:
class FPNDecoder(nn.Module):
    """Feature Pyramid Network decoder with a top-down fusion pathway.

    Every input feature map is projected to ``out_channels`` by a 1x1
    convolution. Starting from the last map in the list, each level is
    resized (nearest neighbour) to the spatial size of the preceding level
    and added to it. A 3x3 convolution is then applied to every fused level.

    Args:
        in_channels: Channel count of each incoming feature map, in the same
            order the maps are passed to ``forward``.
        out_channels: Channel count of every map produced by the decoder.
    """

    def __init__(self, in_channels: List[int], out_channels: int = 256):
        super().__init__()

        # 1x1 projections that bring every level to a common width.
        projections = []
        for width in in_channels:
            projections.append(nn.Conv2d(width, out_channels, kernel_size=1))
        self.lateral_convs = nn.ModuleList(projections)

        # One 3x3 convolution per level, applied after fusion.
        smoothers = []
        for _ in in_channels:
            smoothers.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            )
        self.fpn_convs = nn.ModuleList(smoothers)

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Fuse the feature maps and return one output map per level.

        Args:
            features: Feature maps in the order matching ``in_channels``.
                NOTE(review): presumably ordered from highest to lowest
                resolution, as in a typical backbone -- confirm with caller.

        Returns:
            Fused maps in the same order as ``features``, each keeping the
            spatial size of its corresponding input.
        """
        projected = []
        for feature, lateral_conv in zip(features, self.lateral_convs):
            projected.append(lateral_conv(feature))

        # Walk from the last level back to the first, carrying the running
        # sum down the pyramid.
        merged = [projected[-1]]
        for lateral in projected[-2::-1]:
            carried = F.interpolate(
                merged[-1], size=lateral.shape[-2:], mode='nearest'
            )
            merged.append(lateral + carried)

        # The walk produced levels last-to-first; restore the input order.
        merged.reverse()

        return [smooth(level) for level, smooth in zip(merged, self.fpn_convs)]

## 2. Create the Complete Segmentation Model

In [None]:
class SegmentationModel(nn.Module):
    """Semantic segmentation network: timm backbone + FPN decoder + heads.

    The backbone is created in ``features_only`` mode and its four feature
    stages are fused by an ``FPNDecoder``. A 3x3 convolution per level then
    produces class logits, which are resized to the input resolution.

    Args:
        num_classes: Number of output classes (channels of every prediction).
        backbone_name: timm model name passed to ``timm.create_model``.
            NOTE(review): the backbone must return channels-first (NCHW)
            feature maps for the convolutional decoder -- confirm this before
            swapping in a different architecture.
        pretrained: Whether to load pretrained backbone weights.
        fpn_channels: Channel width used by the FPN and expected by the
            prediction heads.
    """

    def __init__(self,
                 num_classes: int,
                 backbone_name: str = 'convnext_tiny',
                 pretrained: bool = True,
                 fpn_channels: int = 256):
        super().__init__()

        # Load the backbone as a multi-scale feature extractor. ConvNeXt in
        # timm exposes four feature stages indexed 0-3; requesting index 4
        # addresses a stage that does not exist.
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )

        # Read channel widths from the backbone's metadata rather than running
        # a dummy forward pass (which would build an autograd graph and, in
        # train mode, update any normalisation statistics).
        in_channels = self.backbone.feature_info.channels()

        # Initialize FPN
        self.fpn = FPNDecoder(in_channels, out_channels=fpn_channels)

        # One prediction head per pyramid level; the input width must match
        # the FPN output width.
        self.seg_convs = nn.ModuleList([
            nn.Conv2d(fpn_channels, num_classes, 3, padding=1)
            for _ in in_channels
        ])

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute per-level class logits at the input resolution.

        Args:
            x: Input batch; its last two dimensions are used as the target
                size for every prediction.

        Returns:
            Dict mapping ``'p0'``, ``'p1'``, ... to logits for each pyramid
            level, in the order the backbone returns its features. Every
            tensor is bilinearly resized to ``x.shape[-2:]``.
        """
        # Get backbone features
        features = self.backbone(x)

        # Apply FPN
        fpn_features = self.fpn(features)

        # Generate predictions at each scale
        predictions = {}
        for i, (feature, conv) in enumerate(zip(fpn_features, self.seg_convs)):
            pred = conv(feature)
            # Upsample to input resolution
            pred = F.interpolate(
                pred,
                size=x.shape[-2:],
                mode='bilinear',
                align_corners=False
            )
            predictions[f'p{i}'] = pred

        return predictions

## 3. Training Setup

In [None]:
def train_step(model: nn.Module,
               images: torch.Tensor,
               masks: torch.Tensor,
               criterion: nn.Module,
               optimizer: torch.optim.Optimizer) -> float:
    """Run one optimisation step on a single batch.

    Args:
        model: Network whose forward pass returns a dict containing ``'p0'``.
        images: Input batch passed to ``model``.
        masks: Targets passed to ``criterion`` together with ``'p0'``.
        criterion: Loss callable invoked as ``criterion(prediction, masks)``.
        optimizer: Optimizer holding the parameters to update.

    Returns:
        The batch loss as a Python float.
    """
    # Switch to training mode and clear gradients left from the last step.
    model.train()
    optimizer.zero_grad()

    # Only the 'p0' output is supervised.
    outputs = model(images)
    batch_loss = criterion(outputs['p0'], masks)

    # Back-propagate and update the parameters.
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

def validate(model: nn.Module,
            val_loader: torch.utils.data.DataLoader,
            criterion: nn.Module) -> float:
    """Compute the mean per-batch loss over a validation loader.

    The model is put in eval mode and evaluated without gradient tracking.
    Batches are counted while iterating, so loaders that do not implement
    ``__len__`` are supported.

    Args:
        model: Network whose forward pass returns a dict containing ``'p0'``.
        val_loader: Iterable yielding ``(images, masks)`` pairs.
        criterion: Loss callable invoked as ``criterion(prediction, masks)``.

    Returns:
        The sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, masks in val_loader:
            predictions = model(images)
            # Only the 'p0' output is scored, matching the training step.
            total_loss += criterion(predictions['p0'], masks).item()
            num_batches += 1

    if num_batches == 0:
        # An empty loader has no defined mean; fail with a clear message
        # instead of an opaque division error.
        raise ValueError(
            'val_loader yielded no batches; cannot compute a validation loss'
        )

    return total_loss / num_batches

## 4. Example Usage

In [None]:
# Build the network. 21 classes is used as an example (Pascal VOC).
# backbone_name can be changed to 'pvt_v2_b0' or 'swin_tiny_patch4_window7_224'.
model = SegmentationModel(backbone_name='convnext_tiny', num_classes=21)

# Loss and optimizer for training.
# NOTE(review): if the dataset marks void pixels with a special label
# (Pascal VOC commonly uses 255), the loss needs a matching ignore_index --
# confirm against the masks actually used.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

# Example training loop (assuming you have your dataloaders set up)
'''
num_epochs = 100
for epoch in range(num_epochs):
    epoch_loss = 0
    for images, masks in train_loader:
        loss = train_step(model, images, masks, criterion, optimizer)
        epoch_loss += loss
        
    # Validation
    val_loss = validate(model, val_loader, criterion)
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Training Loss: {epoch_loss/len(train_loader):.4f}')
    print(f'Validation Loss: {val_loss:.4f}')
'''

## 5. Visualization Helper

In [None]:
def visualize_prediction(image: torch.Tensor,
                        mask: torch.Tensor,
                        prediction: torch.Tensor):
    """Show the input image, ground-truth mask and predicted mask side by side.

    Args:
        image: Image tensor; its dimensions are permuted with ``(1, 2, 0)``
            before plotting (presumably channels-first input -- confirm).
        mask: Ground-truth tensor, plotted as-is.
        prediction: Per-class scores; the argmax over dimension 0 is plotted.
    """
    # Convert everything to numpy up front, paired with its panel title.
    panels = (
        (image.cpu().permute(1, 2, 0).numpy(), 'Input Image'),
        (mask.cpu().numpy(), 'Ground Truth'),
        (prediction.argmax(dim=0).cpu().numpy(), 'Prediction'),
    )

    # One row, three panels, each drawn the same way with axes hidden.
    _fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (array, title) in zip(axes, panels):
        axis.imshow(array)
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout()
    plt.show()