# Real-world Wildfire Prediction (interactive)

This notebook runs predictions similar to `examples/real.py:my_evaluate`, using a provided `real_targets.hdf5` file. You'll upload the dataset and (optionally) a `BearFire.geojson` perimeter to visualize results on a basemap. You can tweak model parameters (`a`, `c_1`, `c_2`, `p_h`, `p_continue`) and see impacts on predicted spread and accuracy.

Steps:
- Upload `real_targets.hdf5` (and optional `BearFire.geojson`).
- Select dataset group (e.g., an `exp_id`).
- Adjust parameters (defaults come from the model).
- Run prediction (mirrors `my_evaluate`).
- View metrics (Jaccard, Manhattan) and an interactive map overlay.



## Imports

We import core libraries for tensors, I/O, widgets, plotting, and mapping. The wildfire model is loaded from `pytorchfire.model`.


In [None]:
import io
import json
import base64
from typing import Tuple, Dict, Any, Optional

import h5py
import numpy as np
import torch
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt

# Mapping
import folium
from folium import GeoJson, Map, TileLayer, LayerControl
from folium.raster_layers import ImageOverlay

# Model
from pytorchfire.model import WildfireModel

# Default figure size (width, height in inches) for the matplotlib plots in this notebook.
plt.rcParams['figure.figsize'] = (8, 5)

# Run on the first CUDA device when torch reports CUDA as available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')



## Upload data

Upload `real_targets.hdf5` and (optionally) `BearFire.geojson`. Select the dataset group (e.g., `exp_id`) to use for prediction.


In [None]:
# Single-file upload widgets for the HDF5 dataset and the optional GeoJSON perimeter.
h5_uploader = widgets.FileUpload(accept='.hdf5,.h5', multiple=False)
geojson_uploader = widgets.FileUpload(accept='.geojson', multiple=False)
# Filled with the top-level HDF5 keys once "Read HDF5 groups" has been clicked.
exp_dropdown = widgets.Dropdown(options=[], description='exp_id:')
refresh_btn = widgets.Button(description='Read HDF5 groups', button_style='info')

# Help text shown above the manual bounds inputs.
bounds_help = widgets.HTML(
    value=(
        "If geographic bounds are not found in HDF5, please provide them below. "
        "Bounds format: lat_min, lat_max, lon_min, lon_max."
    )
)
# Manual bounds inputs; NaN marks "not provided" (the fallback only accepts finite values).
lat_min = widgets.FloatText(description='lat_min:', value=np.nan)
lat_max = widgets.FloatText(description='lat_max:', value=np.nan)
lon_min = widgets.FloatText(description='lon_min:', value=np.nan)
lon_max = widgets.FloatText(description='lon_max:', value=np.nan)

# Output area for status messages produced while reading the HDF5 file.
h5_info_out = widgets.Output()

def _infer_bounds(ds: h5py.Group) -> Optional[Tuple[float, float, float, float]]:
    """Read geographic bounds from the attributes of an HDF5 group.

    Two layouts are recognised, in this order:

    1. a ``bounds`` attribute holding exactly four values, and
    2. the four scalar attributes ``lat_min``, ``lat_max``, ``lon_min``, ``lon_max``.

    Args:
        ds: Object exposing a mapping-like ``attrs`` attribute.

    Returns:
        ``(lat_min, lat_max, lon_min, lon_max)`` as floats, or ``None`` when no
        layout yields four finite numbers (so the caller can fall back to the
        manually entered bounds).
    """
    needed = ('lat_min', 'lat_max', 'lon_min', 'lon_max')

    def _as_bounds(values) -> Optional[Tuple[float, float, float, float]]:
        # Convert a candidate into four finite floats; None when it does not qualify.
        try:
            if len(values) != 4:
                return None
            a, b, c, d = (float(v) for v in values)
        except (TypeError, ValueError):
            # Scalars have no len(); non-numeric entries cannot be converted.
            return None
        if not all(np.isfinite(v) for v in (a, b, c, d)):
            # NaN/inf bounds cannot place an overlay; treat them as missing.
            return None
        return a, b, c, d

    try:
        attrs = ds.attrs
        if 'bounds' in attrs:
            found = _as_bounds(attrs['bounds'])
            if found is not None:
                return found
        if all(k in attrs for k in needed):
            return _as_bounds([attrs[k] for k in needed])
    except (KeyError, OSError):
        # Best effort: an unreadable attribute is treated like a missing one.
        return None
    return None

