# In [None]:
"""
Backfill historical predictions from snapshots.csv without leakage.

Output:
  data/artifacts_snow/predictions_backfill_prob.csv

What it does:
- Loads snapshots (local or remote)
- Builds labels from lastserviced changes (next service time)
- Builds storm envelopes from NWS alerts log and refines operational start/end using observed service activity
- Adds features (including storm clock + 15m tempo + phase + route completion)
- For each eventid:
    - For a set of "as_of" cutpoints, train on data <= cutpoint
    - Predict probabilities on the next time window
- Enforces "point-in-time" correctness:
    - Keep only rows with as_of_cut_ts < snapshot_ts
    - For each (eventid, segment, snapshot_ts), keep the latest as_of_cut_ts BEFORE snapshot_ts

Saves per-(eventid, segment, snapshot_ts):
  p_1h, p_2h, p_4h, p_8h
  plus labels and key metadata.
"""

from __future__ import annotations

from pathlib import Path
import os
import json
import numpy as np
import pandas as pd

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression

# -----------------------------
# Constants / paths
# -----------------------------
EVENT = "eventid"
SEG = "snowroutesegmentid"
TS = "snapshot_ts"
HORIZONS = [1, 2, 4, 8]

SNAPSHOT_URL = (
    "https://raw.githubusercontent.com/"
    "samedelstein/snow_map_dashboard/main/"
    "data/snapshot_snow_routes/snapshots.csv"
)

NWS_ALERTS_LOG_REMOTE = (
    "https://raw.githubusercontent.com/samedelstein/snow_map_dashboard/main/"
    "data/artifacts_snow/nws_alerts_log.csv"
)

CALIBRATION_MIN_ROWS_ISOTONIC = 500

def get_repo_root() -> Path:
    """Return the repository root directory.

    When running from a file, the root is the grandparent of that file
    (``<root>/<subdir>/<this file>``). Without ``__file__`` (e.g. an
    interactive or notebook session) the current working directory is used.
    """
    if "__file__" not in globals():
        return Path.cwd().resolve()
    this_file = Path(__file__).resolve()
    return this_file.parents[1]

# Repository-relative locations for input data and generated artifacts.
REPO_ROOT = get_repo_root()
DATA_DIR = REPO_ROOT / "data"
ARTIFACT_DIR = DATA_DIR / "artifacts_snow"
# NOTE: creates the artifact directory as a side effect of importing this module.
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)

# Local alerts log, preferred by load_alert_log() over NWS_ALERTS_LOG_REMOTE when it exists.
NWS_ALERTS_LOG_LOCAL = str(ARTIFACT_DIR / "nws_alerts_log.csv")

# Storm envelope parameters (keep in sync with production).
# Only these NWS alert types contribute to storm envelopes.
SNOW_ALERT_EVENTS = {
    "Winter Storm Warning",
    "Winter Storm Watch",
    "Winter Weather Advisory",
}
# Each alert's [start, end] is widened by these many hours before merging into envelopes.
ALERT_START_PAD_H = 6
ALERT_END_PAD_H = 24

# Operational-window detection parameters (keep in sync with production; tune if needed).
# Service events are counted per bucket of this width (a pandas frequency string).
OPS_BUCKET_MIN = "15min"
# A bucket is "active" when it holds at least this many services.
OPS_MIN_SERVICES_PER_BUCKET = 10
# Number of consecutive active buckets required before operations count as started.
OPS_SUSTAIN_BUCKETS = 2


# -----------------------------
# Calibrator helper (fallback)
# -----------------------------
class _PrefitCalibrator:
    def __init__(self, estimator: HistGradientBoostingClassifier, method: str) -> None:
        self.estimator = estimator
        self.method = method
        self.calibrator: IsotonicRegression | LogisticRegression | None = None
        self.classes_ = np.array([0, 1])

    def fit(self, X: pd.DataFrame, y: pd.Series) -> "_PrefitCalibrator":
        p = self.estimator.predict_proba(X)[:, 1]
        if self.method == "isotonic":
            self.calibrator = IsotonicRegression(out_of_bounds="clip")
            self.calibrator.fit(p, y)
        else:
            self.calibrator = LogisticRegression(solver="lbfgs")
            self.calibrator.fit(p.reshape(-1, 1), y)
        return self

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
        if self.calibrator is None:
            raise ValueError("Calibrator not fitted.")
        p = self.estimator.predict_proba(X)[:, 1]
        if isinstance(self.calibrator, IsotonicRegression):
            p_cal = self.calibrator.predict(p)
        else:
            p_cal = self.calibrator.predict_proba(p.reshape(-1, 1))[:, 1]
        p_cal = np.clip(p_cal, 0.0, 1.0)
        return np.column_stack([1 - p_cal, p_cal])


def _fit_prefit_calibrator(
    estimator: HistGradientBoostingClassifier,
    X_calib: pd.DataFrame,
    y_calib: pd.Series,
    method: str,
):
    """Calibrate an already-fitted classifier on (X_calib, y_calib).

    Tries sklearn's ``CalibratedClassifierCV(cv="prefit")`` first. If that
    raises for any reason, falls back to the local ``_PrefitCalibrator``.
    Any error from the fallback propagates.

    Returns:
        An object exposing ``predict_proba`` with calibrated probabilities.
    """
    try:
        calibrated = CalibratedClassifierCV(estimator, cv="prefit", method=method)
        calibrated.fit(X_calib, y_calib)
    except Exception:
        # NOTE(review): presumably guards sklearn versions that no longer accept cv="prefit".
        calibrated = _PrefitCalibrator(estimator, method).fit(X_calib, y_calib)
    return calibrated


