In [None]:
import os
import glob
import json
import pandas as pd
import matplotlib.pyplot as plt

def parse_array_text_to_floats(s: str):
    """Decode a textual array such as ``"{1,2,3}"`` or ``"[1,2,3]"`` into floats.

    Curly braces are rewritten to square brackets so the text can be decoded
    as a JSON array, and every element is then passed through ``float()``.

    Parameters
    ----------
    s : str
        The array text. Surrounding whitespace is ignored.

    Returns
    -------
    list[float]
        The decoded values. An empty list is returned when ``s`` is not a
        non-empty string, when the text is not valid JSON after the brace
        rewrite, when the decoded value is not a list, or when any element
        cannot be converted by ``float()``.
    """
    # Guard clause: only non-empty strings are worth decoding.
    if not (isinstance(s, str) and s):
        return []

    # "{...}" -> "[...]" so that json can read the array.
    bracketed = s.strip().translate(str.maketrans("{}", "[]"))

    try:
        decoded = json.loads(bracketed)
        if not isinstance(decoded, list):
            return []
        return list(map(float, decoded))
    except Exception:
        # Deliberate best-effort: any decoding/conversion failure yields [].
        return []

def load_hourly(path: str) -> pd.DataFrame:
    """Read one hourly CSV and attach the columns the rest of the notebook uses.

    Parameters
    ----------
    path : str
        Location of the CSV file. It must contain a ``net_sum_text`` column.

    Returns
    -------
    pd.DataFrame
        The CSV contents plus:

        * ``scenario`` - kept if already present; otherwise derived from the
          lower-cased file name (``"policy"`` if it contains "policy", else
          ``"baseline"`` if it contains "baseline", else ``"unknown"``).
        * ``state_abbr`` - kept if already present; otherwise the upper-cased
          name of the directory that holds the file.
        * ``net_load`` - ``net_sum_text`` parsed into a list of floats per row.
    """
    hourly = pd.read_csv(path)

    if "scenario" not in hourly.columns:
        file_name = os.path.basename(path).lower()
        # "policy" is checked first, so it wins if both words appear.
        if "policy" in file_name:
            inferred_scenario = "policy"
        elif "baseline" in file_name:
            inferred_scenario = "baseline"
        else:
            inferred_scenario = "unknown"
        hourly["scenario"] = inferred_scenario

    if "state_abbr" not in hourly.columns:
        parent_dir = os.path.dirname(path)
        hourly["state_abbr"] = os.path.basename(parent_dir).upper()

    hourly["net_load"] = hourly["net_sum_text"].apply(parse_array_text_to_floats)
    return hourly

def pick_year_row(df: pd.DataFrame, scenario: str, year: int):
    """Return the first row for ``scenario``/``year`` that carries net-load data.

    Parameters
    ----------
    df : pd.DataFrame
        Table with ``scenario``, ``year`` and ``net_load`` columns.
    scenario : str
        Scenario name; compared case-insensitively with the ``scenario`` column.
    year : int
        Value the ``year`` column must equal.

    Returns
    -------
    pd.Series or None
        The first matching row whose ``net_load`` is a non-empty list, or
        ``None`` when no row matches or no matching row has such a list.
    """
    scenario_match = df["scenario"].str.lower() == scenario.lower()
    year_match = df["year"] == year
    candidates = df[scenario_match & year_match]
    if candidates.empty:
        return None

    def _has_values(cell) -> bool:
        # A usable cell is a list with at least one element.
        return isinstance(cell, list) and len(cell) > 0

    usable = candidates[candidates["net_load"].apply(_has_values)]
    if usable.empty:
        return None
    return usable.iloc[0]

def to_time_series(arr, year: int) -> pd.Series:
    """Wrap a sequence of values in a Series indexed by consecutive hours.

    Parameters
    ----------
    arr : sequence
        The values, one per hour.
    year : int
        Calendar year; the index starts at ``{year}-01-01 00:00:00``.

    Returns
    -------
    pd.Series
        ``arr`` with an hourly ``DatetimeIndex`` of ``len(arr)`` periods.
    """
    # An offset object is used instead of the "H" string alias: that alias is
    # deprecated in recent pandas, while pd.offsets.Hour() works on all versions.
    idx = pd.date_range(
        start=f"{year}-01-01 00:00:00",
        periods=len(arr),
        freq=pd.offsets.Hour(),
    )
    return pd.Series(arr, index=idx)

def aggregate_series(s: pd.Series, aggregation: str = "hourly", agg_func: str = "mean") -> pd.Series:
    """Resample a datetime-indexed series to a coarser resolution.

    Parameters
    ----------
    s : pd.Series
        Series with a ``DatetimeIndex``.
    aggregation : str
        ``"hourly"`` returns ``s`` unchanged; ``"daily"`` resamples per
        calendar day; ``"monthly"`` resamples per month (month-end labels).
    agg_func : str
        Name of the resampler method to apply, e.g. ``"mean"``, ``"max"``,
        ``"sum"``.

    Returns
    -------
    pd.Series
        The aggregated series (or ``s`` itself for ``"hourly"``).

    Raises
    ------
    ValueError
        If ``aggregation`` is not one of the supported names. Unknown values
        used to fall through to monthly resampling silently, which produced
        a plausible-looking but wrong result for typos.
    """
    if aggregation == "hourly":
        return s

    # Offset objects instead of string aliases: the monthly "M" alias is
    # deprecated in recent pandas and its replacement "ME" is unknown to older
    # releases, whereas the offset classes are accepted by every version.
    rules = {
        "daily": pd.offsets.Day(),
        "monthly": pd.offsets.MonthEnd(),
    }
    rule = rules.get(aggregation)
    if rule is None:
        raise ValueError(
            f"Unsupported aggregation {aggregation!r}; "
            "expected 'hourly', 'daily' or 'monthly'."
        )
    return getattr(s.resample(rule), agg_func)()

