# 🔥 California Fire Model - A100 Optimized Training

**High-performance training notebook for Google Colab A100**

Optimizations:
- ✅ Batch size 36 (A100 40GB can handle it)
- ✅ BF16 mixed precision (A100 native)
- ✅ 8 workers + prefetch for local SSD
- ✅ Gradient accumulation for larger effective batch
- ✅ OneCycleLR for faster convergence
- ✅ Aggressive augmentation
- ✅ Per-fire validation tracking

**Prerequisites:** Run `01_colab_setup.ipynb` first!

In [None]:
# ============================================================
# CONFIGURATION - A100 OPTIMIZED
# ============================================================
# Every tunable of the notebook is defined in this cell; the later
# cells read these module-level names directly.

# --- Locations (data is read from the local SSD) ---
LOCAL_DATA_PATH = "/content/local_data"
CODE_PATH = "/content/California-Fire-Model"

# --- Input pipeline ---
BATCH_SIZE = 36           # samples per forward/backward pass
NUM_WORKERS = 8           # DataLoader worker processes
PREFETCH_FACTOR = 4       # batches queued ahead by each worker

# --- Optimisation schedule ---
EPOCHS = 60
LEARNING_RATE = 3e-4      # peak learning rate handed to OneCycleLR
WEIGHT_DECAY = 1e-4
GRADIENT_ACCUMULATION = 2 # batches per optimizer step -> 36*2 = 72 effective

# --- Mixed precision ---
USE_BF16 = True           # bfloat16 autocast; FP16 is used when this is False

# --- Network shape ---
BASE_CHANNELS = 64
USE_ATTENTION = True
DROPOUT = 0.2

# --- Loss mix: BCE / Dice / focal weights, then the burned-pixel weight ---
BCE_WEIGHT, DICE_WEIGHT, FOCAL_WEIGHT = 0.4, 0.4, 0.2
POS_WEIGHT = 2.5

# --- Checkpoint destinations ---
SAVE_TO_DRIVE = True
DRIVE_CHECKPOINT_PATH = "/content/drive/MyDrive/California_Fire_Model/checkpoints"

# One print call, one line per argument: same console output as four prints.
print(
    "✅ Configuration loaded",
    f"   Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION} effective",
    f"   Workers: {NUM_WORKERS}",
    f"   Mixed precision: {'BF16' if USE_BF16 else 'FP16'}",
    sep="\n",
)

In [None]:
import sys
sys.path.insert(0, CODE_PATH)

import os
import time
import json
from pathlib import Path
from datetime import datetime
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Pick the compute device and report what is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"\n🔧 Device: {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Switch the notebook to FP16 autocast when the GPU cannot do bfloat16
    if not torch.cuda.is_bf16_supported():
        USE_BF16 = False
        print("   ⚠️ BF16 not supported, using FP16")
    else:
        print("   ✅ BF16 supported")

## 1. Create Datasets

In [None]:
import rasterio
import albumentations as A
from torch.utils.data import Dataset

# ============================================================
# BAND STATISTICS - Update these after running compute_statistics.py!
# ============================================================
# Per-band normalisation constants stored as float32 vectors. These are
# the default California values; replace them if you computed your own.
BAND_MEANS = np.asarray(
    [1339, 1167, 1002, 1296, 1835, 2149, 2290, 2410, 2004, 1075],
    dtype=np.float32,
)
BAND_STDS = np.asarray(
    [545, 476, 571, 532, 614, 731, 811, 872, 856, 611],
    dtype=np.float32,
)

# Fire keys used for training; the test list is empty because every
# fire is used for training.
TRAINING_FIRE_KEYS = [
    'dixie', 'caldor', 'creek', 'camp',
    'mendocino', 'thomas', 'kincade', 'woolsey',
]
TEST_FIRE_KEYS = []

print(
    f"Training fires: {len(TRAINING_FIRE_KEYS)}",
    f"Test fires: {len(TEST_FIRE_KEYS)}",
    sep="\n",
)

In [None]:
class CaliforniaFireDatasetColab(Dataset):
    """
    Tile dataset for Colab A100 training.

    Every ``*.tif`` found recursively under one of ``data_dirs`` is one
    sample. The first 10 bands are the model input, band index 10 (when
    present) is the label, and tiles are zero-padded / cropped to 256x256.

    Args:
        data_dirs: Iterable of directories to scan; missing ones are skipped.
        mode: Split name. Used for logging and to gate augmentation.
        augment: Augmentation is applied only when this is true AND
            ``mode == 'train'``.
        fire_keys: Optional collection of fire keys to keep. Tiles tagged
            ``'healthy'`` are always kept, whatever this filter says.
            ``None`` keeps every tile.

    Each item is ``(image, label, fire_key)``: a float tensor of the
    normalised bands, a float tensor of the label with a leading channel
    axis, and the fire key string.
    """

    TILE_SIZE = 256
    NUM_BANDS = 10

    def __init__(self, data_dirs, mode='train', augment=True, fire_keys=None):
        self.mode = mode
        self.augment = augment and (mode == 'train')

        # Collect tiles
        self.samples = []
        for data_dir in data_dirs:
            data_path = Path(data_dir)
            if not data_path.exists():
                continue

            # rglob order depends on the filesystem; sorting makes the
            # index -> tile mapping reproducible between runs.
            for tif_file in sorted(data_path.rglob("*.tif")):
                if 'fires' in str(tif_file):
                    # First path component below data_dir (caldor, dixie, ...).
                    # Path.parts avoids assuming '/' as the separator.
                    fire_key = tif_file.relative_to(data_path).parts[0]
                else:
                    fire_key = 'healthy'

                # Filter by fire keys if specified ('healthy' always passes)
                if (fire_keys is not None
                        and fire_key != 'healthy'
                        and fire_key not in fire_keys):
                    continue

                self.samples.append({
                    'path': str(tif_file),
                    'fire_key': fire_key,
                })

        print(f"   {mode.upper()}: {len(self.samples)} tiles")

        # Aggressive augmentation for training
        # NOTE(review): the transform runs on images already scaled to
        # [0, 1] by normalize() — confirm that GaussNoise's
        # std_limit=(5, 25) is in the units the installed albumentations
        # version expects.
        if self.augment:
            self.transform = A.Compose([
                A.RandomRotate90(p=0.5),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=45, p=0.5),
                A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.4),
                A.GaussNoise(std_limit=(5, 25), per_channel=True, p=0.3),
                A.CoarseDropout(num_holes_range=(2, 5), hole_height_range=(16, 32), 
                               hole_width_range=(16, 32), p=0.25),
            ])
        else:
            self.transform = None

    def __len__(self):
        return len(self.samples)

    def normalize(self, image):
        """
        Scale a ``(bands, H, W)`` array to float32 values in ``[0, 1]``.

        Values are clipped to ``[0, 10000]``, standardised per band with the
        module-level ``BAND_MEANS`` / ``BAND_STDS``, clipped to +/-3 sigma and
        mapped linearly onto ``[0, 1]``.
        """
        image = np.clip(image, 0, 10000).astype(np.float32)
        # Broadcast the per-band constants instead of looping band by band.
        n_bands = image.shape[0]
        means = BAND_MEANS[:n_bands].reshape(-1, 1, 1)
        stds = BAND_STDS[:n_bands].reshape(-1, 1, 1)
        image = (image - means) / (stds + 1e-6)
        image = np.clip(image, -3, 3)
        return (image + 3) / 6

    def __getitem__(self, idx):
        sample = self.samples[idx]
        size = self.TILE_SIZE

        with rasterio.open(sample['path']) as src:
            data = src.read()

        if data.shape[0] < self.NUM_BANDS:
            # Fail with the offending path instead of an IndexError later on.
            raise ValueError(
                f"{sample['path']} has {data.shape[0]} bands, "
                f"expected at least {self.NUM_BANDS}"
            )

        # Handle size: zero-pad small tiles, crop large ones (top-left corner)
        _, h, w = data.shape
        if h != size or w != size:
            padded = np.zeros((data.shape[0], size, size), dtype=data.dtype)
            keep_h, keep_w = min(h, size), min(w, size)
            padded[:, :keep_h, :keep_w] = data[:, :keep_h, :keep_w]
            data = padded

        # Split bands and label (an all-zero label when the band is absent)
        image = data[:self.NUM_BANDS]
        if data.shape[0] > self.NUM_BANDS:
            label = data[self.NUM_BANDS]
        else:
            label = np.zeros((size, size), dtype=np.float32)

        # Clean NaN / inf
        image = np.nan_to_num(image, nan=0.0, posinf=10000.0, neginf=0.0)
        label = np.nan_to_num(label, nan=0.0, posinf=1.0, neginf=0.0)
        label = np.clip(label, 0.0, 1.0).astype(np.float32)

        # Normalize
        image = self.normalize(image)

        # Augment (albumentations works on HWC images)
        if self.transform:
            image_hwc = image.transpose(1, 2, 0)
            augmented = self.transform(image=image_hwc, mask=label)
            image = augmented['image'].transpose(2, 0, 1)
            label = augmented['mask']

        # Transposes / flips can hand back strided views; torch.from_numpy
        # rejects negative strides, so materialise contiguous arrays first.
        return (
            torch.from_numpy(np.ascontiguousarray(image)).float(),
            torch.from_numpy(np.ascontiguousarray(label)).float().unsqueeze(0),
            sample['fire_key']
        )