# -----------------------------
# Alerts -> storm envelopes -> operational windows
# -----------------------------
def load_alert_log(source: str | None = None) -> pd.DataFrame:
    """Load the NWS alerts log as a frame of alerts with UTC start/end times.

    Args:
        source: path or URL of the CSV. Defaults to NWS_ALERTS_LOG_LOCAL when
            that file exists, otherwise NWS_ALERTS_LOG_REMOTE.

    Returns:
        A frame with at least ``event`` (str), ``start_ts``/``end_ts``
        (tz-aware UTC, rows with unparseable times dropped) and ``severity``
        (str, "" when the log has no severity column). It is an empty frame
        with those columns when the log cannot be read, is empty, or lacks
        ``start_ts``/``end_ts``.
    """
    empty_cols = ["event", "start_ts", "end_ts", "severity"]
    src = source or (NWS_ALERTS_LOG_LOCAL if os.path.exists(NWS_ALERTS_LOG_LOCAL) else NWS_ALERTS_LOG_REMOTE)
    try:
        a = pd.read_csv(src)
    except Exception as err:
        # Deliberately best-effort (the log is optional), but don't fail silently.
        print(f"WARNING: could not read NWS alerts log from {src}: {err}")
        return pd.DataFrame(columns=empty_cols)

    if a.empty or not {"start_ts", "end_ts"}.issubset(a.columns):
        return pd.DataFrame(columns=empty_cols)

    def column_or(name: str, default: str) -> pd.Series:
        # DataFrame.get(name, scalar) would return the bare scalar (no .astype),
        # so build a full-length Series when the column is missing.
        if name in a.columns:
            return a[name]
        return pd.Series(default, index=a.index)

    a["start_ts"] = pd.to_datetime(a["start_ts"], utc=True, errors="coerce")
    a["end_ts"] = pd.to_datetime(a["end_ts"], utc=True, errors="coerce")
    a["event"] = column_or("event", "").astype(str)
    a["severity"] = column_or("severity", "").astype(str)
    a = a[a["start_ts"].notna() & a["end_ts"].notna()].copy()
    return a


def build_storm_envelopes_from_alerts(alerts: pd.DataFrame) -> pd.DataFrame:
    """Merge snow-related NWS alerts into non-overlapping storm envelopes.

    Only alerts whose ``event`` is in SNOW_ALERT_EVENTS are used. Each alert's
    window is padded by ALERT_START_PAD_H before and ALERT_END_PAD_H after.
    Windows are sorted by padded start, and any window starting at or before
    the running envelope's end is merged into it.

    Returns:
        One row per envelope: ``storm_id`` ("storm_001", ...),
        ``storm_envelope_start``, ``storm_envelope_end`` and ``severity_max``.
        ``severity_max`` is the highest of Minor < Moderate < Severe < Extreme
        seen in the envelope, or "Unknown".
    """
    empty_cols = ["storm_id", "storm_envelope_start", "storm_envelope_end", "severity_max"]
    if alerts.empty:
        return pd.DataFrame(columns=empty_cols)

    snow = alerts.loc[alerts["event"].isin(SNOW_ALERT_EVENTS)].copy()
    if snow.empty:
        return pd.DataFrame(columns=empty_cols)

    rank_of = {"Minor": 1, "Moderate": 2, "Severe": 3, "Extreme": 4}
    label_of = {rank: label for label, rank in rank_of.items()}

    snow["storm_envelope_start"] = snow["start_ts"] - pd.Timedelta(hours=ALERT_START_PAD_H)
    snow["storm_envelope_end"] = snow["end_ts"] + pd.Timedelta(hours=ALERT_END_PAD_H)
    snow["sev_rank"] = snow["severity"].map(rank_of).fillna(0).astype(int)
    snow = snow.sort_values("storm_envelope_start").reset_index(drop=True)

    # Classic interval merge; each cluster is [start, end, max_rank].
    clusters: list[list] = []
    for start, end, rank in zip(
        snow["storm_envelope_start"], snow["storm_envelope_end"], snow["sev_rank"]
    ):
        rank = int(rank)
        if clusters and start <= clusters[-1][1]:
            current = clusters[-1]
            current[1] = max(current[1], end)
            current[2] = max(current[2], rank)
        else:
            clusters.append([start, end, rank])

    records = [
        {
            "storm_id": f"storm_{number:03d}",
            "storm_envelope_start": start,
            "storm_envelope_end": end,
            "severity_max": label_of.get(rank, "Unknown"),
        }
        for number, (start, end, rank) in enumerate(clusters, start=1)
    ]
    return pd.DataFrame(records)


def derive_city_service_events(labeled: pd.DataFrame) -> pd.DataFrame:
    """Return one (EVENT, TS) row per snapshot where a segment's lastserviced changed.

    A change requires both the current and the previous (per EVENT/SEG, in TS
    order) ``lastserviced`` to be present and different. As a result, each
    segment's first row, and any NaT -> value transition, is never counted.
    """
    if labeled.empty:
        return pd.DataFrame(columns=[EVENT, TS])

    ordered = labeled.sort_values([EVENT, SEG, TS])
    current = ordered["lastserviced"]
    previous = ordered.groupby([EVENT, SEG])["lastserviced"].shift(1)

    # Both sides must be known; otherwise the first-row NaT would look like a change.
    changed = current.notna() & previous.notna() & current.ne(previous)
    return ordered.loc[changed, [EVENT, TS]].copy()


