In [1]:
# Cell 1 - Install Dependencies

print("Installing required packages...")
print("="*70)

# Install packages
%pip install torch torchvision numpy pillow matplotlib tqdm pandas tabulate seaborn 

print("\n‚úÖ All packages installed successfully!")

Installing required packages...
Looking in indexes: https://boartifactory.micron.com/artifactory/api/pypi/micron-pypi-rel-virtual/simple
Note: you may need to restart the kernel to use updated packages.

‚úÖ All packages installed successfully!


In [2]:
# Cell 2 - Import Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import vgg16

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import os
import random
import json
from tqdm import tqdm
from typing import Tuple, List, Optional, Dict

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("‚úÖ All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

‚úÖ All libraries imported successfully!
PyTorch version: 2.7.1+cu118
CUDA available: True
CUDA device: NVIDIA RTX 2000 Ada Generation Laptop GPU


In [7]:
# Cell 3 - Configuration

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

# Pascal VOC class names
PASCAL_VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
    'sofa', 'train', 'tvmonitor'
]

import kagglehub
huanghanchina_pascal_voc_2012_path = kagglehub.dataset_download('huanghanchina/pascal-voc-2012')
print('Data source import complete.')

import os

# After kagglehub download
base_path = huanghanchina_pascal_voc_2012_path
print(f"Dataset downloaded to: {base_path}")

# Configure paths
image_dir = os.path.join(base_path,  'VOC2012', 'JPEGImages')
label_dir = os.path.join(base_path,  'VOC2012', 'SegmentationClass')
train_list = os.path.join(base_path,  'VOC2012', 'ImageSets', 'Segmentation', 'train.txt')
val_list = os.path.join(base_path,  'VOC2012', 'ImageSets', 'Segmentation', 'val.txt')

# Verify paths exist
assert os.path.exists(image_dir), f"Image directory not found: {image_dir}"
assert os.path.exists(label_dir), f"Label directory not found: {label_dir}"
print(f"‚úÖ Image directory: {image_dir}")
print(f"‚úÖ Label directory: {label_dir}")

# Check dataset statistics
image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
label_files = [f for f in os.listdir(label_dir) if f.endswith('.png')]
print(f"\nüìä Dataset Statistics:")
print(f"   Total images: {len(image_files)}")
print(f"   Labeled images: {len(label_files)}")

# Training configuration
config = {
    # Dataset paths (UPDATE THESE!)
    'voc_root': 'huanghanchina_pascal_voc_2012_path',  
    
    # Model parameters
    'num_classes': 21,
    'use_pairwise': True,
    
    # Training parameters
    'batch_size': 4,  # Reduce if OOM
    'num_workers': 4,
    'learning_rate': 1e-3,
    'lr_pretrained': 1e-4,  # Lower LR for pretrained layers
    'weight_decay': 5e-4,
    
    # Piecewise training epochs
    'stage1_epochs': 2,   # Unary training (increase to 20 for full training)
    'stage2_epochs': 1,   # Pairwise training (increase to 10 for full training)
    'stage3_epochs': 1,   # Joint fine-tuning (increase to 5 for full training)
    
    # Data augmentation
    'image_size': (512, 512),
    'base_size': 520,
    'crop_size': 512,
    'scale_range': (0.5, 2.0),
    
    # Class weighting
    'use_class_weights': True,
    
    # Debugging (set to None for full dataset)
    'max_train_images': 100,  # Use small subset for testing (set to None for full)
    'max_val_images': 50,     # Use small subset for testing (set to None for full)
}

print("Configuration:")
print(json.dumps(config, indent=2))

Using device: cuda

Data source import complete.
Dataset downloaded to: C:\Users\rkekanaje\.cache\kagglehub\datasets\huanghanchina\pascal-voc-2012\versions\1
‚úÖ Image directory: C:\Users\rkekanaje\.cache\kagglehub\datasets\huanghanchina\pascal-voc-2012\versions\1\VOC2012\JPEGImages
‚úÖ Label directory: C:\Users\rkekanaje\.cache\kagglehub\datasets\huanghanchina\pascal-voc-2012\versions\1\VOC2012\SegmentationClass

üìä Dataset Statistics:
   Total images: 17125
   Labeled images: 2913