print("✅ Dataset class defined")

## 1.5 Compute Band Statistics (Run Once)

Computes actual mean/std from your training data for better normalization.

In [None]:
# ============================================================
# COMPUTE BAND STATISTICS FROM TRAINING DATA
# ============================================================
# Streams over the tiles and merges per-tile moments, so only one tile
# is held in memory at a time.

def compute_band_statistics(data_dirs, max_tiles=1000):
    """
    Compute pooled per-band statistics over the valid pixels of the tiles.

    Tiles are read one at a time and their mean / sum of squared deviations
    are merged into running totals (pairwise update), so the reported mean
    and std are those of all valid pixels together. Averaging per-tile stds
    instead would ignore the variance between tiles.

    A pixel is valid when all 10 spectral bands are finite and band 0 lies
    strictly between 0 and 10000. Tiles with fewer than 10 bands or fewer
    than 100 valid pixels are skipped; unreadable tiles are skipped and
    reported.

    Args:
        data_dirs: Iterable of directories searched recursively for ``*.tif``.
        max_tiles: If truthy and more tiles are found, a reproducible random
            subset of this size is used.

    Returns:
        ``None`` when no tile is found or none is usable, otherwise a dict
        with ``means``, ``stds``, ``mins``, ``maxs`` (lists of 10 floats),
        ``tile_count`` and ``pixel_count``.
    """
    print("📊 Computing band statistics from training data...")

    # Collect tiles (sorted: rglob order depends on the filesystem)
    all_tiles = []
    for data_dir in data_dirs:
        data_path = Path(data_dir)
        if data_path.exists():
            all_tiles.extend(sorted(data_path.rglob('*.tif')))

    if not all_tiles:
        print("⚠️ No tiles found! Using default statistics.")
        return None

    print(f"   Found {len(all_tiles)} tiles")

    # Sample if too many. A private generator keeps the subset reproducible
    # without reseeding NumPy's global random state.
    if max_tiles and len(all_tiles) > max_tiles:
        rng = np.random.RandomState(42)
        chosen = rng.choice(len(all_tiles), max_tiles, replace=False)
        all_tiles = [all_tiles[i] for i in chosen]
        print(f"   Sampling {max_tiles} tiles")

    NUM_BANDS = 10

    # Running totals over all valid pixels seen so far (float64 on purpose)
    running_mean = np.zeros(NUM_BANDS, dtype=np.float64)
    running_m2 = np.zeros(NUM_BANDS, dtype=np.float64)  # sum of squared deviations
    band_min = np.full(NUM_BANDS, np.inf)
    band_max = np.full(NUM_BANDS, -np.inf)

    valid_tiles = 0
    total_pixels = 0
    failed_tiles = []

    for tile_path in tqdm(all_tiles, desc="Computing stats"):
        # Best effort: one unreadable tile must not abort the whole pass,
        # but it is recorded so the problem is visible afterwards.
        try:
            with rasterio.open(tile_path) as src:
                data = src.read()
        except Exception as exc:
            failed_tiles.append((tile_path, exc))
            continue

        if data.shape[0] < NUM_BANDS:
            continue

        # Extract spectral bands (not label), one row per band
        spectral = data[:NUM_BANDS].astype(np.float64).reshape(NUM_BANDS, -1)

        # Mask invalid pixels
        valid_mask = np.isfinite(spectral).all(axis=0)
        valid_mask &= (spectral[0] > 0) & (spectral[0] < 10000)
        spectral = spectral[:, valid_mask]

        tile_pixels = spectral.shape[1]
        if tile_pixels < 100:
            continue

        # Merge this tile's moments into the running totals
        tile_mean = spectral.mean(axis=1)
        tile_m2 = spectral.var(axis=1) * tile_pixels
        delta = tile_mean - running_mean
        merged_pixels = total_pixels + tile_pixels
        running_mean += delta * (tile_pixels / merged_pixels)
        running_m2 += tile_m2 + delta ** 2 * (total_pixels * tile_pixels / merged_pixels)

        band_min = np.minimum(band_min, spectral.min(axis=1))
        band_max = np.maximum(band_max, spectral.max(axis=1))

        total_pixels = merged_pixels
        valid_tiles += 1

    if failed_tiles:
        first_path, first_error = failed_tiles[0]
        print(f"   ⚠️ Skipped {len(failed_tiles)} unreadable tiles "
              f"(first: {first_path}: {first_error})")

    if valid_tiles == 0:
        print("⚠️ No valid tiles! Using default statistics.")
        return None

    stats = {
        'means': running_mean.tolist(),
        'stds': np.sqrt(running_m2 / total_pixels).tolist(),
        'mins': band_min.tolist(),
        'maxs': band_max.tolist(),
        'tile_count': valid_tiles,
        'pixel_count': int(total_pixels),
    }

    return stats

print("✅ Statistics computation function defined")

In [None]:
# Compute statistics over both tile roots and, when that succeeds,
# replace the default normalisation constants with the computed ones.
data_dirs_for_stats = [
    f"{LOCAL_DATA_PATH}/fires",
    f"{LOCAL_DATA_PATH}/healthy",
]

computed_stats = compute_band_statistics(data_dirs_for_stats, max_tiles=1000)

if not computed_stats:
    # Nothing usable: keep whatever BAND_MEANS / BAND_STDS already hold
    print("\n⚠️ Using default statistics (update after computing from your data)")
    print(f"   BAND_MEANS = {BAND_MEANS.tolist()}")
    print(f"   BAND_STDS = {BAND_STDS.tolist()}")
else:
    # Overwrite the module-level constants read by the dataset's normalize()
    BAND_MEANS = np.array(computed_stats['means'], dtype=np.float32)
    BAND_STDS = np.array(computed_stats['stds'], dtype=np.float32)

    print("\n" + "="*60)
    print("📊 COMPUTED BAND STATISTICS")
    print("="*60)
    print(f"   Tiles processed: {computed_stats['tile_count']}")
    print(f"   Pixels processed: {computed_stats['pixel_count']:,}")

    # One table row per band
    BAND_NAMES = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
    print("\n   Band Statistics:")
    print(f"   {'Band':<6} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
    print("   " + "-"*46)

    band_rows = zip(
        BAND_NAMES, BAND_MEANS, BAND_STDS,
        computed_stats['mins'], computed_stats['maxs'],
    )
    for band, band_mean, band_std, band_min, band_max in band_rows:
        print(f"   {band:<6} {band_mean:>10.1f} {band_std:>10.1f} {band_min:>10.1f} {band_max:>10.1f}")

    print("\n   💾 Saving to Drive...")

    # Keep a JSON copy next to the checkpoints for reference
    stats_path = f"{DRIVE_CHECKPOINT_PATH}/band_statistics.json"
    os.makedirs(DRIVE_CHECKPOINT_PATH, exist_ok=True)

    with open(stats_path, 'w') as stats_file:
        json.dump(computed_stats, stats_file, indent=2)

    print(f"   ✅ Saved to: {stats_path}")

    # Emit the values as Python literals that can be pasted into the config
    print("\n   📝 Python code for config:")
    means_str = ', '.join(format(v, '.0f') for v in BAND_MEANS)
    stds_str = ', '.join(format(v, '.0f') for v in BAND_STDS)
    print(f"   BAND_MEANS = [{means_str}]")
    print(f"   BAND_STDS = [{stds_str}]")
    print("="*60)

In [None]:
# Build the train / val / test datasets from the same two tile roots
data_dirs = [f"{LOCAL_DATA_PATH}/{subdir}" for subdir in ("fires", "healthy")]

# Hold out one fire for validation: the last key in the list.
np.random.seed(42)
val_fire = TRAINING_FIRE_KEYS[-1]
train_fires = list(filter(lambda key: key != val_fire, TRAINING_FIRE_KEYS))

print(
    f"\n📊 Data Split:",
    f"   Train fires: {train_fires}",
    f"   Val fire: {val_fire}",
    f"   Test fires: {TEST_FIRE_KEYS}",
    sep="\n",
)

print("\n📂 Creating datasets...")
# NOTE(review): the dataset's fire_keys filter never drops 'healthy'
# tiles, so the same healthy tiles end up in all three datasets.
train_dataset = CaliforniaFireDatasetColab(
    data_dirs, mode='train', augment=True, fire_keys=train_fires,
)
val_dataset = CaliforniaFireDatasetColab(
    data_dirs, mode='val', augment=False, fire_keys=[val_fire],
)
test_dataset = CaliforniaFireDatasetColab(
    data_dirs, mode='test', augment=False, fire_keys=TEST_FIRE_KEYS,
)

In [None]:
# Create dataloaders - A100 optimized
#
# prefetch_factor and persistent_workers only make sense with worker
# processes: DataLoader raises ValueError when they are set while
# num_workers == 0, so they are passed only when workers are in use.
# This keeps NUM_WORKERS = 0 usable for single-process debugging.
_worker_kwargs = (
    {'prefetch_factor': PREFETCH_FACTOR, 'persistent_workers': True}
    if NUM_WORKERS > 0
    else {}
)

# Training: shuffled, incomplete last batch dropped
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True,
    **_worker_kwargs,
)

