# Piecewise Training for Semantic Segmentation
This notebook demonstrates how to:
1. Install dependencies
2. Download and prepare the VOC 2012 dataset
3. Configure dataset paths
4. Visualize a sample image and its label
5. Train the piecewise segmentation model
6. Evaluate and visualize the training history
7. Run inference on a test image

## 1. Install Dependencies
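Run the following cell to install the required packages. The `piecewise_training` package used in Section 5 is not installed by this cell and must be importable from the notebook's environment (for example, from the working directory).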

```python
%pip install torch torchvision numpy pillow matplotlib tqdm
```
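
Before moving on, it can help to confirm that everything imports. This is a minimal sketch using only the standard library: `importlib.util.find_spec` returns `None` for a module that cannot be found. Note that the pip package `pillow` is imported as `PIL`, and `piecewise_training` is assumed to be a local package, since the cell above does not install it. The helper name `missing_modules` is hypothetical.

```python
import importlib.util

# Import names for the packages installed above; pip's "pillow" is imported as "PIL".
# "piecewise_training" is assumed to be a local package (it is not pip-installed above).
REQUIRED_MODULES = ['torch', 'torchvision', 'numpy', 'PIL', 'matplotlib', 'tqdm', 'piecewise_training']


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
print('Missing modules:', missing or 'none')
```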

## 2. Download VOC 2012 Dataset
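Download the VOC 2012 training/validation archive from [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) and extract it.
Expected structure:
```
VOCdevkit/VOC2012/
  ├── JPEGImages/
  ├── SegmentationClass/
  └── ImageSets/Segmentation/
```
Only a subset of the images in `JPEGImages/` has segmentation labels; the ids of the labeled images are listed in `ImageSets/Segmentation/train.txt` and `val.txt`. Each label in `SegmentationClass/` is a palette PNG whose pixel values are class indices (0 for background, 1 to 20 for the object classes), with 255 marking object boundaries and "void" regions.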
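
The download and a layout check can be scripted with the standard library alone. This is a minimal sketch: the archive name `VOCtrainval_11-May-2012.tar` is the VOC 2012 train/val release listed on the page above, but treat the exact URL as an assumption and check it against that page. The helpers `download_and_extract` and `missing_voc_dirs` are hypothetical names.

```python
import os
import tarfile
import urllib.request

# Assumed URL of the VOC 2012 train/val archive; verify it on the VOC page.
VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'
EXPECTED_SUBDIRS = ['JPEGImages', 'SegmentationClass', os.path.join('ImageSets', 'Segmentation')]


def download_and_extract(url, dest_dir):
    """Download the tar archive at `url` into dest_dir (if not already there) and extract it."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)  # roughly 2 GB
    with tarfile.open(archive_path) as tar:
        tar.extractall(dest_dir)  # creates dest_dir/VOCdevkit/VOC2012/...


def missing_voc_dirs(voc_root):
    """Return the expected subdirectories that are absent under voc_root."""
    return [d for d in EXPECTED_SUBDIRS if not os.path.isdir(os.path.join(voc_root, d))]


# Example (downloads data, so left commented out):
# download_and_extract(VOC_URL, '/path/to')
# print(missing_voc_dirs('/path/to/VOCdevkit/VOC2012'))  # [] when the layout is complete
```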

## 3. Configure Dataset Paths
Update the paths below to point to your VOC2012 dataset.

```python
image_dir = '/path/to/VOCdevkit/VOC2012/JPEGImages'
label_dir = '/path/to/VOCdevkit/VOC2012/SegmentationClass'
train_list = '/path/to/VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt'
val_list = '/path/to/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt'
```
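
A quick sanity check on the configured paths: each split file lists one image id per line, and every id should have both a `.jpg` image and a `.png` label. For the standard VOC 2012 release, the segmentation `train` and `val` splits contain 1,464 and 1,449 ids and do not overlap. This is a sketch; `read_split_ids` and `missing_pairs` are hypothetical helper names.

```python
import os


def read_split_ids(split_file):
    """Read one image id per line (e.g. '2007_000032'), skipping blank lines."""
    with open(split_file) as f:
        return [line.strip() for line in f if line.strip()]


def missing_pairs(ids, image_dir, label_dir):
    """Return ids whose JPEG image or PNG label file does not exist."""
    return [
        i for i in ids
        if not os.path.isfile(os.path.join(image_dir, i + '.jpg'))
        or not os.path.isfile(os.path.join(label_dir, i + '.png'))
    ]


# Example with the paths configured above:
# train_ids, val_ids = read_split_ids(train_list), read_split_ids(val_list)
# print(len(train_ids), len(val_ids))                          # 1464 1449 for VOC 2012
# print(set(train_ids) & set(val_ids))                         # disjoint splits: set()
# print(missing_pairs(train_ids + val_ids, image_dir, label_dir))  # expected []
```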

## 4. Visualize a Sample Image and Label

```python
from PIL import Image
import matplotlib.pyplot as plt
import random
import os

# Only the images listed in the segmentation split files have labels in
# SegmentationClass/, so sample an id from train.txt rather than from JPEGImages/
with open(train_list) as f:
    train_ids = [line.strip() for line in f if line.strip()]
sample_id = random.choice(train_ids)
img_path = os.path.join(image_dir, sample_id + '.jpg')
label_path = os.path.join(label_dir, sample_id + '.png')

img = Image.open(img_path)
label = Image.open(label_path)  # palette PNG: pixel values are class indices

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img)
axes[0].set_title(f'Image: {sample_id}')
axes[1].imshow(label)
axes[1].set_title('Segmentation Label')
for ax in axes:
    ax.axis('off')
plt.show()
```
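
Because the label PNGs are palette images, `np.array(label)` yields class indices directly, which makes it easy to list the classes in a sample. This is a sketch: the class order below is the standard PASCAL VOC index order, and `classes_present` is a hypothetical helper.

```python
import numpy as np

# The 21 PASCAL VOC classes in label-index order; 255 marks boundaries and "void".
VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
VOID_INDEX = 255


def classes_present(label_array):
    """Map the unique non-void indices in a VOC label array to class names."""
    values = np.unique(label_array)
    return [VOC_CLASSES[v] for v in values if v != VOID_INDEX]


# With the label loaded above:
# print(classes_present(np.array(label)))
```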

## 5. Train the Piecewise Model
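This section uses the `piecewise_training` package, an implementation of *Efficient Piecewise Training of Deep Structured Models for Semantic Segmentation* that provides the model, trainer, and dataset classes. Training runs in three stages; the number of epochs for each stage is passed to `train_piecewise`.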

```python
from piecewise_training.model import PiecewiseTrainedModel
from piecewise_training.trainer import PiecewiseTrainer
from piecewise_training.dataset import SegmentationDataset, RandomHorizontalFlip
from torch.utils.data import DataLoader
import torch

# Config: 21 classes = 20 VOC object classes + background
num_classes = 21
batch_size = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Datasets
# NOTE: train_list and val_list (Section 3) are not passed here, so both datasets are
# built from the same image/label directories. Unless SegmentationDataset filters by
# split internally, the validation data overlaps the training data.
train_dataset = SegmentationDataset(image_dir=image_dir, label_dir=label_dir, transform=RandomHorizontalFlip(), image_size=(512, 512))
val_dataset = SegmentationDataset(image_dir=image_dir, label_dir=label_dir, image_size=(512, 512))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Model (with CRF, 10 CRF iterations) and trainer
model = PiecewiseTrainedModel(num_classes=num_classes, crf_iterations=10, use_crf=True)
trainer = PiecewiseTrainer(model=model, device=device, num_classes=num_classes, learning_rate=1e-3, weight_decay=5e-4)

# Train in three stages; returns the training history used in Section 6
history = trainer.train_piecewise(train_loader=train_loader, stage1_epochs=20, stage2_epochs=5, stage3_epochs=10, val_loader=val_loader)

# Save the trained weights
torch.save(model.state_dict(), 'piecewise_model_final.pth')
```
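
The idea behind piecewise training in the referenced paper is to replace the full CRF likelihood with a product of independent per-potential likelihoods, so each potential is trained with an ordinary softmax loss: a unary potential normalizes over the K labels of one pixel, a pairwise potential over the K² label pairs of one edge. The NumPy sketch below illustrates that objective under simplifying assumptions (one unary and one pairwise potential, scores taken as negative energies, a precomputed edge list). It is not the internal implementation of `PiecewiseTrainer.train_piecewise`, whose three stages may work differently; `piecewise_nll` is a hypothetical name.

```python
import numpy as np


def log_softmax(scores, axis=-1):
    """Numerically stable log-softmax along `axis`."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def piecewise_nll(unary_scores, pairwise_scores, labels, edges):
    """Negative log-likelihood of a piecewise objective (illustrative sketch).

    unary_scores:    (N, K) score per node and label (negative unary energy)
    pairwise_scores: (E, K, K) score per edge and label pair (negative pairwise energy)
    labels:          (N,) ground-truth label per node
    edges:           (E, 2) node index pairs (p, q)
    """
    num_labels = unary_scores.shape[1]
    # Unary piece: softmax over the K labels of each node, independently.
    unary_lp = log_softmax(unary_scores, axis=1)
    unary_nll = -unary_lp[np.arange(len(labels)), labels].sum()
    # Pairwise piece: softmax over all K*K label pairs of each edge, independently,
    # so no inference over the whole graph is needed.
    flat = pairwise_scores.reshape(len(edges), num_labels * num_labels)
    pair_lp = log_softmax(flat, axis=1)
    pair_index = labels[edges[:, 0]] * num_labels + labels[edges[:, 1]]
    pairwise_nll = -pair_lp[np.arange(len(edges)), pair_index].sum()
    return unary_nll + pairwise_nll
```

Because each piece is normalized only over its own labels, the loss never needs the global partition function of the CRF; the trade-off is that it only approximates the joint likelihood.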

## 6. Evaluate and Visualize Training History

```python
from piecewise_training.utils import plot_training_history
plot_training_history(history, save_path='training_history.png')
```
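
The plot above shows training curves only. The standard quantitative metric for VOC segmentation is mean intersection-over-union, computed from a confusion matrix accumulated over the whole validation set with void pixels (255) ignored. This is a minimal NumPy sketch; `confusion_matrix` and `mean_iou` are hypothetical helpers, not part of `piecewise_training.utils`.

```python
import numpy as np


def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """Count (target, prediction) pairs; rows are ground truth, columns predictions."""
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    keep = target != ignore_index  # drop VOC boundary/void pixels
    idx = num_classes * target[keep].astype(np.int64) + pred[keep].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(conf):
    """Per-class IoU = TP / (TP + FP + FN); classes absent on both sides become NaN and are skipped."""
    tp = np.diag(conf).astype(np.float64)
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = tp / union
    return iou, float(np.nanmean(iou))
```

To evaluate the trained model, one could sum `confusion_matrix(pred, target, num_classes)` over every batch of `val_loader` (after converting predictions and labels to NumPy arrays) and call `mean_iou` once at the end, rather than averaging per-image scores.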

## 7. Run Inference on a Test Image

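```python
from piecewise_training.utils import visualize_segmentation
from PIL import Image
import numpy as np
import torch

# Load the trained weights (in a fresh session, rebuild `model` as in Section 5 first)
model.load_state_dict(torch.load('piecewise_model_final.pth', map_location=device))
model.to(device)
model.eval()

# Load and preprocess the test image: resize to the 512x512 training size, scale to
# [0, 1], then normalize with the ImageNet mean/std (assumed to match SegmentationDataset)
test_image_path = '/path/to/test/image.jpg'
image = Image.open(test_image_path).convert('RGB').resize((512, 512))
image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
image_tensor = (image_tensor - mean) / std
image_tensor = image_tensor.unsqueeze(0).to(device)  # add batch dim: (1, 3, 512, 512)

with torch.no_grad():
    unary_output, crf_output = model(image_tensor, apply_crf=True)
    prediction = crf_output.argmax(dim=1).squeeze(0)  # (512, 512) class indices

# Move tensors to the CPU before plotting. A test image has no ground-truth label, so
# the prediction is passed twice (assuming the argument order is image, label, prediction).
visualize_segmentation(image_tensor.squeeze(0).cpu(), prediction.cpu(), prediction.cpu(), num_classes=num_classes)
```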
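
To save a prediction as an image outside `visualize_segmentation`, the class indices can be mapped to the standard PASCAL VOC color palette, the same colors used in the `SegmentationClass` PNGs. This NumPy sketch uses the well-known bit-interleaving construction of that palette; `voc_colormap` and `colorize` are hypothetical helper names.

```python
import numpy as np


def voc_colormap(n=256):
    """Build the PASCAL VOC palette by spreading the bits of each class index over R, G, B."""
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        r = g = b = 0
        c = i
        for j in range(8):
            # The lowest three bits of c set bit (7 - j) of R, G and B respectively.
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = (r, g, b)
    return cmap


def colorize(prediction, cmap=None):
    """Map an (H, W) array of class indices to an (H, W, 3) uint8 RGB image."""
    if cmap is None:
        cmap = voc_colormap()
    return cmap[np.asarray(prediction, dtype=np.int64)]


# With the prediction from the cell above:
# from PIL import Image
# Image.fromarray(colorize(prediction.cpu().numpy())).save('prediction.png')
```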