Configuration:
{
  "voc_root": "huanghanchina_pascal_voc_2012_path",
  "num_classes": 21,
  "use_pairwise": true,
  "batch_size": 4,
  "num_workers": 4,
  "learning_rate": 0.001,
  "lr_pretrained": 0.0001,
  "weight_decay": 0.0005,
  "stage1_epochs": 2,
  "stage2_epochs": 1,
  "stage3_epochs": 1,
  "image_size": [
    512,
    512
  ],
  "base_size": 520,
  "crop_size": 512,
  "scale_range": [
    0.5,
    2.0
  ],
  "use_class_weights": true,
  "max_train_images": 100,
  "max_val_images": 50
}


In [8]:
# Cell 5 - Dataset Classes

class SegmentationAugmentation:
    """Data augmentation for semantic segmentation."""
    
    def __init__(self, base_size=520, crop_size=512, scale_range=(0.5, 2.0)):
        self.base_size = base_size
        self.crop_size = crop_size
        self.scale_range = scale_range
    
    def __call__(self, image, label):
        # Random scaling
        scale = random.uniform(*self.scale_range)
        w, h = image.size
        new_w, new_h = int(w * scale), int(h * scale)
        
        image = image.resize((new_w, new_h), Image.BILINEAR)
        label = label.resize((new_w, new_h), Image.NEAREST)
        
        # Random crop
        w, h = image.size
        if w > self.crop_size and h > self.crop_size:
            x = random.randint(0, w - self.crop_size)
            y = random.randint(0, h - self.crop_size)
            image = image.crop((x, y, x + self.crop_size, y + self.crop_size))
            label = label.crop((x, y, x + self.crop_size, y + self.crop_size))
        else:
            # Pad if too small
            image = image.resize((self.crop_size, self.crop_size), Image.BILINEAR)
            label = label.resize((self.crop_size, self.crop_size), Image.NEAREST)
        
        # Random horizontal flip
        if random.random() > 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)
        
        return image, label


class EnhancedSegmentationDataset(Dataset):
    """Enhanced dataset with split file support and optional augmentation."""
    
    def __init__(self, image_dir, label_dir, split_list_file=None, 
                 augmentation=None, image_size=(512, 512), max_images=None):
        self.image_dir = Path(image_dir)
        self.label_dir = Path(label_dir)
        self.augmentation = augmentation
        self.image_size = image_size
        
        # Load image IDs from split file
        if split_list_file and Path(split_list_file).exists():
            with open(split_list_file, 'r') as f:
                self.image_ids = [line.strip() for line in f.readlines()]
        else:
            # Fallback: use all images
            self.image_ids = [f.stem for f in self.image_dir.glob('*.jpg')]
        
        # Limit dataset size if specified (for debugging)
        if max_images is not None:
            self.image_ids = self.image_ids[:max_images]
        
        # Normalization
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    
    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        
        # Load image and label
        img_path = self.image_dir / f'{img_id}.jpg'
        label_path = self.label_dir / f'{img_id}.png'
        
        image = Image.open(img_path).convert('RGB')
        label = Image.open(label_path)
        
        # Apply augmentation
        if self.augmentation:
            image, label = self.augmentation(image, label)
        else:
            # Just resize
            image = image.resize(self.image_size, Image.BILINEAR)
            label = label.resize(self.image_size, Image.NEAREST)
        
        # Convert to tensors
        image = transforms.ToTensor()(image)
        label = torch.from_numpy(np.array(label, dtype=np.int64))
        
        # Normalize image
        image = self.normalize(image)
        
        return image, label


def compute_class_weights(dataset, num_classes, ignore_index=255):
    """Compute class weights for handling class imbalance."""
    print("Computing class weights (this may take a few minutes)...")
    
    class_counts = torch.zeros(num_classes)
    
    for idx in tqdm(range(len(dataset)), desc="Scanning dataset"):
        _, label = dataset[idx]
        for c in range(num_classes):
            class_counts[c] += (label == c).sum().item()
    
    # Compute weights (inverse frequency)
    total_pixels = class_counts.sum()
    class_weights = total_pixels / (num_classes * class_counts)
    
    # Normalize weights
    class_weights = class_weights / class_weights.sum() * num_classes
    
    # Clip extreme weights
    class_weights = torch.clamp(class_weights, 0.1, 10.0)
    
    return class_weights


print("‚úÖ Dataset classes defined!")

‚úÖ Dataset classes defined!


In [9]:
# Cell 6 - Model Architecture

