# MLP Fire Holdout (By-Fire Split)

This notebook trains a small MLP to predict whether the **center cell** will be at or above a GOES confidence threshold at time `t+1`, using the **3x3 neighborhood** features at time `t`.

- **Input (X)**: 3x3 patch (center + 8 neighbors), 7 variables per cell (63 features total).
- **Target (y)**: `1` iff `GOES_confidence_center(t+1) >= POSITIVE_THRESHOLD`, else `0`.
- **Split**: train on all fires except `TEST_FIRES`; evaluate only on `TEST_FIRES`.

Data layout expectation (same as `docs/neighbor_cell_confidence_regression.ipynb`):
- `data/multi_fire/<FireName>/*GOES*json`
- `data/multi_fire/<FireName>/rtma/rtma_manifest.json`


In [1]:
from pathlib import Path
import random

# --- Split config ---
# FIRE_SELECTION narrows the fires discovered on disk; TEST_FIRES are held out
# of training and used only for evaluation.
FIRE_SELECTION = "all"  # "all" or list of fire names
TEST_FIRES = ["Dixie", "Kincade"]

# --- Task config ---
# A sample is labelled positive when the centre cell's GOES confidence at t+1
# is >= POSITIVE_THRESHOLD.
POSITIVE_THRESHOLD = 0.10
# A sample is predicted positive when sigmoid(logit) >= this probability.
CLASSIFICATION_PROB_THRESHOLD = 0.50

# --- Training config ---
SEED = 1337
EPOCHS = 1
BATCH_SIZE = 8192
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
HIDDEN_DIMS = [128, 64]  # one hidden Linear layer per entry
DROPOUT = 0.2

# Optional safety knobs (default: full data)
# MAX_HOURS_PER_FIRE stops a fire after that many non-empty hours have been yielded;
# MAX_SAMPLES_PER_HOUR randomly subsamples any hour that has more valid rows.
MAX_HOURS_PER_FIRE = None        # e.g., 48
MAX_SAMPLES_PER_HOUR = None      # e.g., 200_000

# Echo the effective configuration so it is captured in the cell output.
print("cwd:", Path.cwd())
print("fire selection:", FIRE_SELECTION)
print("test fires:", TEST_FIRES)
print("positive confidence threshold:", POSITIVE_THRESHOLD)
print("classification probability threshold:", CLASSIFICATION_PROB_THRESHOLD)
print("epochs:", EPOCHS)
print("batch size:", BATCH_SIZE)


cwd: /Users/seanmay/Desktop/Current Projects/wildfire-prediction/docs
fire selection: all
test fires: ['Dixie', 'Kincade']
positive confidence threshold: 0.1
classification probability threshold: 0.5
epochs: 1
batch size: 8192


In [2]:
import json
import random
import hashlib
import math
import os
from datetime import datetime, timedelta
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from rasterio.warp import Resampling, reproject

import torch
import torch.nn as nn


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    When a CUDA device is available, every CUDA generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed the Python, NumPy and torch RNGs with the configured SEED.
set_seed(SEED)


def pick_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    # Apple Silicon exposes its GPU through the optional `mps` backend.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        name = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)


# Device used for the model and for every batch tensor below.
DEVICE = pick_device()
print("torch:", torch.__version__)
print("device:", DEVICE)


torch: 2.5.1
device: mps


In [3]:
def parse_iso(value: str) -> datetime:
    """Parse an ISO-8601 timestamp, accepting a trailing ``Z`` as UTC.

    A trailing "Z" is rewritten to "+00:00" before the text is handed to
    ``datetime.fromisoformat``; any other input is passed through unchanged.
    """
    iso_text = f"{value[:-1]}+00:00" if value.endswith("Z") else value
    return datetime.fromisoformat(iso_text)


def normalize_time_str(value: str) -> str:
    """Re-format ``value`` as an hour-truncated ``YYYY-MM-DDTHH:00:00Z`` string.

    Minutes and seconds are discarded. The parsed timestamp's UTC offset, if
    any, is NOT applied: the date and hour fields are emitted as parsed and
    the literal "Z" suffix is appended.
    """
    hour_format = "%Y-%m-%dT%H:00:00Z"
    return parse_iso(value).strftime(hour_format)


def affine_from_list(vals: list) -> rasterio.Affine:
    """Build a ``rasterio.Affine`` from the first six entries of ``vals``."""
    coefficients = [vals[i] for i in range(6)]
    return rasterio.Affine(*coefficients)


