# 🔥 California Fire Model - Hackathon Demo

Interactive demonstration of burn severity detection and recovery monitoring.

This notebook showcases:
1. Real-time burn severity prediction
2. Temporal progression visualization (burn → recovery)
3. Interactive location selection
4. Model performance on held-out fires

In [None]:
import sys
from pathlib import Path

# Setup paths: the parent of the current working directory is treated as the
# project root and put first on sys.path so project packages (e.g.
# `inference`, `config`) are importable.
# NOTE(review): assumes the kernel's cwd is a folder one level below the
# project root (e.g. notebooks/) — confirm.
PROJECT_ROOT = Path('.').resolve().parent
# Only insert when not already at the front, so re-running this cell does not
# stack duplicate entries while still giving the project root top priority.
if not sys.path or sys.path[0] != str(PROJECT_ROOT):
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML

print(f"PyTorch: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 1. Load Model

In [None]:
from inference.predict import FirePredictor
from config import CHECKPOINT_DIR

# Load the trained model with test-time augmentation enabled.
model_path = CHECKPOINT_DIR / 'best_model.pth'

if model_path.exists():
    predictor = FirePredictor(str(model_path), use_tta=True)
    print("\n✅ Model loaded and ready!")
else:
    # Bind the name explicitly so `predictor` always exists after this cell;
    # downstream cells can test `predictor is None` instead of hitting a
    # NameError.
    predictor = None
    print(f"\n❌ Model not found at {model_path}")
    print("   Please train the model first using 02_train_model.ipynb")

## 2. Initialize Earth Engine

In [None]:
import ee
import requests
from rasterio.io import MemoryFile

from config import EE_PROJECT_ID, BANDS

try:
    ee.Initialize(project=EE_PROJECT_ID)
except Exception:
    # Presumably no stored credentials yet: run the interactive auth flow
    # once, then retry. `except Exception` (not a bare `except:`) so that
    # KeyboardInterrupt / SystemExit still propagate.
    ee.Authenticate()
    ee.Initialize(project=EE_PROJECT_ID)
print("✅ Earth Engine initialized")

In [None]:
def fetch_sentinel2(lat, lon, date_start, date_end, buffer_m=1280):
    """
    Fetch a cloud-filtered Sentinel-2 median composite around a point.

    The region is the bounding box of a ``buffer_m`` buffer around
    ``(lat, lon)``. Scenes from COPERNICUS/S2_SR_HARMONIZED in the date range
    with CLOUDY_PIXEL_PERCENTAGE < 15 are reduced with a per-pixel median,
    ``BANDS`` are selected, and the result is downloaded as a GeoTIFF at
    scale 10 in EPSG:3857.

    NOTE(review): in EPSG:3857 a 10-unit pixel is smaller than 10 m on the
    ground at California latitudes (Mercator scale factor) — confirm this
    matches the CRS/resolution used for the training data.

    Args:
        lat, lon: Point of interest in decimal degrees.
        date_start, date_end: Date strings passed to ``filterDate``
            (``date_end`` is exclusive).
        buffer_m: Buffer radius around the point, in metres.

    Returns:
        image: numpy array of shape (len(BANDS), H, W), with NaN and -inf
            replaced by 0 and +inf replaced by 10000.

    Raises:
        ValueError: If no scene passes the location/date/cloud filters
            (e.g. dates before Sentinel-2 SR data exist).
        requests.HTTPError: If the GeoTIFF download fails.
    """
    point = ee.Geometry.Point([lon, lat])
    region = point.buffer(buffer_m).bounds()

    collection = (
        ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
        .filterBounds(point)
        .filterDate(date_start, date_end)
        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 15))
    )

    # An empty collection produces a band-less median image, which would only
    # fail later inside getDownloadURL with a far less descriptive error.
    if collection.size().getInfo() == 0:
        raise ValueError(
            f"No Sentinel-2 SR scenes with <15% cloud cover at ({lat}, {lon}) "
            f"between {date_start} and {date_end}"
        )

    s2 = collection.median().select(BANDS)

    url = s2.getDownloadURL({
        'scale': 10,
        'region': region,
        'format': 'GEO_TIFF',
        'crs': 'EPSG:3857'
    })

    response = requests.get(url, timeout=60)
    # Fail loudly on HTTP errors instead of handing an error page to rasterio.
    response.raise_for_status()

    with MemoryFile(response.content) as memfile:
        with memfile.open() as src:
            data = src.read()

    # Masked / invalid pixels come back as NaN or inf; make them finite.
    data = np.nan_to_num(data, nan=0.0, posinf=10000.0, neginf=0.0)

    return data

print("✅ Fetch function ready")

## 3. Demo Locations