def refine_operational_windows(storms: pd.DataFrame, service_events: pd.DataFrame) -> pd.DataFrame:
    """Derive each storm's operational start/end from observed service activity.

    Service events (rows with a datetime TS column, one per observed service)
    that fall inside a storm's envelope are counted per OPS_BUCKET_MIN bucket.
    A bucket is *active* when it holds at least OPS_MIN_SERVICES_PER_BUCKET
    services. Buckets with no services count as inactive, so a "sustained"
    run really is OPS_SUSTAIN_BUCKETS adjacent buckets in time.

    * storm_operational_start: the bucket at which the sustain threshold is
      first met.
    * storm_operational_end: end of the last active bucket in the envelope.

    Both are NaT when the envelope has no sustained activity.

    Args:
        storms: envelopes with storm_id, storm_envelope_start,
            storm_envelope_end and severity_max.
        service_events: frame with a datetime TS column.

    Returns:
        One row per storm with the envelope, the operational window (UTC) and
        severity_max.
    """
    out_cols = [
        "storm_id", "storm_envelope_start", "storm_envelope_end",
        "storm_operational_start", "storm_operational_end", "severity_max",
    ]
    if storms.empty:
        return pd.DataFrame(columns=out_cols)

    if service_events.empty:
        out = storms.copy()
        # UTC dtype, consistent with the non-empty path below.
        out["storm_operational_start"] = pd.Series(pd.NaT, index=out.index, dtype="datetime64[ns, UTC]")
        out["storm_operational_end"] = pd.Series(pd.NaT, index=out.index, dtype="datetime64[ns, UTC]")
        return out

    bucket_width = pd.Timedelta(OPS_BUCKET_MIN)
    svc = service_events.copy()
    svc["bucket"] = svc[TS].dt.floor(OPS_BUCKET_MIN)

    out_rows = []
    for st in storms.itertuples(index=False):
        s0 = st.storm_envelope_start
        s1 = st.storm_envelope_end
        op_start = pd.NaT
        op_end = pd.NaT

        ssvc = svc[(svc[TS] >= s0) & (svc[TS] <= s1)]
        if not ssvc.empty:
            counts = ssvc.groupby("bucket").size().sort_index()
            # groupby only yields buckets that had services. Re-insert the quiet
            # ones as 0 so two busy buckets hours apart are not treated as a
            # consecutive (sustained) run.
            counts = counts.reindex(
                pd.date_range(counts.index[0], counts.index[-1], freq=OPS_BUCKET_MIN),
                fill_value=0,
            )
            active = counts >= OPS_MIN_SERVICES_PER_BUCKET
            sustain = (
                active.astype(int)
                .rolling(OPS_SUSTAIN_BUCKETS, min_periods=OPS_SUSTAIN_BUCKETS)
                .sum()
                >= OPS_SUSTAIN_BUCKETS
            )
            if sustain.any():
                # NOTE(review): this is the bucket where the threshold is first
                # reached, i.e. the LAST bucket of the first sustained run
                # (OPS_SUSTAIN_BUCKETS - 1 buckets after the run began). Confirm
                # that matches production's definition of operational start.
                first_idx = int(np.argmax(sustain.to_numpy()))
                last_active_idx = int(np.flatnonzero(active.to_numpy()).max())
                op_start = counts.index[first_idx]
                op_end = counts.index[last_active_idx] + bucket_width

        out_rows.append({
            "storm_id": st.storm_id,
            "storm_envelope_start": s0,
            "storm_envelope_end": s1,
            "storm_operational_start": op_start,
            "storm_operational_end": op_end,
            "severity_max": st.severity_max,
        })

    out = pd.DataFrame(out_rows, columns=out_cols)
    out["storm_operational_start"] = pd.to_datetime(out["storm_operational_start"], utc=True, errors="coerce")
    out["storm_operational_end"] = pd.to_datetime(out["storm_operational_end"], utc=True, errors="coerce")
    return out


def attach_storm_context(df: pd.DataFrame, storms_ops: pd.DataFrame) -> pd.DataFrame:
    """Tag every snapshot row with the storm whose envelope contains its TS.

    Adds ``storm_id`` ("no_storm" outside every envelope),
    ``storm_operational_start``/``storm_operational_end`` (UTC, NaT when
    unknown) and ``storm_severity_max`` ("Unknown" by default). Storms are
    applied in order, so if envelopes overlapped, later storms would
    overwrite earlier ones.

    Args:
        df: snapshot rows with a datetime TS column.
        storms_ops: output of refine_operational_windows.

    Returns:
        A copy of ``df`` with the four storm columns added.
    """
    out = df.copy()
    out["storm_id"] = "no_storm"
    # Build the all-NaT columns on out's own index with an explicit UTC dtype.
    # A fresh RangeIndex Series would be label-aligned against out's
    # filtered/sorted index instead of matching row-for-row.
    out["storm_operational_start"] = pd.Series(pd.NaT, index=out.index, dtype="datetime64[ns, UTC]")
    out["storm_operational_end"] = pd.Series(pd.NaT, index=out.index, dtype="datetime64[ns, UTC]")
    out["storm_severity_max"] = "Unknown"

    if out.empty or storms_ops.empty:
        return out

    storms = storms_ops.copy()
    storms["storm_operational_start"] = pd.to_datetime(storms["storm_operational_start"], utc=True, errors="coerce")
    storms["storm_operational_end"] = pd.to_datetime(storms["storm_operational_end"], utc=True, errors="coerce")

    for st in storms.itertuples(index=False):
        mask = (out[TS] >= st.storm_envelope_start) & (out[TS] <= st.storm_envelope_end)
        out.loc[mask, "storm_id"] = st.storm_id
        out.loc[mask, "storm_operational_start"] = st.storm_operational_start
        out.loc[mask, "storm_operational_end"] = st.storm_operational_end
        out.loc[mask, "storm_severity_max"] = st.severity_max

    out["storm_operational_start"] = pd.to_datetime(out["storm_operational_start"], utc=True, errors="coerce")
    out["storm_operational_end"] = pd.to_datetime(out["storm_operational_end"], utc=True, errors="coerce")
    return out


