# 04 — Classification Baseline: EfficientNetV2-S — Rice Field Weed Detection

**Purpose:** Train a classification model to identify weed species using contextually relevant datasets.  
**Runtime:** GPU required for training (Kaggle P100 recommended). Data prep runs on CPU.  
**Platform:** Works on Kaggle, Colab, and local (with GPU).

## What This Notebook Does

1. **Setup** — Platform detection, dependency install, W&B logging init
2. **Data pipeline** — Stratified train/val split, augmentation transforms, PyTorch DataLoaders
3. **Model** — EfficientNetV2-S (pretrained ImageNet), dynamic classifier head
4. **Training Phase 1** — Frozen backbone, train classifier head only (10 epochs)
5. **Training Phase 2** — Unfreeze backbone, fine-tune end-to-end (10 epochs, lower LR)
6. **Evaluation** — Confusion matrix, per-class F1, correct/misclassified examples
7. **Export** — Save best model checkpoint

### Dataset Options

| Dataset | Classes | Expected Accuracy | Config Key |
|---------|---------|-------------------|------------|
| **Crop & Weed Detection** (default) | 2 (crop, weed) | 85-95% | `crop_weed_yolo` |
| **Bangladesh Rice Field Weed** | 11 species | 65-80% (frozen) / 80-90% (fine-tuned) | `bangladesh_rice_weed` |

**Default:** Crop & Weed Detection — available on Kaggle, no extra setup needed.  
**Recommended upgrade:** Bangladesh Rice Field Weed — more classes, better species-level relevance.

---
## 1. Platform Detection & Setup

Same pattern as notebook 01 — detect Kaggle vs Colab vs local, then set paths and install dependencies.

In [None]:
import os
import sys

# --- Platform Detection ---
IS_KAGGLE = os.path.exists('/kaggle/input')

try:
    import google.colab
    IS_COLAB = True
except ImportError:
    IS_COLAB = False

IS_LOCAL = not IS_KAGGLE and not IS_COLAB

PLATFORM = 'kaggle' if IS_KAGGLE else ('colab' if IS_COLAB else 'local')
print(f'Platform detected: {PLATFORM}')
print(f'Python version: {sys.version}')

### Install Training Dependencies

These are heavier than notebook 01 — we need PyTorch, `timm` (pretrained models), `albumentations` (augmentation), and `wandb` (experiment tracking).

**On Kaggle/Colab:** PyTorch and torchvision are pre-installed. We only need `timm`, `albumentations`, and `wandb`.

In [None]:
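import importlib
import subprocess

# Map pip package name -> importable module name.
# They differ for scikit-learn, which is imported as `sklearn`.
packages = {
    'timm': 'timm',
    'albumentations': 'albumentations',
    'wandb': 'wandb',
    'scikit-learn': 'sklearn',
}

for pkg, module_name in packages.items():
    try:
        importlib.import_module(module_name)
    except ImportError:
        print(f'Installing {pkg}...')
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pkg])

print('All dependencies ready.')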

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from collections import Counter
from PIL import Image
import json
import time
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

import wandb

# Consistent plot style
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 11

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')
if DEVICE.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('WARNING: No GPU detected. Training will be very slow.')

print('\nImports ready.')

### W&B Experiment Tracking (Optional)

Weights & Biases logs training metrics, plots, and model artifacts to the cloud. This is optional but highly recommended for comparing runs.

- **First time?** Get your API key at [wandb.ai/authorize](https://wandb.ai/authorize)
- **Don't want W&B?** Set `USE_WANDB = False` below — metrics still print to stdout

In [None]:
USE_WANDB = True  # Set to False to skip W&B logging

if USE_WANDB:
    try:
        wandb.login()
        print('W&B login successful.')
    except Exception as e:
        print(f'W&B login failed: {e}')
        print('Continuing without W&B. Set USE_WANDB = False to suppress this.')
        USE_WANDB = False
else:
    print('W&B disabled. Metrics will print to stdout only.')

---
## 2. Load Dataset

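Reuse the dataset loading pattern from notebook 01, adapted to the two datasets configured below. YOLO-format data (one `.txt` label file per image) is reduced to a single image-level label by taking the most frequent box class; folder-format data uses each subdirectory name as the class label.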
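
As a minimal sketch of the YOLO-to-classification reduction performed by the loading cell further down, the snippet below picks the most frequent class id from the text of one label file. The helper name `dominant_class` and the sample label text are hypothetical and only illustrate the idea.

```python
from collections import Counter


def dominant_class(label_text):
    """Return the most frequent class id in a YOLO label file's text, or None."""
    class_ids = []
    for line in label_text.splitlines():
        parts = line.split()
        # A YOLO box line is: class_id x_center y_center width height
        if len(parts) >= 5:
            class_ids.append(int(parts[0]))
    if not class_ids:
        return None
    return Counter(class_ids).most_common(1)[0][0]


# Hypothetical label file: two boxes of class 1 and one box of class 0
example = (
    "1 0.50 0.50 0.20 0.20\n"
    "0 0.30 0.40 0.10 0.10\n"
    "1 0.70 0.60 0.15 0.25\n"
)
print(dominant_class(example))  # prints 1
```

Ties go to whichever class appears first in the file, so an image with equal numbers of crop and weed boxes gets a somewhat arbitrary label. Images with no valid box lines yield `None` and are skipped by the loader.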

In [None]:
# ============================================================
# DATASET CONFIGURATION — Change ACTIVE_DATASET to switch
# ============================================================

DATASET_CONFIGS = {
    'crop_weed_yolo': {
        'name': 'Crop & Weed Detection',
        'format': 'yolo',
        'paths': {
            'kaggle': '/kaggle/input/crop-and-weed-detection-data-with-bounding-boxes',
            'colab': '/content/crop_weed_yolo',
            'local': './data/crop_weed_yolo',
        },
    },
    'bangladesh_rice_weed': {
        'name': 'Bangladesh Rice Field Weed',
        'format': 'folder',
        'paths': {
            'kaggle': '/kaggle/input/bangladesh-rice-field-weed',
            'colab': '/content/bangladesh_rice_weed',
            'local': './data/bangladesh_rice_weed',
        },
    },
}

# >>> CHANGE THIS to switch datasets <<<
ACTIVE_DATASET = 'crop_weed_yolo'

config = DATASET_CONFIGS[ACTIVE_DATASET]

# --- Dataset path ---
if IS_KAGGLE:
    DATA_ROOT = Path(config['paths']['kaggle'])
elif IS_COLAB:
    DATA_ROOT = Path(config['paths']['colab'])
else:
    DATA_ROOT = Path(config['paths']['local'])

print(f'Active dataset: {config["name"]}')
print(f'Format: {config["format"]}')
print(f'Data root: {DATA_ROOT}')
print(f'Exists: {DATA_ROOT.exists()}')

if not DATA_ROOT.exists():
    print()
    if ACTIVE_DATASET == 'crop_weed_yolo':
        print('Add the dataset on Kaggle: search "crop and weed detection"')
    elif ACTIVE_DATASET == 'bangladesh_rice_weed':
        print('Upload from Mendeley Data: https://data.mendeley.com/datasets/mt72bmxz73/4')
        print('Create a private Kaggle Dataset named "bangladesh-rice-field-weed"')

In [None]:
def find_image_path(filename):
    """Find the full path of an image file across different dataset structures."""
    p = Path(filename)
    if p.is_absolute() and p.exists():
        return str(p)
    direct = DATA_ROOT / filename
    if direct.exists():
        return str(direct)
    for subdir in ['images', 'train', 'data', 'img']:
        candidate = DATA_ROOT / subdir / filename
        if candidate.exists():
            return str(candidate)
    if DATA_ROOT.exists():
        matches = list(DATA_ROOT.rglob(Path(filename).name))
        if matches:
            return str(matches[0])
    return None


def load_classes_txt(data_root):
    """Load class names from classes.txt or similar files."""
    if not data_root.exists():
        return None
    for candidate in ['classes.txt', 'obj.names', 'data.names']:
        for search_root in [data_root] + [d for d in data_root.iterdir() if d.is_dir()]:
            path = search_root / candidate
            if path.exists():
                with open(path) as f:
                    names = [line.strip() for line in f if line.strip()]
                return {i: name for i, name in enumerate(names)}
    return None


# --- Load dataset based on format ---
df = pd.DataFrame()
CLASS_NAMES = {}
NUM_CLASSES = 0
LABEL_COL = 'label'

if not DATA_ROOT.exists():
    print(f'Data directory not found: {DATA_ROOT}')
    print('Remaining cells will show "no data" messages.')
elif config['format'] == 'yolo':
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
    all_images = sorted(f for f in DATA_ROOT.rglob('*') if f.suffix.lower() in image_extensions)

    CLASS_NAMES = load_classes_txt(DATA_ROOT)
    if CLASS_NAMES is None:
        CLASS_NAMES = {0: 'crop', 1: 'weed'}

    rows = []
    for img_path in all_images:
        txt_path = img_path.with_suffix('.txt')
        if not txt_path.exists():
            rel = img_path.relative_to(DATA_ROOT)
            parts = list(rel.parts)
            for i, part in enumerate(parts):
                if part.lower() in ('images', 'image', 'img'):
                    parts[i] = 'labels'
                    alt = DATA_ROOT / Path(*parts)
                    alt = alt.with_suffix('.txt')
                    if alt.exists():
                        txt_path = alt
                        break

        if txt_path.exists():
            with open(txt_path) as f:
                class_ids = []
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        class_ids.append(int(parts[0]))

            if class_ids:
                dominant_class = Counter(class_ids).most_common(1)[0][0]
                rows.append({
                    'image_path': str(img_path),
                    ACTIVE_DATASET + '_label': dominant_class,
                })

    df = pd.DataFrame(rows)
    LABEL_COL = ACTIVE_DATASET + '_label'

elif config['format'] == 'folder':
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
    subdirs = sorted(d for d in DATA_ROOT.iterdir() if d.is_dir())

    if len(subdirs) == 1:
        deeper = subdirs[0]
        deeper_subdirs = sorted(d for d in deeper.iterdir() if d.is_dir())
        if len(deeper_subdirs) > 1:
            subdirs = deeper_subdirs

    CLASS_NAMES = {}
    rows = []
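    for subdir in subdirs:
        imgs = [f for f in subdir.iterdir() if f.suffix.lower() in image_extensions]
        if imgs:
            # Assign contiguous label ids so empty folders do not leave gaps
            # (gaps would break CLASS_LIST and the classifier head size).
            idx = len(CLASS_NAMES)
            CLASS_NAMES[idx] = subdir.name
            for img in imgs:
                rows.append({'image_path': str(img), 'label': idx})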

    df = pd.DataFrame(rows)
    LABEL_COL = 'label'

NUM_CLASSES = len(CLASS_NAMES)
CLASS_LIST = [CLASS_NAMES[i] for i in range(NUM_CLASSES)] if NUM_CLASSES > 0 else []
HAS_DATA = len(df) > 0

if HAS_DATA:
    missing_mask = df['image_path'].apply(lambda p: not Path(p).exists())
    if missing_mask.any():
        print(f'WARNING: {missing_mask.sum()} images not found. Dropping them.')
        df = df[~missing_mask].reset_index(drop=True)
    print(f'Dataset: {config["name"]}')
    print(f'Images: {len(df)}')
    print(f'Classes ({NUM_CLASSES}): {CLASS_NAMES}')
else:
    print(f'No data loaded. Downstream cells will skip.')

### Class Distribution & Weight Calculation

If the class imbalance ratio (largest class count divided by smallest class count) exceeds 3x, we use inverse-frequency weights in the loss function. This penalizes mistakes on rare classes more heavily, which discourages the model from ignoring them.
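
To make the weighting concrete, here is a small pure-Python sketch of the same arithmetic the next cell performs with `torch.tensor`: each class gets `total / (num_classes * count)`, and the weights are then rescaled to have mean 1. The function name and the 900/100 class counts are made up for illustration.

```python
def inverse_frequency_weights(counts):
    """Inverse-frequency class weights, rescaled so their mean is 1."""
    total = sum(counts)
    num_classes = len(counts)
    raw = [total / (num_classes * c) for c in counts]
    mean = sum(raw) / num_classes
    return [w / mean for w in raw]


counts = [900, 100]  # hypothetical: 900 crop images, 100 weed images
weights = inverse_frequency_weights(counts)
print([round(w, 3) for w in weights])  # prints [0.2, 1.8]
```

With a 9:1 imbalance, an error on the rare class costs nine times as much as an error on the common class, while the mean weight stays at 1 so the overall loss scale is comparable to the unweighted case.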

In [None]:
if not HAS_DATA:
    class_weights = None
    USE_WEIGHTED_LOSS = False
    print('No data available for class distribution analysis.')
else:
    class_counts = df[LABEL_COL].value_counts().sort_index()
    imbalance_ratio = class_counts.max() / class_counts.min()

    print('Class distribution:')
    for idx, count in class_counts.items():
        pct = count / len(df) * 100
        name = CLASS_NAMES.get(idx, f'class_{idx}')
        print(f'  {name:30s} {count:>5,} images ({pct:5.1f}%)')
    print(f'\nImbalance ratio: {imbalance_ratio:.1f}x')

    USE_WEIGHTED_LOSS = imbalance_ratio > 3.0

    if USE_WEIGHTED_LOSS:
        total = len(df)
        class_weights = torch.tensor(
            [total / (NUM_CLASSES * class_counts.get(i, 1)) for i in range(NUM_CLASSES)],
            dtype=torch.float32
        )
        class_weights = class_weights / class_weights.mean()
        print(f'\nUsing weighted loss. Weights:')
        for i in range(NUM_CLASSES):
            print(f'  {CLASS_NAMES.get(i, f"class_{i}"):30s} {class_weights[i]:.3f}')
    else:
        class_weights = None
        print(f'\nImbalance ratio <=3x. Using standard (unweighted) CrossEntropyLoss.')

---
## 3. Train/Validation Split

**Strategy:** 80% train / 20% validation, **stratified** by class label.

Stratified splitting ensures each split has roughly the same class proportions as the full dataset. Without this, a random split might put very few examples of a rare class into validation, making metrics unreliable.
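
The idea behind stratification can be sketched in a few lines of plain Python: split each class separately, then merge the pieces. This is a simplified stand-in for illustration, not how `train_test_split` is implemented, and the function name and toy labels are assumptions.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, val_fraction=0.2, seed=42):
    """Return (train_indices, val_indices) with per-class proportions preserved."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Take the same fraction from every class (at least one sample)
        n_val = max(1, round(len(indices) * val_fraction))
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return train_idx, val_idx


labels = [0] * 90 + [1] * 10  # hypothetical 9:1 imbalanced dataset
train_idx, val_idx = stratified_split(labels)
val_counts = Counter(labels[i] for i in val_idx)
print(dict(val_counts))  # prints {0: 18, 1: 2}
```

The rare class keeps its 10% share in validation (2 of 20), which an unstratified random split does not guarantee.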

In [None]:
if not HAS_DATA:
    train_df = pd.DataFrame()
    val_df = pd.DataFrame()
    print('No data available for splitting.')
else:
    train_df, val_df = train_test_split(
        df,
        test_size=0.2,
        stratify=df[LABEL_COL],
        random_state=42
    )
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    print(f'Train: {len(train_df):,} images')
    print(f'Val:   {len(val_df):,} images')
    print(f'\nTrain class distribution:')
    print(train_df[LABEL_COL].value_counts().sort_index())
    print(f'\nVal class distribution:')
    print(val_df[LABEL_COL].value_counts().sort_index())

---
## 4. Data Pipeline

### Augmentation Strategy

**Training transforms** add variety to prevent overfitting:
- **Geometric:** horizontal/vertical flip, rotation (up to ±30°), slight shift/scale
- **Color:** brightness, contrast, hue, saturation jitter
- **Normalize:** ImageNet statistics (required for pretrained backbone)

**Validation transforms** only resize and normalize — no augmentation. We want a clean evaluation.

We use [Albumentations](https://albumentations.ai/) instead of torchvision transforms because it works directly on NumPy arrays, offers a wider range of augmentations, and is generally fast for image pipelines.
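
The normalization step is plain per-channel arithmetic: scale pixel values to `[0, 1]`, subtract the mean, divide by the standard deviation. The sketch below applies it to a single RGB pixel and inverts it again, which is what the `denormalize` helper later in this notebook does for whole tensors. The function names and the sample pixel are hypothetical.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Map 0-255 RGB values to normalized floats, one channel at a time."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]


def denormalize_pixel(normalized):
    """Invert normalize_pixel and round back to 0-255 integers."""
    return [round((v * s + m) * 255.0)
            for v, m, s in zip(normalized, IMAGENET_MEAN, IMAGENET_STD)]


pixel = [34, 139, 34]  # hypothetical "leaf green" RGB value
normalized = normalize_pixel(pixel)
print([round(v, 3) for v in normalized])
print(denormalize_pixel(normalized))  # recovers [34, 139, 34]
```

After normalization, values fall roughly between -2.1 (black) and 2.6 (white), which is why the sample check in the dataset cell prints a pixel range that is not `[0, 1]`.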

In [None]:
IMG_SIZE = 224  # Keeps training fast; the pretrained EfficientNetV2-S weights were trained at up to 300x300

# ImageNet normalization (required for pretrained backbone)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),  # Resize first — images may vary in size (smartphone photos)
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.3),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),  # Resize — images may not be uniform size
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

print(f'Image size: {IMG_SIZE}x{IMG_SIZE}')
print(f'Train augmentations: resize, flip, rotate, shift/scale, color jitter')
print(f'Val augmentations: resize + normalize only')

In [None]:
class WeedClassificationDataset(Dataset):
    """PyTorch dataset for weed classification."""

    def __init__(self, dataframe, label_col, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.label_col = label_col
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        img = Image.open(row['image_path']).convert('RGB')
        img = np.array(img)

        if self.transform:
            img = self.transform(image=img)['image']

        label = int(row[self.label_col])
        return img, label


if not HAS_DATA:
    train_dataset = None
    val_dataset = None
    print('No data to create datasets.')
else:
    train_dataset = WeedClassificationDataset(train_df, LABEL_COL, transform=train_transform)
    val_dataset = WeedClassificationDataset(val_df, LABEL_COL, transform=val_transform)

    print(f'Train dataset: {len(train_dataset)} samples')
    print(f'Val dataset:   {len(val_dataset)} samples')

    img, label = train_dataset[0]
    print(f'\nSample shape: {img.shape} (C, H, W)')
    print(f'Sample label: {label} ({CLASS_NAMES.get(label, "?")})')
    print(f'Pixel range: [{img.min():.2f}, {img.max():.2f}] (normalized)')

In [None]:
BATCH_SIZE = 32
NUM_WORKERS = 2

if not HAS_DATA:
    train_loader = None
    val_loader = None
    print('No data for DataLoaders.')
else:
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    print(f'Batch size: {BATCH_SIZE}')
    print(f'Train batches: {len(train_loader)}')
    print(f'Val batches:   {len(val_loader)}')

### Preview Augmented Samples

Let's see what the training data looks like after augmentation. This is a sanity check — the images should look like plausible weed photos, not distorted beyond recognition.

In [None]:
if not HAS_DATA:
    print('No data to preview augmented samples.')
else:
    def denormalize(tensor, mean=IMAGENET_MEAN, std=IMAGENET_STD):
        """Reverse ImageNet normalization for display."""
        mean = torch.tensor(mean).view(3, 1, 1)
        std = torch.tensor(std).view(3, 1, 1)
        return (tensor * std + mean).clamp(0, 1)

    fig, axes = plt.subplots(2, 6, figsize=(18, 6))
    for i, ax in enumerate(axes.flat):
        img, label = train_dataset[i]
        img_display = denormalize(img).permute(1, 2, 0).numpy()
        ax.imshow(img_display)
        ax.set_title(CLASS_NAMES[label], fontsize=9)
        ax.axis('off')

    plt.suptitle('Augmented Training Samples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

---
## 5. Model Setup

### Why EfficientNetV2-S?

| Property | Value |
|----------|-------|
| Parameters | ~21M |
| ImageNet top-1 | ~84% |
| Input size | 224x224 in this notebook (the pretrained weights were trained at up to 300x300) |
| Architecture | Fused-MBConv + MBConv blocks |
| Training speed | Reported 5-11x faster to train than ViT in the EfficientNetV2 paper |

EfficientNetV2-S is a strong modern baseline — accurate, fast, and small enough for a Kaggle P100. We load it with ImageNet-pretrained weights from `timm`.

### Transfer Learning Strategy

1. **Phase 1 (frozen):** Freeze backbone, only train the new classifier head. The backbone already knows generic image features (edges, textures, shapes) from ImageNet. We just need to teach the head which features correspond to which weed class.

2. **Phase 2 (unfrozen):** Unfreeze the entire model, fine-tune with a lower learning rate. This lets the backbone adapt its features to weed-specific patterns (leaf shapes, stem textures, color profiles).

In [None]:
def create_model(num_classes, pretrained=True):
    """Create EfficientNetV2-S with a new classifier head."""
    model = timm.create_model(
        'tf_efficientnetv2_s',
        pretrained=pretrained,
        num_classes=num_classes,
    )
    return model


def freeze_backbone(model):
    """Freeze all layers except the classifier head."""
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze the classifier head
    for param in model.classifier.parameters():
        param.requires_grad = True


def unfreeze_all(model):
    """Unfreeze all layers for end-to-end fine-tuning."""
    for param in model.parameters():
        param.requires_grad = True


def count_parameters(model):
    """Count trainable and total parameters."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


# Create model
model = create_model(NUM_CLASSES, pretrained=True)
model = model.to(DEVICE)

# Start with frozen backbone
freeze_backbone(model)
trainable, total = count_parameters(model)

print(f'Model: EfficientNetV2-S')
print(f'Total parameters:     {total:>12,}')
print(f'Trainable parameters: {trainable:>12,} (classifier head only)')
print(f'Frozen parameters:    {total - trainable:>12,} (backbone)')

### Loss Function & Optimizer

- **Loss:** CrossEntropyLoss (with class weights if imbalanced)
- **Optimizer:** Adam with lr=1e-3 (Phase 1) or lr=1e-4 (Phase 2)
- **Scheduler:** CosineAnnealingLR — smoothly decays learning rate to near-zero
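
For intuition about the schedule, the closed-form cosine annealing curve is $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})\,(1 + \cos(\pi t / T_{\max}))$. The sketch below evaluates that formula for the Phase 1 settings (`lr=1e-3`, `T_max=10`, `eta_min=0`); the function `cosine_lr` is a hypothetical helper, not part of PyTorch.

```python
import math


def cosine_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


schedule = [cosine_lr(e, base_lr=1e-3, t_max=10) for e in range(11)]
for epoch, lr in enumerate(schedule):
    print(f'after {epoch:2d} scheduler steps: lr={lr:.6f}')
```

The rate starts at 1e-3, passes through 5e-4 at the midpoint, and reaches 0 after 10 steps. Because the training loop reads the learning rate after calling `scheduler.step()`, the value logged for each epoch is the rate for the following epoch, so the last logged value in each phase is approximately 0.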

In [None]:
if not HAS_DATA:
    criterion = nn.CrossEntropyLoss()
    print('Loss: Standard CrossEntropyLoss (placeholder — no data)')
    print('No data available. Skipping optimizer setup.')
    LR_PHASE1 = 1e-3
    EPOCHS_PHASE1 = 10
else:
    if USE_WEIGHTED_LOSS and class_weights is not None:
        criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))
        print('Loss: Weighted CrossEntropyLoss')
    else:
        criterion = nn.CrossEntropyLoss()
        print('Loss: Standard CrossEntropyLoss')

    LR_PHASE1 = 1e-3
    EPOCHS_PHASE1 = 10

    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR_PHASE1,
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_PHASE1)

    print(f'\nPhase 1 config:')
    print(f'  Optimizer: Adam')
    print(f'  Learning rate: {LR_PHASE1}')
    print(f'  Epochs: {EPOCHS_PHASE1}')
    print(f'  Scheduler: CosineAnnealingLR')

---
## 6. Training Loop

The training loop runs in two phases:

| Phase | Backbone | LR | Epochs | Goal |
|-------|----------|----|--------|---------|
| 1 | Frozen | 1e-3 | 10 | Train classifier head on frozen ImageNet features |
| 2 | Unfrozen | 1e-4 | 10 | Fine-tune backbone for weed-specific features |

Both phases run unconditionally in the cells below. Phase 2 matters most when Phase 1 accuracy is low (<60%), but it is also worth running when you want to push accuracy higher.

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch. Returns average loss and accuracy."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    avg_loss = running_loss / total
    accuracy = correct / total
    return avg_loss, accuracy


def validate(model, loader, criterion, device):
    """Validate the model. Returns average loss, accuracy, all predictions, and all labels."""
    model.train(False)  # set to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    avg_loss = running_loss / total
    accuracy = correct / total
    return avg_loss, accuracy, np.array(all_preds), np.array(all_labels)


print('Training functions defined.')

In [None]:
if not HAS_DATA:
    USE_WANDB = False
    print('No data. Skipping W&B init.')
elif USE_WANDB:
    wandb.init(
        project='agri-weed-detection',
        config={
            'model': 'tf_efficientnetv2_s',
            'dataset': ACTIVE_DATASET,
            'dataset_name': config['name'],
            'img_size': IMG_SIZE,
            'batch_size': BATCH_SIZE,
            'lr_phase1': LR_PHASE1,
            'epochs_phase1': EPOCHS_PHASE1,
            'weighted_loss': USE_WEIGHTED_LOSS,
            'num_classes': NUM_CLASSES,
            'train_size': len(train_df),
            'val_size': len(val_df),
            'platform': PLATFORM,
        },
        tags=['baseline', 'classification', ACTIVE_DATASET],
    )
    print(f'W&B run: {wandb.run.name}')
    print(f'W&B URL: {wandb.run.get_url()}')
else:
    print('W&B disabled -- logging to stdout.')

### Phase 1: Frozen Backbone

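Only the classifier head trains. This is fast because gradients are computed only for the final layer, so the backbone acts as a feature extractor with fixed weights. Note that `train_one_epoch` still calls `model.train()`, so BatchNorm running statistics in the backbone continue to update during this phase.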

In [None]:
SAVE_DIR = Path('/kaggle/working') if IS_KAGGLE else Path('./checkpoints')
SAVE_DIR.mkdir(parents=True, exist_ok=True)
best_val_acc = 0.0
best_model_path = SAVE_DIR / f'efficientnet_v2s_{ACTIVE_DATASET}_v1.pt'

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'lr': []}

if not HAS_DATA:
    print('No training data. Skipping Phase 1.')
else:
    print('=== Phase 1: Frozen Backbone ===')
    print(f'Training classifier head only ({count_parameters(model)[0]:,} params)')
    print(f'Checkpoint will save to: {best_model_path}\n')

    for epoch in range(EPOCHS_PHASE1):
        t0 = time.time()

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_acc, val_preds, val_labels = validate(model, val_loader, criterion, DEVICE)

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        elapsed = time.time() - t0

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
                'class_names': CLASS_NAMES,
                'num_classes': NUM_CLASSES,
                'img_size': IMG_SIZE,
                'dataset': ACTIVE_DATASET,
            }, best_model_path)
            marker = ' * (best)'
        else:
            marker = ''

        print(f'Epoch [{epoch+1:2d}/{EPOCHS_PHASE1}] '
              f'train_loss={train_loss:.4f} train_acc={train_acc:.4f} | '
              f'val_loss={val_loss:.4f} val_acc={val_acc:.4f} | '
              f'lr={current_lr:.6f} | {elapsed:.1f}s{marker}')

        if USE_WANDB:
            wandb.log({
                'epoch': epoch + 1,
                'phase': 1,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'lr': current_lr,
            })

    print(f'\nPhase 1 complete. Best val accuracy: {best_val_acc:.4f}')

### Phase 2: Unfreeze Backbone (Fine-Tuning)

Now we unfreeze the entire model and train with a lower learning rate (1e-4). This lets the backbone adapt its features to weed-specific patterns.

**Why lower LR?** The backbone weights are already well-trained on ImageNet. We want small, careful updates — large updates would destroy the learned features ("catastrophic forgetting").

**When to skip Phase 2:** If Phase 1 accuracy is already >85%, Phase 2 may not add much. But running it is still recommended for best results.

In [None]:
LR_PHASE2 = 1e-4
EPOCHS_PHASE2 = 10

if not HAS_DATA:
    print('No training data. Skipping Phase 2.')
else:
    unfreeze_all(model)
    trainable, total = count_parameters(model)

    optimizer = optim.Adam(model.parameters(), lr=LR_PHASE2)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_PHASE2)

    print('=== Phase 2: Full Fine-Tuning ===')
    print(f'All parameters unfrozen ({trainable:,} trainable)\n')

    for epoch in range(EPOCHS_PHASE2):
        t0 = time.time()
        global_epoch = EPOCHS_PHASE1 + epoch

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_acc, val_preds, val_labels = validate(model, val_loader, criterion, DEVICE)

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        elapsed = time.time() - t0

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['lr'].append(current_lr)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': global_epoch,
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
                'class_names': CLASS_NAMES,
                'num_classes': NUM_CLASSES,
                'img_size': IMG_SIZE,
                'dataset': ACTIVE_DATASET,
            }, best_model_path)
            marker = ' * (best)'
        else:
            marker = ''

        print(f'Epoch [{global_epoch+1:2d}/{EPOCHS_PHASE1 + EPOCHS_PHASE2}] '
              f'train_loss={train_loss:.4f} train_acc={train_acc:.4f} | '
              f'val_loss={val_loss:.4f} val_acc={val_acc:.4f} | '
              f'lr={current_lr:.6f} | {elapsed:.1f}s{marker}')

        if USE_WANDB:
            wandb.log({
                'epoch': global_epoch + 1,
                'phase': 2,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'lr': current_lr,
            })

    print(f'\nPhase 2 complete. Best val accuracy: {best_val_acc:.4f}')

### Training History Plot

Visualize how loss and accuracy evolved across both phases.

In [None]:
if not history['train_loss']:
    print('No training history to plot.')
else:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    epochs_range = range(1, len(history['train_loss']) + 1)
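    # Epochs are plotted 1-indexed, so the unfreeze falls between the last
    # Phase 1 epoch and the first Phase 2 epoch.
    phase_boundary = EPOCHS_PHASE1 + 0.5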

    axes[0].plot(epochs_range, history['train_loss'], 'b-', label='Train')
    axes[0].plot(epochs_range, history['val_loss'], 'r-', label='Val')
    axes[0].axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.5, label='Unfreeze')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Loss')
    axes[0].legend()

    axes[1].plot(epochs_range, history['train_acc'], 'b-', label='Train')
    axes[1].plot(epochs_range, history['val_acc'], 'r-', label='Val')
    axes[1].axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.5, label='Unfreeze')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Accuracy')
    axes[1].legend()

    axes[2].plot(epochs_range, history['lr'], 'g-')
    axes[2].axvline(x=phase_boundary, color='gray', linestyle='--', alpha=0.5, label='Unfreeze')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Learning Rate')
    axes[2].set_title('Learning Rate Schedule')
    axes[2].legend()

    plt.suptitle('Training History (Phase 1: frozen -> Phase 2: unfrozen)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

---
## 7. Evaluation

Load the best checkpoint and run a thorough evaluation:
- Confusion matrix
- Per-class precision, recall, F1
- Correctly classified and misclassified examples

In [None]:
if not HAS_DATA or not best_model_path.exists():
    print('No model checkpoint found. Skipping evaluation.')
    val_preds = np.array([])
    val_labels = np.array([])
else:
    checkpoint = torch.load(best_model_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f'Loaded best model from epoch {checkpoint["epoch"] + 1}')
    print(f'Best val accuracy: {checkpoint["val_acc"]:.4f}')
    print(f'Best val loss: {checkpoint["val_loss"]:.4f}')

    val_loss, val_acc, val_preds, val_labels = validate(model, val_loader, criterion, DEVICE)
    print(f'\nFinal validation: loss={val_loss:.4f}, accuracy={val_acc:.4f}')

### Confusion Matrix

The confusion matrix shows which classes the model confuses. Diagonal entries = correct predictions. Off-diagonal = mistakes.

**What to look for:**
- Large off-diagonal values = classes the model confuses (visually similar?)
- A row with low diagonal value = a class the model struggles to recognize
- A column with large off-diagonal values = a class the model over-predicts
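
The three checks above can also be done numerically rather than by eye. This is a minimal sketch on a made-up matrix: `toy_classes` and `toy_cm` are hypothetical stand-ins for `CLASS_LIST` and the real confusion matrix.

```python
import numpy as np

# Hypothetical class names and counts (rows = actual, columns = predicted).
toy_classes = ['crop', 'weed_a', 'weed_b']
toy_cm = np.array([
    [8, 2, 0],
    [1, 5, 4],
    [0, 1, 9],
])

# Row-normalize so each row sums to 1; the diagonal then equals per-class recall.
row_totals = toy_cm.sum(axis=1, keepdims=True)
toy_norm = toy_cm / np.maximum(row_totals, 1)  # guard against empty rows

# Check 1: the most confused (actual -> predicted) pair is the largest off-diagonal count.
off_diag = toy_cm.copy()
np.fill_diagonal(off_diag, 0)
i, j = np.unravel_index(off_diag.argmax(), off_diag.shape)
most_confused = (toy_classes[i], toy_classes[j], int(off_diag[i, j]))

# Check 2: the class the model struggles with has the lowest normalized diagonal.
weakest = toy_classes[int(np.diag(toy_norm).argmin())]

# Check 3: the over-predicted class has the largest off-diagonal column sum (false positives).
false_positives = off_diag.sum(axis=0)
over_predicted = toy_classes[int(false_positives.argmax())]

print('Most confused pair:', most_confused)
print('Weakest class:', weakest)
print('Most over-predicted class:', over_predicted)
```

On real data, ties and very small classes make these single-winner summaries noisy, so treat them as pointers into the plot rather than as conclusions.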

In [None]:
if len(val_preds) == 0:
    print('No predictions available. Skipping confusion matrix.')
else:
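    # Pass labels explicitly so the matrix is always NUM_CLASSES x NUM_CLASSES,
    # even if a class is missing from both val_labels and val_preds.
    cm = confusion_matrix(val_labels, val_preds, labels=list(range(NUM_CLASSES)))
    # Row-normalize; np.maximum avoids division by zero for classes with no validation samples.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = cm.astype('float') / np.maximum(row_sums, 1)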

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    im0 = axes[0].imshow(cm, interpolation='nearest', cmap='Blues')
    axes[0].set_title('Confusion Matrix (Counts)')
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            axes[0].text(j, i, str(cm[i, j]), ha='center', va='center',
                         color='white' if cm[i, j] > cm.max() / 2 else 'black', fontsize=8)

    im1 = axes[1].imshow(cm_normalized, interpolation='nearest', cmap='Blues', vmin=0, vmax=1)
    axes[1].set_title('Confusion Matrix (Normalized by Row = Recall)')
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            axes[1].text(j, i, f'{cm_normalized[i, j]:.2f}', ha='center', va='center',
                         color='white' if cm_normalized[i, j] > 0.5 else 'black', fontsize=8)

    for ax in axes:
        ax.set_xticks(range(NUM_CLASSES))
        ax.set_yticks(range(NUM_CLASSES))
        ax.set_xticklabels(CLASS_LIST, rotation=45, ha='right', fontsize=9)
        ax.set_yticklabels(CLASS_LIST, fontsize=9)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')

    fig.colorbar(im0, ax=axes[0], fraction=0.046)
    fig.colorbar(im1, ax=axes[1], fraction=0.046)
    plt.suptitle('Classification Results', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    if USE_WANDB:
        wandb.log({'confusion_matrix': wandb.Image(fig)})

In [None]:
if len(val_preds) == 0:
    print('No predictions available. Skipping classification report.')
else:
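    # Explicit labels keep target_names aligned even if a class is absent from the validation set;
    # zero_division=0 reports 0.0 instead of warning when a class is never predicted.
    label_ids = list(range(NUM_CLASSES))
    report = classification_report(
        val_labels, val_preds,
        labels=label_ids,
        target_names=CLASS_LIST,
        digits=3,
        output_dict=True,
        zero_division=0
    )

    print('Per-Class Classification Report:')
    print('=' * 70)
    print(classification_report(val_labels, val_preds, labels=label_ids,
                                target_names=CLASS_LIST, digits=3, zero_division=0))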

    f1_scores = [report[name]['f1-score'] for name in CLASS_LIST]
    colors = ['green' if f1 > 0.8 else 'orange' if f1 > 0.6 else 'red' for f1 in f1_scores]

    fig, ax = plt.subplots(figsize=(12, 5))
    bars = ax.barh(CLASS_LIST, f1_scores, color=colors)
    ax.set_xlabel('F1 Score')
    ax.set_title('Per-Class F1 Score')
    ax.set_xlim(0, 1)
    ax.axvline(x=0.8, color='gray', linestyle='--', alpha=0.5, label='Good threshold')

    for bar, f1 in zip(bars, f1_scores):
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
                f'{f1:.3f}', va='center', fontsize=10)

    ax.legend()
    plt.tight_layout()
    plt.show()

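    if USE_WANDB:
        # Log the figure and all per-class scores in one call so they share a single W&B step.
        f1_metrics = {f'f1_{name}': report[name]['f1-score'] for name in CLASS_LIST}
        wandb.log({'f1_per_class': wandb.Image(fig), **f1_metrics})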

### Correct vs Misclassified Examples

Looking at what the model gets right and wrong often reveals problems that aggregate metrics hide. Pay attention to:
- **Misclassified images:** Are they genuinely ambiguous? Mislabeled? Or a clear model failure?
- **Which classes get confused?** Cross-reference with the confusion matrix.

In [None]:
if len(val_preds) == 0:
    print('No predictions available. Skipping examples display.')
else:
    correct_indices = np.where(val_preds == val_labels)[0]
    incorrect_indices = np.where(val_preds != val_labels)[0]

    print(f'Correct:   {len(correct_indices):,} ({len(correct_indices)/len(val_labels)*100:.1f}%)')
    print(f'Incorrect: {len(incorrect_indices):,} ({len(incorrect_indices)/len(val_labels)*100:.1f}%)')


    def show_predictions(indices, title, n=5):
        """Show predicted vs actual labels for a set of image indices."""
        if len(indices) == 0:
            print(f'{title}: no examples to show.')
            return
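        # Local generator: reproducible sampling without resetting NumPy's global seed.
        rng = np.random.default_rng(42)
        selected = rng.choice(indices, size=min(n, len(indices)), replace=False)

        # Size the grid to the images actually selected so no empty axes are drawn.
        fig, axes = plt.subplots(1, len(selected), figsize=(3 * len(selected), 4))
        axes = np.atleast_1d(axes)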

        for ax, idx in zip(axes, selected):
            # Assumes val_loader is not shuffled, so prediction index idx matches row idx of val_df.
            img_path = val_df.iloc[idx]['image_path']
            img = Image.open(img_path).convert('RGB')
            ax.imshow(img)

            pred_name = CLASS_NAMES[val_preds[idx]]
            true_name = CLASS_NAMES[val_labels[idx]]

            # Green title for correct predictions, red for mistakes.
            title_color = 'green' if val_preds[idx] == val_labels[idx] else 'red'
            ax.set_title(f'Pred: {pred_name}\nTrue: {true_name}', fontsize=9, color=title_color)
            ax.axis('off')

        plt.suptitle(title, fontsize=13, fontweight='bold')
        plt.tight_layout()
        plt.show()


    show_predictions(correct_indices, 'Correctly Classified Examples', n=5)
    if len(incorrect_indices) > 0:
        show_predictions(incorrect_indices, 'Misclassified Examples', n=5)

---
## 8. Export Model

Summarize the best checkpoint saved during training and, if W&B is enabled, upload it to the run, ready for use in notebook 05 (inference pipeline).

In [None]:
if not HAS_DATA or not best_model_path.exists():
    print('No model checkpoint found.')
    print('Train with data on Kaggle/Colab to generate a checkpoint.')
else:
    model_size_mb = best_model_path.stat().st_size / (1024 * 1024)

    print(f'=== Model Export Summary ===')
    print(f'Model:          EfficientNetV2-S')
    print(f'Dataset:        {config["name"]} ({NUM_CLASSES} classes)')
    print(f'Best val acc:   {best_val_acc:.4f} ({best_val_acc*100:.1f}%)')
    print(f'Checkpoint:     {best_model_path}')
    print(f'Model size:     {model_size_mb:.1f} MB')
    print(f'Input size:     {IMG_SIZE}x{IMG_SIZE} RGB')
    print(f'Normalization:  ImageNet (mean={IMAGENET_MEAN}, std={IMAGENET_STD})')
    print(f'\nCheckpoint contents:')
    for key in checkpoint.keys():
        print(f'  - {key}')

    if USE_WANDB:
        wandb.summary['best_val_acc'] = best_val_acc
        wandb.summary['model_size_mb'] = model_size_mb
        wandb.summary['dataset'] = ACTIVE_DATASET
        wandb.save(str(best_model_path))
        wandb.finish()
        print('\nW&B run finished. Checkpoint uploaded to the run files.')

    print(f'\n=== Next Steps ===')
    print(f'1. Download {best_model_path.name} from Kaggle output tab (if needed locally)')
    print(f'2. Notebook 05: Build inference pipeline using this checkpoint')
    print(f'3. Notebook 02: Explore RiceSEG for segmentation training')