# Piecewise Training for Semantic Segmentation - Complete Pipeline
This notebook demonstrates:
1. Install dependencies
2. Download and prepare the VOC 2012 dataset
3. Configure dataset paths
4. Visualize samples
5. Train the piecewise segmentation model
6. Generate comprehensive evaluation reports
7. Run inference on test images

## 1. Install Dependencies
Run the following cell to install required packages.

In [None]:
%pip install torch torchvision numpy pillow matplotlib tqdm

## 2. Download VOC 2012 Dataset
Download from [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) and extract it.
Expected structure:
```
VOCdevkit/VOC2012/
  ├── JPEGImages/
  ├── SegmentationClass/
  └── ImageSets/Segmentation/
```

## 3. Configure Dataset Paths
Update the paths below to point to your VOC2012 dataset.

In [None]:
# Root of the extracted VOC2012 dataset; this is the only path you need to edit.
voc_root = '/path/to/VOCdevkit/VOC2012'

# Everything below is derived from voc_root using the standard VOC layout.
image_dir = f'{voc_root}/JPEGImages'
label_dir = f'{voc_root}/SegmentationClass'
train_list = f'{voc_root}/ImageSets/Segmentation/train.txt'
val_list = f'{voc_root}/ImageSets/Segmentation/val.txt'

In [None]:
# Device configuration
# torch is imported here because this cell runs before the training cell's
# import block; without it a fresh "Restart & Run All" stops with a NameError.
import torch

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Pascal VOC class names (21 entries, 'background' first).
# The position of a name in this list is used as its class index when the
# notebook reports per-class metrics and labels the visualizer's plots.
PASCAL_VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
    'sofa', 'train', 'tvmonitor'
]

## 4. Visualize a Sample Image and Label

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import random
import numpy as np
import os

def visualize_voc_samples(num_samples=3):
    """Visualize random image/label pairs from the VOC dataset.

    Reads from the notebook-level ``image_dir`` and ``label_dir``. Only JPEGs
    that have a same-named ``.png`` in ``label_dir`` are eligible: the label
    folder is not guaranteed to contain a mask for every image, and sampling
    an unlabelled image would fail when the label file is opened.

    Args:
        num_samples: Number of image/label pairs to show. Capped at the number
            of labelled images that were found.

    Raises:
        FileNotFoundError: If no image in ``image_dir`` has a matching label
            in ``label_dir``.
    """
    def _label_path_for(img_name):
        # Swap only the extension; str.replace could also hit '.jpg' mid-name.
        return os.path.join(label_dir, os.path.splitext(img_name)[0] + '.png')

    labelled_images = [
        name for name in os.listdir(image_dir)
        if name.endswith('.jpg') and os.path.exists(_label_path_for(name))
    ]
    if not labelled_images:
        raise FileNotFoundError(
            f"No image in {image_dir!r} has a matching .png label in {label_dir!r}"
        )

    # random.sample raises if asked for more items than the population holds.
    num_samples = min(num_samples, len(labelled_images))
    sample_images = random.sample(labelled_images, num_samples)

    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(
        num_samples, 2, figsize=(12, 4*num_samples), squeeze=False
    )

    for idx, img_name in enumerate(sample_images):
        img_path = os.path.join(image_dir, img_name)
        label_path = _label_path_for(img_name)

        # Context managers release the file handles once the data is drawn.
        with Image.open(img_path) as img, Image.open(label_path) as label:
            axes[idx, 0].imshow(img)
            axes[idx, 0].set_title(f'Image: {img_name}')
            axes[idx, 0].axis('off')

            axes[idx, 1].imshow(label, cmap='tab20', vmin=0, vmax=20)
            axes[idx, 1].set_title('Segmentation Label')
            axes[idx, 1].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize samples (reads image_dir / label_dir set in the configuration cell)
visualize_voc_samples(num_samples=3)

## 5. Set Up Datasets, Model, and Trainer
This cell builds the data loaders, model, and trainer using the implementation of `Efficient Piecewise Training of Deep Structured Models for Semantic Segmentation` (model, trainer, dataset classes).

In [None]:
from piecewise_training.model import PiecewiseTrainedModel
from piecewise_training.trainer import PiecewiseTrainer
from piecewise_training.dataset import SegmentationDataset, RandomHorizontalFlip
from torch.utils.data import DataLoader
import torch

# Config
# NOTE(review): `device` is also defined in an earlier cell with the same
# expression; this assignment simply repeats it.
num_classes = 21
batch_size = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Datasets
# NOTE(review): both datasets are built from the same image_dir / label_dir and
# neither receives train_list / val_list from the configuration cell, so the
# "validation" set may contain the training images. Confirm how
# SegmentationDataset selects a split before trusting validation numbers (TODO).
# The only visible difference is that the training set applies
# RandomHorizontalFlip; both request image_size=(512, 512).
train_dataset = SegmentationDataset(image_dir=image_dir, label_dir=label_dir, transform=RandomHorizontalFlip(), image_size=(512, 512))
val_dataset = SegmentationDataset(image_dir=image_dir, label_dir=label_dir, image_size=(512, 512))

