# 03 — Segmentation Training: DeepLabV3+ on RiceSEG

**Purpose:** Train a semantic segmentation model to pixel-classify rice field weeds.
**Runtime:** GPU required (Kaggle P100 recommended). Expect ~30-45 min total training.
**Platform:** Works on Kaggle, Colab, and local (with GPU).
**Prerequisite:** Run notebook 02 first to understand the dataset and class distribution.

## What This Notebook Does

1. **Setup** — Platform detection, dependency install, W&B logging init
2. **Data pipeline** — Image-mask pairs, train/val split, augmentation transforms
3. **Model** — DeepLabV3+ with ResNet-50 backbone (ImageNet pretrained)
4. **Training Phase 1** — Frozen encoder, train decoder only (5 epochs)
5. **Training Phase 2** — Unfreeze encoder, fine-tune end-to-end (15 epochs)
6. **Evaluation** — Per-class IoU, mIoU, visual prediction overlays
7. **Export** — Save best model checkpoint

### Key Challenge: Weed Pixel Sparsity

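From notebook 02: weed pixels make up only ~1.6% of all pixels. A naive model that predicts "not weed"
everywhere therefore scores >98% pixel accuracy while detecting no weeds at all. We use **focal loss**
with inverse-frequency class weights so that the rare weed pixels still contribute a meaningful share
of the training signal.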
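
To see why accuracy is the wrong yardstick here, the toy calculation below scores an all-background prediction against a hypothetical 100x100 mask in which 1.6% of pixels (160) are weeds, close to the fraction quoted above. It is an illustration only, not RiceSEG data.

```python
import numpy as np

WEED = 4  # class ID for weeds, as in RICESEG_CLASSES

# Hypothetical 100x100 ground-truth mask: background except 160 weed pixels (1.6%)
target = np.zeros((100, 100), dtype=np.int64)
target.flat[:160] = WEED

# Naive model: predicts background (class 0) everywhere
pred = np.zeros_like(target)

pixel_accuracy = (pred == target).mean()

# Weed IoU = |pred_weed AND true_weed| / |pred_weed OR true_weed|
inter = np.logical_and(pred == WEED, target == WEED).sum()
union = np.logical_or(pred == WEED, target == WEED).sum()
weed_iou = inter / union

print(f'pixel accuracy = {pixel_accuracy:.3f}')  # 0.984
print(f'weed IoU       = {weed_iou:.3f}')        # 0.000
```

Accuracy looks excellent, yet the weed IoU is zero, which is why the rest of this notebook tracks IoU.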

| Parameter | Value |
|-----------|-------|
| **Architecture** | DeepLabV3+ (ResNet-50 encoder) |
| **Input** | 384x384 crop (train) / 512x512 (val) |
| **Classes** | 6: Background, Green veg, Senescent veg, Panicle, Weeds, Duckweed |
| **Loss** | Focal loss (gamma=2) with inverse-frequency class weights |
| **Expected mIoU** | 40-55% overall, 10-30% weed IoU |

---
## 1. Platform Detection & Setup

Same pattern as notebooks 01 and 02 — detect Kaggle vs Colab vs local.

In [None]:
import os
import sys

# --- Platform Detection ---
IS_KAGGLE = os.path.exists('/kaggle/input')

try:
    import google.colab
    IS_COLAB = True
except ImportError:
    IS_COLAB = False

IS_LOCAL = not IS_KAGGLE and not IS_COLAB

PLATFORM = 'kaggle' if IS_KAGGLE else ('colab' if IS_COLAB else 'local')
print(f'Platform detected: {PLATFORM}')
print(f'Python version: {sys.version}')

### Install Dependencies

Training requires PyTorch, `segmentation-models-pytorch` (DeepLabV3+ implementation),
`albumentations` (augmentation), and `wandb` (experiment tracking).

**On Kaggle/Colab:** PyTorch and torchvision are pre-installed.

In [None]:
import subprocess

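# pip package name -> importable module name (they differ for scikit-learn)
packages = {
    'segmentation-models-pytorch': 'segmentation_models_pytorch',
    'albumentations': 'albumentations',
    'wandb': 'wandb',
    'scikit-learn': 'sklearn',
}

for pkg, mod in packages.items():
    try:
        __import__(mod)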
    except ImportError:
        print(f'Installing {pkg}...')
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pkg])

print('All dependencies ready.')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from collections import Counter
import json
import time
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

import wandb

# Consistent plot style
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 11

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')
if DEVICE.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('WARNING: No GPU detected. Training will be extremely slow.')

print('\nImports ready.')

### W&B Experiment Tracking (Optional)

Weights & Biases logs training metrics, plots, and model artifacts to the cloud.
Set `USE_WANDB = False` to disable — metrics still print to stdout.

In [None]:
USE_WANDB = True  # Set to False to skip W&B logging

if USE_WANDB:
    try:
        wandb.login()
        print('W&B login successful.')
    except Exception as e:
        print(f'W&B login failed: {e}')
        print('Continuing without W&B.')
        USE_WANDB = False
else:
    print('W&B disabled.')

---
## 2. Load RiceSEG Dataset

### Class Definitions

| Class ID | Name | Training Priority |
|----------|------|-------------------|
| 0 | Background | Low — dominant class |
| 1 | Green vegetation | Medium |
| 2 | Senescent vegetation | Medium |
| 3 | Panicle | Medium |
| 4 | **Weeds** | **High** — primary target |
| 5 | **Duckweed** | **High** — secondary target |

In [None]:
# --- RiceSEG class definitions ---
RICESEG_CLASSES = {
    0: 'Background',
    1: 'Green vegetation',
    2: 'Senescent vegetation',
    3: 'Panicle',
    4: 'Weeds',
    5: 'Duckweed',
}
NUM_CLASSES = len(RICESEG_CLASSES)

# Colors for visualization (RGB)
CLASS_COLORS = {
    0: (0, 0, 0),          # Background — black
    1: (0, 200, 0),        # Green vegetation
    2: (200, 200, 0),      # Senescent vegetation
    3: (255, 165, 0),      # Panicle — orange
    4: (255, 0, 0),        # Weeds — red
    5: (0, 100, 255),      # Duckweed — blue
}

# --- Set data path ---
if IS_KAGGLE:
    DATA_ROOT = Path('/kaggle/input/riceseg')
elif IS_COLAB:
    DATA_ROOT = Path('/content/riceseg')
else:
    DATA_ROOT = Path('./data/riceseg')

print(f'Data root: {DATA_ROOT}')
print(f'Exists: {DATA_ROOT.exists()}')
print(f'Classes ({NUM_CLASSES}):')
for k, v in RICESEG_CLASSES.items():
    print(f'  {k}: {v}')

if not DATA_ROOT.exists():
    print()
    print('=' * 60)
    print('RICESEG NOT FOUND')
    print('=' * 60)
    print('See notebook 02 for detailed setup instructions.')
    print(f'Expected path: {DATA_ROOT}')
    print()
    print('Quick setup:')
    print('  1. Download RiceSEG from HuggingFace')
    print('  2. Upload as private Kaggle Dataset named "riceseg"')
    print('  3. Attach to this notebook via "Add Data"')
    print('=' * 60)

### Build Image-Mask Pairs

Same pairing logic as notebook 02 — match images to masks by filename across directories.
Tries multiple directory structures: `images/` + `masks/`, country-based subdirs, etc.
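
The core matching rule is "same filename stem, any image extension". The sketch below shows that rule for a single images/masks directory pair. `pair_by_stem` is a hypothetical helper, not part of the notebook; instead of probing `Path.exists()` once per extension, it indexes the mask directory once in a dict.

```python
from pathlib import Path

IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}


def pair_by_stem(img_dir, mask_dir):
    """Hypothetical helper: match images to masks that share a filename stem."""
    img_dir, mask_dir = Path(img_dir), Path(mask_dir)
    # Index masks once: stem -> path (first match wins if several extensions exist)
    masks = {}
    for m in sorted(mask_dir.iterdir()):
        if m.suffix.lower() in IMAGE_EXTS:
            masks.setdefault(m.stem, m)
    # Keep only images whose stem has a mask
    pairs = []
    for img in sorted(img_dir.iterdir()):
        if img.suffix.lower() in IMAGE_EXTS and img.stem in masks:
            pairs.append({'image_path': str(img), 'mask_path': str(masks[img.stem])})
    return pairs
```

For example, with `images/a.jpg`, `images/b.jpg` and `masks/a.png`, only `a` is paired; `b` is silently dropped, so comparing the pair count with the image count is a cheap check for missing masks.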

In [None]:
def find_image_mask_pairs(data_root):
    """Find image-mask pairs in the RiceSEG dataset.

    Tries multiple common directory structures to match images to masks.
    Returns: list of dicts with image_path, mask_path, country.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    mask_dir_names = {'masks', 'mask', 'labels', 'label', 'annotations', 'annotation', 'gt', 'groundtruth'}
    image_dir_names = {'images', 'image', 'img', 'rgb', 'input'}

    pairs = []

    # Strategy 1: Top-level images/ + masks/
    for img_dir_name in image_dir_names:
        img_dir = data_root / img_dir_name
        if not img_dir.exists():
            continue
        for mask_dir_name in mask_dir_names:
            mask_dir = data_root / mask_dir_name
            if not mask_dir.exists():
                continue
            for img_file in sorted(img_dir.rglob('*')):
                if img_file.suffix.lower() not in image_extensions:
                    continue
                for ext in image_extensions:
                    mask_candidate = mask_dir / (img_file.stem + ext)
                    if mask_candidate.exists():
                        pairs.append({
                            'image_path': str(img_file),
                            'mask_path': str(mask_candidate),
                            'country': 'unknown',
                        })
                        break

    if pairs:
        return pairs

    # Strategy 2: country/images/ + country/masks/
    for country_dir in sorted(data_root.iterdir()):
        if not country_dir.is_dir():
            continue
        for img_dir_name in image_dir_names:
            img_dir = country_dir / img_dir_name
            if not img_dir.exists():
                continue
            for mask_dir_name in mask_dir_names:
                mask_dir = country_dir / mask_dir_name
                if not mask_dir.exists():
                    continue
                for img_file in sorted(img_dir.rglob('*')):
                    if img_file.suffix.lower() not in image_extensions:
                        continue
                    for ext in image_extensions:
                        mask_candidate = mask_dir / (img_file.stem + ext)
                        if mask_candidate.exists():
                            pairs.append({
                                'image_path': str(img_file),
                                'mask_path': str(mask_candidate),
                                'country': country_dir.name,
                            })
                            break

    if pairs:
        return pairs

    # Strategy 3: Find all images and match with nearby mask directories
    all_images = sorted(f for f in data_root.rglob('*')
                        if f.suffix.lower() in image_extensions)

    by_dir = {}
    for img in all_images:
        by_dir.setdefault(str(img.parent), []).append(img)

    for dir_path, imgs in by_dir.items():
        dir_p = Path(dir_path)
        parent = dir_p.parent
        if dir_p.name.lower() in mask_dir_names:
            continue
        for mask_dir_name in mask_dir_names:
            mask_dir = parent / mask_dir_name
            if mask_dir.exists():
                for img_file in imgs:
                    for ext in image_extensions:
                        mask_candidate = mask_dir / (img_file.stem + ext)
                        if mask_candidate.exists():
                            rel = img_file.relative_to(data_root)
                            country = rel.parts[0] if len(rel.parts) > 2 else 'unknown'
                            pairs.append({
                                'image_path': str(img_file),
                                'mask_path': str(mask_candidate),
                                'country': country,
                            })
                            break
    return pairs


# --- Build pairs ---
if DATA_ROOT.exists():
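    pairs = find_image_mask_pairs(DATA_ROOT)
    df = pd.DataFrame(pairs)
    if len(df) > 0:
        # An image can match several mask dirs (e.g. masks/ and labels/); keep one pair per image
        df = df.drop_duplicates(subset='image_path').reset_index(drop=True)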

    print(f'Found {len(df)} image-mask pairs')
    if len(df) > 0:
        print(f'\nCountry distribution:')
        print(df['country'].value_counts().to_string())
    else:
        print('No pairs found. Check directory structure.')
        print('Remaining cells will show "no data" messages.')
else:
    df = pd.DataFrame()
    print('Data root not found. Cannot load pairs.')
    print('Remaining cells will show "no data" messages.')

---
## 3. Train / Validation Split

80% train, 20% validation. Random split with fixed seed for reproducibility.

> **Note:** A country-stratified split would guarantee that each split has proportional representation
> from every country. For this baseline we use a simple random split, which usually gives similar
> proportions when every country has many images; check the per-country counts printed below to confirm.
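
If you want the stratified variant, scikit-learn's `train_test_split` accepts `stratify=df['country']` (it raises an error if some country has fewer than two images). The numpy-only sketch below shows what stratification does: shuffle each country's rows separately and send about 20% of each to validation. `stratified_split` is a hypothetical helper and the country counts are made up.

```python
import numpy as np


def stratified_split(groups, val_frac=0.2, seed=42):
    """Hypothetical helper: per-group shuffle, ~val_frac of each group goes to val."""
    rng = np.random.default_rng(seed)
    groups = np.asarray(groups)
    train_idx, val_idx = [], []
    for g in np.unique(groups):
        idx = np.flatnonzero(groups == g)  # row positions belonging to this group
        rng.shuffle(idx)
        n_val = int(round(len(idx) * val_frac))
        val_idx.extend(idx[:n_val].tolist())
        train_idx.extend(idx[n_val:].tolist())
    return np.array(sorted(train_idx)), np.array(sorted(val_idx))


# Made-up country labels for 100 images
countries = ['China'] * 50 + ['Japan'] * 30 + ['India'] * 20
train_idx, val_idx = stratified_split(countries)
val_countries = [countries[i] for i in val_idx]
print({c: val_countries.count(c) for c in ('China', 'Japan', 'India')})
# {'China': 10, 'Japan': 6, 'India': 4}
```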

In [None]:
if len(df) > 0:
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    print(f'Train: {len(train_df):,} pairs')
    print(f'Val:   {len(val_df):,} pairs')

    if 'country' in df.columns and df['country'].nunique() > 1:
        print(f'\nTrain countries: {train_df["country"].value_counts().to_dict()}')
        print(f'Val countries:   {val_df["country"].value_counts().to_dict()}')
else:
    train_df = val_df = None
    print('No data available for splitting.')

---
## 4. Data Pipeline

### Augmentation Strategy

| Transform | Train | Val | Why |
|-----------|:-----:|:---:|-----|
| RandomCrop 384x384 | Yes | — | Augmentation + memory savings |
| HorizontalFlip | Yes | — | Rice fields look same mirrored |
| VerticalFlip | Yes | — | Top-down views are rotation-invariant |
| RandomRotate90 | Yes | — | Further rotation invariance |
| ColorJitter | Yes | — | Handle varying lighting conditions |
| Normalize (ImageNet) | Yes | Yes | Required for pretrained backbone |

**Key:** Albumentations applies the SAME spatial transform to both image and mask.
Color transforms only affect the image — the mask's integer class IDs stay untouched.
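
The numpy sketch below illustrates that rule; it is not how albumentations is implemented. `joint_augment` is a hypothetical helper that draws one random rotation and flip, applies them to both arrays, then scales brightness on the image only, so the mask's class IDs are never altered.

```python
import numpy as np


def joint_augment(image, mask, rng):
    """Hypothetical helper: shared spatial transform, image-only color transform."""
    k = int(rng.integers(4))  # same number of 90-degree turns for both arrays
    image, mask = np.rot90(image, k, axes=(0, 1)), np.rot90(mask, k)
    if rng.random() < 0.5:  # same horizontal-flip decision for both arrays
        image, mask = image[:, ::-1], mask[:, ::-1]
    factor = rng.uniform(0.8, 1.2)  # brightness jitter: image only
    image = np.clip(image * factor, 0, 255)
    return image, mask


rng = np.random.default_rng(0)
mask = rng.integers(0, 6, size=(8, 8))
# Paint weed pixels (class 4) into the green channel so alignment can be checked
image = np.zeros((8, 8, 3))
image[..., 1][mask == 4] = 100.0

aug_img, aug_mask = joint_augment(image, mask, rng)
print(np.array_equal(aug_img[..., 1] > 0, aug_mask == 4))  # True
```

If the spatial transform were drawn independently for image and mask, the final check would almost always fail.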

In [None]:
CROP_SIZE = 384  # Random crop from native 512x512

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = A.Compose([
    A.RandomCrop(CROP_SIZE, CROP_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

val_transform = A.Compose([
    # No crop: evaluate on the full 512x512 image for best metrics.
    # DeepLabV3+ is fully convolutional, so it accepts any input whose height and
    # width are multiples of its output stride; 512 qualifies.
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

print(f'Train: RandomCrop({CROP_SIZE}) + flip + rotate + color jitter + normalize')
print(f'Val:   normalize only (full 512x512)')

In [None]:
class RiceSEGDataset(Dataset):
    """PyTorch dataset for RiceSEG segmentation."""

    def __init__(self, dataframe, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = np.array(Image.open(row['image_path']).convert('RGB'))
        mask = np.array(Image.open(row['mask_path']))

        # Ensure mask is 2D (H, W) with integer class IDs
        if mask.ndim == 3:
            mask = mask[:, :, 0]

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # as_tensor also covers transform=None, where mask is still a numpy array
        return image, torch.as_tensor(mask).long()


if train_df is not None:
    train_dataset = RiceSEGDataset(train_df, transform=train_transform)
    val_dataset = RiceSEGDataset(val_df, transform=val_transform)

    img, msk = train_dataset[0]
    print(f'Train dataset: {len(train_dataset)} samples')
    print(f'Val dataset:   {len(val_dataset)} samples')
    print(f'Image shape:   {img.shape} (C, H, W)')
    print(f'Mask shape:    {msk.shape} (H, W)')
    print(f'Mask dtype:    {msk.dtype}')
    print(f'Mask classes:  {torch.unique(msk).tolist()}')
else:
    print('No data to create datasets.')

In [None]:
BATCH_SIZE = 8      # Training (384x384) — reduce to 4 if OOM
VAL_BATCH_SIZE = 4  # Validation (512x512) — larger images need smaller batch
NUM_WORKERS = 2     # Kaggle has 2 CPU cores

if train_df is not None:
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=VAL_BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    print(f'Train: batch_size={BATCH_SIZE}, {len(train_loader)} batches')
    print(f'Val:   batch_size={VAL_BATCH_SIZE}, {len(val_loader)} batches')
else:
    print('No data for DataLoaders.')

### Preview Augmented Samples

Sanity check: do augmented images and masks still align correctly after spatial transforms?

In [None]:
def mask_to_rgb(mask, class_colors):
    """Convert class ID mask to RGB image for visualization."""
    h, w = mask.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    for cls_id, color in class_colors.items():
        rgb[mask == cls_id] = color
    return rgb


def denormalize(tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Reverse ImageNet normalization for display."""
    mean = torch.tensor(mean).view(3, 1, 1)
    std = torch.tensor(std).view(3, 1, 1)
    return (tensor * std + mean).clamp(0, 1)


if train_df is not None:
    fig, axes = plt.subplots(3, 4, figsize=(16, 12))

    for row in range(3):
        img, msk = train_dataset[row]
        img_np = denormalize(img).permute(1, 2, 0).numpy()
        msk_np = msk.numpy()
        msk_rgb = mask_to_rgb(msk_np, CLASS_COLORS)

        # Overlay: blend image with colored mask
        overlay = img_np.copy()
        for cls_id in range(1, NUM_CLASSES):  # skip background
            cls_mask = msk_np == cls_id
            if cls_mask.any():
                color = np.array(CLASS_COLORS[cls_id]) / 255.0
                overlay[cls_mask] = overlay[cls_mask] * 0.5 + color * 0.5

        axes[row, 0].imshow(img_np)
        axes[row, 0].set_title('Image', fontsize=10)
        axes[row, 1].imshow(msk_rgb)
        axes[row, 1].set_title('Mask', fontsize=10)
        axes[row, 2].imshow(overlay)
        axes[row, 2].set_title('Overlay', fontsize=10)

        # Per-sample class histogram
        unique, counts = np.unique(msk_np, return_counts=True)
        names = [RICESEG_CLASSES.get(u, '?') for u in unique]
        colors = [np.array(CLASS_COLORS.get(u, (128, 128, 128))) / 255.0 for u in unique]
        axes[row, 3].barh(names, counts, color=colors)
        axes[row, 3].set_title('Pixel counts', fontsize=10)

        for ax in axes[row, :3]:
            ax.axis('off')

    plt.suptitle('Augmented Training Samples (384x384 crops)', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Legend
    print('Class colors:')
    for cls_id, name in RICESEG_CLASSES.items():
        r, g, b = CLASS_COLORS[cls_id]
        print(f'  {cls_id}: {name} -- RGB({r}, {g}, {b})')

---
## 5. Model: DeepLabV3+ with ResNet-50

### Why DeepLabV3+?

DeepLabV3+ adds an **encoder-decoder** structure to DeepLabV3:
- **Encoder** (ResNet-50): Extracts multi-scale features using atrous (dilated) convolutions
- **ASPP** (Atrous Spatial Pyramid Pooling): Captures context at multiple scales
- **Decoder**: Recovers spatial detail by combining low-level encoder features with high-level ASPP output

The decoder is what distinguishes DeepLabV3+ from DeepLabV3 — it produces sharper boundaries,
which matters for small, irregular weed patches.

We use `segmentation-models-pytorch` (smp) for a clean implementation.
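
A quick way to see how the decoder recovers detail is to track spatial sizes along the network. The sketch assumes smp's `DeepLabV3Plus` defaults as we understand them: encoder output stride 16 feeding ASPP, low-level features taken at stride 4, and a final 4x upsampling back to input resolution. `deeplabv3plus_sizes` is a hypothetical helper.

```python
def deeplabv3plus_sizes(input_size, output_stride=16, low_level_stride=4):
    """Hypothetical helper: spatial side lengths along the DeepLabV3+ data path."""
    assert input_size % output_stride == 0, 'input must be divisible by the output stride'
    aspp = input_size // output_stride      # high-level, context-rich features
    low = input_size // low_level_stride    # low-level, detail-rich features
    fused = aspp * (output_stride // low_level_stride)  # ASPP output upsampled to low-level size
    assert fused == low
    final = fused * low_level_stride        # decoder output upsampled to input size
    return {'aspp': aspp, 'low_level': low, 'fused': fused, 'output': final}


print(deeplabv3plus_sizes(384))  # training crops
# {'aspp': 24, 'low_level': 96, 'fused': 96, 'output': 384}
print(deeplabv3plus_sizes(512))  # full validation images
# {'aspp': 32, 'low_level': 128, 'fused': 128, 'output': 512}
```

At stride 16, a weed patch about 16 pixels across falls inside roughly one ASPP cell; the stride-4 skip connection is what lets the decoder draw its boundary.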

### Transfer Learning Strategy

| Phase | Encoder | Decoder/Head | LR | Epochs |
|-------|---------|-------------|------|--------|
| 1 (frozen) | Frozen (ImageNet) | Training | 1e-3 | 5 |
| 2 (fine-tune) | Unfrozen | Training | 1e-4 | 15 |

In [None]:
def create_model(num_classes, encoder='resnet50', pretrained=True):
    """Create DeepLabV3+ with specified encoder."""
    model = smp.DeepLabV3Plus(
        encoder_name=encoder,
        encoder_weights='imagenet' if pretrained else None,
        in_channels=3,
        classes=num_classes,
    )
    return model


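def freeze_encoder(model):
    """Freeze encoder weights so only the decoder and segmentation head train.

    Note: requires_grad=False stops weight updates, but BatchNorm running
    statistics in the encoder still update whenever the model is in train() mode.
    """
    for param in model.encoder.parameters():
        param.requires_grad = False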


def unfreeze_all(model):
    """Unfreeze all layers for end-to-end fine-tuning."""
    for param in model.parameters():
        param.requires_grad = True


def count_parameters(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


model = create_model(NUM_CLASSES)
model = model.to(DEVICE)

# Start with frozen encoder (Phase 1)
freeze_encoder(model)
trainable, total = count_parameters(model)

print(f'Model: DeepLabV3+ (ResNet-50 encoder)')
print(f'Total parameters:     {total:>12,}')
print(f'Trainable parameters: {trainable:>12,} (decoder only)')
print(f'Frozen parameters:    {total - trainable:>12,} (encoder)')

### Loss Function: Focal Loss

Standard cross-entropy weights every pixel equally, so the loss is dominated by the abundant, easy
background pixels, and a model can reach >98% accuracy while missing weeds entirely. **Focal loss**
down-weights well-classified pixels and concentrates training on hard ones:

**FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t)**

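- **p_t:** The predicted probability of the pixel's true class.
- **gamma:** Controls focus. gamma=0 recovers (class-weighted) cross-entropy. gamma=2 shrinks the loss of well-classified pixels: at p_t=0.9 the factor (1 - p_t)^gamma is 0.01, so hard, misclassified pixels dominate the gradient.
- **alpha:** Per-class weights (alpha_t is the weight of the pixel's true class). Weeds and duckweed get high weights to compensate for their sparsity.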
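
To make the two knobs concrete, the numpy sketch below computes inverse-frequency weights (normalized to mean 1, as in the next cell) from made-up two-class pixel counts, then evaluates the per-pixel focal loss for an easy and a hard weed pixel. `inverse_frequency_weights` and `focal_term` are hypothetical helpers and the counts are illustrative, not measured on RiceSEG.

```python
import numpy as np


def inverse_frequency_weights(pixel_counts):
    """Inverse-frequency class weights, rescaled so their mean is 1."""
    freq = pixel_counts / pixel_counts.sum()
    raw = 1.0 / np.where(freq > 0, freq, 1.0)  # absent classes get raw weight 1
    return raw / raw.mean()


def focal_term(p_t, alpha_t=1.0, gamma=2.0):
    """Per-pixel focal loss: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)


# Made-up pixel counts for a toy problem: 'background' vs 'weed'
counts = np.array([900.0, 100.0])
w = inverse_frequency_weights(counts)
print(w)  # background ~0.2, weed ~1.8

# Easy weed pixel (p_t = 0.95) vs hard weed pixel (p_t = 0.2)
easy, hard = focal_term(0.95, w[1]), focal_term(0.2, w[1])
ce_easy, ce_hard = -np.log(0.95), -np.log(0.2)
print(f'easy: CE={ce_easy:.4f}  focal={easy:.5f}')
print(f'hard: CE={ce_hard:.4f}  focal={hard:.4f}')
```

The easy pixel's loss shrinks to well under 1% of its cross-entropy, while the hard pixel's loss, boosted by the weed weight, ends up larger than plain cross-entropy.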

In [None]:
# Compute class weights from training data (inverse frequency, normalized to mean=1)
if train_df is not None:
    print('Computing class pixel distribution from training set...')
    class_pixel_counts = np.zeros(NUM_CLASSES, dtype=np.int64)
    total_pixels = 0

    for i, (_, row) in enumerate(train_df.iterrows()):
        try:
            mask = np.array(Image.open(row['mask_path']))
            if mask.ndim == 3:
                mask = mask[:, :, 0]
            for cls_id in range(NUM_CLASSES):
                class_pixel_counts[cls_id] += np.sum(mask == cls_id)
            total_pixels += mask.size
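        except Exception as e:
            # Report instead of silently skipping, so corrupt files are noticed
            print(f'  Skipping unreadable mask {row["mask_path"]}: {e}')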
        if (i + 1) % 500 == 0:
            print(f'  {i + 1}/{len(train_df)}...')

    print(f'Done. Total pixels: {total_pixels:,}')

    # Inverse-frequency weights; classes absent from the training set get a raw
    # weight of 1.0 instead of dividing by zero
    freq = class_pixel_counts / total_pixels
    raw_weights = 1.0 / np.where(freq > 0, freq, 1)
    class_weights_np = raw_weights / raw_weights.mean()

    CLASS_WEIGHTS = torch.tensor(class_weights_np, dtype=torch.float32)

    print('\nComputed class weights (inverse frequency, normalized):')
    for i in range(NUM_CLASSES):
        pct = freq[i] * 100
        print(f'  {RICESEG_CLASSES[i]:25s}: {pct:6.2f}% pixels -> weight {CLASS_WEIGHTS[i]:.2f}')
else:
    # Fallback weights (approximate, from notebook 02 analysis)
    CLASS_WEIGHTS = torch.tensor([0.03, 0.08, 0.30, 0.50, 3.00, 2.00], dtype=torch.float32)
    CLASS_WEIGHTS = CLASS_WEIGHTS / CLASS_WEIGHTS.mean()
    print('Using fallback class weights (data not available).')

In [None]:
class FocalLoss(nn.Module):
    """Focal Loss for semantic segmentation with per-class weights.

    Args:
        alpha: Per-class weight tensor of shape (num_classes,).
        gamma: Focusing parameter. 0 = standard CE. 2 = strong focus on hard examples.
        ignore_index: Class index to ignore in loss computation.
    """

    def __init__(self, alpha=None, gamma=2.0, ignore_index=-1):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        # logits: (B, C, H, W), targets: (B, H, W)
        ce_loss = F.cross_entropy(
            logits, targets,
            reduction='none',
            ignore_index=self.ignore_index,
        )

        # p_t = probability assigned to the correct class
        pt = torch.exp(-ce_loss)

        # Focal modulation: down-weight easy examples
        focal_loss = (1 - pt) ** self.gamma * ce_loss

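        # Labelled pixels; ignored ones already have ce_loss = 0
        valid = targets != self.ignore_index

        # Apply per-class weights
        if self.alpha is not None:
            # Map ignored labels to 0 so indexing is safe for any ignore_index (e.g. -1 or 255)
            safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
            alpha_t = self.alpha.to(logits.device)[safe_targets]
            focal_loss = alpha_t * focal_loss

        # Average over labelled pixels only
        return focal_loss[valid].mean() if valid.any() else focal_loss.sum()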


GAMMA = 2.0
criterion = FocalLoss(alpha=CLASS_WEIGHTS, gamma=GAMMA)

print(f'Loss: Focal Loss (gamma={GAMMA})')
print(f'Class weights applied: yes ({NUM_CLASSES} classes)')

### Evaluation Metrics: IoU (Intersection over Union)

For segmentation, **pixel accuracy is misleading** (>98% by predicting background everywhere).
We use IoU instead:

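**IoU = Intersection / Union**, computed per class: the number of pixels both predicted and labelled
as that class, divided by the number of pixels predicted or labelled as it.

- **mIoU:** Mean of the per-class IoUs, the standard segmentation benchmark metric. Classes that appear in neither the predictions nor the labels during the epoch are skipped.
- **Weed IoU:** The single most important metric for our task.
- IoU ranges from 0 (no overlap) to 1 (perfect overlap).
- Intersections and unions are summed over the whole epoch before dividing, so the score is a dataset-level IoU rather than an average of per-image IoUs.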
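
An equivalent, often faster way to accumulate the same counts is a confusion matrix built with `np.bincount`: the diagonal holds intersections, and union = predicted count + true count - intersection. When every label lies in 0..num_classes-1, this gives the same intersections and unions as `update_iou_accumulators`. `confusion_matrix` and `iou_from_confusion` are hypothetical helpers and the two tiny "images" are made up.

```python
import numpy as np


def confusion_matrix(preds, targets, num_classes):
    """Hypothetical helper: cm[t, p] = number of pixels with true class t predicted as p."""
    valid = (targets >= 0) & (targets < num_classes)  # drop out-of-range labels
    idx = num_classes * targets[valid].astype(np.int64) + preds[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def iou_from_confusion(cm):
    """Per-class IoU = TP / (TP + FP + FN); NaN where the class never appears."""
    tp = np.diag(cm).astype(float)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    with np.errstate(invalid='ignore', divide='ignore'):
        return tp / union


# Two tiny 2x3 "images" with 3 classes, accumulated into one matrix
cm = np.zeros((3, 3), dtype=np.int64)
batches = [  # (pred, target)
    (np.array([[0, 0, 1], [1, 2, 2]]), np.array([[0, 0, 1], [1, 1, 2]])),
    (np.array([[0, 1, 1], [0, 0, 0]]), np.array([[0, 1, 1], [0, 0, 2]])),
]
for pred, target in batches:
    cm += confusion_matrix(pred, target, 3)

ious = iou_from_confusion(cm)
print(ious.round(3))     # per-class IoU
print(np.nanmean(ious))  # mIoU over classes that appear
```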

In [None]:
def compute_epoch_iou(intersection_sum, union_sum, num_classes):
    """Compute per-class IoU from accumulated intersection/union counts."""
    ious = {}
    for cls in range(num_classes):
        if union_sum[cls] > 0:
            ious[cls] = intersection_sum[cls] / union_sum[cls]
        else:
            ious[cls] = float('nan')

    valid_ious = [v for v in ious.values() if not np.isnan(v)]
    miou = np.mean(valid_ious) if valid_ious else 0.0

    return ious, miou


def update_iou_accumulators(preds, targets, intersection_sum, union_sum, num_classes):
    """Update running intersection/union counts for IoU computation."""
    for cls in range(num_classes):
        pred_cls = (preds == cls)
        target_cls = (targets == cls)
        intersection_sum[cls] += (pred_cls & target_cls).sum()
        union_sum[cls] += (pred_cls | target_cls).sum()


print('IoU metric functions defined.')

In [None]:
# --- Phase 1 optimizer (frozen encoder) ---
LR_PHASE1 = 1e-3
EPOCHS_PHASE1 = 5

optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR_PHASE1,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_PHASE1)

print(f'Phase 1 config:')
print(f'  Optimizer:     Adam')
print(f'  Learning rate: {LR_PHASE1}')
print(f'  Epochs:        {EPOCHS_PHASE1}')
print(f'  Scheduler:     CosineAnnealingLR')

---
## 6. Training Loop

Two-phase transfer learning:

| Phase | What trains | LR | Epochs | Goal |
|-------|------------|------|--------|------|
| 1 | Decoder only | 1e-3 | 5 | Learn to decode ImageNet features into segmentation masks |
| 2 | Everything | 1e-4 | 15 | Adapt encoder features to rice field weed patterns |
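
Both phases use `CosineAnnealingLR` stepped once per epoch. With the default `eta_min=0`, its closed form is lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2. The sketch below evaluates that formula for the Phase 1 settings; `cosine_lr` is a hypothetical helper written from the formula, not a PyTorch function.

```python
import math


def cosine_lr(base_lr, t, t_max, eta_min=0.0):
    """Hypothetical helper: cosine-annealed LR at scheduler step t (closed form)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


base_lr, t_max = 1e-3, 5  # Phase 1 settings
for epoch in range(t_max):
    # scheduler.step() runs after each epoch, so epoch e trains with step t = e
    print(f'epoch {epoch + 1}: lr = {cosine_lr(base_lr, epoch, t_max):.2e}')
```

The five epochs train at roughly 1.0e-3, 9.0e-4, 6.5e-4, 3.5e-4 and 9.5e-5. After the final `scheduler.step()` the LR reaches 0, so a value read with `get_last_lr()` after stepping describes the next epoch, not the one just finished.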

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device, num_classes):
    """Train for one epoch. Returns loss and mIoU."""
    model.train()
    running_loss = 0.0
    total_samples = 0
    intersection_sum = np.zeros(num_classes)
    union_sum = np.zeros(num_classes)

    for images, masks in loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(images)  # (B, C, H, W)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        total_samples += images.size(0)

        # Accumulate IoU
        preds = logits.argmax(dim=1)  # (B, H, W)
        update_iou_accumulators(
            preds.cpu().numpy(), masks.cpu().numpy(),
            intersection_sum, union_sum, num_classes,
        )

    avg_loss = running_loss / total_samples
    _, miou = compute_epoch_iou(intersection_sum, union_sum, num_classes)

    return avg_loss, miou


def validate(model, loader, criterion, device, num_classes):
    """Validate the model. Returns loss, mIoU, and per-class IoUs."""
    model.eval()
    running_loss = 0.0
    total_samples = 0
    intersection_sum = np.zeros(num_classes)
    union_sum = np.zeros(num_classes)

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            logits = model(images)
            loss = criterion(logits, masks)

            running_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

            preds = logits.argmax(dim=1)
            update_iou_accumulators(
                preds.cpu().numpy(), masks.cpu().numpy(),
                intersection_sum, union_sum, num_classes,
            )

    avg_loss = running_loss / total_samples
    ious, miou = compute_epoch_iou(intersection_sum, union_sum, num_classes)

    return avg_loss, miou, ious


print('Training functions defined.')

### Phase 1: Frozen Encoder

Only the decoder and segmentation head train. The encoder acts as a fixed feature extractor
that reuses ImageNet-learned representations (edges, textures, shapes).
Epochs are cheaper than in Phase 2 because no gradients are computed for the encoder weights.

In [None]:
# --- W&B run init ---
if USE_WANDB:
    wandb.init(
        project='agri-weed-detection',
        config={
            'model': 'deeplabv3plus_resnet50',
            'task': 'segmentation',
            'dataset': 'riceseg',
            'crop_size': CROP_SIZE,
            'batch_size': BATCH_SIZE,
            'lr_phase1': LR_PHASE1,
            'epochs_phase1': EPOCHS_PHASE1,
            'gamma': GAMMA,
            'num_classes': NUM_CLASSES,
            'train_size': len(train_df) if train_df is not None else 0,
            'val_size': len(val_df) if val_df is not None else 0,
            'platform': PLATFORM,
        },
        tags=['baseline', 'segmentation', 'riceseg'],
    )
    print(f'W&B run: {wandb.run.name}')
    print(f'W&B URL: {wandb.run.get_url()}')

# --- Checkpoint tracking ---
SAVE_DIR = Path('/kaggle/working') if IS_KAGGLE else Path('./checkpoints')
SAVE_DIR.mkdir(parents=True, exist_ok=True)
best_miou = 0.0
best_model_path = SAVE_DIR / 'deeplabv3plus_resnet50_riceseg_v1.pt'

history = {
    'train_loss': [], 'val_loss': [],
    'train_miou': [], 'val_miou': [],
    'val_weed_iou': [], 'lr': [],
}

print(f'\n=== Phase 1: Frozen Encoder ===')
print(f'Training decoder only ({count_parameters(model)[0]:,} params)')
print(f'Checkpoint saves to: {best_model_path}\n')

if train_df is not None:
    for epoch in range(EPOCHS_PHASE1):
        t0 = time.time()

        train_loss, train_miou = train_one_epoch(
            model, train_loader, criterion, optimizer, DEVICE, NUM_CLASSES,
        )
        val_loss, val_miou, val_ious = validate(
            model, val_loader, criterion, DEVICE, NUM_CLASSES,
        )

        current_lr = optimizer.param_groups[0]['lr']  # LR actually used this epoch
        scheduler.step()
        elapsed = time.time() - t0

        weed_iou = val_ious.get(4, float('nan'))

        # Track history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_miou'].append(train_miou)
        history['val_miou'].append(val_miou)
        history['val_weed_iou'].append(weed_iou)
        history['lr'].append(current_lr)

        # Save best model by mIoU
        marker = ''
        if val_miou > best_miou:
            best_miou = val_miou
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_miou': val_miou,
                'val_loss': val_loss,
                'val_ious': val_ious,
                'num_classes': NUM_CLASSES,
                'class_names': RICESEG_CLASSES,
                'crop_size': CROP_SIZE,
            }, best_model_path)
            marker = ' * (best)'

        print(f'Epoch [{epoch+1:2d}/{EPOCHS_PHASE1}] '
              f'loss={train_loss:.4f}/{val_loss:.4f} '
              f'mIoU={train_miou:.4f}/{val_miou:.4f} '
              f'weed_IoU={weed_iou:.4f} '
              f'lr={current_lr:.6f} '
              f'({elapsed:.0f}s){marker}')

        if USE_WANDB:
            wandb.log({
                'epoch': epoch + 1, 'phase': 1,
                'train_loss': train_loss, 'val_loss': val_loss,
                'train_miou': train_miou, 'val_miou': val_miou,
                'val_weed_iou': weed_iou, 'lr': current_lr,
            })

    print(f'\nPhase 1 complete. Best mIoU: {best_miou:.4f}')
else:
    print('No training data. Skipping Phase 1.')

### Phase 2: Full Fine-Tuning

Unfreeze the encoder and train the entire model with a lower learning rate (1e-4).
This lets the ResNet-50 backbone adapt its features from generic ImageNet patterns
to rice-field-specific patterns (leaf textures, weed shapes, color profiles).

**Why a lower LR?** The encoder weights are already well trained on ImageNet. Large updates
could overwrite those learned features (a form of catastrophic forgetting). Small, careful updates
let the backbone specialize without losing its foundation.
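
With Adam, the learning rate maps quite directly onto how far each weight moves, because each gradient is divided by its running RMS. On the very first step, after bias correction, the update is lr * g / (|g| + eps), which is about lr in magnitude however large g is. The numpy sketch below checks this with `torch.optim.Adam`'s default betas (0.9, 0.999) and eps=1e-8; `adam_first_step` is a hypothetical helper, not a PyTorch function.

```python
import numpy as np


def adam_first_step(grad, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: parameter change from Adam's first update (t = 1)."""
    m = (1 - beta1) * grad       # first-moment estimate
    v = (1 - beta2) * grad ** 2  # second-moment estimate
    m_hat = m / (1 - beta1)      # bias correction at t = 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (np.sqrt(v_hat) + eps)


grads = np.array([1e-4, 0.5, 30.0, -2.0])  # wildly different gradient scales
for lr in (1e-3, 1e-4):
    print(lr, np.abs(adam_first_step(grads, lr)))
```

Every entry comes out close to `lr`. After the first step the ratio is no longer exactly one, but the LR still sets the typical per-step change, so dropping from 1e-3 to 1e-4 makes each update roughly 10x smaller.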

In [None]:
if train_df is not None:
    LR_PHASE2 = 1e-4
    EPOCHS_PHASE2 = 15

    # Unfreeze everything
    unfreeze_all(model)
    trainable, total = count_parameters(model)

    # New optimizer for all parameters
    optimizer = optim.Adam(model.parameters(), lr=LR_PHASE2)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_PHASE2)

    print(f'=== Phase 2: Full Fine-Tuning ===')
    print(f'All parameters unfrozen ({trainable:,} trainable)\n')

    for epoch in range(EPOCHS_PHASE2):
        t0 = time.time()
        global_epoch = EPOCHS_PHASE1 + epoch

        train_loss, train_miou = train_one_epoch(
            model, train_loader, criterion, optimizer, DEVICE, NUM_CLASSES,
        )
        val_loss, val_miou, val_ious = validate(
            model, val_loader, criterion, DEVICE, NUM_CLASSES,
        )

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        elapsed = time.time() - t0

        weed_iou = val_ious.get(4, float('nan'))

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_miou'].append(train_miou)
        history['val_miou'].append(val_miou)
        history['val_weed_iou'].append(weed_iou)
        history['lr'].append(current_lr)

        marker = ''
        if val_miou > best_miou:
            best_miou = val_miou
            torch.save({
                'epoch': global_epoch,
                'model_state_dict': model.state_dict(),
                'val_miou': val_miou,
                'val_loss': val_loss,
                'val_ious': val_ious,
                'num_classes': NUM_CLASSES,
                'class_names': RICESEG_CLASSES,
                'crop_size': CROP_SIZE,
            }, best_model_path)
            marker = ' * (best)'

        print(f'Epoch [{global_epoch+1:2d}/{EPOCHS_PHASE1 + EPOCHS_PHASE2}] '
              f'loss={train_loss:.4f}/{val_loss:.4f} '
              f'mIoU={train_miou:.4f}/{val_miou:.4f} '
              f'weed_IoU={weed_iou:.4f} '
              f'lr={current_lr:.6f} '
              f'({elapsed:.0f}s){marker}')

        if USE_WANDB:
            wandb.log({
                'epoch': global_epoch + 1, 'phase': 2,
                'train_loss': train_loss, 'val_loss': val_loss,
                'train_miou': train_miou, 'val_miou': val_miou,
                'val_weed_iou': weed_iou, 'lr': current_lr,
            })

    print(f'\nPhase 2 complete. Best mIoU: {best_miou:.4f}')
else:
    print('No training data. Skipping Phase 2.')

### Training History

Visualize how loss, mIoU, and weed IoU evolved across both phases.

In [None]:
if len(history['train_loss']) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    epochs_range = range(1, len(history['train_loss']) + 1)
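    # Phase 2 starts at epoch EPOCHS_PHASE1 + 1 (1-indexed), so draw the line between the two phases
    phase_boundary = EPOCHS_PHASE1 + 0.5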

    # Loss
    axes[0, 0].plot(epochs_range, history['train_loss'], 'b-', label='Train')
    axes[0, 0].plot(epochs_range, history['val_loss'], 'r-', label='Val')
    axes[0, 0].axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.5, label='Unfreeze')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Focal Loss')
    axes[0, 0].set_title('Loss')
    axes[0, 0].legend()

    # mIoU
    axes[0, 1].plot(epochs_range, history['train_miou'], 'b-', label='Train')
    axes[0, 1].plot(epochs_range, history['val_miou'], 'r-', label='Val')
    axes[0, 1].axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.5, label='Unfreeze')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('mIoU')
    axes[0, 1].set_title('Mean IoU (all classes)')
    axes[0, 1].legend()

    # Weed IoU
    weed_ious_valid = [v if not np.isnan(v) else 0 for v in history['val_weed_iou']]
    axes[1, 0].plot(epochs_range[:len(weed_ious_valid)],
                    weed_ious_valid, 'r-o', markersize=4, label='Weed IoU')
    axes[1, 0].axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.5, label='Unfreeze')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Weed IoU')
    axes[1, 0].set_title('Weed Class IoU (the metric that matters)')
    axes[1, 0].legend()

    # Learning rate
    axes[1, 1].plot(epochs_range, history['lr'], 'g-')
    axes[1, 1].axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.5, label='Unfreeze')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Learning Rate')
    axes[1, 1].set_title('Learning Rate Schedule')
    axes[1, 1].legend()

    plt.suptitle('Training History (Phase 1: frozen encoder -> Phase 2: full fine-tune)',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print('No training history to plot.')

---
## 7. Evaluation

Load the best checkpoint, re-run validation, and report:
- Overall loss and mIoU, plus per-class IoU (printed and as a bar chart)
- Visual predictions (image, ground truth mask, predicted mask, error map)

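`validate()` is defined earlier in the notebook and its internals are not shown in this section. A common way to compute per-class IoU, sketched here under that assumption, is to accumulate one confusion matrix over the whole validation set and derive IoU = TP / (TP + FP + FN) once at the end. Classes that never occur in either ground truth or prediction get NaN, which is why the evaluation code reads `val_ious.get(cls_id, float('nan'))`. The helper names and the ignore-label handling are illustrative assumptions, not the notebook's implementation.

```python
import numpy as np

def confusion_matrix(gt, pred, num_classes):
    """Pixel counts: rows = ground-truth class, columns = predicted class."""
    gt = np.asarray(gt, dtype=np.int64).ravel()
    pred = np.asarray(pred, dtype=np.int64).ravel()
    valid = (gt >= 0) & (gt < num_classes)  # skip an ignore label such as 255, if present
    idx = gt[valid] * num_classes + pred[valid]
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def iou_from_confusion(cm):
    """Per-class IoU = TP / (TP + FP + FN); NaN for classes absent from GT and prediction."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as class c, but GT is another class
    fn = cm.sum(axis=1) - tp  # GT is class c, but predicted as another class
    denom = tp + fp + fn
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = np.where(denom > 0, tp / denom, np.nan)
    return iou, np.nanmean(iou)  # mIoU skips classes that never occur

# Accumulate over every validation batch first, then compute IoU once
n_classes = 6
demo_cm = np.zeros((n_classes, n_classes), dtype=np.int64)
gt = np.array([[0, 0, 4], [4, 1, 1]])
pred = np.array([[0, 4, 4], [0, 1, 1]])
demo_cm += confusion_matrix(gt, pred, n_classes)
per_class, miou = iou_from_confusion(demo_cm)
print(per_class.round(3), round(float(miou), 3))
```

Accumulating counts before dividing weights every pixel equally; averaging per-batch IoUs instead can give noticeably different numbers for a rare class like weeds.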
In [None]:
if train_df is not None and best_model_path.exists():
    # Load best checkpoint
    checkpoint = torch.load(best_model_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f'Loaded best model from epoch {checkpoint["epoch"] + 1}')
    print(f'Best val mIoU: {checkpoint["val_miou"]:.4f}')

    # Run final validation
    val_loss, val_miou, val_ious = validate(
        model, val_loader, criterion, DEVICE, NUM_CLASSES,
    )

    print(f'\n=== Final Validation Results ===')
    print(f'Loss:       {val_loss:.4f}')
    print(f'mIoU:       {val_miou:.4f} ({val_miou*100:.1f}%)')
    print(f'\nPer-class IoU:')
    for cls_id in range(NUM_CLASSES):
        iou = val_ious.get(cls_id, float('nan'))
        bar_len = int(iou * 40) if not np.isnan(iou) else 0
        bar_str = '#' * bar_len
        print(f'  {RICESEG_CLASSES[cls_id]:25s}: {iou:.4f} ({iou*100:.1f}%) {bar_str}')

    weed_iou = val_ious.get(4, float('nan'))
    print(f'\n>>> Weed IoU: {weed_iou:.4f} ({weed_iou*100:.1f}%)')
else:
    print('No model to evaluate.')

In [None]:
if train_df is not None and best_model_path.exists():
    class_names_list = [RICESEG_CLASSES[i] for i in range(NUM_CLASSES)]
    iou_values = [val_ious.get(i, 0) for i in range(NUM_CLASSES)]
    bar_colors_list = [np.array(CLASS_COLORS[i]) / 255.0 for i in range(NUM_CLASSES)]

    fig, ax = plt.subplots(figsize=(12, 5))
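    # A NaN IoU (class absent from the val set) cannot be drawn as a bar; plot it as 0 and label it N/A
    bars = ax.barh(class_names_list, np.nan_to_num(iou_values), color=bar_colors_list, edgecolor='gray')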
    ax.set_xlabel('IoU')
    ax.set_title(f'Per-Class IoU | mIoU: {val_miou:.4f}')
    ax.set_xlim(0, 1)
    ax.axvline(x=val_miou, color='gray', linestyle='--', alpha=0.5, label=f'mIoU={val_miou:.3f}')

    for bar, iou in zip(bars, iou_values):
        label = f'{iou:.3f}' if not np.isnan(iou) else 'N/A'
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                label, va='center', fontsize=10)

    ax.legend()
    plt.tight_layout()
    plt.show()

    if USE_WANDB:
        # One log call keeps the chart and all per-class IoUs on the same W&B step
        wandb.log({
            'per_class_iou': wandb.Image(fig),
            **{f'iou_{RICESEG_CLASSES[cls_id]}': val_ious.get(cls_id, 0) for cls_id in range(NUM_CLASSES)},
        })

### Visual Predictions

Compare ground truth masks with model predictions. The error map highlights
where the model disagrees with the ground truth:
- **Correct pixels:** black
- **Incorrect pixels:** white

Pay attention to weed regions (red in ground truth): does the model detect them?
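The binary error map treats every mistake the same, so confusing green vegetation with senescent vegetation looks identical to missing a weed. A weed-only variant, sketched below as a hypothetical helper (`weed_error_map`, using weed class ID 4 as in `val_ious.get(4, ...)`), separates missed weeds from false alarms. It assumes `gt_mask` and `pred_mask` already have the same shape; the colors are arbitrary choices.

```python
import numpy as np

WEED_CLASS = 4  # weed class ID used throughout this notebook

def weed_error_map(gt_mask, pred_mask, weed_class=WEED_CLASS):
    """Color-code weed errors: green = hit, magenta = missed weed, yellow = false alarm."""
    gt_weed = gt_mask == weed_class
    pred_weed = pred_mask == weed_class
    out = np.zeros(gt_mask.shape + (3,), dtype=np.uint8)  # pixels where neither is weed stay black
    out[gt_weed & pred_weed] = (0, 200, 0)      # true positive
    out[gt_weed & ~pred_weed] = (255, 0, 255)   # false negative: weed the model missed
    out[~gt_weed & pred_weed] = (255, 255, 0)   # false positive: weed predicted where there is none
    return out

gt = np.array([[4, 4], [0, 0]])
pred = np.array([[4, 0], [4, 0]])
print(weed_error_map(gt, pred)[..., 0])
```

In the prediction grid, `axes[row_idx, 3].imshow(weed_error_map(gt_mask, pred_mask))` could replace the grayscale panel when the weed class is the focus.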

In [None]:
if train_df is not None and best_model_path.exists():
    model.eval()

    # Select validation samples — prefer ones with weed pixels
    sample_indices = []
    for i in range(len(val_df)):
        mask = np.array(Image.open(val_df.iloc[i]['mask_path']))
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        if np.sum(mask == 4) > 0:  # Has weed pixels
            sample_indices.append(i)
        if len(sample_indices) >= 6:
            break

    # If not enough weed samples, add random ones
    if len(sample_indices) < 6:
        remaining = [i for i in range(min(50, len(val_df))) if i not in sample_indices]
        sample_indices.extend(remaining[:6 - len(sample_indices)])

    n_samples = len(sample_indices)
    fig, axes = plt.subplots(n_samples, 4, figsize=(20, 4 * n_samples))
    if n_samples == 1:
        axes = axes.reshape(1, -1)

    with torch.no_grad():
        for row_idx, data_idx in enumerate(sample_indices):
            row = val_df.iloc[data_idx]
            image = np.array(Image.open(row['image_path']).convert('RGB'))
            gt_mask = np.array(Image.open(row['mask_path']))
            if gt_mask.ndim == 3:
                gt_mask = gt_mask[:, :, 0]

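            # Run model on the same resized/normalized view used for validation
            transformed = val_transform(image=image, mask=gt_mask)
            img_tensor = transformed['image'].unsqueeze(0).to(DEVICE)
            logits = model(img_tensor)
            pred_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy()

            # val_transform may resize/crop, so compare against the transformed mask
            # and display the transformed image with ImageNet normalization undone
            gt_mask = torch.as_tensor(transformed['mask']).cpu().numpy()
            image = transformed['image'].permute(1, 2, 0).cpu().numpy()
            image = np.clip(image * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN), 0, 1)

            # Visualize
            gt_rgb = mask_to_rgb(gt_mask, CLASS_COLORS)
            pred_rgb = mask_to_rgb(pred_mask, CLASS_COLORS)
            error_map = (gt_mask != pred_mask).astype(np.uint8) * 255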

            axes[row_idx, 0].imshow(image)
            axes[row_idx, 0].set_title('Image', fontsize=10)
            axes[row_idx, 1].imshow(gt_rgb)
            axes[row_idx, 1].set_title('Ground Truth', fontsize=10)
            axes[row_idx, 2].imshow(pred_rgb)
            axes[row_idx, 2].set_title('Prediction', fontsize=10)
            axes[row_idx, 3].imshow(error_map, cmap='gray')
            axes[row_idx, 3].set_title('Error Map', fontsize=10)

            for ax in axes[row_idx]:
                ax.axis('off')

    plt.suptitle('Validation Predictions (prioritizing images with weeds)',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Legend
    print('Mask colors:')
    for cls_id, name in RICESEG_CLASSES.items():
        r, g, b = CLASS_COLORS[cls_id]
        print(f'  {cls_id}: {name} -- RGB({r}, {g}, {b})')

    if USE_WANDB:
        wandb.log({'predictions': wandb.Image(fig)})

---
## 8. Export Model

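The best checkpoint (highest validation mIoU) was already written to `best_model_path` during training. This cell summarizes it, lists its contents, and uploads it to W&B (if enabled) for use in notebook 05 (inference pipeline).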
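The checkpoint stores `model_state_dict` plus metadata (`num_classes`, `class_names`, `crop_size`), not the model object, so notebook 05 has to rebuild DeepLabV3+ and reproduce the preprocessing. A minimal numpy sketch of the preprocessing side follows. The mean/std values are the standard ImageNet statistics and are assumed to match the notebook's `IMAGENET_MEAN` / `IMAGENET_STD`; padding H and W to a multiple of 16 is an assumption based on DeepLabV3+ encoders typically downsampling by 16. `preprocess` and `postprocess` are hypothetical helpers.

```python
import numpy as np

# Assumption: standard ImageNet statistics in 0-1 pixel scale.
# Prefer the notebook's IMAGENET_MEAN / IMAGENET_STD if they differ.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(image_rgb, multiple=16):
    """HxWx3 uint8 RGB -> (1x3xH'xW' float32 batch, original (H, W)).

    H and W are reflect-padded up to a multiple of `multiple` (assumed
    output-stride requirement) so the prediction can be cropped back later.
    """
    h, w = image_rgb.shape[:2]
    x = (image_rgb.astype(np.float32) / 255.0 - MEAN) / STD
    pad_h, pad_w = (-h) % multiple, (-w) % multiple
    x = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    return np.ascontiguousarray(x.transpose(2, 0, 1)[None]), (h, w)

def postprocess(logits, size):
    """1xCxH'xW' logits -> HxW class-ID mask with the padding cropped off."""
    h, w = size
    return logits[0].argmax(axis=0)[:h, :w]

batch, size = preprocess(np.full((100, 150, 3), 128, dtype=np.uint8))
print(batch.shape, size)
```

In notebook 05 the batch would then go through `torch.from_numpy(batch).to(device)` and the rebuilt model, with `postprocess` applied to `logits.cpu().numpy()`.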

In [None]:
if best_model_path.exists():
    model_size_mb = best_model_path.stat().st_size / (1024 * 1024)

    print(f'=== Model Export Summary ===')
    print(f'Model:          DeepLabV3+ (ResNet-50)')
    print(f'Dataset:        RiceSEG ({NUM_CLASSES} classes)')
    print(f'Best val mIoU:  {best_miou:.4f} ({best_miou*100:.1f}%)')
    print(f'Checkpoint:     {best_model_path}')
    print(f'Model size:     {model_size_mb:.1f} MB')
    print(f'Input:          3xHxW RGB (H, W multiples of 16 are safest; trained at {CROP_SIZE}x{CROP_SIZE})')
    print(f'Output:         {NUM_CLASSES}xHxW logits (argmax -> class IDs)')
    print(f'Normalization:  ImageNet (mean={IMAGENET_MEAN}, std={IMAGENET_STD})')

    print(f'\nCheckpoint contents:')
    checkpoint = torch.load(best_model_path, map_location='cpu', weights_only=False)
    for key in checkpoint.keys():
        print(f'  - {key}')

    if USE_WANDB:
        wandb.summary['best_val_miou'] = best_miou
        wandb.summary['model_size_mb'] = model_size_mb
        wandb.save(str(best_model_path))
        wandb.finish()
        print(f'\nW&B run finished. Checkpoint uploaded to the run files.')
    elif 'wandb' in dir() and hasattr(wandb, 'run') and wandb.run is not None:
        wandb.finish()
else:
    print('No model checkpoint found.')
    if USE_WANDB and hasattr(wandb, 'run') and wandb.run is not None:
        wandb.finish()

---
## 9. What's Next

| Next Step | Notebook | What It Does |
|-----------|----------|-------------|
| **Inference pipeline** | `05-inference-pipeline.ipynb` | Run trained classification + segmentation models on new images |
| **Improve weed IoU** | This notebook | Try: larger crop, more epochs, different backbone (ResNet-101), Dice+Focal combo loss |
| **Classification** | `04-classification-baseline.ipynb` | Train EfficientNetV2-S on Crop & Weed or Bangladesh data |

### Ideas to Improve Weed IoU

1. **Dice + Focal combo loss:** Dice loss optimizes a soft (differentiable) version of the Dice coefficient, which is monotonically related to IoU (IoU = Dice / (2 - Dice)), so it targets overlap more directly than pixel-wise losses. Combining it with focal loss often works better than either alone.
2. **Larger backbone:** ResNet-101 or EfficientNet-B4 may capture finer weed textures.
3. **Oversample weed images:** Use a weighted sampler (e.g. PyTorch's `WeightedRandomSampler`) that presents weed-containing images more often.
4. **Philippines-only training:** A smaller dataset, but likely the closest match to Indonesian conditions.
5. **Test-time augmentation (TTA):** Average predictions from flipped/rotated versions of each image.
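Idea 1 can be prototyped before touching the training loop. The sketch below is a forward-only numpy illustration of how the two terms combine, not the notebook's loss function. `gamma=2` matches the focal setting in the header table; summing the terms with equal weight and averaging the weighted focal term over pixels (rather than normalizing by the weight sum) are assumptions. For actual training, the same math written on PyTorch tensors would supply gradients. Function names are hypothetical.

```python
import numpy as np

def softmax(logits, axis=1):
    z = logits - logits.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def focal_dice_loss(logits, target, gamma=2.0, class_weights=None, eps=1e-6):
    """Forward-only Focal + soft Dice loss.

    logits: (N, C, H, W) floats; target: (N, H, W) integer class IDs.
    Returns (total, focal, dice).
    """
    c = logits.shape[1]
    probs = softmax(logits, axis=1)
    onehot = np.moveaxis(np.eye(c)[target], -1, 1)  # (N, H, W, C) -> (N, C, H, W)

    # Focal term: -w_t * (1 - p_t)^gamma * log(p_t), averaged over all pixels
    p_t = (probs * onehot).sum(axis=1)  # predicted probability of the true class
    w_t = 1.0 if class_weights is None else np.asarray(class_weights)[target]
    focal = np.mean(-w_t * (1.0 - p_t) ** gamma * np.log(p_t + eps))

    # Soft Dice term: 1 - mean over classes of 2*sum(P*G) / (sum(P) + sum(G))
    axes = (0, 2, 3)
    inter = (probs * onehot).sum(axis=axes)
    denom = probs.sum(axis=axes) + onehot.sum(axis=axes)
    dice = 1.0 - np.mean((2.0 * inter + eps) / (denom + eps))

    return focal + dice, focal, dice

target = np.zeros((1, 2, 2), dtype=np.int64)
target[0, 1, 1] = 4  # one weed pixel
uniform = np.zeros((1, 6, 2, 2))
confident = 20.0 * np.moveaxis(np.eye(6)[target], -1, 1)
print(focal_dice_loss(uniform, target)[0], focal_dice_loss(confident, target)[0])
```

Note that the Dice term averages over all classes, including ones absent from a batch; the `eps` smoothing decides how those classes are scored, so it is worth tuning alongside the focal weights.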