# -----------------------------
# Snapshots -> labels
# -----------------------------
def load_snapshots(path_or_url: str | Path) -> pd.DataFrame:
    """Load the snapshot CSV and normalize types.

    * TS and the lastservice* columns become tz-aware UTC datetimes. Numeric
      lastservice* values are read as epoch milliseconds.
    * routepriority / snowrouteid / roadname become str ("Unknown" when
      missing or the column is absent).
    * passes / passesleft / passesright become numeric (0 when missing or absent).
    * segmentlength becomes numeric (NaN when missing or absent).
    * Rows without TS, EVENT or SEG are dropped; EVENT and SEG become str.
    * passes_event is passes_phase when present, falling back to passes.

    Raises:
        KeyError: if the TS, EVENT or SEG column is absent.
    """
    df = pd.read_csv(path_or_url)
    df[TS] = pd.to_datetime(df[TS], utc=True, errors="coerce")

    def column_or(name: str, default) -> pd.Series:
        # DataFrame.get(name, scalar) returns the bare scalar when the column is
        # missing, and a scalar has no .fillna/.astype. Return a full-length
        # Series instead.
        if name in df.columns:
            return df[name]
        return pd.Series(default, index=df.index)

    for c in ["lastserviced", "lastserviceleft", "lastserviceright"]:
        if c in df.columns:
            if pd.api.types.is_numeric_dtype(df[c]):
                df[c] = pd.to_datetime(df[c], unit="ms", utc=True, errors="coerce")
            else:
                df[c] = pd.to_datetime(df[c], utc=True, errors="coerce")

    for c in ["routepriority", "snowrouteid", "roadname"]:
        df[c] = column_or(c, "Unknown").fillna("Unknown").astype(str)

    for c in ["passes", "passesleft", "passesright"]:
        df[c] = pd.to_numeric(column_or(c, 0), errors="coerce").fillna(0)

    df["segmentlength"] = pd.to_numeric(column_or("segmentlength", np.nan), errors="coerce")

    df = df[df[TS].notna() & df[EVENT].notna() & df[SEG].notna()].copy()
    df[EVENT] = df[EVENT].astype(str)
    df[SEG] = df[SEG].astype(str)

    if "passes_phase" in df.columns:
        df["passes_event"] = df["passes_phase"].fillna(df["passes"])
    else:
        df["passes_event"] = df["passes"]

    return df


def build_events(df: pd.DataFrame) -> pd.DataFrame:
    """Extract service events: snapshots where a segment's lastserviced is new.

    A row is an event when ``lastserviced`` is present and differs from the
    previous snapshot's value for the same (EVENT, SEG). Unlike
    derive_city_service_events, the first observed value of each segment
    (previous is NaT) also counts as an event.

    Returns:
        Columns EVENT, SEG, ``observed_at`` (snapshot TS) and ``serviced_at``
        (the new lastserviced), sorted by EVENT, SEG, serviced_at.
    """
    ordered = df.sort_values([EVENT, SEG, TS])
    current = ordered["lastserviced"]
    prior = ordered.groupby([EVENT, SEG])["lastserviced"].shift(1)
    is_new = current.notna() & current.ne(prior)

    picked = ordered.loc[is_new, [EVENT, SEG, TS, "lastserviced"]]
    renamed = picked.rename(columns={TS: "observed_at", "lastserviced": "serviced_at"})
    return renamed.sort_values([EVENT, SEG, "serviced_at"])


def label_next_service(df: pd.DataFrame, events: pd.DataFrame) -> pd.DataFrame:
    """Attach the next observed service time after each snapshot.

    For every snapshot row, ``next_serviced_at`` is the earliest
    ``serviced_at`` of the same (EVENT, SEG) that is strictly later than the
    row's TS. These are training targets and intentionally use future data.

    Args:
        df: snapshot rows with EVENT, SEG and a datetime TS column.
        events: service events with EVENT, SEG and serviced_at (e.g. from
            build_events).

    Returns:
        A copy of ``df`` sorted by (EVENT, SEG, TS) with ``next_serviced_at``
        (UTC, NaT when there is no later service), ``hours_to_next_service``
        (float, NaN when none) and ``censored`` (True when none).
    """
    # Convert all service times once (UTC, ns) and index them per segment.
    # Sorting makes searchsorted correct even for an unordered events frame.
    all_times = pd.DatetimeIndex(pd.to_datetime(events["serviced_at"], utc=True)).to_numpy(dtype="datetime64[ns]")
    svc_times = {
        key: np.sort(all_times[positions])
        for key, positions in events.groupby([EVENT, SEG]).indices.items()
    }

    labeled = df.sort_values([EVENT, SEG, TS]).copy()
    snap_times = pd.DatetimeIndex(labeled[TS]).to_numpy(dtype="datetime64[ns]")
    next_times = np.full(len(labeled), np.datetime64("NaT"), dtype="datetime64[ns]")

    # One vectorized searchsorted per segment instead of a Python loop per row.
    for key, rows in labeled.groupby([EVENT, SEG], sort=False).indices.items():
        times = svc_times.get(key)
        if times is None or times.size == 0:
            continue
        nxt = np.searchsorted(times, snap_times[rows], side="right")
        found = nxt < times.size
        next_times[rows[found]] = times[nxt[found]]

    # Assigning a DatetimeIndex is positional, matching labeled's row order.
    labeled["next_serviced_at"] = pd.to_datetime(next_times, utc=True)
    labeled["hours_to_next_service"] = (labeled["next_serviced_at"] - labeled[TS]).dt.total_seconds() / 3600.0
    labeled["censored"] = labeled["hours_to_next_service"].isna()
    return labeled