In [None]:
# Demo locations for the hackathon.
# Each entry holds a point of interest (decimal degrees), three (start, end)
# date windows used for the pre-fire, post-fire and recovery composites, and
# a short description for display.
DEMO_LOCATIONS = {
    'Dixie Fire - Core Burn': {
        'lat': 40.05,
        'lon': -121.15,
        'pre_dates': ('2021-05-01', '2021-06-30'),
        'post_dates': ('2021-10-01', '2021-11-30'),
        'recovery_dates': ('2023-06-01', '2023-08-31'),
        'description': 'Largest fire of 2021 (963K acres)',
    },
    'Caldor Fire - South Tahoe': {
        'lat': 38.82,
        'lon': -120.08,
        'pre_dates': ('2021-06-01', '2021-07-31'),
        'post_dates': ('2021-09-15', '2021-11-15'),
        'recovery_dates': ('2023-06-01', '2023-08-31'),
        'description': 'Crossed the Sierra Nevada',
    },
    'Camp Fire - Paradise': {
        'lat': 39.76,
        'lon': -121.61,
        # NOTE(review): Sentinel-2 SR (L2A) coverage outside Europe is
        # reportedly sparse before late 2018, so the 2018 pre-fire window may
        # return no scenes — verify before relying on this entry.
        'pre_dates': ('2018-06-01', '2018-09-30'),
        'post_dates': ('2018-12-01', '2019-02-28'),
        'recovery_dates': ('2023-06-01', '2023-08-31'),
        'description': 'Destroyed Paradise, CA (2018)',
    },
    'Healthy Forest - Tahoe': {
        # Forested Blackwood Canyon, west of Lake Tahoe. Lake Tahoe itself
        # spans roughly 38.93-39.25 N, 120.16-119.93 W, so the reference
        # point must sit outside that box to show forest, not open water.
        # NOTE(review): confirm forest cover in the fetched imagery.
        'lat': 39.10,
        'lon': -120.20,
        'pre_dates': ('2024-06-01', '2024-08-31'),
        'post_dates': ('2024-06-01', '2024-08-31'),  # Same (healthy)
        'recovery_dates': ('2024-06-01', '2024-08-31'),
        'description': 'Healthy forest reference',
    },
}

# Display the demo locations as an indented tree-style listing.
print("📍 Demo Locations:")
print("-" * 60)
for name, info in DEMO_LOCATIONS.items():
    print(f"\n   {name}")
    print(f"   └─ {info['description']}")
    print(f"      Lat: {info['lat']}, Lon: {info['lon']}")

## 4. Run Predictions

In [None]:
from inference.visualize import (
    plot_prediction, plot_temporal_series,
    rgb_from_sentinel2, get_severity_cmap
)

def run_demo(location_name):
    """
    Run the complete demo pipeline for one entry of ``DEMO_LOCATIONS``.

    For each stage (pre-fire, post-fire, recovery):

    - fetch a Sentinel-2 composite for the stage's date window;
    - center-crop it to 256x256 when both spatial sizes allow;
    - run ``predictor.predict_tile`` on the crop.

    A stage whose fetch or prediction raises is reported and skipped. When
    all three stages succeed, the temporal series is plotted. If the
    post-fire mean severity also exceeds 0.1, a recovery percentage is
    printed.

    Args:
        location_name: Key into ``DEMO_LOCATIONS``.

    Returns:
        (images, predictions): dicts keyed by stage name ('Pre-Fire',
        'Post-Fire', 'Recovery').
        - ``images`` holds the uncropped fetched arrays.
        - ``predictions`` holds the severity maps predicted on the crops.
        A stage is missing from ``predictions`` if anything failed, and also
        from ``images`` if the fetch itself failed.

    Raises:
        KeyError: If ``location_name`` is not a key of ``DEMO_LOCATIONS``.
    """
    crop_size = 256
    loc = DEMO_LOCATIONS[location_name]

    print(f"\n{'='*60}")
    print(f"🔥 {location_name}")
    print(f"   {loc['description']}")
    print(f"{'='*60}")

    stages = [
        ('Pre-Fire', loc['pre_dates']),
        ('Post-Fire', loc['post_dates']),
        ('Recovery', loc['recovery_dates']),
    ]

    images = {}
    predictions = {}

    for stage_name, (start, end) in stages:
        print(f"\n📥 Fetching {stage_name}: {start} to {end}")

        try:
            img = fetch_sentinel2(loc['lat'], loc['lon'], start, end)
            images[stage_name] = img

            # Center crop; images smaller than the crop are used as-is.
            _, h, w = img.shape
            if h >= crop_size and w >= crop_size:
                top = (h - crop_size) // 2
                left = (w - crop_size) // 2
                img_crop = img[:, top:top + crop_size, left:left + crop_size]
            else:
                img_crop = img

            severity, confidence = predictor.predict_tile(img_crop)
            predictions[stage_name] = severity

            print(f"   ✅ Severity: {severity.mean():.1%} (confidence: {confidence:.1%})")

        except Exception as e:
            # Demo boundary: report the failure and continue with other stages.
            print(f"   ❌ Error: {e}")

    missing = [stage_name for stage_name, _ in stages if stage_name not in predictions]
    if missing:
        print(f"\n⚠️  Skipping temporal plot; missing stages: {', '.join(missing)}")
        return images, predictions

    # Visualize temporal progression
    plot_temporal_series(
        {
            'pre_fire': predictions['Pre-Fire'],
            'post_fire': predictions['Post-Fire'],
            'recovery': predictions['Recovery'],
        },
        title=f"{location_name} - Fire Progression",
    )
    plt.show()

    # Recovery is only meaningful after a real burn; the threshold also keeps
    # the division below away from ~0.
    post_severity = predictions['Post-Fire'].mean()
    if post_severity > 0.1:
        recovery_severity = predictions['Recovery'].mean()
        recovery_pct = (post_severity - recovery_severity) / post_severity * 100

        print("\n📊 Recovery Analysis:")
        print(f"   Post-fire severity: {post_severity:.1%}")
        print(f"   Current severity: {recovery_severity:.1%}")
        print(f"   Recovery: {recovery_pct:.0f}%")

    return images, predictions