# Validation: fixed order, every sample kept
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    **_worker_kwargs,
)

print(f"\n📦 DataLoaders:")
print(f"   Train: {len(train_loader)} batches ({len(train_dataset)} samples)")
print(f"   Val: {len(val_loader)} batches ({len(val_dataset)} samples)")

## 2. Create Model

In [None]:
# Import model components
from model.architecture import CaliforniaFireModel
from model.losses import CombinedLoss, FocalLoss, DiceLoss

# Build the network from the configuration constants, then move it to
# the selected device.
model = CaliforniaFireModel(
    input_channels=10,
    output_channels=1,
    base_channels=BASE_CHANNELS,
    use_attention=USE_ATTENTION,
    dropout=DROPOUT,
)
model = model.to(device)

# Report the parameter count in millions
params = sum(map(torch.numel, model.parameters())) / 1e6
print(f"\n🧠 Model: {params:.2f}M parameters")

# Turn on the cuDNN benchmark flag
torch.backends.cudnn.benchmark = True
print("   ✅ cuDNN benchmark enabled")

In [None]:
# Enhanced loss function for better probability output
class EnhancedLoss(nn.Module):
    """
    Weighted sum of positively-weighted BCE, soft Dice and focal loss.

    All three terms are computed from raw logits. Inputs are cast to
    float32 first so that sigmoid / reductions are not carried out in
    bf16 / fp16 when the model ran under autocast.

    Args:
        bce_w: Weight of the BCE term.
        dice_w: Weight of the Dice term.
        focal_w: Weight of the focal term.
        pos_weight: BCE weight applied to pixels whose target is > 0.5
            (all other pixels get weight 1). Stored as a plain attribute,
            so it can be reassigned between epochs.
    """
    def __init__(self, bce_w=0.4, dice_w=0.4, focal_w=0.2, pos_weight=2.5):
        super().__init__()
        self.bce_w = bce_w
        self.dice_w = dice_w
        self.focal_w = focal_w
        self.pos_weight = pos_weight

    def dice_loss(self, logits, targets):
        """Return ``1 - soft Dice`` computed over all elements at once."""
        # reshape (not view) so non-contiguous tensors are accepted too
        probs_flat = torch.sigmoid(logits.float()).reshape(-1)
        targets_flat = targets.float().reshape(-1)
        intersection = (probs_flat * targets_flat).sum()
        return 1 - (2 * intersection + 1e-6) / (probs_flat.sum() + targets_flat.sum() + 1e-6)

    def focal_loss(self, logits, targets, alpha=0.25, gamma=2.0):
        """Return the mean of ``alpha * (1 - p_t)**gamma * BCE`` over all elements."""
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits.float(), targets.float(), reduction='none'
        )
        pt = torch.exp(-bce)  # probability assigned to the true class
        focal = alpha * (1 - pt) ** gamma * bce
        return focal.mean()

    def forward(self, logits, targets):
        """
        Return ``(total, components)``.

        ``total`` is the weighted loss tensor to back-propagate;
        ``components`` maps ``'bce'``, ``'dice'`` and ``'focal'`` to the
        unweighted value of each term as a Python float.
        """
        logits = logits.float()
        targets = targets.float()

        # BCE with positive weight
        weight = torch.where(
            targets > 0.5,
            torch.full_like(targets, self.pos_weight),
            torch.ones_like(targets),
        )
        bce_loss = nn.functional.binary_cross_entropy_with_logits(logits, targets, weight=weight)

        # Dice
        dice = self.dice_loss(logits, targets)

        # Focal
        focal = self.focal_loss(logits, targets)

        total = self.bce_w * bce_loss + self.dice_w * dice + self.focal_w * focal

        return total, {
            'bce': bce_loss.item(),
            'dice': dice.item(),
            'focal': focal.item(),
        }

# Instantiate the loss with the weights from the configuration cell
# (positional order: BCE weight, Dice weight, focal weight, positive weight)
criterion = EnhancedLoss(BCE_WEIGHT, DICE_WEIGHT, FOCAL_WEIGHT, POS_WEIGHT)

print(
    f"\n📉 Loss function:",
    f"   BCE weight: {BCE_WEIGHT}, Dice weight: {DICE_WEIGHT}, Focal weight: {FOCAL_WEIGHT}",
    f"   Positive class weight: {POS_WEIGHT}",
    sep="\n",
)

In [None]:
# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.999),
)

# OneCycleLR - faster convergence
# The scheduler is stepped once per optimizer step, i.e. once every
# GRADIENT_ACCUMULATION batches, hence the integer division.
steps_per_epoch = len(train_loader) // GRADIENT_ACCUMULATION
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.1,  # Warmup 10% of training
    anneal_strategy='cos',
    div_factor=25,
    final_div_factor=1000,
)

# Mixed precision scaler
# Loss scaling exists to stop FP16 gradients from underflowing. bfloat16
# has the exponent range of float32, so the scaler is switched off for
# BF16; a disabled scaler makes scale()/unscale_()/update() no-ops and
# step() simply calls optimizer.step().
scaler = torch.cuda.amp.GradScaler(enabled=not USE_BF16)

print(f"\n⚙️ Optimizer: AdamW (lr={LEARNING_RATE}, wd={WEIGHT_DECAY})")
print(f"   Scheduler: OneCycleLR")
print(f"   Gradient accumulation: {GRADIENT_ACCUMULATION} steps")
print(f"   Mixed precision: {'BF16' if USE_BF16 else 'FP16'}")

## 3. Training Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, scheduler, scaler, device, accum_steps):
    """
    Train for one epoch with gradient accumulation.

    Args:
        model: Network to optimise; switched to train mode here.
        loader: Iterable yielding ``(images, labels, fire_keys)`` batches.
        criterion: Callable returning ``(loss_tensor, components_dict)``.
        optimizer: Stepped once every ``accum_steps`` batches.
        scheduler: LR scheduler, stepped together with the optimizer.
        scaler: Gradient scaler used for backward / unscale / step.
        device: Device the batches are moved to.
        accum_steps: Number of batches whose gradients are summed per step.

    Returns:
        ``(avg_loss, avg_components)``: the mean (unscaled) loss per batch
        and the per-batch mean of every entry of the components dict.

    Raises:
        ValueError: If ``loader`` yields no batches.

    Note:
        Batches left over when the number of batches is not a multiple of
        ``accum_steps`` are back-propagated but never applied: no optimizer
        step follows them and the gradients are cleared by the
        ``zero_grad()`` at the start of the next call.
    """
    model.train()

    # The autocast dtype does not change during the epoch; resolve it once.
    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16

    total_loss = 0.0
    loss_components = defaultdict(float)
    num_batches = 0

    optimizer.zero_grad()

    pbar = tqdm(loader, desc="Training", leave=False)

    for batch_idx, (images, labels, _fire_keys) in enumerate(pbar):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Mixed precision forward
        with torch.amp.autocast('cuda', dtype=amp_dtype):
            logits = model(images)
            loss, components = criterion(logits, labels)
            loss = loss / accum_steps  # Scale for accumulation

        # Backward
        scaler.scale(loss).backward()

        # Apply the accumulated gradients every accum_steps batches
        if (batch_idx + 1) % accum_steps == 0:
            # Unscale first so clipping sees the true gradient norm
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        # Undo the accumulation scaling so the logged loss is per batch
        total_loss += loss.item() * accum_steps
        for k, v in components.items():
            loss_components[k] += v
        num_batches += 1

        pbar.set_postfix({'loss': total_loss / num_batches})

    if num_batches == 0:
        # Clearer than the ZeroDivisionError the averages below would raise
        raise ValueError("train_epoch: the loader yielded no batches")

    avg_loss = total_loss / num_batches
    avg_components = {k: v / num_batches for k, v in loss_components.items()}

    return avg_loss, avg_components