class FeatMapNet(nn.Module):
    """Feature extraction network based on VGG-16."""
    
    def __init__(self, pretrained=True):
        super().__init__()
        
        # Load pretrained VGG-16
        vgg = vgg16(pretrained=pretrained)
        
        # Extract feature layers (conv1 to pool5)
        self.block1 = nn.Sequential(*list(vgg.features[:5]))   # conv1_1, conv1_2, pool1
        self.block2 = nn.Sequential(*list(vgg.features[5:10]))  # conv2_1, conv2_2, pool2
        self.block3 = nn.Sequential(*list(vgg.features[10:17])) # conv3_1, conv3_2, conv3_3, pool3
        self.block4 = nn.Sequential(*list(vgg.features[17:24])) # conv4_1, conv4_2, conv4_3, pool4
        self.block5 = nn.Sequential(*list(vgg.features[24:31])) # conv5_1, conv5_2, conv5_3, pool5
        
        # Additional dilated convolution block
        self.block6 = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        return x


class UnaryNet(nn.Module):
    """Unary potential network with multi-scale and spatial pyramid pooling."""
    
    def __init__(self, in_channels=512, num_classes=21, scales=[1.0, 0.5], pool_sizes=[1, 2]):
        super().__init__()
        self.scales = scales
        self.pool_sizes = pool_sizes
        
        # Convolutional layers for unary potentials
        self.conv1 = nn.Conv2d(in_channels, 512, 3, padding=1)
        self.conv2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv3 = nn.Conv2d(512, num_classes, 1)
        
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, feat):
        # Multi-scale processing
        outputs = []
        for scale in self.scales:
            if scale != 1.0:
                h, w = feat.shape[2:]
                scaled_feat = F.interpolate(feat, size=(int(h*scale), int(w*scale)), 
                                           mode='bilinear', align_corners=False)
            else:
                scaled_feat = feat
            
            # Spatial pyramid pooling
            pooled_feats = []
            for pool_size in self.pool_sizes:
                if pool_size > 1:
                    pooled = F.adaptive_avg_pool2d(scaled_feat, pool_size)
                    pooled = F.interpolate(pooled, size=scaled_feat.shape[2:], 
                                          mode='bilinear', align_corners=False)
                    pooled_feats.append(pooled)
                else:
                    pooled_feats.append(scaled_feat)
            
            # Combine pooled features
            combined = torch.cat(pooled_feats, dim=1) if len(pooled_feats) > 1 else pooled_feats[0]
            
            # Reduce channels back to original
            if combined.shape[1] != feat.shape[1]:
                combined = F.conv2d(combined, 
                                   torch.ones(feat.shape[1], combined.shape[1], 1, 1).to(feat.device) / combined.shape[1])
            
            # Apply convolutions
            x = self.relu(self.conv1(combined))
            x = self.relu(self.conv2(x))
            x = self.conv3(x)
            
            # Resize back to original feature size
            if scale != 1.0:
                x = F.interpolate(x, size=feat.shape[2:], mode='bilinear', align_corners=False)
            
            outputs.append(x)
        
        # Average multi-scale outputs
        return torch.mean(torch.stack(outputs), dim=0)


class PairwiseNet(nn.Module):
    """Pairwise potential network."""
    
    def __init__(self, in_channels=512, num_classes=21):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, 256, 3, padding=1)
        self.conv2 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv3 = nn.Conv2d(256, num_classes, 1)
        
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, feat):
        x = self.relu(self.conv1(feat))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x


class AlternativeCRFModel(nn.Module):
    """Complete alternative CRF model."""
    
    def __init__(self, num_classes=21, use_pairwise=True, scales=[1.0, 0.5], pool_sizes=[1, 2]):
        super().__init__()
        
        self.num_classes = num_classes
        self.use_pairwise = use_pairwise
        
        # Networks
        self.featmap_net = FeatMapNet(pretrained=True)
        self.unary_net = UnaryNet(in_channels=512, num_classes=num_classes, 
                                  scales=scales, pool_sizes=pool_sizes)
        
        if use_pairwise:
            self.pairwise_net = PairwiseNet(in_channels=512, num_classes=num_classes)
    
    def forward(self, x, return_features=False):
        # Extract features
        feat = self.featmap_net(x)
        
        # Compute unary potentials
        unary = self.unary_net(feat)
        
        # Upsample to input size
        unary = F.interpolate(unary, size=x.shape[2:], mode='bilinear', align_corners=False)
        
        if return_features:
            return unary, feat
        return unary
    
    def get_pairwise_scores(self, feat):
        """Get pairwise potential scores."""
        if not self.use_pairwise:
            return None
        return self.pairwise_net(feat)