# Training batches are shuffled; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Model and trainer
# The model is created with its CRF enabled (10 iterations); the trainer gets
# the optimisation hyper-parameters (learning rate and weight decay).
model = PiecewiseTrainedModel(num_classes=num_classes, crf_iterations=10, use_crf=True)
trainer = PiecewiseTrainer(model=model, device=device, num_classes=num_classes, learning_rate=1e-3, weight_decay=5e-4)

print("Model and trainer ready!")


## 6. Train the Piecewise Model

In [None]:
# Section banner, emitted with one call instead of three.
print("\n" + "="*70, "STARTING PIECEWISE TRAINING", "="*70, sep="\n")

# Epoch budget for each stage of the piecewise schedule.
stage_epochs = dict(
    stage1_epochs=20,  # Train unary network
    stage2_epochs=5,   # Train CRF parameters
    stage3_epochs=10,  # Joint fine-tuning
)

# Run all three stages; the returned history feeds the report cells below.
history = trainer.train_piecewise(
    train_loader=train_loader,
    val_loader=val_loader,
    **stage_epochs
)

# Persist the trained weights (state dict only).
model_save_path = 'piecewise_model_final.pth'
torch.save(model.state_dict(), model_save_path)
print(f"\n‚úÖ Model saved to: {model_save_path}")

## 7. Generate Comprehensive Evaluation Report

In [None]:
from pathlib import Path

print("\n" + "="*70)
print("GENERATING COMPREHENSIVE EVALUATION REPORT")
print("="*70)

# NOTE(review): the training cell imports from `piecewise_training.*` while this
# cell uses `src.piecewise_training.*` - confirm both resolve in your setup.
from src.piecewise_training.visualization import ComprehensiveVisualizer
import numpy as np

# Create visualizer
visualizer = ComprehensiveVisualizer(
    num_classes=num_classes,
    class_names=PASCAL_VOC_CLASSES
)

# Collect validation metrics
print("\nüìä Collecting validation metrics...")
model.eval()

