# Real Wildfire Prediction (Colab)

Upload `real_targets.hdf5` (required) and, optionally, `BearFire.geojson`. The notebook reads the simulation inputs and targets from the HDF5 file, leaves everything else to the model's defaults, runs the prediction, prints metrics, and renders a map overlay.

In [None]:
%pip -q install h5py numpy torch matplotlib folium pytorchfire tqdm

In [None]:
import json
from typing import Optional, Tuple

import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import display

import folium
from folium import GeoJson
from folium.raster_layers import ImageOverlay

from tqdm import tqdm
from pytorchfire.model import WildfireModel

# Run on the first CUDA device when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Default size for every matplotlib figure created in this notebook.
plt.rcParams['figure.figsize'] = (8, 5)

# Raw text of the optional perimeter GeoJSON; stays None unless the upload cell sets it.
geojson_text: Optional[str] = None

In [None]:
from google.colab import files  # type: ignore

print('Upload real_targets.hdf5 (required) ...')
uploads = files.upload()

# Match the extension case-insensitively (the GeoJSON cell does the same for
# '.geojson') and pick the first matching upload in a single pass.
h5_name = next((n for n in uploads.keys() if n.lower().endswith('.hdf5')), None)
assert h5_name is not None, 'Missing real_targets.hdf5'
print('HDF5:', h5_name)

In [None]:
# Optional: upload BearFire.geojson
# Clear any text left over from an earlier run of this cell so that the status
# printed below always matches what `geojson_text` actually holds.
geojson_text = None
try:
    from google.colab import files  # type: ignore
    print('Optionally upload BearFire.geojson (or skip this cell) ...')
    uploads_gj = files.upload()
    for n in uploads_gj.keys():
        if n.lower().endswith('.geojson'):
            # GeoJSON is UTF-8 by specification; do not rely on the platform default.
            with open(n, 'r', encoding='utf-8') as f:
                geojson_text = f.read()
            print('GeoJSON: loaded')
            break
    else:
        # Loop finished without hitting `break`: no .geojson file among the uploads.
        print('GeoJSON: none uploaded')
except Exception as e:
    # Deliberately best-effort: the perimeter is optional, so any failure here
    # (e.g. google.colab not importable) only skips the overlay.
    print('GeoJSON upload skipped/not available:', e)


In [None]:
# Choose group and set bounds
EXP_ID: Optional[str] = None  # set if multiple groups exist

# Fallback bounds (used only if not present in HDF5)
LAT_MIN = 39.20
LAT_MAX = 39.95
LON_MIN = -121.45
LON_MAX = -120.65

# The file is intentionally left open: later cells read datasets through `ds`.
h5f = h5py.File(h5_name, 'r')
groups = list(h5f.keys())
if not groups:
    h5f.close()
    raise ValueError(f'{h5_name} contains no groups to predict from')
if EXP_ID is None:
    # Default to the first top-level group when none was chosen explicitly.
    EXP_ID = groups[0]
elif EXP_ID not in h5f:
    h5f.close()
    raise KeyError(f'Group {EXP_ID!r} not found in {h5_name}; available groups: {groups}')
print('Groups:', groups)
print('Using group:', EXP_ID)
ds = h5f[EXP_ID]

# Bounds are unpacked downstream as (lat_min, lat_max, lon_min, lon_max).
# NOTE(review): assumes a 4-element `bounds` attribute is stored in that same
# order - confirm against whatever wrote the HDF5 file.
bounds: Optional[Tuple[float, float, float, float]] = None
try:
    if 'bounds' in ds.attrs and len(ds.attrs['bounds']) == 4:
        b = ds.attrs['bounds']
        bounds = (float(b[0]), float(b[1]), float(b[2]), float(b[3]))
    else:
        needed = ['lat_min', 'lat_max', 'lon_min', 'lon_max']
        if all(k in ds.attrs for k in needed):
            bounds = (
                float(ds.attrs['lat_min']), float(ds.attrs['lat_max']),
                float(ds.attrs['lon_min']), float(ds.attrs['lon_max'])
            )