# Import SegmentationMetrics from existing code
from src.piecewise_training.metrics import SegmentationMetrics

print("‚úÖ Model architecture defined!")

‚úÖ Model architecture defined!


In [16]:
# Cell 7 - Create Datasets & Data Loaders (FIXED)

print("Creating datasets...")

# Training augmentation
train_aug = SegmentationAugmentation(
    base_size=config['base_size'],
    crop_size=config['crop_size'],
    scale_range=config['scale_range']
)

# Create datasets
train_dataset = EnhancedSegmentationDataset(
    image_dir=image_dir,
    label_dir=label_dir,
    split_list_file=train_list,
    augmentation=train_aug,
    image_size=config['image_size'],
    max_images=config['max_train_images']
)

val_dataset = EnhancedSegmentationDataset(
    image_dir=image_dir,
    label_dir=label_dir,
    split_list_file=val_list,
    augmentation=None,  # No augmentation for validation
    image_size=config['image_size'],
    max_images=config['max_val_images']
)

print(f"\nüìä Dataset Statistics:")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Validation samples: {len(val_dataset)}")

# Create data loaders (FIX: Set num_workers=0 for Windows)
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=0,  # ‚úÖ Changed from 4 to 0 (fixes Windows multiprocessing issue)
    pin_memory=False,  # ‚úÖ Changed from True to False
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=0,  # ‚úÖ Changed from 4 to 0
    pin_memory=False  # ‚úÖ Changed from True to False
)

print(f"‚úÖ Data loaders created!")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

Creating datasets...

üìä Dataset Statistics:
   Training samples: 100
   Validation samples: 50
‚úÖ Data loaders created!
   Train batches: 25
   Val batches: 13


In [17]:
# Cell 8 - Compute Class Weights 

class_weights = None
if config['use_class_weights']:
    print("\n" + "="*70)
    print("Computing class weights for handling class imbalance...")
    print("="*70)
    class_weights = compute_class_weights(train_dataset, config['num_classes'])
    class_weights = class_weights.to(device)
    
    # Display class weights
    print("\nClass Weights:")
    for idx, (name, weight) in enumerate(zip(PASCAL_VOC_CLASSES, class_weights)):
        print(f"   {name:15s}: {weight:.4f}")
else:
    print("‚ö†Ô∏è  Not using class weights")


Computing class weights for handling class imbalance...
Computing class weights (this may take a few minutes)...


Scanning dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:02<00:00, 40.73it/s]


Class Weights:
   background     : 0.1000
   aeroplane      : 0.7460
   bicycle        : 2.7984
   bird           : 0.8267
   boat           : 1.5631
   bottle         : 1.7627
   bus            : 0.3974
   car            : 0.6645
   cat            : 0.3138
   chair          : 1.2500
   cow            : 0.9179
   diningtable    : 0.3978
   dog            : 1.2287
   horse          : 0.9798
   motorbike      : 0.4705
   person         : 0.1438
   pottedplant    : 2.5484
   sheep          : 0.8606
   sofa           : 0.8830
   train          : 1.6133
   tvmonitor      : 0.6207





In [18]:
# Cell 9 - Create Model

print("\n" + "="*70)
print("CREATING MODEL")
print("="*70)

model = AlternativeCRFModel(
    num_classes=config['num_classes'],
    use_pairwise=True,
    scales=[1.0, 0.5],  # Multi-scale processing
    pool_sizes=[1, 2]   # Spatial pyramid pooling
)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nüìä Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Model size: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")

# Test forward pass
with torch.no_grad():
    test_input = torch.randn(1, 3, *config['image_size']).to(device)
    test_output = model(test_input)
    print(f"\n‚úÖ Model forward pass successful!")
    print(f"   Input shape: {test_input.shape}")
    print(f"   Output shape: {test_output.shape}")


CREATING MODEL

üìä Model Statistics:
   Total parameters: 28,299,882
   Trainable parameters: 28,299,882
   Model size: 107.96 MB (FP32)

‚úÖ Model forward pass successful!
   Input shape: torch.Size([1, 3, 512, 512])
   Output shape: torch.Size([1, 21, 512, 512])


In [22]:
# Cell 10 - Trainer Implementation (FIXED)