# -----------------------------
# Feature engineering (MATCHES new production signals)
# -----------------------------
def mark_untracked(df: pd.DataFrame) -> pd.DataFrame:
    """Flag segments the model should not score.

    A row gets ``prediction_status = "NO_PRED_UNTRACKED"`` when its segment's
    ``passes_event`` never changes within the event (max - min == 0), or when
    its snowrouteid is "unknown" (case-insensitive). All other rows are "OK".
    Also adds the per-segment ``growth`` column.

    NOTE(review): growth is computed over the whole event, including
    snapshots after each row's TS.
    """
    bounds = df.groupby([EVENT, SEG])["passes_event"].agg(["min", "max"])
    growth = (bounds["max"] - bounds["min"]).rename("growth").reset_index()
    result = df.merge(growth, on=[EVENT, SEG], how="left")

    untracked = result["growth"].eq(0) | result["snowrouteid"].str.lower().eq("unknown")
    result["prediction_status"] = "OK"
    result.loc[untracked, "prediction_status"] = "NO_PRED_UNTRACKED"
    return result


def add_horizon_labels(df: pd.DataFrame) -> pd.DataFrame:
    """Add binary targets ``y_{h}h`` for every h in HORIZONS.

    ``y_{h}h`` is 1 when the next service is at most ``h`` hours away and 0
    otherwise. Censored rows (NaN hours_to_next_service) are 0.
    """
    labeled = df.copy()
    hours_left = labeled["hours_to_next_service"]
    for horizon in HORIZONS:
        # NaN compares False, so censored rows become 0.
        labeled[f"y_{horizon}h"] = hours_left.le(horizon).astype(int)
    return labeled