In [None]:
# Dixie Fire (2021): pre-fire, post-fire and recovery stages
images, predictions = run_demo(location_name='Dixie Fire - Core Burn')

In [None]:
# Caldor Fire (2021): pre-fire, post-fire and recovery stages
images, predictions = run_demo(location_name='Caldor Fire - South Tahoe')

In [None]:
# Camp Fire (2018): pre-fire, post-fire and recovery stages
images, predictions = run_demo(location_name='Camp Fire - Paradise')

In [None]:
# Healthy forest reference: every stage is expected to show low severity
images, predictions = run_demo(location_name='Healthy Forest - Tahoe')

## 5. Interactive Custom Location

In [None]:
# Try any California location!
#
# fetch_sentinel2 queries COPERNICUS/S2_SR_HARMONIZED, so both date windows
# must fall in the Sentinel-2 era. Sentinel-2A launched in June 2015 and the
# SR collection starts in 2017, so pre-2015 fires (e.g. the 2013 Rim Fire)
# return no imagery.
# NOTE(review): SR coverage outside Europe is reportedly sparse before late
# 2018, so prefer fires from 2019 onward.

# Enter coordinates (California wildfires work best)
CUSTOM_LAT = 39.65   # Berry Creek, burned by the North Complex (Bear) Fire, Sept 2020
CUSTOM_LON = -121.40

# Dates to compare
BEFORE_DATES = ('2020-06-01', '2020-07-31')  # Before fire
AFTER_DATES = ('2020-10-01', '2020-11-30')   # After fire

print(f"📍 Custom Location: ({CUSTOM_LAT}, {CUSTOM_LON})")

# Fetch before/after imagery
print("\n📥 Fetching before fire imagery...")
img_before = fetch_sentinel2(CUSTOM_LAT, CUSTOM_LON, *BEFORE_DATES)

print("📥 Fetching after fire imagery...")
img_after = fetch_sentinel2(CUSTOM_LAT, CUSTOM_LON, *AFTER_DATES)

# Center-crop helper (default 256x256)
def center_crop(img, size=256):
    """
    Return the central ``size`` x ``size`` window of a (C, H, W) array.

    If either spatial dimension is smaller than ``size``, the image is
    returned unchanged.

    Args:
        img: Array indexed as (channels, height, width).
        size: Side length of the square crop, in pixels (odd sizes are
            cropped exactly, not to ``size - 1``).

    Returns:
        A slice of ``img`` with shape (C, size, size), or ``img`` itself.
    """
    _, h, w = img.shape
    if h < size or w < size:
        return img
    # Offsets computed from the remaining margin so the window is exactly
    # `size` wide; for even sizes this matches the former cy - size//2 form.
    top = (h - size) // 2
    left = (w - size) // 2
    return img[:, top:top + size, left:left + size]

img_before_crop = center_crop(img_before)
img_after_crop = center_crop(img_after)

# Predict
sev_before, conf_before = predictor.predict_tile(img_before_crop)
sev_after, conf_after = predictor.predict_tile(img_after_crop)

mean_before = sev_before.mean()
mean_after = sev_after.mean()

print("\n📊 Results:")
print(f"   Before: {mean_before:.1%} severity (confidence: {conf_before:.1%})")
print(f"   After: {mean_after:.1%} severity (confidence: {conf_after:.1%})")
# Both values are fractions, so their difference is in percentage points,
# not a relative percent change.
print(f"   Change: {(mean_after - mean_before) * 100:+.1f} percentage points")

# Visualize
from inference.visualize import plot_comparison

fig = plot_comparison(
    [sev_before, sev_after],
    ['Before Fire', 'After Fire'],
    title=f"Custom Location ({CUSTOM_LAT}, {CUSTOM_LON})"
)
plt.show()

## 6. Summary

### What This Model Does:
- **Detects burn severity** from Sentinel-2 satellite imagery
- **Continuous prediction** (0-100% severity, not just yes/no)
- **Tracks recovery** over multiple years
- **Runs on demand** for any California location with available Sentinel-2 imagery

### Use Cases:
1. 🚒 **Emergency Response**: Quickly assess fire damage
2. 🌲 **Forest Management**: Monitor recovery progress
3. 🏠 **Insurance**: Estimate property damage
4. 🔬 **Research**: Study fire behavior patterns

### Technical Highlights:
- U-Net architecture with attention gates
- Trained on 7 major California fires
- Continuous dNBR-based severity labels
- Test-time augmentation for robust predictions