class AlternativeTrainer:
    """Trainer for the alternative CRF model with piecewise training."""
    
    def __init__(self, model, device, num_classes, learning_rate=1e-3, 
                 lr_pretrained=1e-4, weight_decay=5e-4, class_weights=None):
        self.model = model
        self.device = device
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.lr_pretrained = lr_pretrained
        self.weight_decay = weight_decay
        
        # Loss functions
        self.unary_loss_fn = nn.CrossEntropyLoss(
            ignore_index=255,
            weight=class_weights
        )
        self.pairwise_loss_fn = nn.CrossEntropyLoss(
            ignore_index=255,
            weight=class_weights
        )
        
        # Metrics
        self.metrics = SegmentationMetrics(num_classes)
    
    def _get_optimizer_stage1(self):
        """Get optimizer for stage 1 (unary training)."""
        pretrained_params = []
        new_params = []
        
        for name, param in self.model.named_parameters():
            if 'featmap_net.block' in name and 'block6' not in name:
                pretrained_params.append(param)
            else:
                new_params.append(param)
        
        return torch.optim.SGD([
            {'params': pretrained_params, 'lr': self.lr_pretrained},
            {'params': new_params, 'lr': self.learning_rate}
        ], momentum=0.9, weight_decay=self.weight_decay)
    
    def _get_optimizer_stage2(self):
        """Get optimizer for stage 2 (pairwise training)."""
        return torch.optim.SGD(
            self.model.pairwise_net.parameters(),
            lr=self.learning_rate,
            momentum=0.9,
            weight_decay=self.weight_decay
        )
    
    def _get_optimizer_stage3(self):
        """Get optimizer for stage 3 (joint fine-tuning)."""
        trainable_params = list(self.model.unary_net.parameters()) + \
                          list(self.model.pairwise_net.parameters())
        
        return torch.optim.SGD(
            trainable_params,
            lr=self.learning_rate * 0.1,
            momentum=0.9,
            weight_decay=self.weight_decay
        )
    
    def train_stage1_unary(self, train_loader, num_epochs, val_loader=None):
        """Stage 1: Train unary network."""
        print("\n" + "="*70)
        print("STAGE 1: Training Unary Potentials")
        print("="*70)
        
        optimizer = self._get_optimizer_stage1()
        history = {'train_loss': [], 'val_miou': [], 'val_pixel_acc': []}
        
        for epoch in range(num_epochs):
            self.model.train()
            total_loss = 0
            
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for imgs, labels in pbar:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                unary, _ = self.model(imgs, return_features=True)
                loss = self.unary_loss_fn(unary, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            avg_loss = total_loss / len(train_loader)
            history['train_loss'].append(avg_loss)
            
            if val_loader is not None:
                val_metrics = self.validate(val_loader, use_pairwise=False)
                history['val_miou'].append(val_metrics['miou'])  # ‚úÖ Changed from 'mIoU'
                history['val_pixel_acc'].append(val_metrics['pixel_acc'])
                
                print(f"Epoch {epoch+1}/{num_epochs} - "
                      f"Loss: {avg_loss:.4f}, "
                      f"Val mIoU: {val_metrics['miou']:.4f}, "
                      f"Val Acc: {val_metrics['pixel_acc']:.4f}")
            else:
                print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
        
        return history
    
    def train_stage2_pairwise(self, train_loader, num_epochs, val_loader=None):
        """Stage 2: Train pairwise network."""
        print("\n" + "="*70)
        print("STAGE 2: Training Pairwise Potentials")
        print("="*70)
        
        # Freeze unary network
        for param in self.model.featmap_net.parameters():
            param.requires_grad = False
        for param in self.model.unary_net.parameters():
            param.requires_grad = False
        
        optimizer = self._get_optimizer_stage2()
        history = {'train_loss': [], 'val_miou': [], 'val_pixel_acc': []}
        
        for epoch in range(num_epochs):
            self.model.train()
            total_loss = 0
            
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for imgs, labels in pbar:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                unary, feat = self.model(imgs, return_features=True)
                pairwise = self.model.get_pairwise_scores(feat)
                
                # Upsample pairwise to match label size
                if pairwise.shape[2:] != labels.shape[1:]:
                    pairwise = F.interpolate(pairwise, size=labels.shape[1:], 
                                            mode='bilinear', align_corners=False)
                
                loss = self.pairwise_loss_fn(pairwise, labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            avg_loss = total_loss / len(train_loader)
            history['train_loss'].append(avg_loss)
            
            if val_loader is not None:
                val_metrics = self.validate(val_loader, use_pairwise=True)
                history['val_miou'].append(val_metrics['miou'])  # ‚úÖ Changed from 'mIoU'
                history['val_pixel_acc'].append(val_metrics['pixel_acc'])
                
                print(f"Epoch {epoch+1}/{num_epochs} - "
                      f"Loss: {avg_loss:.4f}, "
                      f"Val mIoU: {val_metrics['miou']:.4f}, "
                      f"Val Acc: {val_metrics['pixel_acc']:.4f}")
            else:
                print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
        
        # Unfreeze for next stage
        for param in self.model.parameters():
            param.requires_grad = True
        
        return history
    
    def train_stage3_joint(self, train_loader, num_epochs, val_loader=None):
        """Stage 3: Joint fine-tuning."""
        print("\n" + "="*70)
        print("STAGE 3: Joint Fine-tuning")
        print("="*70)
        
        # Freeze feature extractor
        for param in self.model.featmap_net.parameters():
            param.requires_grad = False
        
        optimizer = self._get_optimizer_stage3()
        history = {'train_loss': [], 'val_miou': [], 'val_pixel_acc': []}
        
        for epoch in range(num_epochs):
            self.model.train()
            total_loss = 0
            
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for imgs, labels in pbar:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                unary, feat = self.model(imgs, return_features=True)
                pairwise = self.model.get_pairwise_scores(feat)
                
                # Upsample pairwise
                if pairwise.shape[2:] != labels.shape[1:]:
                    pairwise = F.interpolate(pairwise, size=labels.shape[1:], 
                                            mode='bilinear', align_corners=False)
                
                # Combined loss
                loss_u = self.unary_loss_fn(unary, labels)
                loss_p = self.pairwise_loss_fn(pairwise, labels)
                loss = loss_u + 0.5 * loss_p
                
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            avg_loss = total_loss / len(train_loader)
            history['train_loss'].append(avg_loss)
            
            if val_loader is not None:
                val_metrics = self.validate(val_loader, use_pairwise=True)
                history['val_miou'].append(val_metrics['miou'])  # ‚úÖ Changed from 'mIoU'
                history['val_pixel_acc'].append(val_metrics['pixel_acc'])
                
                print(f"Epoch {epoch+1}/{num_epochs} - "
                      f"Loss: {avg_loss:.4f}, "
                      f"Val mIoU: {val_metrics['miou']:.4f}, "
                      f"Val Acc: {val_metrics['pixel_acc']:.4f}")
            else:
                print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
        
        return history
    
    def validate(self, val_loader, use_pairwise=False):
        """Validate the model."""
        self.model.eval()
        self.metrics.reset()
        
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                
                if use_pairwise:
                    unary, feat = self.model(imgs, return_features=True)
                    pairwise = self.model.get_pairwise_scores(feat)
                    if pairwise.shape[2:] != labels.shape[1:]:
                        pairwise = F.interpolate(pairwise, size=labels.shape[1:], 
                                                mode='bilinear', align_corners=False)
                    preds = (unary + pairwise).argmax(dim=1)
                else:
                    unary = self.model(imgs)
                    preds = unary.argmax(dim=1)
                
                self.metrics.update(preds, labels)
        
        # ‚úÖ FIX: Use correct method names from SegmentationMetrics
        miou = self.metrics.compute_miou()
        pixel_acc = self.metrics.compute_pixel_accuracy()
        iou_per_class = self.metrics.compute_iou()
        
        return {
            'miou': miou,
            'pixel_acc': pixel_acc,
            'iou_per_class': iou_per_class
        }
    
    def train_piecewise(self, train_loader, stage1_epochs=2, stage2_epochs=2, 
                       stage3_epochs=2, val_loader=None):
        """Run complete piecewise training pipeline."""
        history = {'stage1': {}, 'stage2': {}, 'stage3': {}}
        
        history['stage1'] = self.train_stage1_unary(train_loader, stage1_epochs, val_loader)
        history['stage2'] = self.train_stage2_pairwise(train_loader, stage2_epochs, val_loader)
        history['stage3'] = self.train_stage3_joint(train_loader, stage3_epochs, val_loader)
        
        return history

print("‚úÖ Trainer class defined!")

‚úÖ Trainer class defined!


In [23]:
# Cell 11 - Create Trainer

print("\n" + "="*70)
print("CREATING TRAINER")
print("="*70)

trainer = AlternativeTrainer(
    model=model,
    device=device,
    num_classes=config['num_classes'],
    learning_rate=config['learning_rate'],
    lr_pretrained=config['lr_pretrained'],
    weight_decay=config['weight_decay'],
    class_weights=class_weights
)

print("‚úÖ Trainer initialized!")
print(f"\nTraining Configuration:")
print(f"   Stage 1 (Unary): {config['stage1_epochs']} epochs")
print(f"   Stage 2 (Pairwise): {config['stage2_epochs']} epochs")
print(f"   Stage 3 (Joint): {config['stage3_epochs']} epochs")
print(f"   Total: {config['stage1_epochs'] + config['stage2_epochs'] + config['stage3_epochs']} epochs")


CREATING TRAINER
‚úÖ Trainer initialized!

Training Configuration:
   Stage 1 (Unary): 2 epochs
   Stage 2 (Pairwise): 1 epochs
   Stage 3 (Joint): 1 epochs
   Total: 4 epochs


In [24]:
# Cell 12 - RUN PIECEWISE TRAINING üöÄ

print("\n" + "="*70)
print("STARTING PIECEWISE TRAINING")
print("="*70)

# Run piecewise training
history = trainer.train_piecewise(
    train_loader=train_loader,
    stage1_epochs=config['stage1_epochs'],
    stage2_epochs=config['stage2_epochs'],
    stage3_epochs=config['stage3_epochs'],
    val_loader=val_loader
)

print("\n" + "="*70)
print("‚úÖ TRAINING COMPLETE!")
print("="*70)


STARTING PIECEWISE TRAINING

STAGE 1: Training Unary Potentials


Epoch 1/2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:09<00:00,  2.55it/s, loss=3.0063]


Epoch 1/2 - Loss: 3.0193, Val mIoU: 0.0378, Val Acc: 0.7188


Epoch 2/2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:09<00:00,  2.76it/s, loss=3.0159]