def add_features(df: pd.DataFrame, events: pd.DataFrame) -> pd.DataFrame:
    """Build the model feature columns (see FEATURE_COLS) for every snapshot row.

    Args:
        df: labeled snapshot rows, normally with storm context attached. The
            storm columns are defaulted when missing.
        events: service events with EVENT, SEG and serviced_at (as produced by
            build_events).

    Returns:
        A new frame with the feature columns added. Rows are sorted by
        (EVENT, SEG, TS); the index is reset by the merges.

    NOTE(review): several aggregates below cover the whole clock bucket
    (15 min / 1 h) a snapshot falls in, or the storm's full operational
    window. They can therefore include activity after the row's snapshot_ts.
    Confirm this matches what production sees at prediction time; otherwise
    it is look-ahead in the backfill features.
    """
    f = df.sort_values([EVENT, SEG, TS]).copy()

    # ensure storm columns exist (defensive)
    for c, default in [
        ("storm_id", "no_storm"),
        ("storm_operational_start", pd.NaT),
        ("storm_operational_end", pd.NaT),
        ("storm_severity_max", "Unknown"),
    ]:
        if c not in f.columns:
            f[c] = default
    f["storm_operational_start"] = pd.to_datetime(f["storm_operational_start"], utc=True, errors="coerce")
    f["storm_operational_end"] = pd.to_datetime(f["storm_operational_end"], utc=True, errors="coerce")

    # base features
    # priority_num: first run of digits in routepriority (NaN when there is none).
    f["priority_num"] = f["routepriority"].str.extract(r"(\d+)").astype(float)
    f["hour"] = f[TS].dt.hour
    f["dow"] = f[TS].dt.weekday

    # hours since last service (segment-level)
    f["hours_since_last_service"] = (f[TS] - f["lastserviced"]).dt.total_seconds() / 3600.0

    # Zero-length segments -> NaN instead of inf.
    denom = f["segmentlength"].replace(0, np.nan)
    f["passes_per_len"] = f["passes_event"] / denom

    # last serviced change (robust): both values must be present, so a
    # segment's first row never counts as a change.
    f["prev_last"] = f.groupby([EVENT, SEG])["lastserviced"].shift(1)
    f["lastserviced_changed"] = (
        f["lastserviced"].notna()
        & f["prev_last"].notna()
        & (f["lastserviced"] != f["prev_last"])
    ).astype(int)

    # buckets (clock-aligned, not trailing windows)
    f["hour_bucket"] = f[TS].dt.floor("h")
    f["bucket_15m"] = f[TS].dt.floor("15min")

    # phase (if present): trailing "-<digits>" of eventphaseid, else 0
    if "eventphaseid" in f.columns:
        phase_num = f["eventphaseid"].astype(str).str.extract(r"-(\d+)$")[0]
        f["phase_num"] = pd.to_numeric(phase_num, errors="coerce").fillna(0)
    else:
        f["phase_num"] = 0

    # storm clock
    # NOTE(review): storm_operational_start/end come from the input frame. When
    # they are computed from the whole storm's service activity (as main() does
    # via refine_operational_windows), in_storm and hours_until_storm_end use
    # information from after this row's snapshot_ts.
    f["in_storm"] = (
        f["storm_operational_start"].notna()
        & f["storm_operational_end"].notna()
        & (f[TS] >= f["storm_operational_start"])
        & (f[TS] <= f["storm_operational_end"])
    ).astype(int)

    # Negative values are clipped to 0; -1 means the storm has no operational window.
    f["hours_since_storm_start"] = ((f[TS] - f["storm_operational_start"]).dt.total_seconds() / 3600.0)
    f["hours_until_storm_end"] = ((f["storm_operational_end"] - f[TS]).dt.total_seconds() / 3600.0)
    f["hours_since_storm_start"] = f["hours_since_storm_start"].clip(lower=0).fillna(-1)
    f["hours_until_storm_end"] = f["hours_until_storm_end"].clip(lower=0).fillna(-1)

    # city and route tempo (15m): number of lastserviced changes observed in
    # the same clock 15-minute bucket as the row (city-wide / per route).
    # NOTE(review): includes changes later in that bucket than this row.
    svc = f.loc[f["lastserviced_changed"] == 1, [EVENT, "snowrouteid", TS]].copy()
    svc["bucket_15m"] = svc[TS].dt.floor("15min")

    city_15m = svc.groupby([EVENT, "bucket_15m"]).size().rename("city_services_15m").reset_index()
    f = f.merge(city_15m, on=[EVENT, "bucket_15m"], how="left")
    f["city_services_15m"] = f["city_services_15m"].fillna(0)

    route_15m = svc.groupby([EVENT, "snowrouteid", "bucket_15m"]).size().rename("route_services_15m").reset_index()
    f = f.merge(route_15m, on=[EVENT, "snowrouteid", "bucket_15m"], how="left")
    f["route_services_15m"] = f["route_services_15m"].fillna(0)

    # city services last hour (from events): events whose serviced_at falls in
    # the same clock hour as the snapshot, not a trailing 60 minutes. The events
    # from build_events also include each segment's first observed lastserviced.
    city = (
        events.assign(hour_bucket=events["serviced_at"].dt.floor("h"))
        .groupby([EVENT, "hour_bucket"])
        .size()
        .rename("city_services_last_hour")
        .reset_index()
    )
    f = f.merge(city, on=[EVENT, "hour_bucket"], how="left")
    f["city_services_last_hour"] = f["city_services_last_hour"].fillna(0)

    # route services last hour: sum of change flags over every row of the route
    # in the same clock hour (NOTE(review): includes rows later than this one).
    route = (
        f.groupby([EVENT, "snowrouteid", "hour_bucket"])["lastserviced_changed"]
        .sum()
        .rename("route_services_last_hour")
        .reset_index()
    )
    f = f.merge(route, on=[EVENT, "snowrouteid", "hour_bucket"], how="left")
    f["route_services_last_hour"] = f["route_services_last_hour"].fillna(0)

    # route completion (60m): per row, whether its segment has changed at least
    # once so far (running max), averaged over all of the route's rows in the
    # same clock hour.
    f["_route_served_once"] = f.groupby([EVENT, "snowrouteid", SEG])["lastserviced_changed"].cummax()
    route_completion = (
        f.groupby([EVENT, "snowrouteid", "hour_bucket"])["_route_served_once"]
        .mean()
        .rename("route_completion_60m")
        .reset_index()
    )
    f = f.merge(route_completion, on=[EVENT, "snowrouteid", "hour_bucket"], how="left")
    f["route_completion_60m"] = f["route_completion_60m"].fillna(0)
    f = f.drop(columns=["_route_served_once"])

    # rolling 3h/6h service counts + passes deltas/ratios
    # Both helpers run a time-based rolling window over hour_bucket within each
    # group. Rows sharing an hour_bucket are accumulated in sort order, so a row
    # may also see other rows from the same clock hour.
    def rolling_sum_by_group(data, group_cols, value_col, window_h):
        # Rolling sum of value_col over the trailing window_h hours, aligned back to data.index.
        ordered = data.sort_values(group_cols + ["hour_bucket"]).copy()
        rolled = (
            ordered.set_index("hour_bucket")
            .groupby(group_cols, sort=False)[value_col]
            .rolling(f"{window_h}h", min_periods=1)
            .sum()
            .reset_index(level=group_cols, drop=True)
        )
        # groupby(sort=False) on already-sorted, contiguous groups keeps ordered's row order.
        rolled.index = ordered.index
        return rolled.reindex(data.index)

    def rolling_delta_ratio_by_group(data, group_cols, value_col, window_h, prefix):
        # Rolling (max - min) and (max / min, NaN when min == 0) of value_col over window_h hours.
        ordered = data.sort_values(group_cols + ["hour_bucket"]).copy()
        series = ordered.set_index("hour_bucket").groupby(group_cols, sort=False)[value_col]
        ro = series.rolling(f"{window_h}h", min_periods=1)
        rmin = ro.min().reset_index(level=group_cols, drop=True)
        rmax = ro.max().reset_index(level=group_cols, drop=True)
        delta = (rmax - rmin).to_numpy()
        ratio = (rmax / rmin.replace(0, np.nan)).to_numpy()
        out = pd.DataFrame(
            {
                f"{prefix}_delta_{window_h}h": delta,
                f"{prefix}_ratio_{window_h}h": ratio,
            },
            index=ordered.index,
        )
        return out.reindex(data.index)

    for w in [3, 6]:
        f[f"seg_services_{w}h"] = rolling_sum_by_group(
            f, [EVENT, SEG], "lastserviced_changed", w
        ).fillna(0)
        f[f"route_services_{w}h"] = rolling_sum_by_group(
            f, [EVENT, "snowrouteid"], "lastserviced_changed", w
        ).fillna(0)
        f = f.join(rolling_delta_ratio_by_group(f, [EVENT, SEG], "passes_event", w, "passes_event").fillna(0))
        f = f.join(rolling_delta_ratio_by_group(f, [EVENT, "snowrouteid"], "passes_event", w, "route_passes_event").fillna(0))

    # placeholders (keep consistent with earlier backfill): constant columns so
    # FEATURE_COLS matches production; weather/alert data is not joined here.
    f["neighbor_services_last_hour"] = 0
    f["temp_c"] = np.nan
    f["wind_speed_mps"] = np.nan
    f["wind_gust_mps"] = np.nan
    f["snowfall_rate_mmhr"] = np.nan
    f["freezing_rain"] = 0
    for lag in [1, 2, 3]:
        f[f"temp_c_lag{lag}"] = np.nan
        f[f"snowfall_rate_mmhr_lag{lag}"] = np.nan
        f[f"wind_speed_mps_lag{lag}"] = np.nan
        f[f"wind_gust_mps_lag{lag}"] = np.nan
    f["nws_alert_count"] = 0
    f["nws_alert_active"] = 0

    return f