# Open HDF5 handle backed by the uploaded bytes; None until a file has been read.
_h5_cache: Optional[h5py.File] = None

@refresh_btn.on_click
def _read_h5_groups(_):
    """Open the uploaded HDF5 file and list its top-level keys in ``exp_dropdown``.

    The new file is opened and inspected *before* the previously cached handle is
    closed, so a failed upload never leaves a closed handle in ``_h5_cache``.
    Status and error messages are printed to ``h5_info_out``.
    """
    global _h5_cache
    h5_info_out.clear_output()

    uploads = h5_uploader.value
    if not uploads:
        with h5_info_out:
            print('Please upload real_targets.hdf5 first.')
        return

    # ipywidgets 7 exposes {filename: metadata}; ipywidgets 8 exposes a tuple of
    # metadata dicts. Both keep the raw bytes under the 'content' key.
    entries = list(uploads.values()) if isinstance(uploads, dict) else list(uploads)
    payload = io.BytesIO(entries[-1]['content'])

    new_file = None
    try:
        new_file = h5py.File(payload, 'r')
        groups = list(new_file.keys())
    except Exception as e:  # UI boundary: report instead of raising inside the callback
        if new_file is not None:
            new_file.close()
        with h5_info_out:
            print('Failed to read HDF5:', e)
        return

    # Swap in the new handle first, then release the old one.
    old_file, _h5_cache = _h5_cache, new_file
    if old_file is not None:
        try:
            old_file.close()
        except Exception as e:
            with h5_info_out:
                print('Warning: could not close previous HDF5 file:', e)

    exp_dropdown.options = groups
    if groups:
        exp_dropdown.value = groups[0]
    with h5_info_out:
        print(f'Loaded HDF5 with groups: {groups}')

# Cell output: both upload columns side by side, then the group selector,
# the manual bounds inputs and the status output stacked below.
widgets.VBox([
    widgets.HBox([widgets.VBox([widgets.HTML('<b>Upload HDF5:</b>'), h5_uploader, refresh_btn]),
                  widgets.VBox([widgets.HTML('<b>Upload GeoJSON (optional):</b>'), geojson_uploader])]),
    widgets.HBox([exp_dropdown]),
    bounds_help,
    widgets.HBox([lat_min, lat_max, lon_min, lon_max]),
    h5_info_out
])


## Parameters and model defaults

We load environment tensors from the selected HDF5 group, instantiate `WildfireModel` with its default parameters, and expose widgets to adjust `a`, `c_1`, `c_2`, `p_h`, `p_continue`. Defaults are read from the model.


In [None]:
# Output area for messages produced while loading the environment / building the model.
params_out = widgets.Output()

# Button that triggers loading the environment from the selected HDF5 group.
load_env_btn = widgets.Button(description='Load environment from HDF5', button_style='success')

def _get_selected_group() -> Optional[h5py.Group]:
    """Return the HDF5 entry selected in ``exp_dropdown``, or ``None`` if unavailable.

    ``None`` is returned when no file has been read yet, when nothing is
    selected, or when the lookup itself fails, so that button callbacks can
    show their "please upload/select" hint instead of raising.
    """
    selection = exp_dropdown.value
    if _h5_cache is None or selection is None or selection == '':
        return None
    try:
        return _h5_cache[selection]
    except (KeyError, ValueError):
        # KeyError: the selected name is not in the file (stale dropdown value).
        # ValueError: presumably what h5py raises for a closed handle -- TODO confirm.
        return None