# Rows index the ground-truth class, columns index the predicted class.
confusion_matrix = np.zeros((num_classes, num_classes))
sample_predictions = []
ignore_index = 255  # label value excluded from every metric

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(val_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Get predictions; fall back to the unary output when no CRF output
        # is returned.
        unary_output, crf_output = model(images, apply_crf=True)
        unary_pred = unary_output.argmax(1)
        crf_pred = crf_output.argmax(1) if crf_output is not None else unary_pred

        # Update confusion matrix.
        # Each valid pixel is encoded as gt * num_classes + pred and counted
        # with a single bincount. This yields the same counts as comparing
        # every (gt, pred) class pair separately, but needs one pass and one
        # device-to-host transfer per batch instead of num_classes**2.
        valid = (
            (labels != ignore_index)
            & (labels >= 0) & (labels < num_classes)
            & (crf_pred >= 0) & (crf_pred < num_classes)
        )
        pair_index = labels[valid].long() * num_classes + crf_pred[valid].long()
        batch_counts = torch.bincount(pair_index, minlength=num_classes ** 2)
        confusion_matrix += (
            batch_counts.reshape(num_classes, num_classes).cpu().numpy()
        )

        # Collect sample predictions (up to 3 per batch, first 10 batches)
        if batch_idx < 10:
            for b in range(min(3, images.shape[0])):
                sample_predictions.append({
                    'image': images[b],
                    'gt': labels[b],
                    'unary_pred': unary_pred[b],
                    'crf_pred': crf_pred[b],
                    'pred': crf_pred[b]
                })

        if batch_idx % 10 == 0:
            print(f"   Processed {batch_idx}/{len(val_loader)} batches...")

# Compute final metrics
print("\nüìà Computing final metrics...")
iou_per_class = visualizer._compute_iou_from_cm(confusion_matrix)
# nanmean skips classes whose IoU is NaN.
mean_iou = np.nanmean(iou_per_class)

final_metrics = {
    'mIoU': mean_iou,
    # Correctly classified pixels (diagonal) over all counted pixels.
    'Pixel Acc': confusion_matrix.diagonal().sum() / confusion_matrix.sum()
}

print(f"\n‚úÖ Final mIoU: {mean_iou:.4f}")
print(f"‚úÖ Pixel Accuracy: {final_metrics['Pixel Acc']:.4f}")

# Generate complete report
print("\nüé® Generating visualizations...")
results_dir = 'training_results'
Path(results_dir).mkdir(parents=True, exist_ok=True)

visualizer.generate_full_report(
    history=history,
    final_metrics=final_metrics,
    confusion_matrix=confusion_matrix,
    sample_predictions=sample_predictions,
    save_dir=results_dir
)

print(f"\n‚úÖ Complete report saved to: {results_dir}/")

## 8. Display Individual Visualizations in Notebook

In [None]:
# Renders each report component inline, one after another.
# Relies on objects created by earlier cells: `visualizer`, `final_metrics`,
# `confusion_matrix`, `iou_per_class` and `sample_predictions` from the
# evaluation cell, and `history` from the training cell.

# 8.1 Training Curves
print("\nüìä Training Curves:")
visualizer.plot_training_curves(history)

# 8.2 Metrics Table
print("\nüìã Metrics Summary:")
visualizer.generate_metrics_table(history, final_metrics)

# 8.3 Confusion Matrix
print("\nüî¢ Confusion Matrix:")
visualizer.plot_confusion_matrix(confusion_matrix)

# 8.4 Per-Class IoU
print("\nüìä Per-Class IoU:")
visualizer.plot_per_class_iou(iou_per_class)

# 8.5 Sample Predictions (first six collected samples)
print("\nüñºÔ∏è Sample Predictions:")
visualizer.visualize_predictions_grid(sample_predictions[:6])

# 8.6 CRF Comparison (first three collected samples)
print("\nüîç CRF Refinement Comparison:")
visualizer.plot_crf_comparison(sample_predictions[:3])

## 9. Detailed Per-Class Performance Analysis

In [None]:
print("\n" + "="*70, "PER-CLASS PERFORMANCE ANALYSIS", "="*70, sep="\n")

# Tabular helpers used only by this cell
import pandas as pd
from tabulate import tabulate

# One record per class whose IoU is a number; NaN-IoU classes are left out.
class_performance = []
for class_id, class_name in enumerate(PASCAL_VOC_CLASSES):
    class_iou = iou_per_class[class_id]
    if np.isnan(class_iou):
        continue

    # Diagonal cell = correct pixels; the rest of the column / row are the
    # false positives / false negatives for this class.
    true_pos = confusion_matrix[class_id, class_id]
    false_pos = confusion_matrix[:, class_id].sum() - true_pos
    false_neg = confusion_matrix[class_id, :].sum() - true_pos

    predicted_total = true_pos + false_pos
    actual_total = true_pos + false_neg

    # Each ratio falls back to 0 when its denominator is empty.
    precision = true_pos / predicted_total if predicted_total > 0 else 0
    recall = true_pos / actual_total if actual_total > 0 else 0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0

    record = {
        'Class': class_name,
        'IoU': class_iou,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'Support': int(actual_total)
    }
    class_performance.append(record)

# Best-performing classes first.
df_performance = pd.DataFrame(class_performance).sort_values('IoU', ascending=False)

print(tabulate(df_performance, headers='keys', tablefmt='grid', floatfmt='.4f', showindex=False))

# Save to CSV next to the other report artefacts
df_performance.to_csv(f'{results_dir}/per_class_performance.csv', index=False)
print(f"\n‚úÖ Saved to: {results_dir}/per_class_performance.csv")

## 10. Run Inference on Test Images

In [None]:
# Section banner for the single-image inference demo.
print("\n" + "="*70)
print("RUNNING INFERENCE ON TEST IMAGES")
print("="*70)

def run_inference_on_image(image_path, model, device, visualize=True):
    """Run the segmentation model on a single image file.

    The image is loaded as RGB, resized to 512x512, converted to a tensor and
    normalized before being passed to the model. Predictions are computed on
    that resized input and are not mapped back to the original image size.

    Side effect: ``model`` is switched to eval mode and left that way.

    Args:
        image_path: Path of the image to segment.
        model: Model called as ``model(batch, apply_crf=True)`` and returning
            a ``(unary_output, crf_output)`` pair.
        device: Device the input tensor is moved to.
        visualize: When True, show the input image, the unary prediction and
            the CRF prediction side by side.

    Returns:
        Tuple ``(unary_pred, crf_pred)`` of per-pixel class-index tensors
        (argmax over dim 1, batch dimension removed). ``crf_pred`` is the
        same tensor as ``unary_pred`` when the model returns ``None`` for the
        CRF output.
    """
    from torchvision import transforms

    # Load and preprocess image; the context manager closes the file handle
    # as soon as the RGB copy has been made.
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')

    # Resize and normalize
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    image_tensor = transform(image).unsqueeze(0).to(device)

    # Run inference
    model.eval()
    with torch.no_grad():
        unary_output, crf_output = model(image_tensor, apply_crf=True)

        unary_pred = unary_output.argmax(1).squeeze(0)
        crf_pred = crf_output.argmax(1).squeeze(0) if crf_output is not None else unary_pred

    if visualize:
        # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib
        # releases no longer provide.
        colors = plt.get_cmap('tab20', num_classes)

        # (title, content) for each panel, left to right. Indexing the
        # colormap with the class map yields RGBA; the alpha channel is dropped.
        panels = [
            ('Original Image', image),
            ('Unary (CNN Only)', colors(unary_pred.cpu().numpy())[:, :, :3]),
            ('CRF Refined', colors(crf_pred.cpu().numpy())[:, :, :3]),
        ]

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        for ax, (title, content) in zip(axes, panels):
            ax.imshow(content)
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    return unary_pred, crf_pred

# Example: Run inference on a test image
test_image_path = '/path/to/test/image.jpg'  # Update this path

# Guard first: with a missing file only the hint is printed and nothing runs.
if not os.path.exists(test_image_path):
    print(f"‚ö†Ô∏è Test image not found: {test_image_path}")
    print("Please update the path to a valid image.")
else:
    print(f"\nüñºÔ∏è Running inference on: {test_image_path}")
    unary_pred, crf_pred = run_inference_on_image(test_image_path, model, device)
    print("‚úÖ Inference complete!")

## 11. Batch Inference on Multiple Images

In [None]:

def batch_inference(image_paths, model, device, save_dir='inference_results'):
    """Run inference on multiple images and save the CRF predictions as PNGs.

    Each image is processed with ``run_inference_on_image`` (no plotting) and
    its CRF prediction is written to ``save_dir`` as a colour image named
    ``prediction_NNN.png``, where ``NNN`` is the zero-based position of the
    image in ``image_paths``.

    Args:
        image_paths: Sized sequence of image file paths.
        model: Model forwarded to ``run_inference_on_image``.
        device: Device forwarded to ``run_inference_on_image``.
        save_dir: Output directory; created if it does not exist.
    """
    # Local import so the function does not depend on another cell having
    # imported Path first.
    from pathlib import Path

    output_dir = Path(save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # The colormap does not depend on the image, so build it once.
    # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib releases
    # no longer provide.
    colors = plt.get_cmap('tab20', num_classes)

    for idx, img_path in enumerate(image_paths):
        print(f"\nProcessing {idx+1}/{len(image_paths)}: {img_path}")

        # Only the CRF prediction is saved; the unary prediction is discarded.
        _, crf_pred = run_inference_on_image(
            img_path, model, device, visualize=False
        )

        # Save predictions
        save_path = output_dir / f"prediction_{idx:03d}.png"

        # Convert to color image: colormap lookup gives RGBA floats in [0, 1];
        # drop alpha and scale to 8-bit RGB.
        crf_colored = (colors(crf_pred.cpu().numpy())[:, :, :3] * 255).astype(np.uint8)
        Image.fromarray(crf_colored).save(save_path)

        print(f"   Saved to: {save_path}")

# Example: Process multiple test images
# Placeholder paths - replace them with real image files before enabling the
# call below.
test_images = [
    '/path/to/test/image1.jpg',
    '/path/to/test/image2.jpg',
    '/path/to/test/image3.jpg',
]

# Uncomment to run batch inference
# batch_inference(test_images, model, device)

## 12. Summary and Next Steps

In [None]:
# Final recap. Interpolates values produced by earlier cells: `mean_iou`,
# `final_metrics` and `results_dir` from the evaluation cell and
# `model_save_path` from the training cell, so those cells must have run.
print("\n" + "="*70)
print("TRAINING COMPLETE - SUMMARY")
print("="*70)

# The file names listed under Results are static text; this cell does not
# check that they exist on disk.
print(f"""
‚úÖ Model trained successfully with piecewise strategy
‚úÖ Final mIoU: {mean_iou:.4f}
‚úÖ Pixel Accuracy: {final_metrics['Pixel Acc']:.4f}

üìÅ Generated Files:
   - Model: {model_save_path}
   - Results: {results_dir}/
     ‚îú‚îÄ‚îÄ training_curves.png
     ‚îú‚îÄ‚îÄ metrics_summary.txt
     ‚îú‚îÄ‚îÄ metrics_summary.csv
     ‚îú‚îÄ‚îÄ confusion_matrix.png
     ‚îú‚îÄ‚îÄ per_class_iou.png
     ‚îú‚îÄ‚îÄ sample_predictions.png
     ‚îú‚îÄ‚îÄ crf_comparison.png
     ‚îî‚îÄ‚îÄ per_class_performance.csv

üéØ Next Steps:
   1. Review training curves and metrics
   2. Analyze per-class performance
   3. Run inference on your own images
   4. Fine-tune hyperparameters if needed
   5. Experiment with different CRF iterations
""")

print("="*70)