except Exception as e:
    # Malformed metadata must not abort the run, but report why the fallback is used.
    print('Could not read bounds from HDF5 attributes; using fallback bounds:', e)
    bounds = None

if bounds is None:
    bounds = (LAT_MIN, LAT_MAX, LON_MIN, LON_MAX)
print('Bounds:', bounds)

In [None]:
# Initialize model and run prediction

# Environment layers are read from the selected HDF5 group. For the two wind
# fields only the first record along axis 0 is used here; the prediction loop
# below overwrites them while stepping. Indexing the dataset with [0] reads
# just that record instead of loading the whole array and slicing it afterwards.
env = {
    'p_veg': torch.tensor(ds['p_veg'][:]),
    'p_den': torch.tensor(ds['p_den'][:]),
    'wind_towards_direction': torch.tensor(ds['wind_towards_direction'][0]),
    'wind_velocity': torch.tensor(ds['wind_velocity'][0]),
    'slope': torch.tensor(ds['slope'][:]),
    'initial_ignition': torch.tensor(ds['initial_ignition'][:], dtype=torch.bool),
}

model = WildfireModel(env_data=env, params=None, keep_acc_mask=True).to(device)
model.eval()
# Reset without an explicit seed, then remember the seed the model reports so
# the prediction loop can reset to the same one after every iteration.
model.reset(seed=None)
batch_seed = model.seed