# Model input columns, in the order fed to the classifiers. All of them are
# produced by add_features(); the weather / neighbor / NWS-alert columns are
# constant placeholders there.
FEATURE_COLS = [
    "priority_num",
    "passes_event",
    "passes_per_len",
    "hours_since_last_service",
    "hour",
    "dow",
    "city_services_last_hour",
    "route_services_last_hour",
    "seg_services_3h",
    "seg_services_6h",
    "route_services_3h",
    "route_services_6h",
    "passes_event_delta_3h",
    "passes_event_ratio_3h",
    "passes_event_delta_6h",
    "passes_event_ratio_6h",
    "route_passes_event_delta_3h",
    "route_passes_event_ratio_3h",
    "route_passes_event_delta_6h",
    "route_passes_event_ratio_6h",
    "neighbor_services_last_hour",
    "temp_c",
    "wind_speed_mps",
    "wind_gust_mps",
    "snowfall_rate_mmhr",
    "freezing_rain",
    "temp_c_lag1",
    "temp_c_lag2",
    "temp_c_lag3",
    "snowfall_rate_mmhr_lag1",
    "snowfall_rate_mmhr_lag2",
    "snowfall_rate_mmhr_lag3",
    "wind_speed_mps_lag1",
    "wind_speed_mps_lag2",
    "wind_speed_mps_lag3",
    "wind_gust_mps_lag1",
    "wind_gust_mps_lag2",
    "wind_gust_mps_lag3",
    "nws_alert_count",
    "nws_alert_active",

    # NEW signals (match snow_predict.py)
    "phase_num",
    "in_storm",
    "hours_since_storm_start",
    "hours_until_storm_end",
    "city_services_15m",
    "route_services_15m",
    "route_completion_60m",
]


# -----------------------------
# Modeling
# -----------------------------
def train_models_point_in_time(train_df: pd.DataFrame, as_of: pd.Timestamp | None = None):
    """Fit one (optionally calibrated) classifier per horizon on data known at ``as_of``.

    Split, per event and by snapshot time:
      * rows up to the 80th percentile are eligible; the rest are unused;
      * among eligible rows, those after their own 80th percentile are used
        for calibration, and the earlier ones for fitting.

    ``y_{h}h`` depends on whether a service happens in (TS, TS + h]. That is
    only known once ``TS + h <= as_of``, so for each horizon rows whose label
    window is still open at ``as_of`` are excluded from fitting and
    calibration. Otherwise the model would train on outcomes from after the
    cut.

    Args:
        train_df: featured rows with FEATURE_COLS, EVENT, TS and y_{h}h.
        as_of: time the model is trained at (same tz as TS). Defaults to the
            latest TS in ``train_df``.

    Returns:
        ``(models, calibrated, medians)``. ``models`` and ``calibrated`` are
        dicts keyed by horizon (None when a horizon could not be fitted or
        calibrated). ``medians`` holds the per-feature fill values used for
        NaN/inf inputs.
    """
    if as_of is None:
        as_of = train_df[TS].max()

    train_cutoffs = train_df.groupby(EVENT)[TS].transform(lambda s: s.quantile(0.8))
    train_mask = train_df[TS] <= train_cutoffs

    X = train_df[FEATURE_COLS].replace([np.inf, -np.inf], np.nan)
    med = X.median(numeric_only=True)
    X = X.fillna(med)

    calib_cutoffs = train_df.loc[train_mask].groupby(EVENT)[TS].transform(lambda s: s.quantile(0.8))
    calib_mask = pd.Series(False, index=train_df.index)
    calib_mask.loc[train_mask] = train_df.loc[train_mask, TS] > calib_cutoffs
    base_train_mask = train_mask & ~calib_mask

    models: dict[int, HistGradientBoostingClassifier | None] = {}
    calibrated: dict[int, CalibratedClassifierCV | _PrefitCalibrator | None] = {}

    for h in HORIZONS:
        # Purge rows whose h-hour label window has not closed by as_of.
        label_known = (train_df[TS] + pd.Timedelta(hours=h)) <= as_of
        fit_rows = base_train_mask & label_known
        cal_rows = calib_mask & label_known

        y = train_df[f"y_{h}h"].astype(int)
        X_train, y_train = X[fit_rows], y[fit_rows]
        X_cal, y_cal = X[cal_rows], y[cal_rows]

        if X_train.empty or y_train.nunique() < 2:
            models[h] = None
            calibrated[h] = None
            continue

        # Reweight positives only when classes are strongly imbalanced.
        pos = int(y_train.sum())
        neg = int(len(y_train) - pos)
        pos_rate = pos / len(y_train)
        if pos > 0 and neg > 0 and (pos_rate < 0.2 or pos_rate > 0.8):
            pos_weight = neg / pos
            sample_weight = np.where(y_train == 1, pos_weight, 1.0)
        else:
            sample_weight = None

        clf = HistGradientBoostingClassifier(
            max_depth=6, learning_rate=0.08, max_iter=350, random_state=42
        )
        clf.fit(X_train, y_train, sample_weight=sample_weight)
        models[h] = clf

        if not X_cal.empty and y_cal.nunique() >= 2:
            method = "isotonic" if len(X_cal) >= CALIBRATION_MIN_ROWS_ISOTONIC else "sigmoid"
            calibrated[h] = _fit_prefit_calibrator(clf, X_cal, y_cal, method)
        else:
            calibrated[h] = None

    return models, calibrated, med