def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that looks like the repo.

    A directory qualifies when it contains ``data``, ``scripts`` and ``docs``
    entries. ``start`` itself is checked first, then each parent in turn.

    Raises:
        FileNotFoundError: if no directory on the way up qualifies.
    """
    required = ("data", "scripts", "docs")
    for candidate in (start, *start.parents):
        if all((candidate / name).exists() for name in required):
            return candidate
    raise FileNotFoundError("Could not find repo root containing data/, scripts/, docs/.")


def load_goes_times(goes_meta: dict, goes_conf: np.ndarray):
    """Derive the list of hourly GOES timestamp strings for one fire.

    Three metadata layouts are handled, based on ``goes_meta["time_steps"]``
    and ``goes_meta["start_time"]``:

    * numeric ``time_steps``: each value ``i`` becomes ``start_time`` plus
      ``int(i - 1)`` hours (``start_time`` is then mandatory);
    * no ``time_steps`` but a ``start_time``: one label per hour, one for each
      entry along the first axis of ``goes_conf``;
    * otherwise: each ``time_steps`` entry is passed through
      ``normalize_time_str``.

    Returns:
        A non-empty list of ``YYYY-MM-DDTHH:00:00Z`` strings.

    Raises:
        ValueError: if numeric steps come without ``start_time``, or if no
            time step could be produced.
    """
    hour_format = "%Y-%m-%dT%H:00:00Z"
    raw_steps = goes_meta.get("time_steps", [])
    start_text = goes_meta.get("start_time")

    def _label(origin, hours):
        return (origin + timedelta(hours=hours)).strftime(hour_format)

    steps_are_numeric = bool(raw_steps) and isinstance(raw_steps[0], (int, float))
    if steps_are_numeric:
        if not start_text:
            raise ValueError("GOES time_steps are numeric and metadata.start_time is missing.")
        origin = parse_iso(start_text)
        labels = [_label(origin, int(step - 1)) for step in raw_steps]
    elif not raw_steps and start_text:
        origin = parse_iso(start_text)
        labels = [_label(origin, hour) for hour in range(goes_conf.shape[0])]
    else:
        labels = [normalize_time_str(step) for step in raw_steps]

    if not labels:
        raise ValueError("GOES metadata has no usable time_steps.")
    return labels


def discover_fire_entries(repo_root: Path):
    """List the fires under ``<repo_root>/data/multi_fire`` that have usable inputs.

    A fire directory is kept only when it holds at least one file matching
    ``*GOES*json`` and an ``rtma/rtma_manifest.json``. Directories are visited
    in sorted order and the first GOES match (sorted) is the one recorded.

    Returns:
        A list of dicts with keys ``fire_name``, ``goes_json`` and
        ``rtma_manifest``.

    Raises:
        FileNotFoundError: if the ``multi_fire`` directory does not exist.
    """
    multi_fire_root = repo_root / "data" / "multi_fire"
    if not multi_fire_root.exists():
        raise FileNotFoundError(f"Missing multi-fire directory: {multi_fire_root}")

    found = []
    fire_dirs = sorted(child for child in multi_fire_root.iterdir() if child.is_dir())
    for fire_dir in fire_dirs:
        goes_files = sorted(fire_dir.glob("*GOES*json"))
        manifest = fire_dir / "rtma" / "rtma_manifest.json"
        if goes_files and manifest.exists():
            found.append(
                {
                    "fire_name": fire_dir.name,
                    "goes_json": goes_files[0],
                    "rtma_manifest": manifest,
                }
            )
    return found


def select_fire_entries(entries, fire_selection):
    """Filter ``entries`` down to the fires named in ``fire_selection``.

    ``None`` or the string ``"all"`` returns ``entries`` untouched. Otherwise
    ``fire_selection`` must be a list, tuple or set of names; the result keeps
    the order of ``entries``.

    Raises:
        ValueError: if ``fire_selection`` has an unsupported type, or names a
            fire that is not present in ``entries``.
    """
    if fire_selection is None or fire_selection == "all":
        return entries
    if not isinstance(fire_selection, (list, tuple, set)):
        raise ValueError('FIRE_SELECTION must be "all" or a list/tuple/set of fire names.')

    requested = {str(name) for name in fire_selection}
    chosen = []
    seen = set()
    for entry in entries:
        name = entry["fire_name"]
        if name in requested:
            chosen.append(entry)
            seen.add(name)

    unknown = sorted(requested - seen)
    if unknown:
        raise ValueError(f"Unknown fire names in FIRE_SELECTION: {unknown}")
    return chosen


# Repository root, located by walking up from the current working directory.
REPO_ROOT = find_repo_root(Path.cwd())


In [4]:
# 3x3 neighbourhood as (suffix, row offset dy, column offset dx) relative to the
# centre cell. The list order fixes the order of the per-cell feature groups.
CELL_OFFSETS = [
    ("c", 0, 0),
    ("nw", -1, -1),
    ("n", -1, 0),
    ("ne", -1, 1),
    ("w", 0, -1),
    ("e", 0, 1),
    ("sw", 1, -1),
    ("s", 1, 0),
    ("se", 1, 1),
]

# Order of the seven values emitted per cell. WDIR itself is not a feature: it
# is replaced by its sine and cosine.
VAR_ORDER = ["GOES_conf", "TMP", "WIND", "SPFH", "ACPC01", "WDIR_sin", "WDIR_cos"]
# RTMA variables that must be listed in every fire's manifest.
RTMA_VARS_REQUIRED = ["TMP", "WIND", "WDIR", "SPFH", "ACPC01"]


def feature_names():
    """Return the feature column names as ``<variable>_<cell suffix>``.

    Names are grouped by cell in ``CELL_OFFSETS`` order, and within each cell
    follow ``VAR_ORDER``.
    """
    return [
        f"{variable}_{cell_suffix}"
        for cell_suffix, _dy, _dx in CELL_OFFSETS
        for variable in VAR_ORDER
    ]


# Column names and width of every feature matrix built below.
FEATURE_NAMES = feature_names()
N_FEATURES = len(FEATURE_NAMES)


def to_binary_target(y_continuous: np.ndarray, threshold: float) -> np.ndarray:
    """Return an int32 array that is 1 where ``y_continuous >= threshold``, else 0."""
    is_positive = y_continuous >= threshold
    return is_positive.astype(np.int32)


def resolve_manifest_file_path(path_str: str, repo_root: Path, manifest_dir: Path) -> Path:
    """Locate an RTMA part file named in a manifest.

    Candidates are tried in this order and the first existing one wins:

    1. ``path_str`` itself (with ``~`` expanded);
    2. the portion of the path from its first ``data`` component onward,
       re-rooted under ``repo_root``;
    3. ``path_str`` resolved relative to ``manifest_dir``.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    literal = Path(path_str).expanduser()
    if literal.exists():
        return literal

    segments = literal.parts
    if "data" in segments:
        rebased = repo_root.joinpath(*segments[segments.index("data"):])
        if rebased.exists():
            return rebased

    beside_manifest = (manifest_dir / path_str).resolve()
    if beside_manifest.exists():
        return beside_manifest

    raise FileNotFoundError(f"Could not resolve RTMA part path: {path_str}")