@torch.no_grad()
def validate(model, loader, criterion, device):
    """
    Validate with per-fire tracking.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(images, labels, fire_keys)`` batches.
        criterion: Callable returning ``(loss_tensor, components_dict)``.
        device: Device the batches are moved to.

    Returns:
        Dict with ``'loss'`` (mean loss per batch), ``'mae'`` and ``'iou'``
        (means over all samples) and ``'per_fire'`` (fire key ->
        ``{'mae', 'iou'}``). Every number is a built-in ``float`` so the
        result can be passed to ``json.dump`` as is.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()

    # The autocast dtype does not change during the pass; resolve it once.
    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16

    total_loss = 0.0
    num_batches = 0
    per_fire_metrics = defaultdict(lambda: {'mae': [], 'iou': []})

    for images, labels, fire_keys in tqdm(loader, desc="Validating", leave=False):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.amp.autocast('cuda', dtype=amp_dtype):
            logits = model(images)
            loss, _ = criterion(logits, labels)

        total_loss += loss.item()
        num_batches += 1

        # Move the whole batch to the host once instead of once per sample.
        # .float() is needed because NumPy has no bfloat16.
        batch_probs = torch.sigmoid(logits)[:, 0].float().cpu().numpy()
        batch_truth = labels[:, 0].float().cpu().numpy()

        # Per-sample metrics
        for fire, p, t in zip(fire_keys, batch_probs, batch_truth):
            # MAE between predicted probability and label
            mae = np.abs(p - t).mean()
            per_fire_metrics[fire]['mae'].append(float(mae))

            # IoU of the masks thresholded at 0.5
            # NOTE(review): a tile with no positive pixel in either mask
            # gets IoU 0 (0 / 1e-6), not 1 — confirm this is intended.
            p_bin = (p > 0.5).astype(float)
            t_bin = (t > 0.5).astype(float)
            intersection = (p_bin * t_bin).sum()
            union = p_bin.sum() + t_bin.sum() - intersection
            iou = intersection / (union + 1e-6)
            per_fire_metrics[fire]['iou'].append(float(iou))

    if num_batches == 0:
        # Clearer than the ZeroDivisionError the average below would raise
        raise ValueError("validate: the loader yielded no batches")

    avg_loss = total_loss / num_batches

    # Aggregate per-fire
    fire_summary = {}
    all_mae = []
    all_iou = []

    for fire, metrics in per_fire_metrics.items():
        fire_summary[fire] = {
            'mae': float(np.mean(metrics['mae'])),
            'iou': float(np.mean(metrics['iou'])),
        }
        all_mae.extend(metrics['mae'])
        all_iou.extend(metrics['iou'])

    return {
        'loss': avg_loss,
        'mae': float(np.mean(all_mae)),
        'iou': float(np.mean(all_iou)),
        'per_fire': fire_summary,
    }

print("✅ Training functions defined")

In [None]:
# Make sure both checkpoint locations exist (Drive first, then local disk)
local_checkpoint_path = "/content/checkpoints"
for _checkpoint_dir in (DRIVE_CHECKPOINT_PATH, local_checkpoint_path):
    os.makedirs(_checkpoint_dir, exist_ok=True)

print(
    f"💾 Checkpoints:",
    f"   Local: {local_checkpoint_path}",
    f"   Drive: {DRIVE_CHECKPOINT_PATH}",
    sep="\n",
)

In [None]:
# ============================================================
# RESUME FROM CHECKPOINT (Optional)
# ============================================================
# Set to True to resume from epoch 10 checkpoint

RESUME_FROM_CHECKPOINT = True
CHECKPOINT_PATH = f"{DRIVE_CHECKPOINT_PATH}/epoch_10.pth"
START_EPOCH = 10  # Will start from epoch 11

# Boost POS_WEIGHT for better fire detection
NEW_POS_WEIGHT = 7.0

if RESUME_FROM_CHECKPOINT:
    print("📂 Loading checkpoint...")

    # map_location: load onto the device in use, whatever device the
    # tensors were saved from.
    # weights_only=False: same call form as the best-model load in the
    # comprehensive test; checkpoints written by this notebook can hold
    # plain metric values next to the state dict. Only load files you
    # created yourself this way.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"   ✅ Loaded model from epoch {checkpoint.get('epoch', 10)}")
    print(f"   📈 Previous best IoU: {checkpoint.get('val_iou', 'N/A')}")

    # Update criterion with new POS_WEIGHT; report the value actually
    # being replaced rather than a hard-coded one.
    previous_pos_weight = criterion.pos_weight
    criterion.pos_weight = NEW_POS_WEIGHT
    print(f"   🔧 POS_WEIGHT boosted: {previous_pos_weight} → {NEW_POS_WEIGHT}")

    # Note: We create a fresh optimizer/scheduler for the remaining epochs,
    # so the one-cycle schedule restarts and spans only those epochs.
    remaining_epochs = EPOCHS - START_EPOCH
    steps_per_epoch = len(train_loader) // GRADIENT_ACCUMULATION

    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE * 0.5,  # Lower LR for fine-tuning
        weight_decay=WEIGHT_DECAY,
    )

    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE * 0.5,
        epochs=remaining_epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=0.1,
        anneal_strategy='cos',
    )

    print(f"   ⚙️ New optimizer with LR={LEARNING_RATE * 0.5:.2e}")
    print(f"   🔁 Will train for {remaining_epochs} more epochs")
else:
    START_EPOCH = 0
    print("Starting fresh training from epoch 0")

In [None]:
# ============================================================
# MAIN TRAINING LOOP - ADAPTIVE
# ============================================================
# Dynamically adjusts POS_WEIGHT if IoU stagnates
#
# Per epoch: train, validate, log, record history, save the best model
# when validation IoU beats the previous best by more than 5%, double the
# loss' pos_weight if IoU is stuck below IOU_MIN_THRESHOLD, stop early
# after PATIENCE epochs without improvement, and write a periodic
# checkpoint every 10th epoch.

import shutil

print("\n" + "="*70)
print("🚀 STARTING ADAPTIVE A100 TRAINING")
print("="*70)
print(f"   Epochs: {EPOCHS}")
print(f"   Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Device: {device}")
print("="*70 + "\n")

# Tracking (built-in floats only, so the dict can be dumped to JSON)
history = {
    'train_loss': [],
    'val_loss': [],
    'val_mae': [],
    'val_iou': [],
    'lr': [],
    'pos_weight': [],
}

best_iou = 0.0
best_mae = 1.0
patience_counter = 0
PATIENCE = 15

# Adaptive training parameters
# Start from the weight the criterion really uses: the resume cell may
# have replaced POS_WEIGHT with NEW_POS_WEIGHT. Starting from POS_WEIGHT
# would log the wrong value and let the adaptive step below lower it.
current_pos_weight = float(criterion.pos_weight)
iou_stagnation_counter = 0
IOU_STAGNATION_THRESHOLD = 5  # epochs without IoU improvement
IOU_MIN_THRESHOLD = 0.05      # if IoU < this after threshold epochs, boost
MAX_POS_WEIGHT = 15.0         # maximum pos_weight to try

training_start = time.time()

for epoch in range(START_EPOCH, EPOCHS):
    epoch_start = time.time()

    print(f"\n📈 Epoch {epoch + 1}/{EPOCHS}")
    print("-" * 50)

    # Train
    train_loss, train_components = train_epoch(
        model, train_loader, criterion, optimizer, scheduler, scaler, device, GRADIENT_ACCUMULATION
    )

    # Validate
    val_metrics = validate(model, val_loader, criterion, device)

    # The metrics may be NumPy scalars; convert once so that history and
    # checkpoints only ever hold built-in floats.
    val_loss = float(val_metrics['loss'])
    val_mae = float(val_metrics['mae'])
    val_iou = float(val_metrics['iou'])

    current_lr = optimizer.param_groups[0]['lr']
    epoch_time = time.time() - epoch_start

    # Log
    print(f"   Train Loss: {train_loss:.4f} (BCE={train_components['bce']:.4f}, Dice={train_components['dice']:.4f})")
    print(f"   Val Loss: {val_loss:.4f}")
    print(f"   Val MAE: {val_mae:.4f}")
    print(f"   Val IoU: {val_iou:.4f}")
    print(f"   LR: {current_lr:.2e} | POS_WEIGHT: {current_pos_weight:.1f}")
    print(f"   Time: {epoch_time:.1f}s")

    # Per-fire metrics
    if val_metrics['per_fire']:
        print("\n   Per-Fire IoU:")
        for fire, metrics in sorted(val_metrics['per_fire'].items()):
            print(f"      {fire:<25}: {metrics['iou']:.4f} (MAE: {metrics['mae']:.4f})")

    # History
    history['train_loss'].append(float(train_loss))
    history['val_loss'].append(val_loss)
    history['val_mae'].append(val_mae)
    history['val_iou'].append(val_iou)
    history['lr'].append(float(current_lr))
    history['pos_weight'].append(float(current_pos_weight))

    # ============================================================
    # ADAPTIVE POS_WEIGHT ADJUSTMENT
    # ============================================================
    improved = False

    if val_iou > best_iou * 1.05:  # 5% improvement threshold
        best_iou = val_iou
        improved = True
        iou_stagnation_counter = 0

        # Save best (local first, then Drive)
        best_path = f"{local_checkpoint_path}/best_model.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_iou': val_iou,
            'val_mae': val_mae,
            'pos_weight': float(current_pos_weight),
        }, best_path)

        # Copy to Drive
        if SAVE_TO_DRIVE:
            shutil.copy(best_path, f"{DRIVE_CHECKPOINT_PATH}/best_model.pth")

        print(f"\n   ⭐ NEW BEST IoU: {val_iou:.4f}")
    else:
        iou_stagnation_counter += 1

    if val_mae < best_mae:
        best_mae = val_mae

    # Check for stagnation and boost POS_WEIGHT
    if (iou_stagnation_counter >= IOU_STAGNATION_THRESHOLD and 
        val_iou < IOU_MIN_THRESHOLD and 
        current_pos_weight < MAX_POS_WEIGHT):

        old_weight = current_pos_weight
        current_pos_weight = min(current_pos_weight * 2, MAX_POS_WEIGHT)

        # Update criterion with new pos_weight
        criterion.pos_weight = current_pos_weight

        print(f"\n   🔧 ADAPTIVE ADJUSTMENT: POS_WEIGHT {old_weight:.1f} → {current_pos_weight:.1f}")
        print(f"      (IoU stagnated at {val_iou:.4f} for {iou_stagnation_counter} epochs)")

        iou_stagnation_counter = 0  # Reset counter

    # Early stopping based on IoU
    if improved:
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\n⚠️ Early stopping triggered (no improvement for {PATIENCE} epochs)")
            break

    # Periodic checkpoint
    if (epoch + 1) % 10 == 0:
        periodic_path = f"{local_checkpoint_path}/epoch_{epoch+1}.pth"
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict()}, periodic_path)

        if SAVE_TO_DRIVE:
            shutil.copy(periodic_path, f"{DRIVE_CHECKPOINT_PATH}/epoch_{epoch+1}.pth")

        print(f"   💾 Checkpoint saved: epoch_{epoch+1}.pth")

# Final save
total_time = time.time() - training_start

print("\n" + "="*70)
print("🎉 TRAINING COMPLETE!")
print("="*70)
print(f"   Total time: {total_time/3600:.1f} hours")
print(f"   Best IoU: {best_iou:.4f}")
print(f"   Best MAE: {best_mae:.4f}")
print(f"   Final POS_WEIGHT: {current_pos_weight:.1f}")
print(f"   Checkpoints saved to: {DRIVE_CHECKPOINT_PATH}")
print("="*70)

## 4. Training History

In [None]:
# Plot training history: loss, validation IoU, validation MAE and the
# learning rate, one panel each, then save the figure and the raw numbers.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss
axes[0, 0].plot(history['train_loss'], label='Train', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Loss', fontsize=12, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# IoU (dashed line marks the best value)
axes[0, 1].plot(history['val_iou'], 'g-', linewidth=2)
axes[0, 1].axhline(best_iou, color='r', linestyle='--', label=f'Best: {best_iou:.4f}')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('IoU')
axes[0, 1].set_title('Validation IoU', fontsize=12, fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# MAE (dashed line marks the best value)
axes[1, 0].plot(history['val_mae'], 'r-', linewidth=2)
axes[1, 0].axhline(best_mae, color='g', linestyle='--', label=f'Best: {best_mae:.4f}')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('MAE')
axes[1, 0].set_title('Validation MAE', fontsize=12, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# LR (log scale)
axes[1, 1].semilogy(history['lr'], 'purple', linewidth=2)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_title('Learning Rate Schedule', fontsize=12, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle('🔥 California Fire Model - Training History', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(f"{DRIVE_CHECKPOINT_PATH}/training_history.png", dpi=150, bbox_inches='tight')
plt.show()

# Save history
# The lists can hold NumPy scalars (the validation MAE is a float32),
# which json cannot serialise; convert every entry to a built-in float.
serializable_history = {
    key: [float(value) for value in values]
    for key, values in history.items()
}
with open(f"{DRIVE_CHECKPOINT_PATH}/training_history.json", 'w') as f:
    json.dump(serializable_history, f, indent=2)

print(f"\n✅ Training history saved!")

## 5. Quick Test

In [None]:
# Load best model and test
# map_location: load onto the device in use.
# weights_only=False: the checkpoint stores metric values next to the
# state dict (same call form as the comprehensive test further down).
best_checkpoint = torch.load(
    f"{local_checkpoint_path}/best_model.pth",
    map_location=device,
    weights_only=False,
)
model.load_state_dict(best_checkpoint['model_state_dict'])
model.eval()

# Get a sample
sample_batch = next(iter(val_loader))
images, labels, fire_keys = sample_batch
images = images.to(device)

# Run the forward pass with the same autocast dtype as training/validation
amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16
with torch.no_grad():
    with torch.amp.autocast('cuda', dtype=amp_dtype):
        logits = model(images)
    # .float() first: NumPy has no bfloat16
    probs = torch.sigmoid(logits).float().cpu().numpy()

# Visualize: predictions on the top row, ground truth below
NUM_COLUMNS = 4
fig, axes = plt.subplots(2, NUM_COLUMNS, figsize=(16, 8))

# The batch can hold fewer samples than there are columns
num_shown = min(NUM_COLUMNS, probs.shape[0])

for i in range(NUM_COLUMNS):
    if i >= num_shown:
        # Nothing to draw here; hide the empty panels
        axes[0, i].axis('off')
        axes[1, i].axis('off')
        continue

    pred = probs[i, 0]
    gt = labels[i, 0].numpy()
    fire = fire_keys[i]

    # Prediction
    axes[0, i].imshow(pred, cmap='hot', vmin=0, vmax=1)
    axes[0, i].set_title(f"{fire}\nPred: {pred.mean():.1%}")
    axes[0, i].axis('off')

    # Ground truth
    axes[1, i].imshow(gt, cmap='hot', vmin=0, vmax=1)
    axes[1, i].set_title(f"GT: {gt.mean():.1%}")
    axes[1, i].axis('off')

axes[0, 0].set_ylabel('Prediction', fontsize=12)
axes[1, 0].set_ylabel('Ground Truth', fontsize=12)

plt.suptitle('Sample Predictions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n✅ Model is working!")

---

## ✅ Training Complete!

**Your model is saved at:**
- Google Drive: `California_Fire_Model/checkpoints/best_model.pth`

**Next steps:**
1. Download `best_model.pth` from Drive
2. Use it in your hackathon demo with `03_demo.ipynb`

**Metrics:**
- Best IoU: Check the training output above
- Best MAE: Check the training output above

## 6. Comprehensive Model Test

Test the trained model on 10+ California forest fire images with proper preprocessing.

In [None]:
# ============================================================
# COMPREHENSIVE MODEL TEST - 10+ CALIFORNIA FOREST FIRES
# ============================================================
# Tests model on forest/mountain fire tiles with proper preprocessing

import matplotlib.pyplot as plt
from collections import defaultdict
import pandas as pd

# Load best model
checkpoint = torch.load(f"{DRIVE_CHECKPOINT_PATH}/best_model.pth", weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"✅ Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
print(f"   Best val IoU: {checkpoint.get('val_iou', 'N/A')}")

# Forest/mountain fires for testing (similar to training distribution)
TEST_FIRES = ['dixie', 'caldor', 'camp', 'creek', 'mendocino', 'thomas']
SAMPLES_PER_FIRE = 2  # 2 samples per fire = 12 total
MIN_FIRE_PERCENT = 10  # At least 10% fire damage in tile

# Collect test samples: for each fire, keep the first SAMPLES_PER_FIRE tiles whose
# label band has at least MIN_FIRE_PERCENT of pixels above 0.5.
# NOTE(review): candidates are drawn from train_dataset.samples, so the metrics
# computed from them are not a held-out estimate - confirm this is intended.
test_samples = []

for fire in TEST_FIRES:
    found = 0

    for sample in train_dataset.samples:
        if found >= SAMPLES_PER_FIRE:
            break
        if sample['fire_key'] != fire:
            continue

        with rasterio.open(sample['path']) as src:
            # Tiles need 10 image bands + 1 label band; skip anything shorter.
            if src.count < 11:
                continue
            # rasterio band indexes are 1-based, so band 11 is the label (index 10
            # of a full read). Reading only this band avoids loading the 10 image
            # bands of every candidate tile just to inspect its label.
            label = src.read(11).astype(np.float32)

        label = np.nan_to_num(label, nan=0.0)
        fire_percent = (label > 0.5).mean() * 100

        if fire_percent >= MIN_FIRE_PERCENT:
            test_samples.append({
                'path': sample['path'],
                'fire': fire,
                'fire_percent': fire_percent,
            })
            found += 1

print(f"\n📊 Found {len(test_samples)} test samples with >={MIN_FIRE_PERCENT}% fire damage")
for fire in TEST_FIRES:
    count = sum(1 for s in test_samples if s['fire'] == fire)
    print(f"   {fire}: {count} samples")


In [None]:
# ============================================================
# RUN PREDICTIONS AND COMPUTE METRICS
# ============================================================

results = []

for sample in test_samples:
    # Read the whole tile: the first 10 bands are the model input, band index 10 is the label.
    with rasterio.open(sample['path']) as src:
        bands = src.read()

    # Replace NaN/inf values, and keep the label inside [0, 1].
    image = np.nan_to_num(bands[:10].astype(np.float32), nan=0.0, posinf=10000.0, neginf=0.0)
    label = np.clip(np.nan_to_num(bands[10].astype(np.float32), nan=0.0), 0.0, 1.0)

    # Normalization: clip to [0, 10000], per-band z-score with BAND_MEANS / BAND_STDS,
    # clip to [-3, 3], then rescale to [0, 1].
    image = np.clip(image, 0, 10000)
    for band in range(10):
        image[band] = (image[band] - BAND_MEANS[band]) / (BAND_STDS[band] + 1e-6)
    image = (np.clip(image, -3, 3) + 3) / 6

    # Forward pass under autocast; the sigmoid output is cast to float32 for NumPy.
    batch = torch.from_numpy(image).unsqueeze(0).float().to(device)
    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            logits = model(batch)
        probs = torch.sigmoid(logits).float().cpu().numpy()[0, 0]

    # Binarized ground truth used for the IoU at every threshold.
    gt_binary = (label > 0.5).astype(float)

    result = {
        'fire': sample['fire'],
        'fire_percent': sample['fire_percent'],
        'pred_max': probs.max(),
        'pred_mean': probs.mean(),
        'mae': np.abs(probs - label).mean(),
    }

    # IoU of the thresholded prediction against the binary label, one entry per threshold.
    for thresh in (0.2, 0.3, 0.4, 0.5):
        pred_bin = (probs > thresh).astype(float)
        overlap = (pred_bin * gt_binary).sum()
        union = pred_bin.sum() + gt_binary.sum() - overlap
        result[f'iou_{thresh}'] = overlap / (union + 1e-6)

    # Keep the arrays so the plotting cell can reuse them.
    result.update(image=image, label=label, probs=probs)
    results.append(result)

print(f"\n✅ Processed {len(results)} samples")


In [None]:
# ============================================================
# METRICS SUMMARY TABLE
# ============================================================

# Build one formatted table row per evaluated tile.
summary_rows = []
for r in results:
    row = {
        'Fire': r['fire'],
        'Fire %': f"{r['fire_percent']:.1f}%",
        'Pred Max': f"{r['pred_max']:.2f}",
        'MAE': f"{r['mae']:.3f}",
    }
    for t in ('0.2', '0.3', '0.4', '0.5'):
        row[f'IoU@{t}'] = f"{r[f'iou_{t}']:.2f}"
    summary_rows.append(row)
df = pd.DataFrame(summary_rows)

heavy_rule = "=" * 80
light_rule = "-" * 80

print("\n" + heavy_rule)
print("📊 MODEL TEST RESULTS - CALIFORNIA FOREST FIRES")
print(heavy_rule)
print(df.to_string(index=False))

# Means over all evaluated tiles.
print("\n" + light_rule)
print("📈 AGGREGATE METRICS")
print(light_rule)

avg_mae = np.mean([r['mae'] for r in results])
avg_iou_02, avg_iou_03, avg_iou_04, avg_iou_05 = (
    np.mean([r[f'iou_{t}'] for r in results]) for t in ('0.2', '0.3', '0.4', '0.5')
)

print(f"   Average MAE:      {avg_mae:.4f}")
for t, avg_iou in zip(('0.2', '0.3', '0.4', '0.5'),
                      (avg_iou_02, avg_iou_03, avg_iou_04, avg_iou_05)):
    print(f"   Average IoU@{t}:  {avg_iou:.4f}")

# Mean IoU at threshold 0.3 for each fire that has at least one result.
print("\n" + light_rule)
print("🔥 PER-FIRE AVERAGE IoU@0.3")
print(light_rule)

for fire in TEST_FIRES:
    fire_ious = [r['iou_0.3'] for r in results if r['fire'] == fire]
    if not fire_ious:
        continue
    print(f"   {fire:<15}: {np.mean(fire_ious):.4f}")

print(heavy_rule)


In [None]:
# ============================================================
# VISUALIZATION - 10 SAMPLE PREDICTIONS
# ============================================================

# Plot up to 10 results, one row per tile with four columns:
# RGB composite | ground truth | prediction heatmap | prediction binarized at 0.3.
n_samples = min(10, len(results))

if n_samples == 0:
    # plt.subplots() rejects a zero-row grid, so report instead of crashing.
    print("⚠️ No test results to visualize - run the prediction cell first")
else:
    # squeeze=False keeps `axes` two-dimensional even when there is a single row,
    # so the axes[i, col] indexing below works for any n_samples >= 1.
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4*n_samples), squeeze=False)

    for i, r in enumerate(results[:n_samples]):
        # Composite from normalized bands 3, 2, 1 of the input tile.
        # NOTE(review): if the tiles follow the S2_BANDS order (B2, B3, B4, B5, ...),
        # these indices are B5/B4/B3 rather than B4/B3/B2 - confirm the intended bands.
        rgb = np.stack([r['image'][3], r['image'][2], r['image'][1]], axis=0)
        rgb = np.clip(rgb, 0, 1).transpose(1, 2, 0)

        # Column 1: RGB
        axes[i, 0].imshow(rgb)
        axes[i, 0].set_title(f"{r['fire']}\nRGB", fontsize=10)
        axes[i, 0].axis('off')

        # Column 2: Ground Truth
        axes[i, 1].imshow(r['label'], cmap='hot', vmin=0, vmax=1)
        axes[i, 1].set_title(f"Ground Truth\n{r['fire_percent']:.1f}% fire", fontsize=10)
        axes[i, 1].axis('off')

        # Column 3: Prediction Heatmap
        axes[i, 2].imshow(r['probs'], cmap='hot', vmin=0, vmax=1)
        axes[i, 2].set_title(f"Prediction\nmax={r['pred_max']:.2f}", fontsize=10)
        axes[i, 2].axis('off')

        # Column 4: Binary @ 0.3
        pred_binary = (r['probs'] > 0.3).astype(float)
        axes[i, 3].imshow(pred_binary, cmap='hot', vmin=0, vmax=1)
        axes[i, 3].set_title(f"Binary @ 0.3\nIoU={r['iou_0.3']:.2f}", fontsize=10)
        axes[i, 3].axis('off')

    plt.suptitle('🔥 California Fire Model - Test Results', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f"{DRIVE_CHECKPOINT_PATH}/comprehensive_test_results.png", dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✅ Results saved to {DRIVE_CHECKPOINT_PATH}/comprehensive_test_results.png")


## 7. Test on Random Forest/Deforestation Areas (Earth Engine)

Download Sentinel-2 tiles for 10 forest locations (deforestation hotspots plus intact forest for comparison) from Google Earth Engine and run the model on them.

**Note:** The model was trained on FIRE damage. Performance on deforestation is unknown - 
it may generalize if spectral signatures are similar.

In [None]:
# ============================================================
# DOWNLOAD RANDOM FOREST/DEFORESTATION IMAGES FROM GEE
# ============================================================
# Uses same bands and preprocessing as training data

import ee
import requests
from io import BytesIO

# Initialize Earth Engine (should already be authenticated from setup).
# If initialization fails, fall back to the interactive authentication flow and retry.
try:
    ee.Initialize()
    print("✅ Earth Engine initialized")
except Exception:
    # `except Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit (e.g. interrupting the cell) are not swallowed and turned into an
    # authentication prompt.
    ee.Authenticate()
    ee.Initialize()
    print("✅ Earth Engine authenticated and initialized")

# Sentinel-2 band names requested from Earth Engine (same as training), in channel order.
S2_BANDS = 'B2 B3 B4 B5 B6 B7 B8 B8A B11 B12'.split()

# Fixed list of forest test sites: deforestation hotspots plus forests for comparison.
# Each entry becomes a dict with keys 'name', 'lon', 'lat', 'year' (in that order).
TEST_LOCATIONS = [
    dict(zip(('name', 'lon', 'lat', 'year'), site))
    for site in (
        # Amazon deforestation hotspots
        ('amazon_rondonia_1', -63.5, -10.5, 2023),
        ('amazon_para_1', -55.0, -6.0, 2023),
        ('amazon_mato_grosso', -56.5, -12.0, 2023),
        # Southeast Asia deforestation
        ('borneo_kalimantan', 116.0, 0.5, 2023),
        ('sumatra_riau', 102.0, 0.5, 2023),
        # California forests (for comparison - should detect if degraded)
        ('california_sierra_1', -120.5, 38.5, 2023),
        ('california_plumas', -121.0, 40.0, 2023),
        # African deforestation
        ('congo_basin', 21.0, 1.0, 2023),
        # More California/Oregon forests
        ('oregon_cascade', -122.0, 44.0, 2023),
        ('california_shasta', -122.5, 41.0, 2023),
    )
]

print(f"📍 Will download {len(TEST_LOCATIONS)} test locations")


In [None]:
# ============================================================
# DOWNLOAD FUNCTION - Same preprocessing as training data
# ============================================================

def download_gee_tile(location, tile_size=256):
    """
    Download one Sentinel-2 median-composite tile from Earth Engine.

    The composite is the per-pixel median of all COPERNICUS/S2_SR_HARMONIZED
    images that intersect the location, fall inside the requested calendar
    year's date filter, and have CLOUDY_PIXEL_PERCENTAGE below 20. Only the
    bands listed in S2_BANDS are requested.

    Args:
        location: dict with 'lon', 'lat' (degrees) and 'year' keys.
        tile_size: tile edge length in pixels at 10 m per pixel; the download
            region is the bounding box of a buffer of tile_size * 10 / 2
            metres around the point.

    Returns:
        The downloaded NumPy array, or None on failure. When the response is
        a structured array, its fields are stacked in S2_BANDS order along
        axis 0; otherwise the loaded array is returned unchanged.
        Any exception raised while requesting, downloading or decoding the
        tile is printed (truncated to 50 characters) and results in None.
    """
    lon, lat, year = (location[key] for key in ('lon', 'lat', 'year'))
    center = ee.Geometry.Point([lon, lat])

    # Yearly low-cloud median composite restricted to the model's input bands.
    composite = (
        ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
        .filterBounds(center)
        .filterDate(f'{year}-01-01', f'{year}-12-31')
        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20))
        .select(S2_BANDS)
        .median()
    )

    # 10 m per pixel, so the default 256-pixel tile covers 2560 m x 2560 m.
    scale = 10
    half_side_m = tile_size * scale / 2
    region = center.buffer(half_side_m).bounds()

    try:
        url = composite.getDownloadURL({
            'bands': S2_BANDS,
            'region': region,
            'scale': scale,
            'format': 'NPY',
        })

        response = requests.get(url, timeout=60)
        if response.status_code != 200:
            print(f"   ❌ Download failed: {response.status_code}")
            return None

        # NOTE(review): allow_pickle=True lets np.load unpickle objects taken from
        # a network response; check whether allow_pickle=False is sufficient here.
        data = np.load(BytesIO(response.content), allow_pickle=True)

        # A structured array has one named field per band: stack them band-first.
        if data.dtype.names:
            data = np.stack([data[band] for band in S2_BANDS], axis=0)

        return data

    except Exception as e:
        # Best effort: report a shortened message and let the caller skip this tile.
        print(f"   ❌ Error: {str(e)[:50]}")
        return None

print("✅ Download function defined")


In [None]:
# ============================================================
# DOWNLOAD ALL TEST TILES
# ============================================================

gee_test_results = []

print("📥 Downloading test tiles from Earth Engine...")
print("-" * 50)

for loc in TEST_LOCATIONS:
    print(f"   Downloading {loc['name']}...", end=" ")

    data = download_gee_tile(loc)
    if data is None:
        print("❌ Failed")
        continue

    image = data.astype(np.float32)

    # Reject arrays that are not band-first with exactly 10 bands.
    if image.ndim == 2:
        print(f"❌ Wrong shape {image.shape}")
        continue
    if image.shape[0] != 10:
        print(f"❌ Wrong bands {image.shape[0]}")
        continue

    # Fit to 256x256: center-crop dimensions that are too large and
    # zero-pad (centered) dimensions that are too small.
    h, w = image.shape[1], image.shape[2]
    if (h, w) != (256, 256):
        keep_h, keep_w = min(h, 256), min(w, 256)
        src_top, src_left = (h - keep_h) // 2, (w - keep_w) // 2
        dst_top, dst_left = (256 - keep_h) // 2, (256 - keep_w) // 2
        canvas = np.zeros((10, 256, 256), dtype=np.float32)
        canvas[:, dst_top:dst_top + keep_h, dst_left:dst_left + keep_w] = \
            image[:, src_top:src_top + keep_h, src_left:src_left + keep_w]
        image = canvas

    # Replace NaN/inf values before normalizing.
    image = np.nan_to_num(image, nan=0.0, posinf=10000.0, neginf=0.0)

    # Normalization: clip to [0, 10000], per-band z-score with BAND_MEANS / BAND_STDS,
    # clip to [-3, 3], then rescale to [0, 1].
    image = np.clip(image, 0, 10000)
    for band in range(10):
        image[band] = (image[band] - BAND_MEANS[band]) / (BAND_STDS[band] + 1e-6)
    image = (np.clip(image, -3, 3) + 3) / 6

    gee_test_results.append(
        dict(name=loc['name'], image=image, lon=loc['lon'], lat=loc['lat'])
    )
    print(f"✅ Shape: {image.shape}")

print("-" * 50)
print(f"\n✅ Downloaded {len(gee_test_results)} tiles successfully")


In [None]:
# ============================================================
# RUN PREDICTIONS ON GEE TILES
# ============================================================

print("🔮 Running predictions on downloaded tiles...")

for tile in gee_test_results:
    # Forward pass under autocast; the sigmoid output is cast to float32 for NumPy.
    batch = torch.from_numpy(tile['image']).unsqueeze(0).float().to(device)
    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            logits = model(batch)
        probs = torch.sigmoid(logits).float().cpu().numpy()[0, 0]

    # Store the probability map and summary statistics on the tile record;
    # 'degradation_percent' is the share of pixels with probability above 0.3.
    tile.update(
        probs=probs,
        pred_max=probs.max(),
        pred_mean=probs.mean(),
        degradation_percent=(probs > 0.3).mean() * 100,
    )

heavy_rule = "=" * 70

print("\n" + heavy_rule)
print("📊 DEFORESTATION/DEGRADATION PREDICTIONS")
print(heavy_rule)
print(f"{'Location':<25} {'Max Prob':>10} {'Mean Prob':>10} {'Degradation %':>15}")
print("-" * 70)

for tile in gee_test_results:
    name, peak, mean, share = (
        tile[key] for key in ('name', 'pred_max', 'pred_mean', 'degradation_percent')
    )
    print(f"{name:<25} {peak:>10.2f} {mean:>10.4f} {share:>14.1f}%")

print(heavy_rule)
print("\n💡 Note: High degradation % may indicate fire damage, deforestation, or")
print("   other vegetation stress the model learned to detect.")


In [None]:
# ============================================================
# VISUALIZE GEE PREDICTIONS
# ============================================================

# Plot up to 10 downloaded tiles, one row per tile with three columns:
# composite image | probability heatmap | prediction binarized at 0.3.
n_gee = min(10, len(gee_test_results))

if n_gee == 0:
    # Every download can fail; plt.subplots() rejects a zero-row grid, so report instead.
    print("⚠️ No GEE tiles to visualize - no downloads succeeded")
else:
    # squeeze=False keeps `axes` two-dimensional even when only one tile was
    # downloaded, so the axes[i, col] indexing below works for any n_gee >= 1.
    fig, axes = plt.subplots(n_gee, 3, figsize=(12, 4*n_gee), squeeze=False)

    for i, r in enumerate(gee_test_results[:n_gee]):
        # Composite from normalized bands 3, 2, 1.
        # NOTE(review): in S2_BANDS order these indices are B5/B4/B3, not the
        # B4/B3/B2 of a true-colour image - confirm which composite is intended.
        rgb = np.stack([r['image'][3], r['image'][2], r['image'][1]], axis=0)
        rgb = np.clip(rgb, 0, 1).transpose(1, 2, 0)

        # Composite image
        axes[i, 0].imshow(rgb)
        axes[i, 0].set_title(f"{r['name']}\n({r['lat']:.1f}, {r['lon']:.1f})", fontsize=10)
        axes[i, 0].axis('off')

        # Prediction heatmap
        axes[i, 1].imshow(r['probs'], cmap='hot', vmin=0, vmax=1)
        axes[i, 1].set_title(f"Degradation Probability\nmax={r['pred_max']:.2f}", fontsize=10)
        axes[i, 1].axis('off')

        # Binary @ 0.3
        axes[i, 2].imshow(r['probs'] > 0.3, cmap='hot')
        axes[i, 2].set_title(f"Binary @ 0.3\n{r['degradation_percent']:.1f}% detected", fontsize=10)
        axes[i, 2].axis('off')

    plt.suptitle('🌲 Forest Degradation Detection - GEE Test Images', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f"{DRIVE_CHECKPOINT_PATH}/gee_deforestation_test.png", dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✅ Results saved to {DRIVE_CHECKPOINT_PATH}/gee_deforestation_test.png")


## 8. Test on California Vegetation & Fire Areas (NEW)

Download tiles for 10 hand-picked California locations: a mix of recent fire areas and healthy vegetation.

In [None]:
# ============================================================
# CALIFORNIA-SPECIFIC TEST LOCATIONS
# ============================================================
# Recent fire areas + healthy vegetation

# California fire locations (2020-2023 major fires)
# Each entry becomes a dict with keys 'name', 'lon', 'lat', 'expected' (in that
# order); 'expected' is the outcome the evaluation cell scores the prediction against.
CA_TEST_LOCATIONS = [
    dict(zip(('name', 'lon', 'lat', 'expected'), site))
    for site in (
        # 2021 Dixie Fire area
        ('dixie_fire_2023_center', -121.4, 40.1, 'fire'),
        ('dixie_fire_2023_edge', -121.2, 39.9, 'fire'),
        # 2020 Creek Fire
        ('creek_fire_2023', -119.3, 37.2, 'fire'),
        # 2021 Caldor Fire
        ('caldor_fire_2023', -120.3, 38.7, 'fire'),
        # 2020 August Complex Fire
        ('august_complex_2023', -122.7, 39.8, 'fire'),
        # Healthy California forests (should NOT detect fire)
        ('yosemite_healthy', -119.6, 37.8, 'healthy'),
        ('redwood_coast', -123.9, 41.2, 'healthy'),
        ('big_sur_forest', -121.8, 36.2, 'healthy'),
        ('tahoe_forest', -120.1, 39.1, 'healthy'),
        ('sequoia_healthy', -118.6, 36.5, 'healthy'),
    )
]

n_fire_sites = sum(1 for site in CA_TEST_LOCATIONS if site['expected'] == 'fire')
n_healthy_sites = sum(1 for site in CA_TEST_LOCATIONS if site['expected'] == 'healthy')

print(f"📍 California Test Locations: {len(CA_TEST_LOCATIONS)}")
print(f"   Fire areas: {n_fire_sites}")
print(f"   Healthy vegetation: {n_healthy_sites}")


In [None]:
# ============================================================
# DOWNLOAD CALIFORNIA TEST TILES
# ============================================================

ca_test_results = []

print("📥 Downloading California test tiles...")
print("-" * 60)

for loc in CA_TEST_LOCATIONS:
    print(f"   {loc['name']} ({loc['expected']})...", end=" ")

    # The California entries carry no 'year'; request 2023 imagery for all of them.
    data = download_gee_tile(dict(loc, year=2023))
    if data is None:
        print("❌")
        continue

    image = data.astype(np.float32)

    # Reject arrays that are not band-first with exactly 10 bands.
    if image.ndim == 2 or image.shape[0] != 10:
        print("❌ Wrong shape")
        continue

    # Fit to 256x256: center-crop dimensions that are too large and
    # zero-pad (centered) dimensions that are too small.
    h, w = image.shape[1], image.shape[2]
    if (h, w) != (256, 256):
        keep_h, keep_w = min(h, 256), min(w, 256)
        src_top, src_left = (h - keep_h) // 2, (w - keep_w) // 2
        dst_top, dst_left = (256 - keep_h) // 2, (256 - keep_w) // 2
        canvas = np.zeros((10, 256, 256), dtype=np.float32)
        canvas[:, dst_top:dst_top + keep_h, dst_left:dst_left + keep_w] = \
            image[:, src_top:src_top + keep_h, src_left:src_left + keep_w]
        image = canvas

    # Normalization: replace NaN/inf, clip to [0, 10000], per-band z-score with
    # BAND_MEANS / BAND_STDS, clip to [-3, 3], then rescale to [0, 1].
    image = np.nan_to_num(image, nan=0.0, posinf=10000.0, neginf=0.0)
    image = np.clip(image, 0, 10000)
    for band in range(10):
        image[band] = (image[band] - BAND_MEANS[band]) / (BAND_STDS[band] + 1e-6)
    image = (np.clip(image, -3, 3) + 3) / 6

    ca_test_results.append(dict(
        name=loc['name'],
        expected=loc['expected'],
        image=image,
        lon=loc['lon'],
        lat=loc['lat'],
    ))
    print("✅")

print("-" * 60)
print(f"✅ Downloaded {len(ca_test_results)} California tiles")


In [None]:
# ============================================================
# PREDICT & EVALUATE CALIFORNIA TILES
# ============================================================

print("🔮 Running predictions...")

for r in ca_test_results:
    # Forward pass under autocast; the sigmoid output is cast to float32 for NumPy.
    batch = torch.from_numpy(r['image']).unsqueeze(0).float().to(device)
    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            logits = model(batch)
        probs = torch.sigmoid(logits).float().cpu().numpy()[0, 0]

    # 'fire_percent' is the share of pixels with probability above 0.3.
    r.update(
        probs=probs,
        pred_max=probs.max(),
        pred_mean=probs.mean(),
        fire_percent=(probs > 0.3).mean() * 100,
    )

    # A fire site counts as correct above 5% detected area,
    # a healthy site counts as correct below 10%.
    r['correct'] = (r['fire_percent'] > 5) if r['expected'] == 'fire' else (r['fire_percent'] < 10)

# Per-location results table
heavy_rule = "=" * 80

print("\n" + heavy_rule)
print("📊 CALIFORNIA FIRE DETECTION TEST")
print(heavy_rule)
print(f"{'Location':<25} {'Expected':>10} {'Max':>8} {'Fire%':>8} {'Result':>10}")
print("-" * 80)

for r in ca_test_results:
    if r['correct']:
        result_icon = '✅'
    else:
        result_icon = '❌'
    print(f"{r['name']:<25} {r['expected']:>10} {r['pred_max']:>8.2f} {r['fire_percent']:>7.1f}% {result_icon:>10}")

# Overall accuracy (reported as 0 when nothing was downloaded)
total = len(ca_test_results)
correct = sum(1 for r in ca_test_results if r['correct'])
accuracy = 0 if total <= 0 else correct / total * 100

print(heavy_rule)
print(f"\n📈 ACCURACY: {correct}/{total} = {accuracy:.0f}%")

# Breakdown by expected class
fire_locs = []
healthy_locs = []
for r in ca_test_results:
    if r['expected'] == 'fire':
        fire_locs.append(r)
    if r['expected'] == 'healthy':
        healthy_locs.append(r)

fire_correct = sum(1 for r in fire_locs if r['correct'])
healthy_correct = sum(1 for r in healthy_locs if r['correct'])

print(f"   Fire detection: {fire_correct}/{len(fire_locs)}")
print(f"   Healthy (no false positives): {healthy_correct}/{len(healthy_locs)}")


In [None]:
# ============================================================
# VISUALIZE CALIFORNIA RESULTS
# ============================================================

# Plot every California tile, one row per tile with three columns:
# composite image | probability heatmap | prediction binarized at 0.3.
n = len(ca_test_results)

if n == 0:
    # Every download can fail; plt.subplots() rejects a zero-row grid, so report instead.
    print("⚠️ No California tiles to visualize - no downloads succeeded")
else:
    # squeeze=False keeps `axes` two-dimensional even when only one tile was
    # downloaded, so the axes[i, col] indexing below works for any n >= 1.
    fig, axes = plt.subplots(n, 3, figsize=(12, 4*n), squeeze=False)

    for i, r in enumerate(ca_test_results):
        # Composite from normalized bands 3, 2, 1.
        # NOTE(review): in S2_BANDS order these indices are B5/B4/B3, not the
        # B4/B3/B2 of a true-colour image - confirm which composite is intended.
        rgb = np.stack([r['image'][3], r['image'][2], r['image'][1]], axis=0)
        rgb = np.clip(rgb, 0, 1).transpose(1, 2, 0)

        result_icon = '✅' if r['correct'] else '❌'

        axes[i, 0].imshow(rgb)
        axes[i, 0].set_title(f"{r['name']}\n({r['expected']}) {result_icon}", fontsize=10)
        axes[i, 0].axis('off')

        axes[i, 1].imshow(r['probs'], cmap='hot', vmin=0, vmax=1)
        axes[i, 1].set_title(f"Prediction\nmax={r['pred_max']:.2f}", fontsize=10)
        axes[i, 1].axis('off')

        axes[i, 2].imshow(r['probs'] > 0.3, cmap='hot')
        axes[i, 2].set_title(f"Binary @ 0.3\n{r['fire_percent']:.1f}% fire", fontsize=10)
        axes[i, 2].axis('off')

    plt.suptitle('🔥 California Fire Detection - Real-World Test', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f"{DRIVE_CHECKPOINT_PATH}/california_test_results.png", dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✅ Saved to {DRIVE_CHECKPOINT_PATH}/california_test_results.png")
