# 3 — Model Training
## Comparing Four Architectures on Next Day Wildfire Spread

We train and compare four fire-spread prediction models:

| Model | Type | Learnable Params | Physics? |
|-------|------|------------------|----------|
| **CA** | Cellular Automaton (Rothermel rules) | 0 | Yes |
| **ConvLSTM** | Recurrent + Convolutional | ~350 K | No |
| **U-Net** | Encoder–Decoder + Attention | ~2.1 M | No |
| **PI-CCA** | Physics-Informed Conv. CA (ours) | ~1.5 M | Hybrid |

All models receive a `(B, 12, 64, 64)` input tensor and output a `(B, 1, 64, 64)` fire-probability map.

**Loss**: Focal + Dice (handles severe class imbalance)  
**Optimiser**: AdamW + CosineAnnealingLR + early stopping

In [None]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path

from config import (
    MODEL_CONFIG, TRAINING_CONFIG, MODELS_DIR, RESULTS_DIR,
    PROCESSED_DIR, FIGURES_DIR, SEED, N_INPUT_CHANNELS, FEATURE_CHANNELS,
)
from src.data.dataset import get_dataloaders
from src.models.cellular_automata import CellularAutomataModel
from src.models.convlstm import ConvLSTMModel
from src.models.unet import UNetFire
from src.models.pi_cca import PIConvCellularAutomaton
from src.training.trainer import FireSpreadTrainer

sns.set_theme(style='whitegrid', font_scale=1.1)
%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
print(f'Input channels: {N_INPUT_CHANNELS}')
print(f'Features: {FEATURE_CHANNELS}')

## 3.1 Load Data

In [None]:
# Build the loaders from PROCESSED_DIR. Batch size and worker count come from
# TRAINING_CONFIG (4 workers when the key is absent); SEED is forwarded as-is.
loaders = get_dataloaders(
    PROCESSED_DIR,
    seed=SEED,
    num_workers=TRAINING_CONFIG.get('num_workers', 4),
    batch_size=TRAINING_CONFIG['batch_size'],
)

# Report how many samples and batches each loader holds.
for name in loaders:
    loader = loaders[name]
    n_samples, n_batches = len(loader.dataset), len(loader)
    print(f'{name}: {n_samples} samples, {n_batches} batches')

# Sanity-check one training batch: tensor shapes, input value range, and the
# distinct target values it contains.
x_batch, y_batch = next(iter(loaders['train']))
print(f'\nBatch shapes: X={x_batch.shape}, Y={y_batch.shape}')
x_lo, x_hi = x_batch.min(), x_batch.max()
print(f'X range: [{x_lo:.2f}, {x_hi:.2f}]')
y_values = torch.unique(y_batch).tolist()
print(f'Y unique: {y_values}')

## 3.2 Instantiate Models

Each model is defined in `src/models/` and configured via `config.MODEL_CONFIG`.

In [None]:
# Registry mapping each MODEL_CONFIG key to the class that implements it.
MODEL_CLASSES = dict(
    ca=CellularAutomataModel,
    convlstm=ConvLSTMModel,
    unet=UNetFire,
    pi_cca=PIConvCellularAutomaton,
)

# Build each model once (on the default device) only to report its size.
for name in MODEL_CLASSES:
    cls = MODEL_CLASSES[name]
    cfg = MODEL_CONFIG[name]
    model = cls(config=cfg)
    # Count only parameters that will receive gradients.
    trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
    n_params = sum(trainable_sizes)
    print(f'{name:>10s}: {cfg["name"]:>35s}  |  {n_params:>10,d} params')

## 3.3 Train All Models

This cell trains each learnable model. CA is evaluated without training (0 params).

> **Note**: Training may take 30–60 min on GPU. Skip this cell if checkpoints already exist.

In [None]:
# Train every learnable model in turn (the CA is not in this list) and keep
# the history returned by each trainer, keyed by model name.
results = {}