def resample_stack(src_stack, src_transform, src_crs, dst_shape, dst_transform, dst_crs):
    """Reproject each band of ``src_stack`` onto the destination grid.

    Bands are warped one at a time with bilinear resampling into a float32
    array of shape ``(bands, dst_shape[0], dst_shape[1])``, which is returned.
    """
    n_bands = src_stack.shape[0]
    out_rows, out_cols = dst_shape[0], dst_shape[1]
    warped = np.empty((n_bands, out_rows, out_cols), dtype=np.float32)
    for band_index in range(n_bands):
        source_band = src_stack[band_index]
        target_band = warped[band_index]  # view: reproject fills `warped` in place
        reproject(
            source=source_band,
            destination=target_band,
            src_transform=src_transform,
            src_crs=src_crs,
            dst_transform=dst_transform,
            dst_crs=dst_crs,
            resampling=Resampling.bilinear,
        )
    return warped


def build_hour_samples(conf_t, conf_t1, rtma_hour):
    """Assemble per-cell feature rows and continuous targets for one hour.

    Only interior grid cells are used, so the centre and all eight neighbours
    exist. For each entry of ``CELL_OFFSETS`` (in order) a row holds seven
    values read at time ``t``: GOES confidence, TMP, WIND, SPFH, ACPC01, and
    the sine and cosine of WDIR (converted from degrees). The target is the
    centre cell's value in ``conf_t1``.

    Rows containing a non-finite feature or target are dropped. When
    ``MAX_SAMPLES_PER_HOUR`` is set and exceeded, that many rows are drawn
    without replacement from a generator freshly seeded with ``SEED``.

    Returns:
        ``(X, y)``: float64 arrays of shape ``(n, N_FEATURES)`` and ``(n,)``;
        both are empty when the grid is smaller than 3x3 or no row is valid.
    """

    def _no_samples():
        return (
            np.empty((0, N_FEATURES), dtype=np.float64),
            np.empty((0,), dtype=np.float64),
        )

    n_rows, n_cols = conf_t.shape
    if min(n_rows, n_cols) < 3:
        return _no_samples()

    targets = conf_t1[1:-1, 1:-1].astype(np.float64)
    scalar_vars = ("TMP", "WIND", "SPFH", "ACPC01")

    columns = []
    for _suffix, dy, dx in CELL_OFFSETS:
        # Interior window shifted by (dy, dx): same shape as `targets`.
        window = (
            slice(1 + dy, n_rows - 1 + dy),
            slice(1 + dx, n_cols - 1 + dx),
        )
        columns.append(conf_t[window].astype(np.float64))
        for key in scalar_vars:
            columns.append(rtma_hour[key][window].astype(np.float64))
        heading = np.deg2rad(rtma_hour["WDIR"][window].astype(np.float64))
        columns.append(np.sin(heading))
        columns.append(np.cos(heading))

    X = np.stack(columns, axis=-1).reshape(-1, N_FEATURES)
    y = targets.reshape(-1)

    keep = np.isfinite(y) & np.isfinite(X).all(axis=1)
    if not keep.any():
        return _no_samples()
    X, y = X[keep], y[keep]

    cap = MAX_SAMPLES_PER_HOUR
    if cap is not None and X.shape[0] > cap:
        picks = np.random.default_rng(SEED).choice(X.shape[0], size=int(cap), replace=False)
        X, y = X[picks], y[picks]

    return X, y