Epoch 2/2 - Loss: 2.9972, Val mIoU: 0.0378, Val Acc: 0.7188

STAGE 2: Training Pairwise Potentials


Epoch 1/1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:05<00:00,  4.52it/s, loss=3.0477]


Epoch 1/1 - Loss: 3.0335, Val mIoU: 0.0378, Val Acc: 0.7188

STAGE 3: Joint Fine-tuning


Epoch 1/1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 25/25 [00:06<00:00,  3.78it/s, loss=4.3905]


Epoch 1/1 - Loss: 4.4921, Val mIoU: 0.0378, Val Acc: 0.7188

‚úÖ TRAINING COMPLETE!


In [None]:
# Cell 13 - Save Model

save_dir = Path('alternative_model_results')
save_dir.mkdir(exist_ok=True)

model_path = save_dir / 'alternative_model_final.pth'
torch.save(model.state_dict(), model_path)

# Save training history
history_path = save_dir / 'training_history.json'
with open(history_path, 'w') as f:
    # Convert numpy arrays to lists for JSON serialization
    history_json = {}
    for stage, metrics in history.items():
        history_json[stage] = {k: [float(v) for v in vals] 
                              for k, vals in metrics.items()}
    json.dump(history_json, f, indent=2)

print(f"‚úÖ Model saved to: {model_path}")
print(f"‚úÖ History saved to: {history_path}")