def predict_with_models(df: pd.DataFrame, models, calibrated, medians: pd.Series) -> pd.DataFrame:
    """Score ``df`` per horizon and return a copy with ``p_{h}h`` columns.

    For each horizon the calibrated model is preferred, with the raw
    classifier as fallback; a horizon with neither gets NaN. Probabilities are
    then made non-decreasing in the horizon (p_2h >= p_1h, and so on). Rows
    whose prediction_status is not "OK" get NaN for every horizon.

    Args:
        df: featured rows containing FEATURE_COLS and prediction_status.
        models: dict horizon -> fitted classifier or None.
        calibrated: dict horizon -> calibrated model or None.
        medians: fill values for NaN/inf features.
    """
    scored = df.copy()
    features = scored[FEATURE_COLS].replace([np.inf, -np.inf], np.nan).fillna(medians)

    for horizon in HORIZONS:
        estimator = calibrated.get(horizon) or models.get(horizon)
        scored[f"p_{horizon}h"] = (
            np.nan if estimator is None else estimator.predict_proba(features)[:, 1]
        )

    # Serviced within 1h implies within 2h, etc.: enforce monotone horizons.
    for shorter, longer in (("p_1h", "p_2h"), ("p_2h", "p_4h"), ("p_4h", "p_8h")):
        scored[longer] = scored[[shorter, longer]].max(axis=1)

    not_ok = scored["prediction_status"] != "OK"
    scored.loc[not_ok, [f"p_{h}h" for h in HORIZONS]] = np.nan
    return scored


# -----------------------------
# Main backfill
# -----------------------------
def main() -> None:
    """Run the point-in-time backfill and write predictions_backfill_prob.csv.

    Steps:
      1. Load snapshots, derive service events and next-service labels.
      2. Build storm envelopes from the NWS alerts log, refine them with
         observed service activity and attach storm context to every row.
      3. Add features, mark untracked segments and add y_{h}h labels.
      4. For each event, every ``cut_every_hours``: train on rows at or
         before the cut and predict rows in the next ``predict_window_hours``.
      5. Per (event, segment, snapshot), keep the prediction from the latest
         cut strictly before the snapshot and write it to ARTIFACT_DIR.

    Raises:
        SystemExit: if no snapshots load or no backfill rows are produced.
    """
    df = load_snapshots(SNAPSHOT_URL)
    if df.empty:
        raise SystemExit("No snapshots loaded")

    events = build_events(df)
    labeled = label_next_service(df, events)

    # storms from alerts + operational activity
    alerts_log = load_alert_log()
    storms_env = build_storm_envelopes_from_alerts(alerts_log)
    service_events = derive_city_service_events(labeled)
    storms_ops = refine_operational_windows(storms_env, service_events)
    labeled = attach_storm_context(labeled, storms_ops)

    featured = add_features(labeled, events)
    featured = mark_untracked(featured)
    featured = add_horizon_labels(featured)

    cut_every_hours = 6
    predict_window_hours = 6
    # Minimum sizes for an event / cut to be used.
    min_snapshot_times = 50
    min_train_rows = 200
    min_test_rows = 50

    predict_window = pd.Timedelta(hours=predict_window_hours)
    keep_cols = (
        [
            EVENT, SEG, TS, "snowrouteid", "routepriority", "priority_num",
            "prediction_status",
            "next_serviced_at", "hours_to_next_service", "censored",
        ]
        + [f"y_{h}h" for h in HORIZONS]
        + [f"p_{h}h" for h in HORIZONS]
    )

    backfills = []
    for _event_id, ev_df in featured.groupby(EVENT):
        ev_df = ev_df.sort_values(TS).copy()
        times = ev_df[TS].dropna().sort_values().unique()
        if len(times) < min_snapshot_times:
            continue

        t0 = pd.to_datetime(times[0], utc=True)
        t1 = pd.to_datetime(times[-1], utc=True)
        cutpoints = pd.date_range(start=t0, end=t1, freq=f"{cut_every_hours}h", tz="UTC")

        for cut in cutpoints:
            train = ev_df[ev_df[TS] <= cut]
            test = ev_df[(ev_df[TS] > cut) & (ev_df[TS] <= cut + predict_window)]
            if train.empty or test.empty:
                continue

            train_ok = train[train["prediction_status"] == "OK"].copy()
            test_ok = test[test["prediction_status"] == "OK"].copy()
            if len(train_ok) < min_train_rows or len(test_ok) < min_test_rows:
                continue

            models, calibrated, med = train_models_point_in_time(train_ok)
            pred = predict_with_models(test_ok, models, calibrated, med)

            pred["as_of_cut_ts"] = cut
            backfills.append(pred[keep_cols + ["as_of_cut_ts"]])

    if not backfills:
        raise SystemExit("No backfill rows generated (try lowering thresholds / increasing windows).")

    out = pd.concat(backfills, ignore_index=True)

    # --- CRITICAL: enforce correct point-in-time prediction (no leakage) ---
    out["as_of_cut_ts"] = pd.to_datetime(out["as_of_cut_ts"], utc=True, errors="coerce")
    out[TS] = pd.to_datetime(out[TS], utc=True, errors="coerce")

    out = out[out["as_of_cut_ts"] < out[TS]].copy()

    # Keep only the latest cut before each snapshot.
    out = (
        out.sort_values([EVENT, SEG, TS, "as_of_cut_ts"])
           .drop_duplicates(subset=[EVENT, SEG, TS], keep="last")
           .sort_values([EVENT, TS, SEG])
    )

    out_path = ARTIFACT_DIR / "predictions_backfill_prob.csv"
    out.to_csv(out_path, index=False)
    print(f"Saved backfill predictions: {out_path} | rows={len(out):,}")


if __name__ == "__main__":
    main()


# In [None]:
# Diagnostic cell: show the refined storm windows. `storms_ops` only exists as a
# local variable inside main(), so rebuild it from the same inputs when it is
# not already defined at module level.
if __name__ == "__main__":
    if "storms_ops" not in globals():
        _preview_snapshots = load_snapshots(SNAPSHOT_URL)
        storms_ops = refine_operational_windows(
            build_storm_envelopes_from_alerts(load_alert_log()),
            derive_city_service_events(_preview_snapshots),
        )
    print(storms_ops[["storm_id", "storm_operational_start", "storm_operational_end", "severity_max"]])