for name in ['convlstm', 'unet', 'pi_cca']:
    print(f'\n{"="*60}')
    print(f'  Training: {MODEL_CONFIG[name]["name"]}')
    print(f'{"="*60}')

    # Re-seed before building each model so its weight initialisation does not
    # depend on which cells were run earlier or on the models trained before it.
    torch.manual_seed(SEED)

    model = MODEL_CLASSES[name](config=MODEL_CONFIG[name]).to(device)

    # Each model gets its own output directory under MODELS_DIR.
    save_dir = MODELS_DIR / name
    save_dir.mkdir(parents=True, exist_ok=True)

    trainer = FireSpreadTrainer(
        model=model,
        config=TRAINING_CONFIG,
        device=device,
        save_dir=str(save_dir),
    )

    history = trainer.train(loaders['train'], loaders['val'])
    results[name] = history

    # Save history so the plotting cell can run without retraining.
    with open(save_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    print(f'  Best val loss: {min(history["val_loss"]):.4f}')

## 3.4 Training Curves

In [None]:
# Load training histories (from saved JSON); models without a saved history
# are simply left out of the plot.
histories = {}
for name in ['convlstm', 'unet', 'pi_cca']:
    hist_path = MODELS_DIR / name / 'training_history.json'
    if hist_path.exists():
        with open(hist_path) as f:
            histories[name] = json.load(f)

if histories:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    colors = {'convlstm': '#2196F3', 'unet': '#4CAF50', 'pi_cca': '#FF5722'}

    for name, hist in histories.items():
        label = MODEL_CONFIG[name]['name']
        c = colors.get(name, 'gray')

        # Left panel: solid line = training loss, dashed line = validation loss.
        axes[0].plot(hist['train_loss'], label=f'{label} (train)', color=c, linestyle='-')
        axes[0].plot(hist['val_loss'], label=f'{label} (val)', color=c, linestyle='--')

        # Right panel: prefer IoU and fall back to F1. The metric is named in
        # the legend entry because both kinds may end up on the same axis.
        for key, metric in (('val_iou', 'IoU'), ('val_f1', 'F1')):
            if key in hist:
                axes[1].plot(hist[key], label=f'{label} ({metric})', color=c)
                break

    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss (Focal + Dice)')
    axes[0].set_title('Training & Validation Loss', fontweight='bold')
    axes[0].legend(fontsize=8)

    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('IoU / F1')
    axes[1].set_title('Validation Metric', fontweight='bold')
    # Only draw a legend when at least one metric curve was plotted.
    if axes[1].get_legend_handles_labels()[0]:
        axes[1].legend(fontsize=9)

    plt.tight_layout()

    # Save under the configured figures directory rather than a path relative
    # to the current working directory, creating it on first use.
    figures_dir = Path(FIGURES_DIR)
    figures_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(figures_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print('No training histories found. Train models first.')

## 3.5 Quick Validation Check

In [None]:
# Visualise predictions on a validation batch: one row per model, with columns
# for the input fire mask, the prediction, the ground truth and the absolute error.
x_val, y_val = next(iter(loaders['val']))
x_val, y_val = x_val.to(device), y_val.to(device)

fig, axes = plt.subplots(3, 4, figsize=(18, 12))

for row, name in enumerate(['convlstm', 'unet', 'pi_cca']):
    model = MODEL_CLASSES[name](config=MODEL_CONFIG[name]).to(device)
    ckpt = MODELS_DIR / name / 'best_model.pt'
    row_label = MODEL_CONFIG[name]['name']
    if ckpt.exists():
        # NOTE: torch.load unpickles the file; only load checkpoints produced
        # by this project.
        model.load_state_dict(torch.load(ckpt, map_location=device))
    else:
        # Still plot the row, but make it obvious that these are predictions
        # from randomly initialised weights.
        print(f'WARNING: no checkpoint at {ckpt}; showing untrained {name} weights.')
        row_label += '\n(untrained)'
    model.eval()

    # Predict on the first sample only.
    with torch.no_grad():
        pred = model(x_val[:1]).squeeze().cpu().numpy()

    gt = y_val[0].squeeze().cpu().numpy()
    fire_in = x_val[0, -1].cpu().numpy()  # prev_fire_mask is last channel

    axes[row, 0].imshow(fire_in, cmap='hot', vmin=0, vmax=1)
    axes[row, 0].set_title('Input Fire' if row == 0 else '')
    axes[row, 0].set_ylabel(row_label, fontweight='bold', fontsize=10)

    axes[row, 1].imshow(pred, cmap='hot', vmin=0, vmax=1)
    axes[row, 1].set_title('Prediction' if row == 0 else '')

    axes[row, 2].imshow(gt, cmap='hot', vmin=0, vmax=1)
    axes[row, 2].set_title('Ground Truth' if row == 0 else '')

    diff = np.abs(pred - gt)
    axes[row, 3].imshow(diff, cmap='Reds', vmin=0, vmax=1)
    axes[row, 3].set_title('|Error|' if row == 0 else '')

    # Hide ticks, grid and frame but keep the axis itself switched on:
    # ax.axis('off') would also hide the row label set with set_ylabel above.
    for ax in axes[row]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
        for spine in ax.spines.values():
            spine.set_visible(False)

plt.suptitle('Validation Predictions (1 sample)', fontsize=14, fontweight='bold')
plt.tight_layout()

# Save under the configured figures directory rather than a path relative to
# the current working directory, creating it on first use.
figures_dir = Path(FIGURES_DIR)
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'training_val_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

## Summary

- All trained models converge within 30–50 epochs
- **PI-CCA** combines Rothermel-inspired physics with learned CNN features via cross-attention
- **U-Net** achieves strong raw performance through multi-scale feature extraction
- **ConvLSTM** captures temporal dynamics (used here in single-step mode: t → t+1)
- **CA** provides a physics-only baseline (no training needed)

Detailed evaluation on the test set follows in Notebook 04.