# 🔥 California Fire Model - Training Notebook

This notebook trains a burn severity detection model on California wildfire data.

## Prerequisites
1. Run `data/download_fire_data.py` to queue data downloads
2. Download data from Google Drive to `data/raw/`
3. Run `data/compute_statistics.py` to calculate normalization stats
4. Run `data/validate_dataset.py` to check data quality

In [None]:
# Only run this cell in Google Colab
# Uncomment if needed:

# from google.colab import drive
# drive.mount('/content/drive')

# # Copy data to local SSD for faster training
# !mkdir -p /content/local_data
# !cp -r /content/drive/MyDrive/California_Fire_Model/* /content/local_data/

In [None]:
import sys
import os
from pathlib import Path

# Put the notebook's working directory at the front of the import path so the
# project-local imports further down this cell (e.g. `config`) can be resolved.
PROJECT_ROOT = Path.cwd().resolve()
sys.path[:0] = [str(PROJECT_ROOT)]

# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt

from config import (
    TRAINING_CONFIG, MODEL_CONFIG, CHECKPOINT_DIR, RAW_DATA_DIR,
    TRAINING_FIRES, TEST_FIRES, print_config_summary
)

# Print configuration
# (print_config_summary is imported from the project's `config` module above;
# presumably it echoes the training/model settings used by the cells below —
# see config.py for the exact output.)
print_config_summary()

## 1. Setup Device

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
# The availability check is done once and reused for the diagnostics below.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"\n🔧 Device: {device}")

if use_cuda:
    # Report the name and total memory (in GB, base 10) of GPU 0.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 2. Create Datasets

In [None]:
from data.dataset import create_train_val_datasets
from torch.utils.data import DataLoader

# Data directories: the "fires" and "healthy" sub-folders of RAW_DATA_DIR.
data_dirs = [
    str(RAW_DATA_DIR / "fires"),
    str(RAW_DATA_DIR / "healthy"),
]

# Check data exists: count the .tif tiles found recursively under each
# directory and print a pass/fail marker. This is informational only; a
# missing directory is reported as 0 tiles and does not stop the cell.
for d in data_dirs:
    tile_dir = Path(d)
    exists = tile_dir.exists()
    # Count lazily instead of materialising the full list of paths.
    count = sum(1 for _ in tile_dir.glob('**/*.tif')) if exists else 0
    status = '✅' if count > 0 else '❌'
    print(f"{status} {d}: {count} tiles")

# Create datasets. The project helper returns three datasets (train, val,
# test); val_split=0.15 is forwarded to it unchanged.
train_dataset, val_dataset, test_dataset = create_train_val_datasets(
    data_dirs,
    val_split=0.15,
)

print("\n📊 Dataset sizes:")
print(f"   Train: {len(train_dataset)}")
print(f"   Val: {len(val_dataset)}")
print(f"   Test: {len(test_dataset)}")

In [None]:
# Create dataloaders
batch_size = TRAINING_CONFIG['batch_size']

# Settings shared by the train and validation loaders. Pinned host memory
# only helps host-to-GPU copies, so it is requested only when CUDA is
# available (asking for it on a CPU-only machine just triggers a warning).
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=TRAINING_CONFIG['num_workers'],
    pin_memory=torch.cuda.is_available(),
)

# Training batches are reshuffled every epoch; the last incomplete batch is
# dropped so every training step sees exactly `batch_size` samples.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)

# Validation keeps a fixed order and keeps the final partial batch.
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    **loader_kwargs,
)