In [None]:
# Cell 14 - Plot Training Curves

def plot_training_curves(history):
    """Plot training curves for all stages."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    stages = ['stage1', 'stage2', 'stage3']
    stage_names = ['Stage 1: Unary', 'Stage 2: Pairwise', 'Stage 3: Joint']
    
    for idx, (stage, name) in enumerate(zip(stages, stage_names)):
        ax = axes[idx]
        
        # Plot training loss
        ax.plot(history[stage]['train_loss'], label='Train Loss', linewidth=2)
        
        # Plot validation metrics if available
        if 'val_miou' in history[stage] and history[stage]['val_miou']:
            ax2 = ax.twinx()
            ax2.plot(history[stage]['val_miou'], 'g-', label='Val mIoU', linewidth=2)
            ax2.set_ylabel('mIoU', fontsize=12)
            ax2.legend(loc='upper right')
            ax2.grid(True, alpha=0.3)
        
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(name, fontsize=14, fontweight='bold')
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

plot_training_curves(history)

In [None]:
# Cell 15 - Final Evaluation

print("\n" + "="*70)
print("FINAL EVALUATION")
print("="*70)

# Evaluate on validation set
final_metrics = trainer.validate(val_loader, use_pairwise=True)

print(f"\nüìä Final Metrics:")
print(f"   Mean IoU: {final_metrics['mIoU']:.4f}")
print(f"   Pixel Accuracy: {final_metrics['pixel_acc']:.4f}")

# Per-class IoU
print(f"\nüìã Per-Class IoU:")
for idx, (name, iou) in enumerate(zip(PASCAL_VOC_CLASSES, final_metrics['iou_per_class'])):
    if iou > 0:
        print(f"   {name:15s}: {iou:.4f}")

In [None]:
# Cell 16 - Visualize Predictions

def visualize_predictions(model, dataset, device, num_samples=4):
    """Visualize model predictions."""
    model.eval()
    
    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)
    
    indices = random.sample(range(len(dataset)), num_samples)
    
    with torch.no_grad():
        for idx, sample_idx in enumerate(indices):
            image, label = dataset[sample_idx]
            image_input = image.unsqueeze(0).to(device)
            
            # Get predictions
            unary, feat = model(image_input, return_features=True)
            pairwise = model.get_pairwise_scores(feat)
            
            # Upsample pairwise
            if pairwise.shape[2:] != unary.shape[2:]:
                pairwise = F.interpolate(pairwise, size=unary.shape[2:], 
                                        mode='bilinear', align_corners=False)
            
            pred_unary = unary.argmax(1).squeeze(0).cpu()
            pred_combined = (unary + pairwise).argmax(1).squeeze(0).cpu()
            
            # Denormalize image
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            image_denorm = image * std + mean
            image_denorm = torch.clamp(image_denorm, 0, 1)
            
            # Plot
            axes[idx, 0].imshow(image_denorm.permute(1, 2, 0))
            axes[idx, 0].set_title('Input Image', fontweight='bold')
            axes[idx, 0].axis('off')
            
            axes[idx, 1].imshow(label, cmap='tab20', vmin=0, vmax=20)
            axes[idx, 1].set_title('Ground Truth', fontweight='bold')
            axes[idx, 1].axis('off')
            
            axes[idx, 2].imshow(pred_unary, cmap='tab20', vmin=0, vmax=20)
            axes[idx, 2].set_title('Unary Only', fontweight='bold')
            axes[idx, 2].axis('off')
            
            axes[idx, 3].imshow(pred_combined, cmap='tab20', vmin=0, vmax=20)
            axes[idx, 3].set_title('Unary + Pairwise', fontweight='bold')
            axes[idx, 3].axis('off')
    
    plt.tight_layout()
    plt.savefig(save_dir / 'sample_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()

visualize_predictions(model, val_dataset, device, num_samples=4)

In [None]:
# Cell 17 - Inference Function

def run_inference(model, image_path, device):
    """Run inference on a single image."""
    model.eval()
    
    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    original_size = image.size
    
    transform = transforms.Compose([
        transforms.Resize(config['image_size']),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    image_tensor = transform(image).unsqueeze(0).to(device)
    
    # Run inference
    with torch.no_grad():
        unary, feat = model(image_tensor, return_features=True)
        pairwise = model.get_pairwise_scores(feat)
        
        if pairwise.shape[2:] != unary.shape[2:]:
            pairwise = F.interpolate(pairwise, size=unary.shape[2:], 
                                    mode='bilinear', align_corners=False)
        
        pred = (unary + pairwise).argmax(1).squeeze(0).cpu()
    
    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    
    axes[0].imshow(image)
    axes[0].set_title('Input Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')
    
    axes[1].imshow(pred, cmap='tab20', vmin=0, vmax=20)
    axes[1].set_title('Prediction', fontsize=14, fontweight='bold')
    axes[1].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    return pred

# Example usage (update path to your test image)
# test_image_path = '/path/to/test/image.jpg'
# if os.path.exists(test_image_path):
#     prediction = run_inference(model, test_image_path, device)

In [None]:
# Cell 18 - Summary

print("\n" + "="*70)
print("TRAINING COMPLETE - SUMMARY")
print("="*70)

print(f"""
‚úÖ Model trained successfully with piecewise strategy
‚úÖ Final mIoU: {final_metrics['mIoU']:.4f}
‚úÖ Pixel Accuracy: {final_metrics['pixel_acc']:.4f}

üìÅ Generated Files:
   - Model: {model_path}
   - History: {history_path}
   - Training curves: {save_dir / 'training_curves.png'}
   - Sample predictions: {save_dir / 'sample_predictions.png'}

üéØ Next Steps:
   1. Review training curves and metrics
   2. Analyze per-class performance
   3. Run inference on your own images
   4. Fine-tune hyperparameters if needed
   5. Experiment with different scales and pool sizes
""")

print("="*70)