## 0) Variable meanings

- **Target** = center-cell GOES fire confidence at time `t+1`, converted to binary class
- **Positive class (1)** = `confidence_t+1 >= 0.10`
- **Negative class (0)** = `confidence_t+1 < 0.10`
- **Cell features** = the 7 per-cell variables below, taken at time `t` for the center cell and its 8 neighbors (9 × 7 = 63 features)
- **Variables per cell**:
  - `fire_confidence` (GOES fire confidence at time `t`)
  - `temperature` (RTMA near-surface temperature, raw key `TMP`)
  - `wind_speed` (RTMA wind speed, raw key `WIND`)
  - `specific_humidity` (RTMA specific humidity, raw key `SPFH`)
  - `precipitation_1h` (RTMA 1-hour accumulated precipitation, raw key `ACPC01`)
  - `wind_direction_sin` (sine of RTMA wind direction, raw key `WDIR`, converted from degrees to radians)
  - `wind_direction_cos` (cosine of RTMA wind direction, raw key `WDIR`, converted from degrees to radians)
- **Cell order** = `C, NW, N, NE, W, E, SW, S, SE`
- **Model** = logistic regression (streaming SGD optimization)
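Why `sin`/`cos` instead of raw degrees: on a compass, 359° and 1° are only 2° apart, but as raw numbers they are 358 apart. A minimal sketch of the encoding (the helper `encode_direction` is illustrative, not part of the notebook) shows that the two encodings land close together while opposite directions stay far apart:

```python
import numpy as np


def encode_direction(deg):
    """Map compass degrees to (sin, cos) of the angle in radians."""
    rad = np.deg2rad(np.asarray(deg, dtype=np.float64))
    return np.sin(rad), np.cos(rad)


s, c = encode_direction([359.0, 1.0, 180.0])
# Euclidean distance between encodings: 359° vs 1° is small, 1° vs 180° is near the maximum of 2.
d_close = float(np.hypot(s[0] - s[1], c[0] - c[1]))
d_far = float(np.hypot(s[1] - s[2], c[1] - c[2]))
print(round(d_close, 3), round(d_far, 3))
```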


## 1) Summary

- Data: GOES hourly fire confidence (~2 km grid) + RTMA hourly meteorology, bilinearly resampled onto the GOES grid
- Task: predict whether next-hour center-cell confidence is at least 0.10 (`POSITIVE_THRESHOLD`), using center + 8-neighbor features at the current hour
- Wind direction features: encoded as `sin` and `cos` (no raw degree feature)
- Split: complete-fire holdout (train on full fires, test on different full fires)
- Feature normalization: z-score normalization using train-fire statistics only
- Training: logistic regression on normalized train-fire features only
- Evaluation: on the held-out test fires, report `test_accuracy_overall`, `test_positive_accuracy` (true-positive rate), and `test_negative_accuracy` (true-negative rate)


## 2) Section toggles

Use this block to quickly enable/disable major notebook sections.

Dependency notes:
- `RUN_DATA_STATS_SECTION`, `RUN_NORMALIZATION_SECTION`, and `RUN_TRAINING_SECTION` require fire discovery + helper definitions.
- `RUN_EVALUATION_SECTION`, `RUN_PR_SECTION`, and `RUN_COEFFICIENT_SECTION` require a trained model.
- `RUN_SUMMARY_SECTION` and `RUN_REPORT_SECTION` expect prior sections to have produced metrics.
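These dependencies can be checked before anything runs. A minimal sketch, assuming a hypothetical `SECTION_DEPENDENCIES` table that transcribes the notes above (the fire-discovery requirement for data stats and normalization comes from the checks those cells raise at run time; mapping "prior sections" to evaluation is an assumption):

```python
# Hypothetical table transcribing the dependency notes; not defined in the notebook.
SECTION_DEPENDENCIES = {
    "RUN_DATA_STATS_SECTION": ["RUN_FIRE_DISCOVERY_SECTION"],
    "RUN_NORMALIZATION_SECTION": ["RUN_FIRE_DISCOVERY_SECTION"],
    "RUN_TRAINING_SECTION": ["RUN_FIRE_DISCOVERY_SECTION"],
    "RUN_EVALUATION_SECTION": ["RUN_TRAINING_SECTION"],
    "RUN_PR_SECTION": ["RUN_TRAINING_SECTION"],
    "RUN_COEFFICIENT_SECTION": ["RUN_TRAINING_SECTION"],
    # Assumption: "prior sections" is read as "the evaluation metrics exist".
    "RUN_SUMMARY_SECTION": ["RUN_EVALUATION_SECTION"],
    "RUN_REPORT_SECTION": ["RUN_EVALUATION_SECTION"],
}


def missing_dependencies(toggles):
    """Return (section, prerequisite) pairs where an enabled section's prerequisite is disabled."""
    problems = []
    for section, prereqs in SECTION_DEPENDENCIES.items():
        if not toggles.get(section, False):
            continue
        for prereq in prereqs:
            if not toggles.get(prereq, False):
                problems.append((section, prereq))
    return problems


toggles = {
    "RUN_FIRE_DISCOVERY_SECTION": True,
    "RUN_TRAINING_SECTION": False,
    "RUN_EVALUATION_SECTION": True,
}
print(missing_dependencies(toggles))
```

In the notebook, the toggles could be collected with `{k: v for k, v in globals().items() if k.startswith("RUN_")}`.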


In [262]:
from pathlib import Path

# Fire selection:
# - "all" => use every fire in data/multi_fire/*
# - list => use only named fires, e.g. ["Creek", "Dixie"]
FIRE_SELECTION = "all"

# Fire-level train/test split settings.
# - If TRAIN_FIRES and TEST_FIRES are both "auto": split selected fires by FIRE_TRAIN_FRACTION
# - If one side is "auto" and the other is a list: auto side gets remaining fires
# - If both are lists: both lists are used exactly
TRAIN_FIRES = "auto"
TEST_FIRES = "auto"
FIRE_TRAIN_FRACTION = 0.70
FIRE_SPLIT_SEED = 42

# Classification settings.
POSITIVE_THRESHOLD = 0.10
CLASSIFICATION_PROB_THRESHOLD = 0.50

# Feature scaling settings.
NORMALIZE_FEATURES = True

# Section toggles.
RUN_FIRE_DISCOVERY_SECTION = True
RUN_DATA_STATS_SECTION = True
RUN_NORMALIZATION_SECTION = True
RUN_TRAINING_SECTION = True
RUN_EVALUATION_SECTION = True
RUN_SUMMARY_SECTION = True
RUN_PR_SECTION = True
RUN_COEFFICIENT_SECTION = True
RUN_REPORT_SECTION = True


In [263]:
print("cwd:", Path.cwd())
print("fire selection:", FIRE_SELECTION)
print("train fires config:", TRAIN_FIRES)
print("test fires config:", TEST_FIRES)
print("fire train fraction (auto mode):", FIRE_TRAIN_FRACTION)
print("fire split seed:", FIRE_SPLIT_SEED)
print("positive confidence threshold:", POSITIVE_THRESHOLD)
print("classification probability threshold:", CLASSIFICATION_PROB_THRESHOLD)
print("normalize features:", NORMALIZE_FEATURES)
print("section toggles:")
print("  fire discovery:", RUN_FIRE_DISCOVERY_SECTION)
print("  data stats:", RUN_DATA_STATS_SECTION)
print("  normalization:", RUN_NORMALIZATION_SECTION)
print("  training:", RUN_TRAINING_SECTION)
print("  evaluation:", RUN_EVALUATION_SECTION)
print("  summary:", RUN_SUMMARY_SECTION)
print("  precision-recall:", RUN_PR_SECTION)
print("  coefficients:", RUN_COEFFICIENT_SECTION)
print("  report:", RUN_REPORT_SECTION)


cwd: /Users/seanmay/Desktop/Current Projects/wildfire-prediction/docs
fire selection: all
train fires config: auto
test fires config: auto
fire train fraction (auto mode): 0.7
fire split seed: 42
positive confidence threshold: 0.1
classification probability threshold: 0.5
normalize features: True
section toggles:
  fire discovery: True
  data stats: True
  normalization: True
  training: True
  evaluation: True
  summary: True
  precision-recall: True
  coefficients: True
  report: True


## 3) Imports and generic helpers


In [264]:
import json
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from rasterio.warp import Resampling, reproject
from sklearn.linear_model import SGDClassifier


In [265]:
def parse_iso(value: str) -> datetime:
    if value.endswith("Z"):
        value = value[:-1] + "+00:00"
    return datetime.fromisoformat(value)


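def normalize_time_str(value: str) -> str:
    """Truncate a timestamp to the hour and format it as a UTC string ending in 'Z'."""
    dt = parse_iso(value)
    if dt.utcoffset() is not None:
        # Shift offset-aware times (e.g. '-07:00') to UTC before appending 'Z'.
        dt = (dt - dt.utcoffset()).replace(tzinfo=None)
    return dt.strftime("%Y-%m-%dT%H:00:00Z")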


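def affine_from_list(vals: list) -> rasterio.Affine:
    # Assumes `vals` is in Affine order (a, b, c, d, e, f). If the metadata stores a
    # GDAL geotransform (c, a, b, f, d, e), use rasterio.Affine.from_gdal(*vals) instead.
    return rasterio.Affine(*vals[:6])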


In [266]:
def find_repo_root(start: Path) -> Path:
    for p in [start] + list(start.parents):
        if (p / "data").exists() and (p / "scripts").exists() and (p / "docs").exists():
            return p
    raise FileNotFoundError("Could not find repo root containing data/, scripts/, docs/.")


def load_goes_times(goes_meta: dict, goes_conf: np.ndarray):
    goes_time_steps = goes_meta.get("time_steps", [])
    goes_start = goes_meta.get("start_time")

    # Numeric time_steps are treated as 1-based hour indices relative to start_time.
    if goes_time_steps and isinstance(goes_time_steps[0], (int, float)):
        if not goes_start:
            raise ValueError("GOES time_steps are numeric and metadata.start_time is missing.")
        start_dt = parse_iso(goes_start)
        goes_time_steps = [
            (start_dt + timedelta(hours=int(i - 1))).strftime("%Y-%m-%dT%H:00:00Z")
            for i in goes_time_steps
        ]
    # No time_steps: assume one consecutive hour per GOES frame, starting at start_time.
    elif not goes_time_steps and goes_start:
        start_dt = parse_iso(goes_start)
        goes_time_steps = [
            (start_dt + timedelta(hours=i)).strftime("%Y-%m-%dT%H:00:00Z")
            for i in range(goes_conf.shape[0])
        ]
    else:
        goes_time_steps = [normalize_time_str(t) for t in goes_time_steps]

    if not goes_time_steps:
        raise ValueError("GOES metadata has no usable time_steps.")

    return goes_time_steps


REPO_ROOT = find_repo_root(Path.cwd())
print("repo root:", REPO_ROOT)


repo root: /Users/seanmay/Desktop/Current Projects/wildfire-prediction
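`load_goes_times` treats numeric `time_steps` as 1-based hour indices counted from `metadata.start_time` (hence `i - 1`). A standalone sketch of that mapping under the same assumption; `hours_from_start` and the start time below are illustrative:

```python
from datetime import datetime, timedelta


def hours_from_start(start_iso, steps):
    """Convert 1-based hour indices to hourly UTC strings (assumes step 1 == start_time)."""
    start = datetime.fromisoformat(start_iso.replace("Z", "+00:00"))
    return [
        (start + timedelta(hours=int(i) - 1)).strftime("%Y-%m-%dT%H:00:00Z")
        for i in steps
    ]


print(hours_from_start("2021-08-14T22:00:00Z", [1, 2, 3]))
```

If a metadata file actually uses 0-based indices, every GOES frame would be labeled one hour early, which would silently misalign it with RTMA.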


## 4) Fire discovery and split


In [267]:
def discover_fire_entries(repo_root: Path):
    base = repo_root / "data" / "multi_fire"
    if not base.exists():
        raise FileNotFoundError(f"Missing multi-fire directory: {base}")

    entries = []
    for fire_dir in sorted([d for d in base.iterdir() if d.is_dir()]):
        goes_candidates = sorted(fire_dir.glob("*GOES*json"))
        manifest_path = fire_dir / "rtma" / "rtma_manifest.json"
        if not goes_candidates or not manifest_path.exists():
            continue
        entries.append(
            {
                "fire_name": fire_dir.name,
                "goes_json": goes_candidates[0],
                "rtma_manifest": manifest_path,
            }
        )
    return entries


def select_fire_entries(entries, fire_selection):
    if fire_selection is None or fire_selection == "all":
        return entries

    if not isinstance(fire_selection, (list, tuple, set)):
        raise ValueError('FIRE_SELECTION must be "all" or a list/tuple/set of fire names.')

    wanted = {str(x) for x in fire_selection}
    selected = [e for e in entries if e["fire_name"] in wanted]
    found = {e["fire_name"] for e in selected}
    missing = sorted(wanted - found)
    if missing:
        raise ValueError(f"Unknown fire names in FIRE_SELECTION: {missing}")
    return selected


In [268]:
def split_fire_entries(entries, train_fires, test_fires, train_fraction, split_seed):
    if len(entries) < 2:
        raise ValueError("Need at least 2 selected fires for fire-level train/test split.")

    names = [e["fire_name"] for e in entries]
    name_set = set(names)

    def normalize_group(value):
        if value is None or value == "auto":
            return "auto"
        if not isinstance(value, (list, tuple, set)):
            raise ValueError("TRAIN_FIRES/TEST_FIRES must be 'auto' or list/tuple/set of fire names.")
        normalized = [str(x) for x in value]
        unknown = sorted(set(normalized) - name_set)
        if unknown:
            raise ValueError(f"Unknown fire names in train/test split: {unknown}")
        return normalized

    train_group = normalize_group(train_fires)
    test_group = normalize_group(test_fires)

    if train_group == "auto" and test_group == "auto":
        if not (0.0 < train_fraction < 1.0):
            raise ValueError("FIRE_TRAIN_FRACTION must be between 0 and 1.")
        rng = np.random.default_rng(split_seed)
        perm_names = list(np.array(names)[rng.permutation(len(names))])
        n_train = int(round(train_fraction * len(perm_names)))
        n_train = max(1, min(len(perm_names) - 1, n_train))
        train_names = set(perm_names[:n_train])
        test_names = set(perm_names[n_train:])
    elif train_group == "auto":
        test_names = set(test_group)
        train_names = set(names) - test_names
    elif test_group == "auto":
        train_names = set(train_group)
        test_names = set(names) - train_names
    else:
        train_names = set(train_group)
        test_names = set(test_group)

    overlap = sorted(train_names & test_names)
    if overlap:
        raise ValueError(f"Train/test fire sets overlap: {overlap}")
    if len(train_names) == 0:
        raise ValueError("Train fire set is empty.")
    if len(test_names) == 0:
        raise ValueError("Test fire set is empty.")

    train_entries = [e for e in entries if e["fire_name"] in train_names]
    test_entries = [e for e in entries if e["fire_name"] in test_names]
    return train_entries, test_entries
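In `"auto"` mode, the train count is `round(FIRE_TRAIN_FRACTION * n_fires)`, clamped so both sides keep at least one fire. A small sketch of that arithmetic (the helper name is illustrative). With the 28 fires listed in the discovery output and `FIRE_TRAIN_FRACTION = 0.70`, it gives 20 train and 8 test fires:

```python
def n_train_fires(n_fires, train_fraction):
    """Number of train fires in the notebook's "auto" split (rounded, then clamped)."""
    n_train = int(round(train_fraction * n_fires))
    # Keep at least one fire on each side of the split.
    return max(1, min(n_fires - 1, n_train))


print(n_train_fires(28, 0.70))
print(n_train_fires(3, 0.95))  # clamped so the test set is not empty
```

Python's `round` rounds halves to even, so 5 fires at 0.50 gives 2 train fires, not 3.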


In [269]:
all_fire_entries = []
fire_entries = []
train_fire_entries = []
test_fire_entries = []

if RUN_FIRE_DISCOVERY_SECTION:
    all_fire_entries = discover_fire_entries(REPO_ROOT)
    fire_entries = select_fire_entries(all_fire_entries, FIRE_SELECTION)
    # Fail with a clear message before splitting (the split itself needs at least 2 fires).
    if len(fire_entries) == 0:
        raise RuntimeError("No fires selected.")

    train_fire_entries, test_fire_entries = split_fire_entries(
        fire_entries,
        TRAIN_FIRES,
        TEST_FIRES,
        FIRE_TRAIN_FRACTION,
        FIRE_SPLIT_SEED,
    )

    print("Available fires:", [e["fire_name"] for e in all_fire_entries])
    print("Selected fires:", [e["fire_name"] for e in fire_entries])
    print("Train fires:", [e["fire_name"] for e in train_fire_entries])
    print("Test fires:", [e["fire_name"] for e in test_fire_entries])
else:
    print("Skipped fire discovery/split section.")


Available fires: ['Antelope', 'August_Complex', 'Beckwourth_Complex', 'Bobcat', 'CZU_Lightning_Complex', 'Caldor', 'Creek', 'Dixie', 'Dolan', 'Glass', 'July_Complex', 'KNP_Complex', 'Kincade', 'LNU_Lightning_Complex', 'McCash', 'McFarland', 'Monument', 'North_Complex', 'Red_Salmon_Complex', 'River_Complex', 'SCU_Lightning_Complex', 'SQF_Complex', 'Slater_and_Devil', 'Tamarack', 'W-5_Cold_Springs', 'Walker', 'Windy', 'Zogg']
Selected fires: ['Antelope', 'August_Complex', 'Beckwourth_Complex', 'Bobcat', 'CZU_Lightning_Complex', 'Caldor', 'Creek', 'Dixie', 'Dolan', 'Glass', 'July_Complex', 'KNP_Complex', 'Kincade', 'LNU_Lightning_Complex', 'McCash', 'McFarland', 'Monument', 'North_Complex', 'Red_Salmon_Complex', 'River_Complex', 'SCU_Lightning_Complex', 'SQF_Complex', 'Slater_and_Devil', 'Tamarack', 'W-5_Cold_Springs', 'Walker', 'Windy', 'Zogg']
Train fires: ['Antelope', 'Bobcat', 'Caldor', 'Creek', 'Dixie', 'Glass', 'July_Complex', 'KNP_Complex', 'Kincade', 'McFarland', 'Monument', 'Nort

## 5) Feature schema and sample building


In [270]:
CELL_OFFSETS = [
    ("c", 0, 0),
    ("nw", -1, -1),
    ("n", -1, 0),
    ("ne", -1, 1),
    ("w", 0, -1),
    ("e", 0, 1),
    ("sw", 1, -1),
    ("s", 1, 0),
    ("se", 1, 1),
]
VAR_ORDER = [
    "fire_confidence",
    "temperature",
    "wind_speed",
    "specific_humidity",
    "precipitation_1h",
    "wind_direction_sin",
    "wind_direction_cos",
]
RTMA_VARS_REQUIRED = ["TMP", "WIND", "WDIR", "SPFH", "ACPC01"]


def feature_names():
    names = []
    for n_name, _, _ in CELL_OFFSETS:
        for v in VAR_ORDER:
            names.append(f"{v}_{n_name}")
    return names


FEATURE_NAMES = feature_names()
N_FEATURES = len(FEATURE_NAMES)


print("feature count:", N_FEATURES)
print("first 10 feature names:", FEATURE_NAMES[:10])


feature count: 63
first 10 feature names: ['fire_confidence_c', 'temperature_c', 'wind_speed_c', 'specific_humidity_c', 'precipitation_1h_c', 'wind_direction_sin_c', 'wind_direction_cos_c', 'fire_confidence_nw', 'temperature_nw', 'wind_speed_nw']
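Feature columns are cell-major: all 7 variables for `c`, then all 7 for `nw`, and so on, so column index = `cell_index * 7 + var_index`. A sketch of that lookup (the `feature_index` helper is illustrative; the lists copy the `CELL_OFFSETS` names and `VAR_ORDER`):

```python
CELLS = ["c", "nw", "n", "ne", "w", "e", "sw", "s", "se"]
VARS = [
    "fire_confidence", "temperature", "wind_speed", "specific_humidity",
    "precipitation_1h", "wind_direction_sin", "wind_direction_cos",
]


def feature_index(var, cell):
    """Column of f"{var}_{cell}" in the cell-major layout produced by feature_names()."""
    return CELLS.index(cell) * len(VARS) + VARS.index(var)


print(feature_index("fire_confidence", "nw"))
print(feature_index("wind_direction_cos", "se"))
```

Equivalently, `X.reshape(-1, 9, 7)` turns a sample matrix into a (sample, cell, variable) array with cells in `CELL_OFFSETS` order.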


In [271]:
def to_binary_target(y_continuous: np.ndarray, threshold: float) -> np.ndarray:
    return (y_continuous >= threshold).astype(np.int32)


def resolve_manifest_file_path(path_str: str, repo_root: Path, manifest_dir: Path) -> Path:
    p = Path(path_str).expanduser()
    if p.exists():
        return p

    parts = p.parts
    if "data" in parts:
        idx = parts.index("data")
        candidate = repo_root.joinpath(*parts[idx:])
        if candidate.exists():
            return candidate

    candidate = (manifest_dir / path_str).resolve()
    if candidate.exists():
        return candidate

    raise FileNotFoundError(f"Could not resolve RTMA part path: {path_str}")
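`resolve_manifest_file_path` handles manifests written on another machine by re-rooting the path at its first `data` component. A pure-path sketch of just that step, without the existence checks (the helper and the file name are hypothetical):

```python
from pathlib import PurePosixPath


def reroot_at_data(path_str, repo_root):
    """Rebuild a path from another machine under repo_root, starting at its first 'data' part."""
    parts = PurePosixPath(path_str).parts
    if "data" not in parts:
        return None
    return PurePosixPath(repo_root).joinpath(*parts[parts.index("data"):])


print(reroot_at_data("/old/home/proj/data/multi_fire/Creek/rtma/TMP_part0.tif", "/new/repo"))
```

Because the first `data` component wins, a path whose ancestors include an unrelated directory named `data` would be re-rooted at the wrong place.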


In [272]:
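def resample_stack(src_stack, src_transform, src_crs, dst_shape, dst_transform, dst_crs):
    bands = src_stack.shape[0]
    # Use NaN as destination nodata so GOES pixels outside the RTMA footprint stay NaN
    # and are dropped by the isfinite filter in build_hour_samples, instead of being
    # filled with reproject's default nodata of 0.
    dst = np.full((bands, dst_shape[0], dst_shape[1]), np.nan, dtype=np.float32)
    for b in range(bands):
        reproject(
            source=src_stack[b],
            destination=dst[b],
            src_transform=src_transform,
            src_crs=src_crs,
            dst_transform=dst_transform,
            dst_crs=dst_crs,
            dst_nodata=np.nan,
            resampling=Resampling.bilinear,
        )
    return dst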


def build_hour_samples(conf_t, conf_t1, rtma_hour):
    h, w = conf_t.shape
    if h < 3 or w < 3:
        return np.empty((0, N_FEATURES), dtype=np.float64), np.empty((0,), dtype=np.float64)

    y = conf_t1[1:-1, 1:-1].astype(np.float64)
    feat_blocks = []

    for _, dy, dx in CELL_OFFSETS:
        ys = slice(1 + dy, h - 1 + dy)
        xs = slice(1 + dx, w - 1 + dx)

        go_cell = conf_t[ys, xs].astype(np.float64)
        tmp_cell = rtma_hour["TMP"][ys, xs].astype(np.float64)
        wind_cell = rtma_hour["WIND"][ys, xs].astype(np.float64)
        spfh_cell = rtma_hour["SPFH"][ys, xs].astype(np.float64)
        precip_cell = rtma_hour["ACPC01"][ys, xs].astype(np.float64)
        wdir_deg_cell = rtma_hour["WDIR"][ys, xs].astype(np.float64)
        wdir_rad_cell = np.deg2rad(wdir_deg_cell)
        wdir_sin_cell = np.sin(wdir_rad_cell)
        wdir_cos_cell = np.cos(wdir_rad_cell)

        feat_blocks.extend([
            go_cell,
            tmp_cell,
            wind_cell,
            spfh_cell,
            precip_cell,
            wdir_sin_cell,
            wdir_cos_cell,
        ])

    X = np.stack(feat_blocks, axis=-1).reshape(-1, N_FEATURES)
    y = y.reshape(-1)

    valid = np.isfinite(y)
    valid &= np.isfinite(X).all(axis=1)

    if not valid.any():
        return np.empty((0, N_FEATURES), dtype=np.float64), np.empty((0,), dtype=np.float64)

    return X[valid], y[valid]


In [273]:
def iter_aligned_hours_for_fire(
    goes_conf,
    goes_time_index,
    rtma_manifest,
    rtma_manifest_path: Path,
    goes_shape,
    goes_transform,
    goes_crs,
):
    rtma_vars = rtma_manifest["variables"]
    for req in RTMA_VARS_REQUIRED:
        if req not in rtma_vars:
            raise KeyError(f"RTMA manifest missing required variable: {req}")

    manifest_dir = rtma_manifest_path.parent
    rtma_files = rtma_manifest["files"]
    resolved_files = {
        var: [resolve_manifest_file_path(path, REPO_ROOT, manifest_dir) for path in rtma_files[var]]
        for var in rtma_vars
    }

    n_parts = len(resolved_files[rtma_vars[0]])
    for v in rtma_vars:
        if len(resolved_files[v]) != n_parts:
            raise ValueError("RTMA variable file lists do not have equal part counts.")

    parts = list(zip(*[resolved_files[v] for v in rtma_vars]))
    rtma_time_steps = [normalize_time_str(t) for t in rtma_manifest["time_steps"]]

    rtma_time_ptr = 0

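    for part_paths in parts:
        rtma_arrays = {}
        rtma_transform = None
        rtma_crs = None
        band_count = None

        for var, part_path in zip(rtma_vars, part_paths):
            # Only the required variables feed the model; skip reading the rest.
            if var not in RTMA_VARS_REQUIRED:
                continue
            with rasterio.open(part_path) as ds:
                if rtma_transform is None:
                    rtma_transform = ds.transform
                    rtma_crs = ds.crs
                    band_count = ds.count
                elif ds.count != band_count:
                    # All variables in a part must cover the same hours.
                    raise ValueError(f"RTMA {var} part has {ds.count} bands, expected {band_count}: {part_path}")
                rtma_arrays[var] = ds.read().astype("float32")

        if band_count is None:
            continue

        # Resample each required RTMA stack onto the GOES grid.
        resampled = {}
        for var in RTMA_VARS_REQUIRED:
            resampled[var] = resample_stack(
                rtma_arrays[var],
                rtma_transform,
                rtma_crs,
                goes_shape,
                goes_transform,
                goes_crs,
            )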

        for local_idx in range(band_count):
            global_idx = rtma_time_ptr + local_idx
            if global_idx >= len(rtma_time_steps):
                break

            time_str = rtma_time_steps[global_idx]
            if time_str not in goes_time_index:
                continue

            t = goes_time_index[time_str]
            if t + 1 >= goes_conf.shape[0]:
                continue

            rtma_hour = {var: resampled[var][local_idx] for var in RTMA_VARS_REQUIRED}
            yield t, rtma_hour

        rtma_time_ptr += band_count
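The generator keeps a running pointer because each RTMA part file holds several hourly bands, while the manifest's `time_steps` is one flat list; this assumes the list concatenates all parts in order. A sketch of the index bookkeeping alone (`global_time_indices` is illustrative):

```python
def global_time_indices(band_counts, n_time_steps):
    """Yield (part_idx, local_idx, global_idx) the way the RTMA loop walks its parts."""
    ptr = 0
    for part_idx, count in enumerate(band_counts):
        for local_idx in range(count):
            global_idx = ptr + local_idx
            if global_idx >= n_time_steps:
                break  # more bands than listed time steps: ignore the extras
            yield part_idx, local_idx, global_idx
        ptr += count


print(list(global_time_indices([3, 2], 4)))
```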


In [274]:
def iter_fire_hour_samples(entry):
    with Path(entry["goes_json"]).open("r", encoding="utf-8") as f:
        goes_json_local = json.load(f)
    with Path(entry["rtma_manifest"]).open("r", encoding="utf-8") as f:
        rtma_manifest_local = json.load(f)

    goes_conf_local = np.array(goes_json_local["data"], dtype=np.float32)
    goes_meta_local = goes_json_local["metadata"]
    goes_transform_local = affine_from_list(goes_meta_local["geo_transform"])
    goes_crs_local = goes_meta_local.get("crs")
    goes_shape_local = tuple(goes_meta_local["grid_shape"])
    goes_times_local = load_goes_times(goes_meta_local, goes_conf_local)
    goes_time_index_local = {t: i for i, t in enumerate(goes_times_local)}

    for t, rtma_hour in iter_aligned_hours_for_fire(
        goes_conf_local,
        goes_time_index_local,
        rtma_manifest_local,
        Path(entry["rtma_manifest"]),
        goes_shape_local,
        goes_transform_local,
        goes_crs_local,
    ):
        X_hour, y_hour_cont = build_hour_samples(goes_conf_local[t], goes_conf_local[t + 1], rtma_hour)
        if X_hour.shape[0] == 0:
            continue
        y_hour = to_binary_target(y_hour_cont, POSITIVE_THRESHOLD)
        yield entry["fire_name"], t, X_hour, y_hour
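Every section calls `iter_fire_hour_samples` again, so each pass (stats, normalization, training, evaluation, validation) re-parses the GOES JSON and re-reads and re-resamples every RTMA part. One option is to cache each fire's samples once. A minimal sketch using `np.savez_compressed`; `save_fire_cache`, `iter_cached_hours`, and the file layout are hypothetical, not part of the notebook:

```python
import os
import tempfile

import numpy as np


def save_fire_cache(batches, path):
    """Concatenate (t, X_hour, y_hour) batches for one fire into a single .npz file."""
    ts, Xs, ys = [], [], []
    for t, X_hour, y_hour in batches:
        ts.append(np.full(X_hour.shape[0], t, dtype=np.int64))
        Xs.append(X_hour)
        ys.append(y_hour)
    if not Xs:
        raise ValueError("No batches to cache.")
    np.savez_compressed(path, t=np.concatenate(ts), X=np.concatenate(Xs), y=np.concatenate(ys))


def iter_cached_hours(path):
    """Re-yield (t, X_hour, y_hour) batches, in ascending hour order, from save_fire_cache output."""
    with np.load(path) as data:
        t_all, X_all, y_all = data["t"], data["X"], data["y"]
    for t in np.unique(t_all):
        mask = t_all == t
        yield int(t), X_all[mask], y_all[mask]


# Tiny illustrative batches: hour 5 has two samples, hour 6 has one.
demo = [
    (5, np.ones((2, 3)), np.array([0, 1], dtype=np.int32)),
    (6, np.zeros((1, 3)), np.array([1], dtype=np.int32)),
]
with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "demo_fire.npz")
    save_fire_cache(demo, cache_path)
    reloaded = list(iter_cached_hours(cache_path))

print([(t, X.shape, y.tolist()) for t, X, y in reloaded])
```

In the notebook, the cache could be filled from `((t, X, y) for _, t, X, y in iter_fire_hour_samples(entry))`. The cached `y` is already binarized, so the cache would need rebuilding whenever `POSITIVE_THRESHOLD` changes.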


## 6) Dataset pass 1 (sample/class stats)


In [None]:
n_total_samples = 0
n_train_samples = 0
n_test_samples = 0
n_hours_used = 0
n_train_hours = 0
n_test_hours = 0

train_pos = 0
train_neg = 0
test_pos = 0
test_neg = 0

if RUN_DATA_STATS_SECTION:
    if not train_fire_entries or not test_fire_entries:
        raise RuntimeError("Data stats section requires fire discovery section to be enabled.")

    for entry in train_fire_entries:
        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            n_hours_used += 1
            n_train_hours += 1
            n_rows = X_hour.shape[0]
            n_total_samples += n_rows
            n_train_samples += n_rows
            train_pos += int(y_hour.sum())
            train_neg += int(y_hour.shape[0] - y_hour.sum())

    for entry in test_fire_entries:
        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            n_hours_used += 1
            n_test_hours += 1
            n_rows = X_hour.shape[0]
            n_total_samples += n_rows
            n_test_samples += n_rows
            test_pos += int(y_hour.sum())
            test_neg += int(y_hour.shape[0] - y_hour.sum())

    if n_train_samples == 0:
        raise RuntimeError("No training samples from train-fire set.")
    if n_test_samples == 0:
        raise RuntimeError("No testing samples from test-fire set.")

    train_pos_rate = train_pos / n_train_samples
    test_pos_rate = test_pos / n_test_samples

    print("train fires:", [e["fire_name"] for e in train_fire_entries])
    print("test fires:", [e["fire_name"] for e in test_fire_entries])
    print("hours used total:", n_hours_used)
    print("hours used train:", n_train_hours)
    print("hours used test:", n_test_hours)
    print("total samples:", n_total_samples)
    print("train samples:", n_train_samples)
    print("test samples:", n_test_samples)
    print("observed train sample fraction:", n_train_samples / n_total_samples)
    print("observed test sample fraction:", n_test_samples / n_total_samples)
    print("train positives:", train_pos, "train negatives:", train_neg, "train pos rate:", train_pos_rate)
    print("test positives:", test_pos, "test negatives:", test_neg, "test pos rate:", test_pos_rate)
else:
    train_pos_rate = None
    test_pos_rate = None
    print("Skipped data stats section.")
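The pass above yields `train_pos` and `train_neg`. If the positive rate is small, one option is to reweight the classes. `SGDClassifier` accepts a `class_weight` dict, but `partial_fit` rejects `class_weight="balanced"`, so the weights would have to be computed up front. A sketch using the same formula as scikit-learn's balanced mode, `n_samples / (n_classes * n_samples_in_class)` (`balanced_class_weights` is illustrative):

```python
def balanced_class_weights(n_pos, n_neg):
    """Balanced-mode weights for a binary problem from class counts."""
    if n_pos == 0 or n_neg == 0:
        raise ValueError("Both classes need at least one sample.")
    n = n_pos + n_neg
    return {0: n / (2 * n_neg), 1: n / (2 * n_pos)}


print(balanced_class_weights(n_pos=10, n_neg=90))
```

It could then be passed as `SGDClassifier(..., class_weight=balanced_class_weights(train_pos, train_neg))`. Reweighting shifts the predicted probabilities, so `CLASSIFICATION_PROB_THRESHOLD` would need re-tuning afterwards.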


## 7) Normalization


In [None]:
feature_mean = np.zeros(N_FEATURES, dtype=np.float64)
feature_std = np.ones(N_FEATURES, dtype=np.float64)
feature_std_safe = np.ones(N_FEATURES, dtype=np.float64)
n_norm_samples = 0
n_zero_std_features = 0


def normalize_X(X):
    return X


if RUN_NORMALIZATION_SECTION:
    if NORMALIZE_FEATURES:
        if not train_fire_entries:
            raise RuntimeError("Normalization section requires fire discovery section to be enabled.")

        feature_sum = np.zeros(N_FEATURES, dtype=np.float64)
        feature_sq_sum = np.zeros(N_FEATURES, dtype=np.float64)

        for entry in train_fire_entries:
            for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
                X_block = X_hour.astype(np.float64)
                feature_sum += X_block.sum(axis=0)
                feature_sq_sum += np.square(X_block).sum(axis=0)
                n_norm_samples += X_block.shape[0]

        if n_norm_samples == 0:
            raise RuntimeError("No training samples available for normalization stats.")

        feature_mean = feature_sum / n_norm_samples
        feature_var = (feature_sq_sum / n_norm_samples) - np.square(feature_mean)
        feature_var = np.maximum(feature_var, 0.0)
        feature_std = np.sqrt(feature_var)
        feature_std_safe = np.where(feature_std > 0.0, feature_std, 1.0)
        n_zero_std_features = int((feature_std == 0.0).sum())

        def normalize_X(X):
            return (X - feature_mean) / feature_std_safe

        print("normalization mode:", "zscore_from_train_fires")
        print("normalization samples:", n_norm_samples)
        print("zero-std feature count:", n_zero_std_features)
    else:
        print("normalization mode:", "none (NORMALIZE_FEATURES=False)")
else:
    print("Skipped normalization section.")
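The cell above computes the variance as `E[x²] - E[x]²` from running sums. In float64 this is usually adequate, but it can lose precision for a feature whose mean is large relative to its spread (near-surface temperature is a likely candidate if it is stored in kelvin, which is an assumption). A sketch of the pairwise merge of Chan et al., which folds in one batch at a time and matches a one-shot `np.mean`/`np.std` (`merge_moments` is illustrative):

```python
import numpy as np


def merge_moments(n_a, mean_a, m2_a, X_b):
    """Fold batch X_b (rows = samples) into running per-feature (count, mean, M2).

    M2 is the sum of squared deviations from the mean; population variance = M2 / count.
    """
    n_b = X_b.shape[0]
    if n_b == 0:
        return n_a, mean_a, m2_a
    mean_b = X_b.mean(axis=0)
    m2_b = np.square(X_b - mean_b).sum(axis=0)
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * (n_b / n)
    m2 = m2_a + m2_b + np.square(delta) * (n_a * n_b / n)
    return n, mean, m2


rng = np.random.default_rng(0)
chunks = [rng.normal(290.0, 5.0, size=(k, 3)) for k in (50, 120, 7)]
n, mean, m2 = 0, np.zeros(3), np.zeros(3)
for chunk in chunks:
    n, mean, m2 = merge_moments(n, mean, m2, chunk)
std = np.sqrt(m2 / n)  # population std (ddof=0), like the notebook's formula

all_rows = np.concatenate(chunks)
print(np.allclose(mean, all_rows.mean(axis=0)), np.allclose(std, all_rows.std(axis=0)))
```

In the normalization loop, `merge_moments` would replace the `feature_sum`/`feature_sq_sum` updates, with `feature_std = np.sqrt(m2 / n)`.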


## 8) Train logistic regression


In [None]:
clf = None
trained = False
intercept = None
coef_map = {}

if RUN_TRAINING_SECTION:
    if not train_fire_entries:
        raise RuntimeError("Training section requires fire discovery section to be enabled.")

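    clf = SGDClassifier(
        loss="log_loss",
        penalty="l2",
        alpha=1e-4,
        max_iter=1,  # only used by fit(); each partial_fit call makes one pass over its batch
        tol=None,
        random_state=None,  # set an int for reproducible runs
    )

    classes = np.array([0, 1], dtype=np.int32)

    # Single streaming pass: one partial_fit call per (fire, hour) batch, in fire order.
    for entry in train_fire_entries: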
        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            X_train = normalize_X(X_hour.astype(np.float64))
            y_train = y_hour

            if not trained:
                clf.partial_fit(X_train, y_train, classes=classes)
                trained = True
            else:
                clf.partial_fit(X_train, y_train)

    if not trained:
        raise RuntimeError("Model did not receive training data from train-fire set.")

    intercept = float(clf.intercept_[0])
    coef_std = clf.coef_.ravel()
    coef_map = {name: float(val) for name, val in zip(FEATURE_NAMES, coef_std)}

    print("intercept:", intercept)
    print("coefficient count:", len(coef_map))
else:
    print("Skipped training section.")
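Because the model is fit on z-scores, each coefficient is the change in log-odds per training standard deviation of its feature. Substituting `z_j = (x_j - mu_j) / sigma_j` gives raw-unit coefficients `w_j / sigma_j` and intercept `b - sum_j(w_j * mu_j / sigma_j)`. A sketch that checks both forms give the same probability on random numbers (`to_raw_units` is illustrative; for a binary `log_loss` model, `predict_proba(X)[:, 1]` is the logistic sigmoid of `decision_function(X)`):

```python
import numpy as np


def to_raw_units(coef_z, intercept_z, mean, std_safe):
    """Convert z-score logistic coefficients to coefficients on the raw features."""
    coef_raw = coef_z / std_safe
    intercept_raw = intercept_z - np.sum(coef_z * mean / std_safe)
    return coef_raw, intercept_raw


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(1)
mean, std = rng.normal(size=4), rng.uniform(0.5, 2.0, size=4)
coef_z, intercept_z = rng.normal(size=4), -2.0
X = rng.normal(size=(5, 4))

p_z = sigmoid(((X - mean) / std) @ coef_z + intercept_z)
coef_raw, intercept_raw = to_raw_units(coef_z, intercept_z, mean, std)
p_raw = sigmoid(X @ coef_raw + intercept_raw)
print(np.allclose(p_z, p_raw))
```

With the notebook's variables this would be `to_raw_units(coef_std, intercept, feature_mean, feature_std_safe)`; when `NORMALIZE_FEATURES` is off, the coefficients are already in raw units.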


## 9) Evaluate on held-out test fires


In [None]:
correct_test = 0
count_test_eval = 0

tp = 0
fp = 0
fn = 0
tn = 0

test_accuracy_overall = None
test_positive_accuracy = None
test_negative_accuracy = None

if RUN_EVALUATION_SECTION:
    if clf is None or not trained:
        raise RuntimeError("Evaluation section requires training section to run first.")
    if not test_fire_entries:
        raise RuntimeError("Evaluation section requires fire discovery section to be enabled.")

    for entry in test_fire_entries:
        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            X_test = normalize_X(X_hour.astype(np.float64))
            y_test = y_hour

            prob_test = clf.predict_proba(X_test)[:, 1]
            y_hat = (prob_test >= CLASSIFICATION_PROB_THRESHOLD).astype(np.int32)

            correct_test += int((y_hat == y_test).sum())
            count_test_eval += y_test.shape[0]

            tp += int(((y_hat == 1) & (y_test == 1)).sum())
            fp += int(((y_hat == 1) & (y_test == 0)).sum())
            fn += int(((y_hat == 0) & (y_test == 1)).sum())
            tn += int(((y_hat == 0) & (y_test == 0)).sum())

    if count_test_eval == 0:
        raise RuntimeError("No valid evaluation samples in full test-fire set.")

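    test_accuracy_overall = float(correct_test / count_test_eval)
    # Accuracy on positives = recall / true-positive rate.
    test_positive_accuracy = float(tp / (tp + fn)) if (tp + fn) > 0 else None
    # Accuracy on negatives = specificity / true-negative rate.
    test_negative_accuracy = float(tn / (tn + fp)) if (tn + fp) > 0 else None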

    print("test evaluation samples:", count_test_eval)
    print("test accuracy overall:", test_accuracy_overall)
    print("test positive accuracy:", test_positive_accuracy)
    print("test negative accuracy:", test_negative_accuracy)
    print("TP:", tp, "FP:", fp, "FN:", fn, "TN:", tn)
else:
    print("Skipped evaluation section.")
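Because positives are rare, `test_accuracy_overall` is dominated by the negative class. The same four counters also give precision, F1, and balanced accuracy. A sketch using made-up counts (`confusion_metrics` is illustrative):

```python
def confusion_metrics(tp, fp, fn, tn):
    """Derive accuracy-style metrics from confusion counts (None where undefined)."""
    def ratio(num, den):
        return num / den if den > 0 else None

    tpr = ratio(tp, tp + fn)  # "test_positive_accuracy" (recall)
    tnr = ratio(tn, tn + fp)  # "test_negative_accuracy" (specificity)
    balanced = (tpr + tnr) / 2 if tpr is not None and tnr is not None else None
    return {
        "accuracy": ratio(tp + tn, tp + fp + fn + tn),
        "tpr": tpr,
        "tnr": tnr,
        "precision": ratio(tp, tp + fp),
        "f1": ratio(2 * tp, 2 * tp + fp + fn),
        "balanced_accuracy": balanced,
    }


m = confusion_metrics(tp=5, fp=3, fn=5, tn=87)
print({k: round(v, 4) for k, v in m.items()})
```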


## 10) Summary tables


In [None]:
summary_df = None

if RUN_SUMMARY_SECTION:
    summary_df = pd.DataFrame(
        [
            {
                "model": "logistic_regression",
                "target": "center_confidence_t+1_binary",
                "fires_used_count": len(fire_entries),
                "fires_used": [e["fire_name"] for e in fire_entries],
                "train_fires_count": len(train_fire_entries),
                "test_fires_count": len(test_fire_entries),
                "train_fires": [e["fire_name"] for e in train_fire_entries],
                "test_fires": [e["fire_name"] for e in test_fire_entries],
                "positive_threshold": POSITIVE_THRESHOLD,
                "total_samples": int(n_total_samples),
                "train_samples": int(n_train_samples),
                "test_samples": int(n_test_samples),
                "hours_used": int(n_hours_used),
                "train_hours": int(n_train_hours),
                "test_hours": int(n_test_hours),
                "train_positive_rate": train_pos_rate,
                "test_positive_rate": test_pos_rate,
                "test_accuracy_overall": test_accuracy_overall,
                "test_positive_accuracy": test_positive_accuracy,
                "test_negative_accuracy": test_negative_accuracy,
                "tp": int(tp),
                "fp": int(fp),
                "fn": int(fn),
                "tn": int(tn),
                "classification_prob_threshold": CLASSIFICATION_PROB_THRESHOLD,
                # Report "none" when normalization stats were never computed (section skipped).
                "feature_normalization": (
                    "zscore_from_train_fires" if (NORMALIZE_FEATURES and n_norm_samples > 0) else "none"
                ),
                "intercept": intercept,
            }
        ]
    )
    # Jupyter only auto-displays a bare expression at the top level of a cell, not inside `if`.
    from IPython.display import display

    display(summary_df)
else:
    print("Skipped summary table section.")


In [None]:
confusion_breakdown_df = None

if RUN_SUMMARY_SECTION:
    confusion_breakdown_df = pd.DataFrame(
        [
            {
                "true_positives": int(tp),
                "false_positives": int(fp),
                "false_negatives": int(fn),
                "true_negatives": int(tn),
            }
        ]
    )
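    # Jupyter only auto-displays a bare expression at the top level of a cell, not inside `if`.
    from IPython.display import display

    display(confusion_breakdown_df)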
else:
    print("Skipped confusion breakdown table.")


## 11) Validation-fire PR curve (threshold selection)

Use this section to choose `CLASSIFICATION_PROB_THRESHOLD` from validation fires (not from test fires).

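- Split method: fire-level split of the current `train_fire_entries` into inner-train and validation fires; a temporary model and its normalization stats are fit on the inner-train fires only
- Output: validation PR curve + `val_pr_df`, sorted by `f1` (descending)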
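Counting TP and FP at 1,001 thresholds with a broadcast comparison (`prob[:, None] >= thresholds[None, :]`) allocates an n_samples x 1,001 boolean matrix per hour of data. Sorting each class's scores once and using `np.searchsorted` gives identical counts for the `prob >= threshold` rule with O(n + T) memory instead of O(n * T). A sketch that checks this against the brute-force version on random data (`counts_at_thresholds` is illustrative):

```python
import numpy as np


def counts_at_thresholds(prob, is_pos, thresholds):
    """TP and FP counts for the rule `prob >= threshold`, for every threshold at once."""
    pos_sorted = np.sort(prob[is_pos])
    neg_sorted = np.sort(prob[~is_pos])
    # searchsorted(side="left") counts scores strictly below each threshold.
    tp = pos_sorted.size - np.searchsorted(pos_sorted, thresholds, side="left")
    fp = neg_sorted.size - np.searchsorted(neg_sorted, thresholds, side="left")
    return tp, fp


rng = np.random.default_rng(2)
prob = rng.random(1000)
is_pos = rng.random(1000) < 0.1
thresholds = np.linspace(0.0, 1.0, 1001)

tp, fp = counts_at_thresholds(prob, is_pos, thresholds)
# Brute-force check with an explicit (n_samples x n_thresholds) comparison.
pred = prob[:, None] >= thresholds[None, :]
print(np.array_equal(tp, (pred & is_pos[:, None]).sum(axis=0)),
      np.array_equal(fp, (pred & ~is_pos[:, None]).sum(axis=0)))
```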


In [None]:
# Validation-fire PR curve + best-threshold table (sorted by F1)

# Validation split config.
VALIDATION_FIRES = "auto"  # "auto" or explicit list like ["Creek", "Dixie"]
VALIDATION_FRACTION_OF_TRAIN_FIRES = 0.30
VALIDATION_SPLIT_SEED = 123
VAL_PR_THRESHOLDS = np.linspace(0.0, 1.0, 1001)

# Split current train fires into inner-train and validation by full fires.
train_names = [e["fire_name"] for e in train_fire_entries]
if len(train_names) < 2:
    raise RuntimeError("Need at least 2 train fires to make a validation split.")

if VALIDATION_FIRES == "auto":
    rng = np.random.default_rng(VALIDATION_SPLIT_SEED)
    perm = list(np.array(train_names)[rng.permutation(len(train_names))])
    n_val = max(1, int(round(len(perm) * VALIDATION_FRACTION_OF_TRAIN_FIRES)))
    n_val = min(n_val, len(perm) - 1)
    val_name_set = set(perm[:n_val])
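else:
    if not isinstance(VALIDATION_FIRES, (list, tuple, set)):
        raise ValueError('VALIDATION_FIRES must be "auto" or list/tuple/set of fire names.')
    val_name_set = {str(x) for x in VALIDATION_FIRES}
    # Validation fires must come from the current train fires, never from the test fires.
    not_in_train = sorted(val_name_set - set(train_names))
    if not_in_train:
        raise ValueError(f"VALIDATION_FIRES not in current train fires: {not_in_train}")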

inner_train_entries = [e for e in train_fire_entries if e["fire_name"] not in val_name_set]
val_entries = [e for e in train_fire_entries if e["fire_name"] in val_name_set]

if not inner_train_entries or not val_entries:
    raise RuntimeError("Validation split produced empty inner-train or validation fire set.")

print("inner-train fires:", [e["fire_name"] for e in inner_train_entries])
print("validation fires:", [e["fire_name"] for e in val_entries])

# Fit normalization from inner-train only.
if NORMALIZE_FEATURES:
    s = np.zeros(N_FEATURES, dtype=np.float64)
    ss = np.zeros(N_FEATURES, dtype=np.float64)
    n = 0
    for entry in inner_train_entries:
        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            Xb = X_hour.astype(np.float64)
            s += Xb.sum(axis=0)
            ss += np.square(Xb).sum(axis=0)
            n += Xb.shape[0]

    if n == 0:
        raise RuntimeError("No inner-train samples available for validation normalization.")

    mu = s / n
    var = np.maximum((ss / n) - np.square(mu), 0.0)
    std = np.sqrt(var)
    std_safe = np.where(std > 0.0, std, 1.0)

    def val_normalize_X(X):
        return (X - mu) / std_safe
else:

    def val_normalize_X(X):
        return X.astype(np.float64)

# Train temporary validation model on inner-train fires.
val_clf = SGDClassifier(
    loss="log_loss",
    penalty="l2",
    alpha=1e-4,
    max_iter=1,
    tol=None,
    random_state=None,
)

classes = np.array([0, 1], dtype=np.int32)
val_trained = False

for entry in inner_train_entries:
    for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
        X_train = val_normalize_X(X_hour)
        y_train = y_hour
        if not val_trained:
            val_clf.partial_fit(X_train, y_train, classes=classes)
            val_trained = True
        else:
            val_clf.partial_fit(X_train, y_train)

if not val_trained:
    raise RuntimeError("Validation model did not receive inner-train samples.")

# Compute precision-recall on validation fires.
val_pr_tp = np.zeros(VAL_PR_THRESHOLDS.shape[0], dtype=np.int64)
val_pr_fp = np.zeros(VAL_PR_THRESHOLDS.shape[0], dtype=np.int64)
val_total_pos = 0
val_total_neg = 0

for entry in val_entries:
    for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
        X_val = val_normalize_X(X_hour)
        y_val = y_hour.astype(np.int32)

        prob = val_clf.predict_proba(X_val)[:, 1]
        pos = y_val == 1
        val_total_pos += int(pos.sum())
        val_total_neg += int((~pos).sum())

        pred = prob[:, None] >= VAL_PR_THRESHOLDS[None, :]
        val_pr_tp += (pred & pos[:, None]).sum(axis=0).astype(np.int64)
        val_pr_fp += (pred & (~pos)[:, None]).sum(axis=0).astype(np.int64)

if val_total_pos == 0:
    raise RuntimeError("No positive samples in validation set; cannot compute precision/recall.")

val_pr_fn = val_total_pos - val_pr_tp
val_pr_tn = val_total_neg - val_pr_fp

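# Precision is defined as 1.0 when nothing is predicted positive (threshold above every score).
# np.divide(..., where=...) skips the 0/0 cases instead of emitting RuntimeWarnings like np.where(cond, a / b, ...).
val_pr_pred_pos = val_pr_tp + val_pr_fp
val_pr_precision = np.divide(
    val_pr_tp, val_pr_pred_pos, out=np.ones(val_pr_tp.shape, dtype=np.float64), where=val_pr_pred_pos > 0
)
val_pr_recall = val_pr_tp / val_total_pos
val_pr_pr_sum = val_pr_precision + val_pr_recall
val_pr_f1 = np.divide(
    2 * val_pr_precision * val_pr_recall, val_pr_pr_sum, out=np.zeros_like(val_pr_pr_sum), where=val_pr_pr_sum > 0
)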

val_pr_df = pd.DataFrame(
    {
        "threshold": VAL_PR_THRESHOLDS,
        "precision": val_pr_precision,
        "recall": val_pr_recall,
        "f1": val_pr_f1,
        "tp": val_pr_tp,
        "fp": val_pr_fp,
        "fn": val_pr_fn,
        "tn": val_pr_tn,
    }
).sort_values("f1", ascending=False)

best_val = val_pr_df.iloc[0]
val_baseline = val_total_pos / (val_total_pos + val_total_neg)

print("validation positive rate (baseline precision):", float(val_baseline))
print("best threshold by F1 (validation):", float(best_val["threshold"]))
print(
    "best precision/recall/F1:",
    float(best_val["precision"]),
    float(best_val["recall"]),
    float(best_val["f1"]),
)

val_pr_plot_df = val_pr_df.sort_values("recall")

plt.figure(figsize=(7, 4))
plt.plot(val_pr_plot_df["recall"], val_pr_plot_df["precision"], linewidth=2)
plt.hlines(
    val_baseline,
    0,
    1,
    linestyles="dashed",
    colors="gray",
    label=f"baseline (pos rate={val_baseline:.3f})",
)
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve (Validation Fires)")
plt.grid(True, alpha=0.25)
plt.legend()
plt.show()

# Top thresholds by F1 for manual threshold selection.
val_pr_df.head(20)

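The validation cell fits normalization statistics in a single streaming pass. It accumulates per-feature sums and sums of squares, then computes `var = E[x^2] - E[x]^2`. The sketch below pulls that arithmetic into a standalone function so it can be checked against `np.mean`/`np.std` (population std, `ddof=0`) on the concatenated data. `streaming_zscore_stats` is a hypothetical helper name, and each batch stands in for one `X_hour` array yielded by `iter_fire_hour_samples`.

```python
import numpy as np


def streaming_zscore_stats(batches, n_features):
    """Per-feature mean and safe std from an iterable of (n_i, n_features) arrays.

    Mirrors the validation cell: accumulate sums and sums of squares in float64,
    then var = E[x^2] - E[x]^2, clipped at 0 to absorb round-off.
    """
    s = np.zeros(n_features, dtype=np.float64)
    ss = np.zeros(n_features, dtype=np.float64)
    n = 0
    for X in batches:
        Xb = np.asarray(X, dtype=np.float64)
        s += Xb.sum(axis=0)
        ss += np.square(Xb).sum(axis=0)
        n += Xb.shape[0]
    if n == 0:
        raise ValueError("No samples were provided.")
    mu = s / n
    std = np.sqrt(np.maximum(ss / n - np.square(mu), 0.0))
    # Features with zero spread get a divisor of 1, so they normalize to 0 instead of NaN.
    std_safe = np.where(std > 0.0, std, 1.0)
    return mu, std_safe
```

Because `E[x^2] - E[x]^2` subtracts two large, nearly equal numbers, round-off can leave a tiny nonzero variance for a feature that is constant but not exactly representable in binary. In that case `std_safe` does not fall back to 1.0. Welford's or Chan's update is the usual remedy if that case matters.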

## 11) Precision-Recall curve (test split)

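Precision-recall is often more informative than ROC when positives are rare. The ROC false-positive rate `FP / (FP + TN)` divides by the large negative count, so many false alarms can still look small; precision exposes them directly.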

- Precision = `TP / (TP + FP)`
- Recall = `TP / (TP + FN)`

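This sweeps the *classification probability threshold* over 1001 evenly spaced values in [0, 1] on the held-out test fires. It is not the 0.10 confidence label threshold (`POSITIVE_THRESHOLD`). The best-F1 threshold reported here is chosen on test data and is therefore optimistic. Pick the operating threshold from the validation-fire sweep, and use this curve for reporting.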
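
Both PR cells count true and false positives per threshold with the broadcast comparison `prob[:, None] >= thresholds[None, :]`. That comparison materializes an `n_samples × n_thresholds` boolean array for every hour. If memory per hour becomes a concern, the same counts can be obtained by sorting each class's scores once and calling `np.searchsorted`. The sketch below assumes that trade-off is worth making. `sweep_counts_sorted` is a hypothetical name, not part of the notebook.

```python
import numpy as np


def sweep_counts_sorted(prob, y, thresholds):
    """TP and FP counts for `prob >= threshold` at every threshold.

    Equivalent to the broadcast form in the notebook, but uses O(n + T) memory
    instead of an n-by-T boolean matrix.
    """
    prob = np.asarray(prob, dtype=np.float64)
    pos = np.asarray(y).astype(bool)
    pos_scores = np.sort(prob[pos])
    neg_scores = np.sort(prob[~pos])
    # searchsorted(side="left") counts scores strictly below each threshold,
    # so size minus that count is the number of scores >= threshold.
    tp = pos_scores.size - np.searchsorted(pos_scores, thresholds, side="left")
    fp = neg_scores.size - np.searchsorted(neg_scores, thresholds, side="left")
    return tp.astype(np.int64), fp.astype(np.int64)
```

Inside the streaming loop, `tp, fp = sweep_counts_sorted(prob, pos, PR_THRESHOLDS)` could replace the `pred` line and the two count lines. The running totals would then be updated with `pr_tp += tp` and `pr_fp += fp` as before.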


In [None]:
pr_df = None

if RUN_PR_SECTION:
    if clf is None or not trained:
        raise RuntimeError("Precision-recall section requires training section to run first.")
    if not test_fire_entries:
        raise RuntimeError("Precision-recall section requires fire discovery section to be enabled.")

    PR_THRESHOLDS = np.linspace(0.0, 1.0, 1001)

    pr_tp = np.zeros(PR_THRESHOLDS.shape[0], dtype=np.int64)
    pr_fp = np.zeros(PR_THRESHOLDS.shape[0], dtype=np.int64)

    pr_total_pos = 0
    pr_total_neg = 0

    for entry in test_fire_entries:
        for fire_name, t, X_hour, y_hour in iter_fire_hour_samples(entry):
            X_test = normalize_X(X_hour.astype(np.float64))
            y_test = y_hour.astype(np.int32)

            prob = clf.predict_proba(X_test)[:, 1]
            pos = y_test == 1
            pr_total_pos += int(pos.sum())
            pr_total_neg += int((~pos).sum())

            pred = prob[:, None] >= PR_THRESHOLDS[None, :]
            pr_tp += (pred & pos[:, None]).sum(axis=0).astype(np.int64)
            pr_fp += (pred & (~pos)[:, None]).sum(axis=0).astype(np.int64)

    if pr_total_pos == 0:
        raise RuntimeError("No positive samples in the test-fire set; cannot compute precision/recall.")

    pr_fn = pr_total_pos - pr_tp
    pr_tn = pr_total_neg - pr_fp

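    # Precision is 1.0 when nothing is predicted positive; np.divide(..., where=...) avoids 0/0 RuntimeWarnings.
    pr_pred_pos = pr_tp + pr_fp
    pr_precision = np.divide(pr_tp, pr_pred_pos, out=np.ones(pr_tp.shape, dtype=np.float64), where=pr_pred_pos > 0)
    pr_recall = pr_tp / pr_total_pos

    pr_sum = pr_precision + pr_recall
    pr_f1 = np.divide(2 * pr_precision * pr_recall, pr_sum, out=np.zeros_like(pr_sum), where=pr_sum > 0)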

    pr_df = pd.DataFrame(
        {
            "threshold": PR_THRESHOLDS,
            "precision": pr_precision,
            "recall": pr_recall,
            "f1": pr_f1,
            "tp": pr_tp,
            "fp": pr_fp,
            "fn": pr_fn,
            "tn": pr_tn,
        }
    )

    best = pr_df.loc[pr_df["f1"].idxmax()]
    baseline = pr_total_pos / (pr_total_pos + pr_total_neg)

    print("test-fire positive rate (baseline precision at recall=1):", float(baseline))
    print("best threshold by F1:", float(best["threshold"]))
    print("precision:", float(best["precision"]), "recall:", float(best["recall"]), "f1:", float(best["f1"]))

    pr_plot_df = pr_df.sort_values("recall")

    plt.figure(figsize=(7, 4))
    plt.plot(pr_plot_df["recall"], pr_plot_df["precision"], linewidth=2)
    plt.hlines(baseline, 0, 1, linestyles="dashed", colors="gray", label=f"baseline (pos rate={baseline:.3f})")
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xlabel("Recall (TP / (TP+FN))")
    plt.ylabel("Precision (TP / (TP+FP))")
    plt.title("Precision-Recall Curve (Test-Fire Set)")
    plt.grid(True, alpha=0.25)
    plt.legend()
    plt.show()

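    # A bare expression inside an if-block is not auto-displayed in Jupyter, so display it explicitly.
    from IPython.display import display

    display(pr_df.sort_values("f1", ascending=False).head(12))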
else:
    print("Skipped precision-recall section.")


In [None]:
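if pr_df is not None:  # pr_df stays None when RUN_PR_SECTION is off
    from IPython.display import display
    display(pr_df.sort_values("f1", ascending=False).head(12))  # best thresholds by F1
    display(pr_df.head(501))  # raw sweep rows for thresholds 0.000 to 0.500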

## 12) Coefficients (top features)


In [None]:
coef_df = None
coef_top = None

if RUN_COEFFICIENT_SECTION:
    if not coef_map:
        raise RuntimeError("Coefficient section requires training section to run first.")

    coef_rows = []
    for feat, coef in coef_map.items():
        coef_rows.append(
            {
                "feature": feat,
                "coef": coef,
                "odds_ratio": float(np.exp(coef)),
                "abs_coef": abs(coef),
            }
        )

    coef_df = pd.DataFrame(coef_rows).sort_values("abs_coef", ascending=False)
    coef_top = coef_df.head(20).drop(columns=["abs_coef"])
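    from IPython.display import display

    display(coef_top)  # bare expressions inside an if-block are not auto-displayed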
else:
    print("Skipped coefficient table section.")

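When `NORMALIZE_FEATURES` is on, the model is fit on z-scored features. The `odds_ratio = exp(coef)` column is therefore the odds multiplier for a one-standard-deviation increase (train-fire std) in that feature, not for a one-unit increase. Because the z-score is affine, the coefficients can be mapped back to raw units exactly.

The sketch below assumes `mu` and `std_safe` are the per-feature arrays used by `normalize_X`. The training section's actual variable names are not shown here, so treat them as placeholders. `coefs_to_raw_units` is a hypothetical helper.

```python
import numpy as np


def coefs_to_raw_units(coef_z, intercept_z, mu, std_safe):
    """Map logistic coefficients fit on z-scored features back to raw feature units.

    logit = b + sum_j w_j * (x_j - mu_j) / s_j
          = (b - sum_j w_j * mu_j / s_j) + sum_j (w_j / s_j) * x_j
    """
    coef_z = np.asarray(coef_z, dtype=np.float64)
    coef_raw = coef_z / std_safe
    intercept_raw = float(intercept_z - np.sum(coef_z * mu / std_safe))
    return coef_raw, intercept_raw


# Hypothetical usage (mu / std_safe names assumed):
# coef_raw, intercept_raw = coefs_to_raw_units(clf.coef_[0], clf.intercept_[0], mu, std_safe)
```

Raw-unit coefficients for features on very different scales are not comparable in magnitude. The z-scored coefficients are the ones to rank by, as the table above does.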

In [None]:
if RUN_COEFFICIENT_SECTION:
    if coef_df is None:
        raise RuntimeError("Coefficient plot section requires coefficient table generation.")

    coef_plot_df = coef_df.head(10).iloc[::-1]
    colors = ["#1f77b4" if c >= 0 else "#d62728" for c in coef_plot_df["coef"]]

    plt.figure(figsize=(8, 6))
    plt.barh(coef_plot_df["feature"], coef_plot_df["coef"], color=colors)
    plt.axvline(0.0, color="black", linewidth=0.8)
    plt.title("Top logistic coefficients")
    plt.xlabel("Coefficient (log-odds)")
    plt.tight_layout()
    plt.show()
else:
    print("Skipped coefficient plot section.")


## 13) JSON-style report object


In [None]:
report = None

if RUN_REPORT_SECTION:
    report = {
        "model": "logistic_regression",
        "target": "center_confidence_t_plus_1_binary",
        "fires_used": [e["fire_name"] for e in fire_entries],
        "train_fires": [e["fire_name"] for e in train_fire_entries],
        "test_fires": [e["fire_name"] for e in test_fire_entries],
        "thresholds": {
            "positive_confidence": POSITIVE_THRESHOLD,
            "classification_probability": CLASSIFICATION_PROB_THRESHOLD,
        },
        "split": {
            "method": "fire_holdout",
            "train_fire_count": len(train_fire_entries),
            "test_fire_count": len(test_fire_entries),
            "train_fire_fraction_target": FIRE_TRAIN_FRACTION,
            "split_seed": FIRE_SPLIT_SEED,
        },
        "feature_order": FEATURE_NAMES,
        "feature_normalization": {
            "enabled": NORMALIZE_FEATURES,
            "method": "zscore_from_train_fires" if NORMALIZE_FEATURES else "none",
            "samples_used": int(n_norm_samples),
            "zero_std_feature_count": int(n_zero_std_features),
        },
        "metrics_test": {
            "test_accuracy_overall": test_accuracy_overall,
            "test_positive_accuracy": test_positive_accuracy,
            "test_negative_accuracy": test_negative_accuracy,
            "tp": int(tp),
            "fp": int(fp),
            "fn": int(fn),
            "tn": int(tn),
        },
        "class_balance": {
            "train_positive_rate": train_pos_rate,
            "test_positive_rate": test_pos_rate,
            "train_positives": int(train_pos),
            "train_negatives": int(train_neg),
            "test_positives": int(test_pos),
            "test_negatives": int(test_neg),
        },
        "coefficients": {
            "intercept": intercept,
            "values": coef_map,
        },
        "data": {
            "total_samples": int(n_total_samples),
            "train_samples": int(n_train_samples),
            "test_samples": int(n_test_samples),
            "hours_used": int(n_hours_used),
            "train_hours": int(n_train_hours),
            "test_hours": int(n_test_hours),
        },
    }

    print("Report keys:", list(report.keys()))
else:
    print("Skipped report section.")

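The report is a plain dict, but `json.dumps` rejects NumPy scalar types such as `np.int64`, `np.float32`, and `np.bool_`. Only `np.float64`, a `float` subclass, passes through unchanged. If the metrics or `coef_map` values come out of NumPy, a `default=` hook can convert them before the report is written to disk. In the sketch below, `to_jsonable` is a hypothetical helper, and the file name in the commented usage line is a placeholder.

```python
import json

import numpy as np


def to_jsonable(obj):
    """`default=` hook for json.dumps that converts NumPy values to built-in types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # np.int64 -> int, np.float32 -> float, np.bool_ -> bool, ...
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Hypothetical usage (file name is a placeholder):
# Path("report.json").write_text(json.dumps(report, indent=2, default=to_jsonable))
```

`json.dumps` recurses through nested dicts and lists itself. It calls the hook only for objects it cannot encode, so nested NumPy values inside `report` are also handled.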

## 14) Notes and constraints

- GOES confidence is a proxy signal, not direct flame-front geometry.
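- This is a linear decision-boundary model (logistic regression). Each feature contributes additively to the log-odds, so non-linear effects and feature interactions are not modeled unless they are added as explicit features.
- Features use center + 8-neighbor cell data at the current hour `t`.
- Wind direction is encoded as `sin`/`cos` (from RTMA degree values), not passed as raw degrees.
- Split is by complete fires: train and test sets use different fires (no within-fire sample split).
- Features are z-score normalized with means and standard deviations computed from train fires only. The same statistics are then applied to both train and test data; the validation step refits them on inner-train fires only.
- `test_positive_accuracy` = `TP / (TP + FN)` (recall on positives) and `test_negative_accuracy` = `TN / (TN + FP)` (specificity).
- Runtime can be substantial. Every section iterates over all hourly samples of its full fire group (no subsampling), and several sections make separate passes over the same fires.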

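The two per-class metrics are recall on positives and recall on negatives. With rare positives they matter more than overall accuracy. On a test set with 40 positives and 960 negatives, a model that never predicts a positive still scores 0.96 overall accuracy but 0.0 positive accuracy.

The sketch below computes the three reported metrics from confusion counts. `holdout_metrics` is a hypothetical name, and it returns NaN for an empty class rather than raising.

```python
import math


def _ratio(num, den):
    # Return NaN instead of raising ZeroDivisionError when a class is absent from the split.
    return num / den if den > 0 else math.nan


def holdout_metrics(tp, fp, fn, tn):
    """Overall, positive-class, and negative-class accuracy from confusion counts."""
    return {
        "test_accuracy_overall": _ratio(tp + tn, tp + fp + fn + tn),
        # TP / (TP + FN): share of actual positives predicted positive (recall).
        "test_positive_accuracy": _ratio(tp, tp + fn),
        # TN / (TN + FP): share of actual negatives predicted negative (specificity).
        "test_negative_accuracy": _ratio(tn, tn + fp),
    }


# The all-negative example: 40 positives, 960 negatives, nothing predicted positive.
print(holdout_metrics(tp=0, fp=0, fn=40, tn=960))
```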