print("\n📦 DataLoaders:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

## 3. Visualize Sample Data

In [None]:
from inference.visualize import rgb_from_sentinel2, get_severity_cmap

# Get a sample batch. Only the first two items of the batch are used, so this
# works whether the loader yields (images, labels) or
# (images, labels, metadata) as the test-set cell unpacks it.
batch = next(iter(train_loader))
images, labels = batch[0], batch[1]
print(f"Batch shapes: images={images.shape}, labels={labels.shape}")

# Visualize up to 4 samples (fewer if the batch is smaller than 4).
n_cols = 4
n_show = min(n_cols, len(images))
fig, axes = plt.subplots(2, n_cols, figsize=(16, 8))

cmap = get_severity_cmap()

# Per-band scale/offset used for the rough denormalization below, shaped
# (bands, 1, 1) so they broadcast over the spatial axes.
# NOTE(review): presumably the per-band std and mean from
# data/compute_statistics.py — confirm they match the current stats.
band_scale = np.array(
    [545, 476, 571, 532, 614, 731, 811, 872, 856, 611]
).reshape(-1, 1, 1)
band_offset = np.array(
    [1339, 1167, 1002, 1296, 1835, 2149, 2290, 2410, 2004, 1075]
).reshape(-1, 1, 1)

im = None  # last label image, reused as the colorbar's mappable
for i in range(n_show):
    # RGB (denormalize roughly)
    img = images[i].numpy()
    img = img * 6 - 3  # Map [0, 1] back to [-3, 3]
    img = img * band_scale
    img = img + band_offset
    rgb = rgb_from_sentinel2(img)

    axes[0, i].imshow(rgb)
    axes[0, i].set_title(f"RGB {i+1}")
    axes[0, i].axis('off')

    # Label (first channel of the label tensor), drawn on a fixed 0..1 scale
    label = labels[i, 0].numpy()
    im = axes[1, i].imshow(label, cmap=cmap, vmin=0, vmax=1)
    axes[1, i].set_title(f"Severity (mean: {label.mean():.1%})")
    axes[1, i].axis('off')

# Blank out any panels that did not receive a sample.
for unused_ax in axes[:, n_show:].ravel():
    unused_ax.axis('off')

# A colorbar needs a mappable; skip it if nothing was drawn.
if im is not None:
    plt.colorbar(im, ax=axes[1, :].tolist(), shrink=0.6)
plt.suptitle('Training Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Create Model

In [None]:
from model.architecture import CaliforniaFireModel
from model.losses import CombinedLoss

# Create model from the project configuration and move it to the device.
model = CaliforniaFireModel(**MODEL_CONFIG).to(device)

# Count parameters (reported in millions).
params = sum(p.numel() for p in model.parameters()) / 1e6
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

print("\n🧠 Model:")
print(f"   Total parameters: {params:.2f}M")
print(f"   Trainable: {trainable:.2f}M")

# Test forward pass on random input: batch of 2, 10 channels, 256x256 pixels.
# NOTE(review): the 10 input channels are hard-coded here; confirm they match
# MODEL_CONFIG.
x = torch.randn(2, 10, 256, 256).to(device)

# Run the smoke test in eval mode: torch.no_grad() disables gradients but NOT
# train-mode side effects, so any BatchNorm layers would otherwise fold this
# random noise into their running statistics. The previous mode is restored
# afterwards so training behaviour is unchanged.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y = model(x)
finally:
    model.train(was_training)

print(f"\n   Input: {x.shape}")
print(f"   Output: {y.shape}")

In [None]:
# Loss function: equal-weight combination of the BCE and Dice terms provided
# by the project's CombinedLoss.
criterion = CombinedLoss(
    bce_weight=0.5,
    dice_weight=0.5,
    pos_weight=2.0,  # Weight burned pixels more
)

# Test loss on the first two samples of the batch loaded in the visualization
# cell. The model is switched to eval mode for this check because
# torch.no_grad() does not stop train-mode layers (e.g. BatchNorm) from
# updating their running statistics; the previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        images_gpu = images[:2].to(device)
        labels_gpu = labels[:2].to(device)
        logits = model(images_gpu)
        # The criterion returns the total loss plus a mapping of named components.
        loss, components = criterion(logits, labels_gpu)
finally:
    model.train(was_training)

print("\n📉 Loss components:")
for k, v in components.items():
    print(f"   {k}: {v:.4f}")

## 5. Train

In [None]:
from training.train import Trainer
import torch.optim as optim

# Optimizer: AdamW with learning rate and weight decay taken from the
# training configuration.
optimizer = optim.AdamW(
    model.parameters(),
    lr=TRAINING_CONFIG['learning_rate'],
    weight_decay=TRAINING_CONFIG['weight_decay'],
)

# Scheduler: multiply the learning rate by `factor` once the monitored value
# has stopped decreasing (mode='min') for `patience` epochs, never going
# below 1e-7.
# NOTE(review): which metric is passed to scheduler.step() is decided inside
# Trainer — presumably the validation loss; confirm there.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=TRAINING_CONFIG['lr_scheduler_factor'],
    patience=TRAINING_CONFIG['lr_scheduler_patience'],
    min_lr=1e-7,
)

# Create trainer: hand every training component to the project Trainer.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    config=TRAINING_CONFIG,
)

print("✅ Trainer ready!")

In [None]:
# Train!
# Runs the project Trainer for the configured number of epochs. The returned
# mapping is read as results['history'] by the plotting cell that follows.
results = trainer.train(epochs=TRAINING_CONFIG['epochs'])

## 6. Plot Training History

In [None]:
history = results['history']

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
loss_ax, iou_ax = axes[0, 0], axes[0, 1]
mae_ax, lr_ax = axes[1, 0], axes[1, 1]

