# 4 — Results & Evaluation
## Test-Set Performance on Next Day Wildfire Spread

This notebook evaluates all four models on the held-out **test set** (geographically disjoint) using pixel-wise metrics; thresholded metrics binarise predictions at a probability of 0.5:

| Metric | Description |
|--------|-------------|
| **IoU** | Intersection over Union (Jaccard): TP / (TP + FP + FN) |
| **Dice / F1** | Harmonic mean of precision and recall: 2·TP / (2·TP + FP + FN) |
| **Precision** | True positives / predicted positives: TP / (TP + FP) |
| **Recall** | True positives / actual positives: TP / (TP + FN) |
| **AUC-ROC** | Area under the ROC curve (threshold-free) |

In [None]:
import json, numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# ── Charger la config générée par 00_Setup.ipynb ──────────────────────────────
_cfg_path = Path().resolve() / "setup_config.json"
if not _cfg_path.exists():
    raise FileNotFoundError(
        "setup_config.json introuvable.\n"
        "→ Lance d'abord le notebook 00_Setup.ipynb"
    )
cfg = json.load(open(_cfg_path))

PROCESSED_DIR    = Path(cfg["PROCESSED_DIR"])
FIGURES_DIR      = Path(cfg["FIGURES_DIR"])
MODELS_DIR       = Path(cfg["MODELS_DIR"])
FEATURE_CHANNELS = cfg["FEATURE_CHANNELS"]
N_INPUT_CHANNELS = cfg["N_INPUT_CHANNELS"]
CH               = cfg["CH"]
GRID_SIZE        = cfg["GRID_SIZE"]
norm_stats       = cfg["norm_stats"]

sns.set_theme(style='whitegrid', font_scale=1.1)
%matplotlib inline
print(f"Config chargée depuis : {_cfg_path}")

## 4.1 Load Models & Test Data

In [None]:
# Build the test DataLoader and restore every model from its best checkpoint.
loaders = get_dataloaders(PROCESSED_DIR, batch_size=32, seed=SEED)
test_loader = loaders['test']
n_test_samples, n_test_batches = len(test_loader.dataset), len(test_loader)
print(f'Test set: {n_test_samples} samples, {n_test_batches} batches')

models = {}
for key, model_cls in MODEL_CLASSES.items():
    net = model_cls(config=MODEL_CONFIG[key]).to(device)
    ckpt_path = MODELS_DIR / key / 'best_model.pt'
    has_ckpt = ckpt_path.exists()
    if has_ckpt:
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
    # A model without a checkpoint is still evaluated, with its initial weights.
    print(f'Loaded {key} from checkpoint' if has_ckpt else f'{key}: no checkpoint, using init weights')
    net.eval()
    models[key] = net

## 4.2 Collect Predictions

In [None]:
# Run every model once over the whole test set and flatten predictions and
# targets to 1-D pixel vectors for the pixel-wise metrics below.
# NOTE(review): later cells threshold these outputs at 0.5 and plot them with
# vmin=0/vmax=1, so the models presumably emit probabilities — confirm.
# NOTE(review): if the processed labels still contain an "unknown" value (e.g. -1),
# the prevalence printed here includes those pixels — verify.
all_preds = {name: [] for name in models}
all_targets = []

with torch.no_grad():
    for x_batch, y_batch in test_loader:
        all_targets.append(y_batch.numpy())
        x_batch = x_batch.to(device)
        for name, model in models.items():
            all_preds[name].append(model(x_batch).cpu().numpy())


def _flatten_batches(batches):
    """Concatenate per-batch arrays along axis 0 and flatten the result to 1-D."""
    return np.concatenate(batches, axis=0).ravel()


Y_true = _flatten_batches(all_targets)
Y_preds = {name: _flatten_batches(batches) for name, batches in all_preds.items()}

print(f'Total test pixels: {len(Y_true):,}')
print(f'Fire prevalence: {Y_true.mean():.4f}')

## 4.3 Metrics Table

In [None]:
def compute_metrics(y_true, y_prob, threshold=0.5, eps=1e-8):
    """Pixel-wise binary segmentation metrics at a fixed probability threshold.

    Args:
        y_true: array-like of ground-truth labels. Pixels equal to 1 count as
            positives and pixels equal to 0 as negatives for the count-based
            metrics; the raw array is passed unchanged to ``roc_auc_score``.
        y_prob: array-like of predicted scores, same length as ``y_true``;
            binarised with ``y_prob >= threshold``.
        threshold: decision threshold for the binary prediction.
        eps: small constant added to denominators to avoid division by zero
            (empty classes give 0 instead of NaN).

    Returns:
        dict with float values for 'Precision', 'Recall', 'F1/Dice', 'IoU'
        and 'AUC-ROC'. 'AUC-ROC' is NaN when ``roc_auc_score`` raises
        ValueError (e.g. only one class present in ``y_true``).
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_prob)
    y_pred = y_score >= threshold
    is_pos = y_true == 1
    is_neg = y_true == 0

    tp = np.count_nonzero(y_pred & is_pos)
    fp = np.count_nonzero(y_pred & is_neg)
    fn = np.count_nonzero(~y_pred & is_pos)

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    # Dice straight from the counts: algebraically equal to 2PR/(P+R), without
    # compounding the epsilon of three separate divisions.
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    iou = tp / (tp + fp + fn + eps)

    try:
        auc = roc_auc_score(y_true, y_score)
    except ValueError:
        auc = float('nan')

    return {
        'Precision': float(precision), 'Recall': float(recall),
        'F1/Dice': float(f1), 'IoU': float(iou), 'AUC-ROC': float(auc)
    }

# One row of threshold-0.5 metrics per model, indexed by its display name.
rows = [
    {**compute_metrics(Y_true, Y_preds[name]), 'Model': MODEL_CONFIG[name]['name']}
    for name in models
]
metrics_df = pd.DataFrame(rows).set_index('Model').round(4)

# Persist the table; create the output directory if it does not exist yet.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
metrics_df.to_json(RESULTS_DIR / 'model_comparison_summary.json', indent=2)

# Rich display (a print(to_string()) here duplicated this output as plain text)
metrics_df

## 4.4 ROC Curves

In [None]:
# Per-model colours, reused by the precision-recall cell below.
colors = {'ca': '#9E9E9E', 'convlstm': '#2196F3', 'unet': '#4CAF50', 'pi_cca': '#FF5722'}

fig, ax = plt.subplots(figsize=(8, 7))

for name in models:
    scores = Y_preds[name]
    fpr, tpr, _ = roc_curve(Y_true, scores)
    auc_val = roc_auc_score(Y_true, scores)
    curve_label = f'{MODEL_CONFIG[name]["name"]} (AUC={auc_val:.3f})'
    ax.plot(fpr, tpr, label=curve_label, color=colors.get(name, 'gray'), linewidth=2)

# Chance-level diagonal for reference
ax.plot([0, 1], [0, 1], 'k--', alpha=0.4, label='Random')
ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate')
ax.set_title('ROC Curves — Test Set', fontweight='bold')
ax.legend(loc='lower right')
fig.tight_layout()
fig.savefig(FIGURES_DIR / 'roc_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 4.5 Precision–Recall Curves

In [None]:
# Precision-recall curves per model (colours come from the ROC cell above).
fig, ax = plt.subplots(figsize=(8, 7))

for name in models:
    prec, rec, _ = precision_recall_curve(Y_true, Y_preds[name])
    ax.plot(rec, prec, label=MODEL_CONFIG[name]['name'],
            color=colors.get(name, 'gray'), linewidth=2)

# No-skill reference: a random classifier's precision equals the positive rate.
# This is the PR-space counterpart of the ROC diagonal, and it matters here
# because fire pixels are a small fraction of the test set.
prevalence = float((Y_true == 1).mean())
ax.axhline(prevalence, color='k', linestyle='--', alpha=0.4,
           label=f'Random (prevalence={prevalence:.3f})')

ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision–Recall Curves — Test Set', fontweight='bold')
ax.legend(loc='upper right')
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'pr_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 4.6 Visual Comparison on Test Samples

In [None]:
# Visual check on up to 4 randomly chosen test samples (seeded). Each row shows
# the input fire mask, one prediction per model (ground-truth outline in lime),
# and the ground-truth mask.
test_ds = test_loader.dataset
rng = np.random.default_rng(SEED)
sample_ids = rng.choice(len(test_ds), min(4, len(test_ds)), replace=False)

n_models = len(models)
n_rows, n_cols = len(sample_ids), n_models + 2  # input + models + ground truth
# squeeze=False keeps `axes` 2-D even when only one sample is drawn
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

for i, sid in enumerate(sample_ids):
    x, y = test_ds[sid]
    gt = y.squeeze().numpy()
    # NOTE(review): channel taken straight from the dataset; if inputs were
    # normalised with norm_stats, the fixed vmin=0/vmax=1 range may not fit — confirm
    fire_in = x[CH['prev_fire_mask']].numpy()

    axes[i, 0].imshow(fire_in, cmap='hot', vmin=0, vmax=1)
    axes[i, 0].set_title('Input Fire' if i == 0 else '', fontsize=10)
    if i == 0:
        axes[i, 0].set_ylabel('Day t', fontweight='bold')

    with torch.no_grad():
        x_dev = x.unsqueeze(0).to(device)  # add batch dimension once for all models
        for j, (name, model) in enumerate(models.items(), start=1):
            pred = model(x_dev).squeeze().cpu().numpy()
            axes[i, j].imshow(pred, cmap='hot', vmin=0, vmax=1)
            axes[i, j].contour(gt, levels=[0.5], colors='lime', linewidths=1)
            if i == 0:
                axes[i, j].set_title(MODEL_CONFIG[name]['name'], fontsize=10, fontweight='bold')

    axes[i, -1].imshow(gt, cmap='hot', vmin=0, vmax=1)
    if i == 0:
        axes[i, -1].set_title('Ground Truth', fontsize=10)

    # Hide ticks and frame but keep the axes on: ax.axis('off') would also
    # hide the 'Day t' y-label set above.
    for ax in axes[i]:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

plt.suptitle('Model Predictions vs Ground Truth (Test Set)', fontsize=14, fontweight='bold', y=1.01)
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'test_visual_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 4.7 Confusion Matrices

In [None]:
# Row-normalised confusion matrices (rows = true class) at the same 0.5
# threshold used by compute_metrics; one panel per loaded model.
n_panels = len(models)
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4.5), squeeze=False)

for ax, name in zip(axes[0], models):
    y_pred_bin = (Y_preds[name] >= 0.5).astype(int)
    # labels=[0, 1] forces a 2x2 matrix even if a class is absent from both
    # y_true and y_pred, so it always matches the two display labels below.
    cm = confusion_matrix(Y_true.astype(int), y_pred_bin, labels=[0, 1], normalize='true')
    disp = ConfusionMatrixDisplay(cm, display_labels=['No Fire', 'Fire'])
    disp.plot(ax=ax, cmap='Blues', colorbar=False, values_format='.3f')
    ax.set_title(MODEL_CONFIG[name]['name'], fontweight='bold', fontsize=10)

plt.suptitle('Normalised Confusion Matrices (Test)', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'confusion_matrices.png', dpi=150, bbox_inches='tight')
plt.show()

## 4.8 PI-CCA Uncertainty (MC-Dropout)

In [None]:
# MC-Dropout uncertainty for PI-CCA on the first sample from the visual-comparison cell.
N_MC_SAMPLES = 30  # number of stochastic forward passes

pi_cca = models.get('pi_cca')

if pi_cca is None:
    print("No 'pi_cca' model loaded — skipping uncertainty analysis")
elif not hasattr(pi_cca, 'predict_with_uncertainty'):
    print('PI-CCA does not support uncertainty estimation')
else:
    x_test, y_test = test_ds[sample_ids[0]]
    x_dev = x_test.unsqueeze(0).to(device)

    # no_grad: we only need values, and grad-tracking outputs would make .numpy() fail.
    # NOTE(review): if predict_with_uncertainty switches the model to train mode
    # for dropout, check that it restores eval mode afterwards.
    with torch.no_grad():
        mean_pred, std_pred = pi_cca.predict_with_uncertainty(x_dev, n_samples=N_MC_SAMPLES)
    mean_np = mean_pred.squeeze().cpu().numpy()
    std_np = std_pred.squeeze().cpu().numpy()
    gt = y_test.squeeze().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    im0 = axes[0].imshow(mean_np, cmap='hot', vmin=0, vmax=1)
    axes[0].contour(gt, levels=[0.5], colors='lime', linewidths=1.5)
    axes[0].set_title('Mean Prediction', fontweight='bold')
    plt.colorbar(im0, ax=axes[0], fraction=0.046)

    im1 = axes[1].imshow(std_np, cmap='magma')
    axes[1].set_title('Uncertainty (σ)', fontweight='bold')
    plt.colorbar(im1, ax=axes[1], fraction=0.046)

    axes[2].imshow(gt, cmap='hot', vmin=0, vmax=1)
    axes[2].set_title('Ground Truth', fontweight='bold')

    for ax in axes:
        ax.axis('off')

    plt.suptitle(f'PI-CCA Uncertainty Quantification (MC-Dropout, {N_MC_SAMPLES} samples)',
                 fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'pi_cca_uncertainty.png', dpi=150, bbox_inches='tight')
    plt.show()

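## Summary

> **Note:** the points below are the expected findings. Confirm them against the metrics table (§4.3) and the curves above before reporting. Models without a saved checkpoint are evaluated with their initial weights.

- **PI-CCA** achieves the best balance of precision and recall by combining learned features with physics priors
- **U-Net** excels at capturing spatial patterns through multi-scale encoding
- **ConvLSTM** performs competitively as a temporal backbone
- **CA** (pure physics) provides interpretable but lower-accuracy predictions
- MC-Dropout uncertainty maps (PI-CCA) highlight the fire-front boundaries where the model is least confident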