def iter_aligned_hours_for_fire(
    goes_conf,
    goes_time_index,
    rtma_manifest,
    rtma_manifest_path: Path,
    goes_shape,
    goes_transform,
    goes_crs,
):
    """Yield ``(t, rtma_hour)`` for every RTMA hour that can be paired with GOES.

    The manifest is read through its ``"variables"``, ``"files"`` and
    ``"time_steps"`` keys. Each variable has a list of part files; the bands
    of successive parts are mapped, in order, onto successive entries of
    ``time_steps``. Every part is resampled onto the GOES grid before its
    hours are emitted.

    An hour is yielded only when its normalized time string is a key of
    ``goes_time_index`` and index ``t + 1`` exists along the first axis of
    ``goes_conf``.

    Yields:
        ``t``: the GOES index looked up in ``goes_time_index``.
        ``rtma_hour``: dict mapping each name in ``RTMA_VARS_REQUIRED`` to that
        hour's band of the resampled stack.

    Raises:
        KeyError: if a required variable is absent from the manifest.
        ValueError: if the variables do not all list the same number of parts.
    """
    rtma_vars = rtma_manifest["variables"]
    for req in RTMA_VARS_REQUIRED:
        if req not in rtma_vars:
            raise KeyError(f"RTMA manifest missing required variable: {req}")

    # Manifest paths may be stale or relative; resolve all of them up front so a
    # missing file fails before any raster is opened.
    manifest_dir = rtma_manifest_path.parent
    rtma_files = rtma_manifest["files"]
    resolved_files = {
        var: [resolve_manifest_file_path(path, REPO_ROOT, manifest_dir) for path in rtma_files[var]]
        for var in rtma_vars
    }

    n_parts = len(resolved_files[rtma_vars[0]])
    for v in rtma_vars:
        if len(resolved_files[v]) != n_parts:
            raise ValueError("RTMA variable file lists do not have equal part counts.")

    # parts[k] holds the k-th file of every variable, in rtma_vars order.
    parts = list(zip(*[resolved_files[v] for v in rtma_vars]))
    rtma_time_steps = [normalize_time_str(t) for t in rtma_manifest["time_steps"]]

    # Index into rtma_time_steps of the first band of the current part.
    rtma_time_ptr = 0

    for part_paths in parts:
        rtma_arrays = {}
        rtma_transform = None
        rtma_crs = None
        band_count = None

        for var, part_path in zip(rtma_vars, part_paths):
            with rasterio.open(part_path) as ds:
                # Transform, CRS and band count are taken from the first
                # variable's file and reused for every variable in this part.
                if rtma_transform is None:
                    rtma_transform = ds.transform
                    rtma_crs = ds.crs
                    band_count = ds.count
                rtma_arrays[var] = ds.read().astype("float32")

        if band_count is None:
            continue

        # Every manifest variable is resampled, not only the required ones.
        resampled = {}
        for var in rtma_vars:
            resampled[var] = resample_stack(
                rtma_arrays[var],
                rtma_transform,
                rtma_crs,
                goes_shape,
                goes_transform,
                goes_crs,
            )

        for local_idx in range(band_count):
            global_idx = rtma_time_ptr + local_idx
            # More bands than manifest time steps: the surplus bands are ignored.
            if global_idx >= len(rtma_time_steps):
                break

            time_str = rtma_time_steps[global_idx]
            if time_str not in goes_time_index:
                continue

            t = goes_time_index[time_str]
            # The target needs GOES at t + 1, so the last GOES step cannot be used.
            if t + 1 >= goes_conf.shape[0]:
                continue

            rtma_hour = {var: resampled[var][local_idx] for var in RTMA_VARS_REQUIRED}
            yield t, rtma_hour

        rtma_time_ptr += band_count


def iter_fire_hour_samples(entry):
    """Stream ``(fire_name, t, X_hour, y_hour)`` tuples for one fire.

    ``entry`` is a dict with ``fire_name``, ``goes_json`` and ``rtma_manifest``
    keys. The GOES JSON supplies the confidence stack (``"data"``) and grid
    metadata (``"metadata"``); the RTMA manifest drives the hour alignment.

    For each aligned hour ``t`` the features come from time ``t`` and the
    binary labels from thresholding GOES confidence at ``t + 1`` with
    ``POSITIVE_THRESHOLD``. Hours with no valid sample are skipped and do not
    count towards ``MAX_HOURS_PER_FIRE``.
    """
    with Path(entry["goes_json"]).open("r", encoding="utf-8") as goes_file:
        goes_payload = json.load(goes_file)
    with Path(entry["rtma_manifest"]).open("r", encoding="utf-8") as manifest_file:
        manifest = json.load(manifest_file)

    confidence = np.array(goes_payload["data"], dtype=np.float32)
    meta = goes_payload["metadata"]
    grid_transform = affine_from_list(meta["geo_transform"])
    grid_crs = meta.get("crs")
    grid_shape = tuple(meta["grid_shape"])
    time_labels = load_goes_times(meta, confidence)
    index_of_time = dict(zip(time_labels, range(len(time_labels))))

    aligned_hours = iter_aligned_hours_for_fire(
        confidence,
        index_of_time,
        manifest,
        Path(entry["rtma_manifest"]),
        grid_shape,
        grid_transform,
        grid_crs,
    )

    emitted = 0
    for t, rtma_hour in aligned_hours:
        X_hour, y_continuous = build_hour_samples(confidence[t], confidence[t + 1], rtma_hour)
        if X_hour.shape[0] > 0:
            y_hour = to_binary_target(y_continuous, POSITIVE_THRESHOLD)
            yield entry["fire_name"], t, X_hour, y_hour

            emitted += 1
            if MAX_HOURS_PER_FIRE is not None and emitted >= MAX_HOURS_PER_FIRE:
                break


# Sanity check of the feature layout: 9 cells x 7 variables.
print("feature count:", N_FEATURES)
print("first 10 feature names:", FEATURE_NAMES[:10])


feature count: 63
first 10 feature names: ['GOES_conf_c', 'TMP_c', 'WIND_c', 'SPFH_c', 'ACPC01_c', 'WDIR_sin_c', 'WDIR_cos_c', 'GOES_conf_nw', 'TMP_nw', 'WIND_nw']


In [5]:
# Discover every fire on disk, then narrow to the configured selection.
all_fire_entries = discover_fire_entries(REPO_ROOT)
fire_entries = select_fire_entries(all_fire_entries, FIRE_SELECTION)

available = [e["fire_name"] for e in all_fire_entries]
selected = [e["fire_name"] for e in fire_entries]

print("available fires:", available)
print("selected fires:", selected)

# Build the lookup set once rather than once per entry inside the comprehensions.
test_fire_names = set(TEST_FIRES)

missing_test = sorted(test_fire_names - set(available))
if missing_test:
    raise RuntimeError(f"Missing TEST_FIRES in data/multi_fire: {missing_test}")