def _load_env_from_ds(ds: h5py.Group) -> Dict[str, torch.Tensor]:
    """Build the environment tensors for the model from one HDF5 group.

    Args:
        ds: Mapping-like group providing the datasets ``p_veg``, ``p_den``,
            ``wind_towards_direction``, ``wind_velocity``, ``slope`` and
            ``initial_ignition``.

    Returns:
        A dict with one tensor per dataset name. Only the first entry along the
        leading axis of the two wind datasets is used, and
        ``initial_ignition`` is converted to ``torch.bool``.

    Raises:
        KeyError: If a required dataset is missing from ``ds``.
    """
    static_layers = ('p_veg', 'p_den', 'slope')
    env = {name: torch.tensor(ds[name][:]) for name in static_layers}
    # Index the dataset directly: ``ds[name][0]`` reads a single frame, whereas
    # ``ds[name][:][0]`` would load the whole series just to keep the first one.
    env['wind_towards_direction'] = torch.tensor(ds['wind_towards_direction'][0])
    env['wind_velocity'] = torch.tensor(ds['wind_velocity'][0])
    env['initial_ignition'] = torch.tensor(ds['initial_ignition'][:], dtype=torch.bool)
    return env

# Placeholders to share state across cells.
# 'model' / 'env' / 'ds_attrs' / 'bounds' / 'geojson' are filled when the
# environment is loaded; the prediction cell adds further keys at run time.
_state: Dict[str, Any] = {
    'model': None,
    'env': None,
    'ds_attrs': None,
    'bounds': None,
    'geojson': None,
}

# Parameter widgets (filled after model init)
# The initial values below are placeholders until the model defaults are copied in.
a_slider = widgets.FloatSlider(description='a', min=-2.0, max=2.0, step=0.01, value=0.0)
c1_slider = widgets.FloatSlider(description='c_1', min=-2.0, max=2.0, step=0.01, value=0.0)
c2_slider = widgets.FloatSlider(description='c_2', min=-2.0, max=2.0, step=0.01, value=0.0)
ph_slider = widgets.FloatSlider(description='p_h', min=0.0, max=1.0, step=0.01, value=0.3)
pc_slider = widgets.FloatSlider(description='p_continue', min=0.0, max=1.0, step=0.01, value=0.3)

# Seed passed to model.reset() by the prediction cell.
# NOTE(review): an integer text box presumably cannot hold None, so value=None
# likely falls back to the widget default (0) and the seed is never "unset" -- confirm.
seed_text = widgets.IntText(description='seed', value=None)

@load_env_btn.on_click
def _init_model_from_h5(_):
    """Load the environment from the selected group and build a fresh model.

    Steps:
      * read the environment tensors from the selected HDF5 group,
      * resolve map bounds (HDF5 attributes first, then the manual inputs),
      * instantiate ``WildfireModel`` and copy its default parameters into the sliders,
      * read the optional GeoJSON upload.

    ``_state`` is only updated after the model has been built, and results of
    an earlier prediction are dropped because they belong to the previous
    environment. Messages are printed to ``params_out``.
    """
    params_out.clear_output()
    ds = _get_selected_group()
    if ds is None:
        with params_out:
            print('Please upload HDF5 and select a valid group.')
        return

    try:
        env = _load_env_from_ds(ds)
    except (KeyError, IndexError, OSError, TypeError, ValueError) as e:
        # Typically a missing or malformed dataset; report it in the UI.
        with params_out:
            print('Failed to load environment from HDF5 group:', e)
        return

    # Try to infer bounds; otherwise use the user-provided ones if all are finite.
    b = _infer_bounds(ds)
    if b is None:
        entered = (lat_min.value, lat_max.value, lon_min.value, lon_max.value)
        if all(np.isfinite(v) for v in entered):
            b = entered

    model = WildfireModel(env_data=env, params=None, keep_acc_mask=True).to(device)

    # Read optional geojson. ipywidgets 7 exposes {filename: metadata}, ipywidgets 8
    # a tuple of metadata dicts whose 'content' may be a memoryview, hence bytes().
    uploads = geojson_uploader.value
    if uploads:
        entries = list(uploads.values()) if isinstance(uploads, dict) else list(uploads)
        geojson_text = bytes(entries[-1]['content']).decode('utf-8', errors='ignore')
    else:
        geojson_text = None

    # Commit everything at once so a failure above cannot leave _state half-updated.
    _state['bounds'] = b
    _state['env'] = env
    _state['ds_attrs'] = dict(ds.attrs)
    _state['model'] = model
    _state['geojson'] = geojson_text
    # Results of a previous run no longer match the newly loaded environment.
    for stale_key in ('outputs', 'targets', 'jaccard', 'manhattan', 'counts_out', 'counts_tgt'):
        _state.pop(stale_key, None)

    # Fill widgets with defaults from model
    a_slider.value = float(model.a.item())
    c1_slider.value = float(model.c_1.item())
    c2_slider.value = float(model.c_2.item())
    ph_slider.value = float(model.p_h.item())
    pc_slider.value = float(model.p_continue.item())

    with params_out:
        print('Environment loaded. Model initialized on', device)
        if b is None:
            print('Bounds not provided or found. Mapping will require manual bounds.')
        else:
            print('Bounds:', b)