def _load_scenario_frames(state_dir, baseline_csv, policy_csv) -> pd.DataFrame:
    """Load the baseline and policy hourly tables and stack them into one frame."""
    if state_dir:
        frames = []
        for pattern in ("hourly_baseline*.csv", "hourly_policy*.csv"):
            matches = sorted(glob.glob(os.path.join(state_dir, pattern)))
            if matches:
                # Several files may match; the alphabetically first one is used.
                frames.append(load_hourly(matches[0]))
        if not frames:
            raise FileNotFoundError("No baseline/policy hourly files found.")
        return pd.concat(frames, ignore_index=True)

    # Without a directory both explicit paths are required; fail early with a
    # clear message instead of letting pd.read_csv(None) raise an obscure error.
    if not (baseline_csv and policy_csv):
        raise ValueError(
            "Provide either state_dir or both baseline_csv and policy_csv."
        )
    return pd.concat(
        [load_hourly(baseline_csv), load_hourly(policy_csv)],
        ignore_index=True,
    )


def _annotate_peak(ax, series: pd.Series, label: str, y_offset: int) -> None:
    """Write the maximum of ``series`` (divided by 1000, labelled GW) next to its point."""
    peak_value = series.max()
    peak_time = series.idxmax()
    ax.annotate(
        f"{peak_value/1000:.1f} GW ({label})",
        xy=(peak_time, peak_value),
        xytext=(0, y_offset),
        textcoords="offset points",
        ha="center", color="black", fontweight="bold",
    )


def plot_state_netload(
    state_dir: str | None = None,
    baseline_csv: str | None = None,
    policy_csv: str | None = None,
    year: int = 2040,
    aggregation: str = "daily",
    agg_func: str = "max",
    title: str | None = None,
):
    """Plot baseline vs. policy net load for one year and label each curve's peak.

    Data comes either from ``state_dir`` (the first ``hourly_baseline*.csv``
    and first ``hourly_policy*.csv`` found there, in sorted order) or from the
    two explicit CSV paths. ``state_dir`` takes precedence when given.

    Parameters
    ----------
    state_dir : str, optional
        Directory searched for the hourly baseline/policy CSV files.
    baseline_csv, policy_csv : str, optional
        Explicit file paths; both are required when ``state_dir`` is not given.
    year : int
        Year whose row is plotted; also the start year of the time axis.
    aggregation : str
        Passed to ``aggregate_series`` (e.g. ``"hourly"``, ``"daily"``).
    agg_func : str
        Resampler method name passed to ``aggregate_series``.
    title : str, optional
        Figure title; a descriptive default is built when omitted.

    Raises
    ------
    FileNotFoundError
        ``state_dir`` was given but contains no matching hourly file.
    ValueError
        Neither ``state_dir`` nor both CSV paths were supplied, or no usable
        baseline/policy row exists for ``year``.

    Notes
    -----
    The y-axis is labelled in MW and the peak annotations divide by 1000 and
    are labelled GW, i.e. the data is assumed to be in MW (NOTE(review):
    confirm the unit of ``net_sum_text``). The figure is shown with
    ``plt.show()``; nothing is returned.
    """
    df = _load_scenario_frames(state_dir, baseline_csv, policy_csv)

    # Coerce so that years read as text still compare equal to the int `year`.
    df["year"] = pd.to_numeric(df["year"], errors="coerce")
    brow = pick_year_row(df, "baseline", year)
    prow = pick_year_row(df, "policy", year)
    if brow is None or prow is None:
        raise ValueError(f"Missing baseline or policy rows for {year}")

    s_base_agg = aggregate_series(to_time_series(brow["net_load"], year), aggregation, agg_func)
    s_poli_agg = aggregate_series(to_time_series(prow["net_load"], year), aggregation, agg_func)

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.plot(s_base_agg.index, s_base_agg.values, label="Baseline", alpha=0.9)
    ax.plot(s_poli_agg.index, s_poli_agg.values, label="Policy", alpha=0.9)

    ax.set_xlabel("Time")
    ax.set_ylabel("Net Load (MW)")
    # The dash is written as an escape so it cannot be mangled by a re-encoding
    # of the file (the previous literal had degraded into mojibake).
    default_title = (
        f"{aggregation.capitalize()} Net Load \u2014 Baseline vs Policy "
        f"({year}) [{agg_func.upper()}]"
    )
    ax.set_title(title or default_title)
    ax.legend()

    # Peak labels: baseline above its point, policy below, so they do not
    # overlap when both peaks coincide.
    _annotate_peak(ax, s_base_agg, "baseline", 8)
    _annotate_peak(ax, s_poli_agg, "policy", -12)

    fig.tight_layout()
    plt.show()


In [None]:
# 2) Provide explicit files for baseline & policy
# Both CSVs are given as absolute paths under /Volumes/..., so this cell only
# runs on a machine where those exact paths exist.
# `year` and `agg_func` are not passed, so the function's defaults apply.
plot_state_netload(
    baseline_csv="/Volumes/Seagate Portabl/permit_power/dgen_runs/per_state_outputs/wi/hourly_baseline_test_run_all_states_no_net_billing.csv",
    policy_csv="/Volumes/Seagate Portabl/permit_power/dgen_runs/per_state_outputs/wi/hourly_policy_test_run_all_states_no_net_billing.csv",
    aggregation="daily"
)