# A test fire that exists on disk but was filtered out by FIRE_SELECTION would
# otherwise drop out of the evaluation without any message.
unselected_test = sorted(test_fire_names - set(selected))
if unselected_test:
    print("WARNING: TEST_FIRES excluded by FIRE_SELECTION (will not be evaluated):", unselected_test)

# By-fire split: a fire is either entirely in train or entirely in test.
train_entries = [e for e in fire_entries if e["fire_name"] not in test_fire_names]
test_entries = [e for e in fire_entries if e["fire_name"] in test_fire_names]

print("train fires:", [e["fire_name"] for e in train_entries])
print("test fires:", [e["fire_name"] for e in test_entries])

if not train_entries:
    raise RuntimeError("No training fires after split.")
if not test_entries:
    raise RuntimeError("No test fires after split.")


available fires: ['August_Complex', 'Bobcat', 'CZU_Lightning_Complex', 'Creek', 'Dixie', 'Dolan', 'Glass', 'July_Complex', 'Kincade', 'LNU_Lightning_Complex', 'North_Complex', 'Red_Salmon_Complex', 'SCU_Lightning_Complex', 'SQF_Complex', 'Slater_and_Devil', 'W-5_Cold_Springs', 'Walker', 'Zogg']
selected fires: ['August_Complex', 'Bobcat', 'CZU_Lightning_Complex', 'Creek', 'Dixie', 'Dolan', 'Glass', 'July_Complex', 'Kincade', 'LNU_Lightning_Complex', 'North_Complex', 'Red_Salmon_Complex', 'SCU_Lightning_Complex', 'SQF_Complex', 'Slater_and_Devil', 'W-5_Cold_Springs', 'Walker', 'Zogg']
train fires: ['August_Complex', 'Bobcat', 'CZU_Lightning_Complex', 'Creek', 'Dolan', 'Glass', 'July_Complex', 'LNU_Lightning_Complex', 'North_Complex', 'Red_Salmon_Complex', 'SCU_Lightning_Complex', 'SQF_Complex', 'Slater_and_Devil', 'W-5_Cold_Springs', 'Walker', 'Zogg']
test fires: ['Dixie', 'Kincade']


In [6]:
def welford_update(count: int, mean: np.ndarray, m2: np.ndarray, X: np.ndarray):
    """Merge a batch of rows into running per-feature mean / M2 statistics.

    ``m2`` is the running sum of squared deviations from the mean, so the
    sample variance is ``m2 / (count - 1)``. The batch's own mean and M2 are
    computed first and then combined with the running values.

    Args:
        count: number of rows already folded in.
        mean: running per-feature mean, shape ``(d,)``.
        m2: running per-feature sum of squared deviations, shape ``(d,)``.
        X: new rows, shape ``(n, d)``.

    Returns:
        The updated ``(count, mean, m2)``. An empty batch returns the inputs
        unchanged.
    """
    batch = X.astype(np.float64, copy=False)
    batch_size = batch.shape[0]
    if batch_size == 0:
        return count, mean, m2

    batch_mean = batch.mean(axis=0)
    centered = batch - batch_mean
    batch_m2 = (centered ** 2).sum(axis=0)

    # First non-empty batch: its statistics are the running statistics.
    if count == 0:
        return batch_size, batch_mean, batch_m2

    merged_count = count + batch_size
    shift = batch_mean - mean
    merged_mean = mean + shift * (batch_size / merged_count)
    merged_m2 = m2 + batch_m2 + (shift ** 2) * (count * batch_size / merged_count)
    return merged_count, merged_mean, merged_m2


# Pass A: stream every TRAIN-fire hour once to accumulate the feature
# normalization statistics and the class balance. Test fires are not read here.
count = 0
mean = np.zeros((N_FEATURES,), dtype=np.float64)
m2 = np.zeros((N_FEATURES,), dtype=np.float64)

train_pos = train_neg = train_samples = hours_used = 0

for entry in train_entries:
    for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
        n_hour = int(y_hour.shape[0])
        pos_hour = int(y_hour.sum())

        hours_used += 1
        train_samples += n_hour
        train_pos += pos_hour
        train_neg += n_hour - pos_hour

        count, mean, m2 = welford_update(count, mean, m2, X_hour)

if count < 2:
    raise RuntimeError("Not enough training samples to compute normalization stats.")

# Sample variance, floored before the square root so std is never zero.
var = m2 / (count - 1)
std = np.sqrt(np.maximum(var, 1e-12))

# Negatives per positive; the max() guards against a training set with no positives.
pos_weight_value = train_neg / max(train_pos, 1)

for label, value in (
    ("training hours used:", hours_used),
    ("training samples:", train_samples),
    ("training positives:", train_pos),
    ("training negatives:", train_neg),
    ("training pos rate:", train_pos / max(train_samples, 1)),
    ("pos_weight (neg/pos):", pos_weight_value),
):
    print(label, value)


training hours used: 11329
training samples: 78365262
training positives: 222000
training negatives: 78143262
training pos rate: 0.0028328878680964533
pos_weight (neg/pos): 351.9966756756757