# --- Curves -----------------------------------------------------------------
# Loss panel: train and validation loss share one axis and get a legend.
for split_name, history_key in (('Train', 'train_loss'), ('Val', 'val_loss')):
    loss_ax.plot(history[history_key], label=split_name)
loss_ax.legend()

# Validation metrics, one panel each.
iou_ax.plot(history['val_iou'], 'g-', label='IoU')
mae_ax.plot(history['val_mae'], 'r-', label='MAE')

# Learning rate on a logarithmic y-axis.
lr_ax.semilogy(history['lr'])

# --- Decoration ---------------------------------------------------------------
# Best value per metric for the titles: highest IoU, lowest MAE.
best_iou = max(history['val_iou'])
best_mae = min(history['val_mae'])

panels = (
    (loss_ax, 'Loss', 'Loss'),
    (iou_ax, 'IoU', f'Validation IoU (Best: {best_iou:.4f})'),
    (mae_ax, 'MAE', f'Validation MAE (Best: {best_mae:.4f})'),
    (lr_ax, 'Learning Rate', 'Learning Rate Schedule'),
)
for panel, y_label, panel_title in panels:
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.grid(True, alpha=0.3)

# Title the figure, save it next to the checkpoints, then display it.
plt.suptitle('Training History', fontsize=14, fontweight='bold')
plt.tight_layout()
history_png = CHECKPOINT_DIR / 'training_history.png'
plt.savefig(str(history_png), dpi=150)
plt.show()

## 7. Evaluate on Test Set

In [None]:
from model.architecture import load_model
from model.metrics import MetricTracker

# Load best model: restore the checkpoint written during training.
best_model_path = CHECKPOINT_DIR / 'best_model.pth'
best_model = load_model(str(best_model_path), device=str(device), **MODEL_CONFIG)

# Test loader: fixed order, same batch size as training. The worker count and
# pinned-memory setting follow the same rules as the train/val loaders instead
# of a hard-coded worker count.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=TRAINING_CONFIG['num_workers'],
    pin_memory=torch.cuda.is_available(),
)

# Evaluate
tracker = MetricTracker(threshold=0.5)
best_model.eval()

with torch.no_grad():
    # Test batches carry a third element with per-sample metadata.
    for images, labels, metadata in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = best_model(images)

        # Feed samples one at a time (slices keep the batch dimension) so each
        # update is tagged with that sample's fire key as its category.
        for i in range(images.size(0)):
            fire_key = metadata['fire_key'][i]
            tracker.update(logits[i:i+1], labels[i:i+1], category=fire_key)

# Print results
tracker.print_summary(prefix="🧪 TEST SET: ")

In [None]:
# Visualize test predictions
from inference.visualize import plot_prediction

# Take the first batch from the test loader and run the best model on it.
test_iter = iter(test_loader)
images, labels, metadata = next(test_iter)
images_gpu = images.to(device)

with torch.no_grad():
    # Sigmoid turns the model's logits into values in (0, 1).
    probabilities = torch.sigmoid(best_model(images_gpu))
predictions = probabilities.cpu().numpy()

# Per-band scale/offset for the rough denormalization, shaped (bands, 1, 1)
# so they broadcast across the spatial axes.
band_scale = np.array(
    [545, 476, 571, 532, 614, 731, 811, 872, 856, 611]
).reshape(-1, 1, 1)
band_offset = np.array(
    [1339, 1167, 1002, 1296, 1835, 2149, 2290, 2410, 2004, 1075]
).reshape(-1, 1, 1)

# Plot at most the first three samples; zip stops at the shortest input, so a
# smaller batch simply yields fewer figures.
samples = zip(
    images[:3],
    predictions,
    labels,
    metadata['fire_key'],
    metadata['stage'],
)
for image, prediction, label, fire, stage in samples:
    # Denormalize for visualization: [0, 1] -> [-3, 3] -> scale -> offset.
    raw_bands = (image.numpy() * 6 - 3) * band_scale + band_offset

    fig = plot_prediction(
        raw_bands, prediction[0], label[0].numpy(),
        title=f"{fire} - {stage}"
    )
    plt.show()

## Done! üéâ

Your trained model is saved at:
- `checkpoints/best_model.pth` - Best validation IoU
- `checkpoints/final_model.pth` - Final epoch

For inference on new images, use:
```python
from inference.predict import FirePredictor

predictor = FirePredictor('checkpoints/best_model.pth')
severity, metadata = predictor.predict_file('new_image.tif', 'output.tif')
```