def jaccard_index(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """Return the Jaccard index (intersection over union) of two masks.

    Both tensors are cast to boolean before comparison, so any non-zero
    entry counts as set.

    Args:
        y_true: Reference mask.
        y_pred: Predicted mask; must be combinable element-wise with ``y_true``.

    Returns:
        ``|true & pred| / |true | pred|`` as a Python float, or ``1.0`` when
        neither mask has any element set (empty union).
    """
    truth_mask = y_true.to(torch.bool)
    pred_mask = y_pred.to(torch.bool)

    union_size = torch.sum(truth_mask | pred_mask).float()
    if not union_size.item() > 0:
        # Two empty masks are treated as a perfect match.
        return 1.0

    overlap_size = torch.sum(truth_mask & pred_mask).float()
    return (overlap_size / union_size).item()

def manhattan_distance(t1: torch.Tensor, t2: torch.Tensor) -> float:
    """Return the L1 (Manhattan) distance between two equally shaped tensors.

    Boolean and integer inputs are promoted to ``int64`` before subtracting:
    a direct ``t1 - t2`` wraps around for unsigned dtypes such as ``uint8``
    (``0 - 1`` becomes ``255``) and is not supported at all for ``bool``
    tensors. Floating-point and complex inputs are used as they are.

    Args:
        t1: First tensor.
        t2: Second tensor; must have the same shape as ``t1``.

    Returns:
        Sum of the absolute element-wise differences, as a Python float.

    Raises:
        AssertionError: If the two shapes differ.
    """
    assert t1.shape == t2.shape
    lhs, rhs = (
        t if (t.is_floating_point() or t.is_complex()) else t.to(torch.int64)
        for t in (t1, t2)
    )
    return float(torch.sum(torch.abs(lhs - rhs)).item())

# Number of simulation steps between wind updates, and number of target
# snapshots to evaluate (falls back to the length of `target` along axis 0).
wind_step_interval = int(ds.attrs.get('wind_step_interval', 1))
max_iterations = int(ds.attrs.get('day_count', ds['target'].shape[0]))

outputs, targets, jaccard_list = [], [], []
out_counts, tgt_counts = [], []

# Inference only: no_grad stops autograd from recording every simulation step.
with torch.no_grad():
    for iterations in tqdm(range(max_iterations), desc='Predict'):
        # Each iteration replays the simulation from the reset state up to this snapshot.
        iter_max_steps = (iterations + 1) * wind_step_interval
        for steps in range(iter_max_steps):
            if steps % wind_step_interval == 0:
                # Use the wind record that belongs to the step being replayed.
                # Indexing by `iterations` applied the final record to every
                # replayed step and made the modulo guard above pointless.
                wind_idx = steps // wind_step_interval
                model.wind_towards_direction = torch.tensor(ds['wind_towards_direction'][wind_idx], device=device)
                model.wind_velocity = torch.tensor(ds['wind_velocity'][wind_idx], device=device)
            model.compute(attach=True)

        tgt = torch.tensor(ds['target'][iterations], device=device)
        targets.append(tgt.detach().cpu())
        # Union of the first two state channels is what gets compared to the target.
        # NOTE(review): presumably these channels are "burning" and "burned" - confirm
        # against the pytorchfire documentation.
        affected = (model.state[0] | model.state[1])
        outputs.append(affected.detach().cpu())

        out_counts.append(affected.sum().item())
        tgt_counts.append(tgt.sum().item())
        jaccard_list.append(jaccard_index(tgt, affected))

        # Start the next replay from the same seed as this one.
        model.reset(seed=batch_seed)

# Distance between the per-iteration affected-cell counts of prediction and target.
manhattan_val = manhattan_distance(torch.tensor(out_counts), torch.tensor(tgt_counts))

print('Prediction done.')
print('Mean Jaccard:', float(np.mean(jaccard_list)))
print('Manhattan:', float(manhattan_val))

In [None]:
# Metrics: Jaccard index per iteration, numbered from 1.
day_numbers = np.arange(len(jaccard_list)) + 1

fig, ax = plt.subplots()
ax.plot(day_numbers, jaccard_list, marker='o')
ax.set(xlabel='Iteration (day)', ylabel='Jaccard index', title='Jaccard over time')
ax.grid(True, alpha=0.3)
plt.show()

print('Manhattan distance (affected cell counts):', float(manhattan_val))

In [None]:
# Map overlay (last timestep)

def make_overlay_image(mask: np.ndarray, rgba=(255, 0, 0, 120)) -> np.ndarray:
    """Turn a 2-D mask into an RGBA image array.

    Args:
        mask: Two-dimensional array; entries greater than zero are painted.
        rgba: Four channel values (red, green, blue, alpha) written to every
            painted pixel.

    Returns:
        A ``(height, width, 4)`` ``uint8`` array that holds ``rgba`` where
        ``mask > 0`` and zeros everywhere else.
    """
    height, width = mask.shape
    canvas = np.zeros((height, width, 4), dtype=np.uint8)
    painted = mask > 0
    # Fill the four channels one at a time for the selected pixels.
    for channel in range(4):
        canvas[painted, channel] = rgba[channel]
    return canvas

# Render the most recent prediction; `idx` is its 1-based position, used in the layer name.
idx = len(outputs)
arr = outputs[-1].numpy().astype(np.uint8)
rgba = make_overlay_image(arr)

# Centre the map on the midpoint of the bounding box.
lat_min, lat_max, lon_min, lon_max = bounds
m_lat = (lat_min + lat_max) / 2
m_lon = (lon_min + lon_max) / 2

m = folium.Map(location=[m_lat, m_lon], zoom_start=11, tiles='CartoDB positron')
folium.TileLayer('OpenStreetMap').add_to(m)

# The overlay is stretched between the south-west and north-east corners.
overlay_corners = [[lat_min, lon_min], [lat_max, lon_max]]
prediction_layer = ImageOverlay(
    image=rgba,
    bounds=overlay_corners,
    opacity=0.6,
    name=f'Prediction t={idx}',
)
prediction_layer.add_to(m)


def _perimeter_style(_feature):
    """Outline-only style (blue line, no fill) applied to every GeoJSON feature."""
    return {'color': 'blue', 'weight': 2, 'fillOpacity': 0.0}


if geojson_text:
    try:
        gj = json.loads(geojson_text)
        GeoJson(gj, name='BearFire perimeter', style_function=_perimeter_style).add_to(m)
    except Exception as e:
        # The perimeter is optional; report the problem and still show the prediction.
        print('Failed to parse GeoJSON:', e)

folium.LayerControl().add_to(m)
display(m)