In [7]:
class MLP(nn.Module):
    """Fully connected network producing one logit per input row.

    Each hidden width in ``hidden_dims`` contributes a ``Linear`` -> ``ReLU``
    pair, followed by ``Dropout`` when ``dropout`` is greater than zero. A final
    ``Linear`` maps to a single output. All layers live in ``self.net``.
    """

    def __init__(self, in_dim: int, hidden_dims: list[int], dropout: float):
        super().__init__()
        use_dropout = bool(dropout) and dropout > 0
        widths = [in_dim, *hidden_dims]

        stack: list[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
            if use_dropout:
                stack.append(nn.Dropout(dropout))
        stack.append(nn.Linear(widths[-1], 1))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw logits for ``x``; no sigmoid is applied."""
        return self.net(x)


# Model and optimizer are created once here; the training cell updates them in place.
model = MLP(N_FEATURES, HIDDEN_DIMS, DROPOUT).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# pos_weight_value is the neg/pos ratio measured on the training fires; passing it
# as pos_weight up-weights the positive-class term of the loss.
pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32, device=DEVICE)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

print(model)


MLP(
  (net): Sequential(
    (0): Linear(in_features=63, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=64, out_features=1, bias=True)
  )
)


In [8]:
def batch_iter(n_rows: int, batch_size: int):
    """Yield ``(start, end)`` half-open row ranges that cover ``range(n_rows)``.

    Every range spans ``batch_size`` rows except possibly the last one, which
    is clipped to ``n_rows``. Nothing is yielded when ``n_rows`` is 0.
    """
    for lo in range(0, n_rows, batch_size):
        hi = lo + batch_size
        yield lo, (hi if hi < n_rows else n_rows)


def logits_to_pred(logits: torch.Tensor, prob_threshold: float) -> torch.Tensor:
    """Turn logits into int32 class predictions.

    An element is 1 where ``sigmoid(logit) >= prob_threshold`` and 0 otherwise.
    """
    is_positive = torch.ge(torch.sigmoid(logits), prob_threshold)
    return is_positive.to(dtype=torch.int32)


# Train (streaming over hours)
# Each epoch re-reads every training fire hour by hour, so only one hour of
# samples is held in memory at a time. Rows are consumed in the order the
# generator yields them; no shuffling is applied.
model.train()

for epoch in range(1, EPOCHS + 1):
    epoch_loss = 0.0
    epoch_batches = 0

    # Running confusion-matrix counts for this epoch.
    tp = fp = fn = tn = 0

    for entry in train_entries:
        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            # Standardize with the train-only statistics from Pass A.
            Xn = (X_hour - mean) / std

            y_np = y_hour.astype(np.float32, copy=False)
            n_rows = Xn.shape[0]

            for start, end in batch_iter(n_rows, BATCH_SIZE):
                xb = torch.from_numpy(Xn[start:end]).to(device=DEVICE, dtype=torch.float32)
                # unsqueeze(1) gives the targets the same (batch, 1) shape as the logits.
                yb = torch.from_numpy(y_np[start:end]).to(device=DEVICE, dtype=torch.float32).unsqueeze(1)

                optimizer.zero_grad(set_to_none=True)
                logits = model(xb)
                loss = criterion(logits, yb)
                loss.backward()
                optimizer.step()

                epoch_loss += float(loss.detach().cpu())
                epoch_batches += 1

                # Training metrics reuse the logits of the forward pass above, i.e.
                # they are computed in train mode and before this batch's update.
                with torch.no_grad():
                    pred = logits_to_pred(logits, CLASSIFICATION_PROB_THRESHOLD)
                    y_int = yb.to(torch.int32)
                    tp += int(((pred == 1) & (y_int == 1)).sum().cpu())
                    fp += int(((pred == 1) & (y_int == 0)).sum().cpu())
                    fn += int(((pred == 0) & (y_int == 1)).sum().cpu())
                    tn += int(((pred == 0) & (y_int == 0)).sum().cpu())

    # Mean of the per-batch losses (short final batches weigh the same as full ones).
    avg_loss = epoch_loss / max(epoch_batches, 1)
    acc = (tp + tn) / max(tp + tn + fp + fn, 1)
    tpr = tp / max(tp + fn, 1)
    tnr = tn / max(tn + fp, 1)

    print(f"epoch {epoch}/{EPOCHS} avg_loss={avg_loss:.6f} acc={acc:.4f} tpr={tpr:.4f} tnr={tnr:.4f} TP={tp} FP={fp} FN={fn} TN={tn}")


epoch 1/1 avg_loss=0.352312 acc=0.9802 tpr=0.8764 tnr=0.9805 TP=194564 FP=1520461 FN=27436 TN=76622801


In [9]:
# Evaluate on held-out fires (no gradient)
# eval() switches the Dropout layers off for inference.
model.eval()

# Confusion counts at CLASSIFICATION_PROB_THRESHOLD: pooled, and per test fire.
overall = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
per_fire = {}

# Precision/Recall/F1 sweep (match logistic notebook style)
# pr_tp[k] / pr_fp[k] accumulate the true / false positives obtained when
# predicting positive for prob >= PR_THRESHOLDS[k].
PR_THRESHOLDS = np.linspace(0.0, 1.0, 1001)
pr_tp = np.zeros(PR_THRESHOLDS.shape[0], dtype=np.int64)
pr_fp = np.zeros(PR_THRESHOLDS.shape[0], dtype=np.int64)
pr_total_pos = 0
pr_total_neg = 0


def batch_iter(n_rows: int, batch_size: int):
    """Yield consecutive ``(start, end)`` row slices of at most ``batch_size`` rows.

    The slices are half-open, start at 0 and end at ``n_rows``; the last one
    may be shorter than ``batch_size``.
    """
    yield from (
        (begin, min(n_rows, begin + batch_size))
        for begin in range(0, n_rows, batch_size)
    )


def metrics_from_counts(tp, fp, fn, tn):
    """Derive classification metrics from confusion-matrix counts.

    Every ratio uses ``max(denominator, 1)``, so an empty denominator gives 0
    rather than raising. F1 is 0.0 when precision and recall are both zero.
    ``tpr`` is reported as an alias of ``recall``.

    Returns:
        A dict with ``n``, ``accuracy``, ``precision``, ``recall``, ``tpr``,
        ``tnr``, ``f1`` and the four input counts.
    """

    def _ratio(numerator, denominator):
        return numerator / max(denominator, 1)

    total = tp + fp + fn + tn
    recall = _ratio(tp, tp + fn)
    precision = _ratio(tp, tp + fp)

    pr_sum = precision + recall
    if pr_sum > 0:
        f1 = 2 * precision * recall / pr_sum
    else:
        f1 = 0.0

    return {
        "n": total,
        "accuracy": _ratio(tp + tn, total),
        "precision": precision,
        "recall": recall,
        "tpr": recall,
        "tnr": _ratio(tn, tn + fp),
        "f1": f1,
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "tn": tn,
    }


# Stream the held-out fires hour by hour, accumulating confusion counts at the
# fixed classification threshold and TP/FP counts for the whole threshold sweep.
with torch.no_grad():
    for entry in test_entries:
        fire = entry["fire_name"]
        per_fire.setdefault(fire, {"tp": 0, "fp": 0, "fn": 0, "tn": 0, "n": 0, "pos": 0, "neg": 0})

        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            # Same train-only normalization as used during training.
            Xn = (X_hour - mean) / std
            y_np = y_hour.astype(np.int32, copy=False)

            n_rows = Xn.shape[0]
            for start, end in batch_iter(n_rows, BATCH_SIZE):
                xb = torch.from_numpy(Xn[start:end]).to(device=DEVICE, dtype=torch.float32)
                logits = model(xb)
                # (batch, 1) logits -> 1-D NumPy array of probabilities on the host.
                probs = torch.sigmoid(logits).squeeze(1).detach().cpu().numpy()

                yb = y_np[start:end]
                pred = (probs >= CLASSIFICATION_PROB_THRESHOLD).astype(np.int32)

                tp = int(((pred == 1) & (yb == 1)).sum())
                fp = int(((pred == 1) & (yb == 0)).sum())
                fn = int(((pred == 0) & (yb == 1)).sum())
                tn = int(((pred == 0) & (yb == 0)).sum())

                overall["tp"] += tp
                overall["fp"] += fp
                overall["fn"] += fn
                overall["tn"] += tn

                pf = per_fire[fire]
                pf["tp"] += tp
                pf["fp"] += fp
                pf["fn"] += fn
                pf["tn"] += tn
                pf["n"] += int(yb.shape[0])
                pf["pos"] += int(yb.sum())
                pf["neg"] += int(yb.shape[0] - yb.sum())

                # Threshold sweep counts (exact for this threshold grid)
                pos = yb == 1
                pr_total_pos += int(pos.sum())
                pr_total_neg += int((~pos).sum())

                # Broadcast to a (batch, n_thresholds) boolean matrix, then count
                # per threshold column.
                sweep_pred = probs[:, None] >= PR_THRESHOLDS[None, :]
                pr_tp += (sweep_pred & pos[:, None]).sum(axis=0).astype(np.int64)
                pr_fp += (sweep_pred & (~pos)[:, None]).sum(axis=0).astype(np.int64)


overall_metrics = metrics_from_counts(**overall)
print("overall test metrics:", overall_metrics)

per_fire_metrics = {}
for fire, c in per_fire.items():
    per_fire_metrics[fire] = metrics_from_counts(c["tp"], c["fp"], c["fn"], c["tn"])
    # Share of positive labels among this fire's evaluated samples.
    per_fire_metrics[fire]["pos_rate"] = c["pos"] / max(c["n"], 1)

print("per-fire test metrics:")
for fire in sorted(per_fire_metrics):
    print(fire, per_fire_metrics[fire])


overall test metrics: {'n': 25817908, 'accuracy': 0.9953858383878353, 'tpr': 0.8702959528307287, 'tnr': 0.9958104686632779, 'precision': 0.4135462285450045, 'tp': 76016, 'fp': 107799, 'fn': 11329, 'tn': 25622764}
per-fire test metrics:
Dixie {'n': 25370800, 'accuracy': 0.9954943478329418, 'tpr': 0.8696321926248933, 'tnr': 0.995914405873701, 'precision': 0.41533672891907186, 'tp': 73390, 'fp': 103310, 'fn': 11002, 'tn': 25183098, 'pos_rate': 0.003326343670676526}
Kincade {'n': 447108, 'accuracy': 0.9892285532801918, 'tpr': 0.889265154080596, 'tnr': 0.9898931679256116, 'precision': 0.3690794096978215, 'tp': 2626, 'fp': 4489, 'fn': 327, 'tn': 439666, 'pos_rate': 0.006604668223337538}


In [14]:
# Precision-Recall curve + top thresholds by F1 (Test-Fire Set)
#
# Reads the sweep accumulators (PR_THRESHOLDS, pr_tp, pr_fp, pr_total_pos,
# pr_total_neg), so the cell that defines and fills them must have been run
# first; otherwise this cell fails with a NameError.

pr_df = None

if pr_total_pos == 0:
    raise RuntimeError("No positive samples in the held-out test set; cannot compute precision/recall.")

pr_fn = pr_total_pos - pr_tp
pr_tn = pr_total_neg - pr_fp

# Precision is taken as 1.0 at thresholds where nothing is predicted positive.
# np.divide with `where=` skips those entries entirely, so no 0/0 RuntimeWarning
# is emitted (np.where would still evaluate the division for every element).
pr_predicted_pos = pr_tp + pr_fp
pr_precision = np.divide(
    pr_tp,
    pr_predicted_pos,
    out=np.ones(pr_tp.shape, dtype=np.float64),
    where=pr_predicted_pos > 0,
)
pr_recall = pr_tp / pr_total_pos

# F1 is 0.0 where precision and recall are both zero.
pr_pr_sum = pr_precision + pr_recall
pr_f1 = np.divide(
    2 * pr_precision * pr_recall,
    pr_pr_sum,
    out=np.zeros(pr_pr_sum.shape, dtype=np.float64),
    where=pr_pr_sum > 0,
)

pr_df = pd.DataFrame(
    {
        "threshold": PR_THRESHOLDS,
        "precision": pr_precision,
        "recall": pr_recall,
        "f1": pr_f1,
        "tp": pr_tp,
        "fp": pr_fp,
        "fn": pr_fn,
        "tn": pr_tn,
    }
)

# idxmax returns an index label, so look it up with .loc.
best = pr_df.loc[pr_df["f1"].idxmax()]
baseline = pr_total_pos / (pr_total_pos + pr_total_neg)

print("test-fire positive rate (baseline precision at recall=1):", float(baseline))
print("best threshold by F1:", float(best["threshold"]))
print("precision:", float(best["precision"]), "recall:", float(best["recall"]), "f1:", float(best["f1"]))

# Many thresholds share the same recall. Breaking ties by descending threshold
# traces the curve in sweep order instead of leaving tied points in whatever
# order a single-key sort happens to produce.
pr_plot_df = pr_df.sort_values(["recall", "threshold"], ascending=[True, False])

plt.figure(figsize=(7, 4))
plt.plot(pr_plot_df["recall"], pr_plot_df["precision"], linewidth=2)
plt.hlines(baseline, 0, 1, linestyles="dashed", colors="gray", label=f"baseline (pos rate={baseline:.3f})")
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.xlabel("Recall (TP / (TP+FN))")
plt.ylabel("Precision (TP / (TP+FP))")
plt.title("Precision-Recall Curve (Test-Fire Set)")
plt.grid(True, alpha=0.25)
plt.legend()
plt.show()

pr_df.sort_values("f1", ascending=False).head(12)


NameError: name 'pr_total_pos' is not defined

In [None]:
# Save report
# Writes the configuration, training-set statistics and held-out test metrics
# to data/analysis/mlp_fire_holdout/report.json under the repo root.
out_dir = REPO_ROOT / "data" / "analysis" / "mlp_fire_holdout"
out_dir.mkdir(parents=True, exist_ok=True)

# Approx AP via trapezoid on the PR curve (recall-sorted)
# Ties in recall are broken by descending threshold so the integration order,
# and therefore the AP value, does not depend on how a single-key sort happens
# to arrange tied points.
pr_plot_df = pr_df.sort_values(["recall", "threshold"], ascending=[True, False])
# np.trapezoid is the newer name; fall back to np.trapz where it is absent.
trapz = getattr(np, "trapezoid", None)
if trapz is None:
    trapz = np.trapz
approx_ap = float(trapz(pr_plot_df["precision"].to_numpy(), pr_plot_df["recall"].to_numpy()))

# idxmax returns an index label, so look it up with .loc.
best = pr_df.loc[pr_df["f1"].idxmax()]

# FIRE_SELECTION may legitimately be a set, which json cannot serialize;
# store it as a sorted list in that case and leave every other value as is.
fire_selection_for_json = (
    sorted(str(name) for name in FIRE_SELECTION)
    if isinstance(FIRE_SELECTION, (set, frozenset))
    else FIRE_SELECTION
)

report = {
    "model": "mlp_pytorch",
    "config": {
        "fire_selection": fire_selection_for_json,
        "test_fires": TEST_FIRES,
        "positive_threshold": POSITIVE_THRESHOLD,
        "classification_prob_threshold": CLASSIFICATION_PROB_THRESHOLD,
        "seed": SEED,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "hidden_dims": HIDDEN_DIMS,
        "dropout": DROPOUT,
        "max_hours_per_fire": MAX_HOURS_PER_FIRE,
        "max_samples_per_hour": MAX_SAMPLES_PER_HOUR,
        "n_features": N_FEATURES,
        "feature_names": FEATURE_NAMES,
        "pr_threshold_count": int(PR_THRESHOLDS.shape[0]),
    },
    "train": {
        "samples": int(train_samples),
        "positives": int(train_pos),
        "negatives": int(train_neg),
        "pos_rate": float(train_pos / max(train_samples, 1)),
        "pos_weight": float(pos_weight_value),
    },
    "test": {
        "overall": overall_metrics,
        "per_fire": per_fire_metrics,
        "threshold_sweep": {
            "baseline_pos_rate": float(pr_total_pos / (pr_total_pos + pr_total_neg)),
            "best_threshold_by_f1": float(best["threshold"]),
            "best_precision": float(best["precision"]),
            "best_recall": float(best["recall"]),
            "best_f1": float(best["f1"]),
            "approx_ap": approx_ap,
            "top_by_f1": pr_df.sort_values("f1", ascending=False).head(12).to_dict(orient="records"),
        },
    },
}

report_path = out_dir / "report.json"
with report_path.open("w", encoding="utf-8") as f:
    json.dump(report, f, indent=2)

print("saved:", report_path)