# Cell output: load button, the two rows of parameter widgets, then the status output.
widgets.VBox([
    load_env_btn,
    widgets.HBox([a_slider, c1_slider, c2_slider]),
    widgets.HBox([ph_slider, pc_slider, seed_text]),
    params_out
])


## Run prediction (like `my_evaluate`)

This step mirrors `examples/real.py:Fig7Trainer.my_evaluate`:
- Iterate over days, updating the wind every `wind_step_interval` steps.
- Accumulate `output_list` (affected cells) and `targets_list`.
- Compute Jaccard per iteration and a global Manhattan distance on affected cell counts.

Use the button to run with current parameters.


In [None]:
# Output area and trigger button for the prediction run.
run_out = widgets.Output()
run_btn = widgets.Button(description='Run prediction', button_style='warning')


def _jaccard_index(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """Return the Jaccard index (intersection over union) of two masks.

    Both inputs are converted to boolean masks first. When the union is empty
    (no cell is set in either mask) the index is defined as ``1.0``.
    """
    truth = y_true.bool()
    guess = y_pred.bool()
    overlap_size = torch.logical_and(truth, guess).sum().float()
    union_size = torch.logical_or(truth, guess).sum().float()
    if union_size.item() == 0:
        return 1.0
    ratio = overlap_size / union_size
    return ratio.item()


def _manhattan_distance(t1: torch.Tensor, t2: torch.Tensor) -> float:
    """Return the sum of absolute element-wise differences between two tensors.

    The tensors must have identical shapes (checked with an assertion).
    """
    assert t1.shape == t2.shape
    difference = t1 - t2
    return difference.abs().sum().item()


@run_btn.on_click
def _run_eval(_):
    """Run the day-by-day prediction with the current slider values.

    For each day ``d`` the model is stepped ``(d + 1) * wind_step_interval``
    times, the affected cells (``model.state[0] | model.state[1]``) are compared
    with ``ds['target'][d]``, and the model is reset with the same seed before
    the next day. During a run the wind fields are switched every
    ``wind_step_interval`` steps to those of the simulated day
    (``steps // wind_step_interval``).

    Results are stored in ``_state`` under ``outputs``, ``targets``,
    ``jaccard``, ``manhattan``, ``counts_out`` and ``counts_tgt``; a short
    summary is printed to ``run_out``.
    """
    run_out.clear_output()
    model: WildfireModel = _state.get('model')
    env = _state.get('env')
    ds = _get_selected_group()
    if model is None or env is None or ds is None:
        with run_out:
            print('Please load environment first.')
        return

    # Apply current parameters in place so each parameter keeps its own shape,
    # dtype and device instead of being swapped for a new 0-dim tensor.
    for param, slider in (
        (model.a, a_slider),
        (model.c_1, c1_slider),
        (model.c_2, c2_slider),
        (model.p_h, ph_slider),
        (model.p_continue, pc_slider),
    ):
        param.data.fill_(float(slider.value))

    # Seed
    seed_val = int(seed_text.value) if seed_text.value not in (None, '') else None

    # Attributes
    wind_step_interval = int(ds.attrs.get('wind_step_interval', 1))
    max_iterations = int(ds.attrs.get('day_count', ds['target'].shape[0]))
    if wind_step_interval < 1:
        # A non-positive interval would make the modulo below fail or never match.
        with run_out:
            print('wind_step_interval must be a positive integer, got', wind_step_interval)
        return

    # Read the wind fields once; the nested loop below would otherwise hit the
    # HDF5 file O(days^2) times for the same slices.
    wind_direction_days = ds['wind_towards_direction'][:max_iterations]
    wind_velocity_days = ds['wind_velocity'][:max_iterations]

    output_list = []
    targets_list = []
    jaccard_list = []
    out_counts = []
    tgt_counts = []

    # Reset/eval
    model.eval()
    model.reset(seed=seed_val)
    batch_seed = model.seed

    # Pure evaluation: no gradients are needed, so avoid tracking them across
    # the (potentially many) simulation steps.
    with torch.no_grad():
        for iterations in range(max_iterations):
            iter_max_steps = (iterations + 1) * wind_step_interval
            for steps in range(iter_max_steps):
                if steps % wind_step_interval == 0:
                    # Wind of the day being simulated at this step, not of the
                    # final day of this run (indexing with `iterations` applied
                    # the last day's wind to every earlier day as well).
                    day = steps // wind_step_interval
                    model.wind_towards_direction = torch.tensor(wind_direction_days[day], device=device)
                    model.wind_velocity = torch.tensor(wind_velocity_days[day], device=device)
                model.compute(attach=True)

            targets = torch.tensor(ds['target'][iterations], device=device)
            targets_list.append(targets.detach().cpu())

            affected_cell = (model.state[0] | model.state[1])
            output_list.append(affected_cell.detach().cpu())

            out_counts.append(affected_cell.sum().item())
            tgt_counts.append(targets.sum().item())

            jaccard_list.append(_jaccard_index(targets, affected_cell))

            # Restart with the same seed so every day is an independent re-run.
            model.reset(seed=batch_seed)

    manhattan_val = _manhattan_distance(torch.tensor(out_counts), torch.tensor(tgt_counts))

    _state['outputs'] = output_list
    _state['targets'] = targets_list
    _state['jaccard'] = jaccard_list
    _state['manhattan'] = manhattan_val
    _state['counts_out'] = out_counts
    _state['counts_tgt'] = tgt_counts

    with run_out:
        print('Prediction done.')
        print('Mean Jaccard:', float(np.mean(jaccard_list)))
        print('Manhattan:', float(manhattan_val))

# Cell output: run button with its message area below.
widgets.VBox([run_btn, run_out])


## Accuracy metrics

We plot the Jaccard index over iterations and report the Manhattan distance between affected cell counts (prediction vs. target).


In [None]:
# Output area and trigger button for the metrics plot.
metrics_out = widgets.Output()
refresh_metrics = widgets.Button(description='Show metrics', button_style='')

@refresh_metrics.on_click
def _show_metrics(_):
    """Plot the stored per-iteration Jaccard values and print the Manhattan distance.

    Reads ``jaccard`` and ``manhattan`` from ``_state``; when either is missing
    a hint to run the prediction first is printed instead. All output goes to
    ``metrics_out``.
    """
    metrics_out.clear_output()
    jaccard_series = _state.get('jaccard')
    manhattan_total = _state.get('manhattan')

    with metrics_out:
        if jaccard_series is None or manhattan_total is None:
            print('Run prediction first.')
            return

        # Iterations are shown 1-based on the x axis.
        day_numbers = np.arange(1, len(jaccard_series) + 1)
        _, axes = plt.subplots()
        axes.plot(day_numbers, jaccard_series, marker='o')
        axes.set_xlabel('Iteration (day)')
        axes.set_ylabel('Jaccard index')
        axes.set_title('Jaccard over time')
        axes.grid(True, alpha=0.3)
        plt.show()
        print('Manhattan distance (affected cell counts):', float(manhattan_total))

# Cell output: metrics button with the plot/output area below.
widgets.VBox([refresh_metrics, metrics_out])


## Interactive map

We create a folium map. The `output_list` is rendered as a semi-transparent raster overlay. Use the timestep selector to change the overlay. If a `BearFire.geojson` was uploaded, its perimeters are added as a layer. If geographic bounds are missing, please enter them above.


In [None]:
# Output area for the map, the 1-based timestep selector and the render button.
map_out = widgets.Output()
# The slider starts with a single step; its maximum is raised once results exist.
step_slider = widgets.IntSlider(description='timestep', min=1, max=1, step=1, value=1)
show_map_btn = widgets.Button(description='Show map', button_style='primary')


def _make_overlay_image(mask: np.ndarray, cmap_color=(255, 0, 0, 120)) -> np.ndarray:
    """Turn a 2-D mask into an RGBA image with one colour for the set cells.

    Args:
        mask: 2-D bool or 0/1 array; cells with ``mask > 0`` are painted.
        cmap_color: ``(R, G, B, A)`` values written to the painted cells.

    Returns:
        ``uint8`` array of shape ``(height, width, 4)``. Unpainted cells keep
        all four channels at zero, i.e. they are fully transparent.
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 4), dtype=np.uint8)
    painted = mask > 0
    # Write the colour one channel at a time into the selected cells.
    for channel in range(4):
        canvas[painted, channel] = cmap_color[channel]
    return canvas


def _display_map(t_idx: int):
    """Render the prediction of one timestep as an overlay on a folium map.

    Args:
        t_idx: 1-based timestep; values outside the available range are clamped.

    Bounds are taken from ``_state['bounds']``. When none were stored, the
    manual bound inputs are read at render time, so bounds typed in after the
    environment was loaded are still honoured. If a GeoJSON text is stored it
    is added as an extra layer. Everything is written to ``map_out``.
    """
    map_out.clear_output()
    outputs = _state.get('outputs')
    bounds = _state.get('bounds')
    geojson_text = _state.get('geojson')

    if outputs is None or len(outputs) == 0:
        with map_out:
            print('Run prediction first.')
        return
    if bounds is None:
        # Fall back to whatever is currently entered in the bound widgets.
        entered = (lat_min.value, lat_max.value, lon_min.value, lon_max.value)
        if all(np.isfinite(v) for v in entered):
            bounds = entered
    if bounds is None:
        with map_out:
            print('Bounds are required to place overlay. Please set/enter bounds.')
        return

    t_idx = max(1, min(t_idx, len(outputs)))
    arr = outputs[t_idx - 1].numpy().astype(np.uint8)
    rgba = _make_overlay_image(arr)

    # Local names deliberately differ from the lat_min/lat_max/lon_min/lon_max widgets.
    south, north, west, east = bounds
    center_lat = (south + north) / 2
    center_lon = (west + east) / 2

    m = Map(location=[center_lat, center_lon], zoom_start=11, tiles='CartoDB positron')
    TileLayer('OpenStreetMap').add_to(m)

    # The RGBA array is handed to ImageOverlay as-is; folium handles the image encoding.
    img_bounds = [[south, west], [north, east]]

    overlay = ImageOverlay(image=rgba, bounds=img_bounds, opacity=0.6, name=f'Prediction t={t_idx}')
    overlay.add_to(m)

    if geojson_text:
        try:
            gj = json.loads(geojson_text)
            GeoJson(gj, name='BearFire perimeter', style_function=lambda x: {
                'color': 'blue', 'weight': 2, 'fillOpacity': 0.0
            }).add_to(m)
        except Exception as e:
            # Best effort: a broken perimeter file must not prevent showing the prediction.
            with map_out:
                print('Failed to parse GeoJSON:', e)

    LayerControl().add_to(m)

    with map_out:
        display(m)


@show_map_btn.on_click
def _render_map(_):
    """Button callback: sync the slider with the latest results, then draw the map."""
    # Results are produced by another cell after this one has run, so the
    # slider range has to be refreshed at click time.
    _update_slider_range()
    _display_map(step_slider.value)


def _update_slider_range():
    """Match the timestep slider to the number of stored prediction outputs.

    Does nothing while no outputs exist. When the number of outputs differs
    from the current slider maximum, the maximum is updated and the slider
    jumps to the last timestep; otherwise the user's selection is kept.
    """
    outs = _state.get('outputs')
    if outs is None or len(outs) == 0:
        return
    latest = len(outs)
    if step_slider.max != latest:
        step_slider.max = latest
        step_slider.value = latest


# Pick up results that already exist when this cell is (re-)executed.
_update_slider_range()

# Must stay the last expression of the cell so the notebook displays it.
widgets.VBox([widgets.HBox([step_slider, show_map_btn]), map_out])
