In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import re

########################################
# Config (same station order)
########################################

stations = [
    "01_Rulo",
    "02_St_Joseph",
    "03_Kansas_City",
    "04_Waverly",
    "05_Boonville",
    "06_Hermann",
    "07_St_Charles",
    "08_Grafton",
    "09_ST_Louis",
    "10_Chester",
    "11_Thebes"
]

# This is where your merged per-subset station CSVs live
# like: 01_Rulo_mem1_q_tolerance_25.0_day_bracket_730.csv
BASE_DIR = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH")

# Output directory for KGE seasonal plots and CSV
OUT_DIR = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage")
OUT_DIR.mkdir(parents=True, exist_ok=True)

########################################
# Seasonal logic and utilities
########################################

def season_from_month(ts):
    """Return 'dry' or 'wet' based on month of timestamp."""
    m = ts.month
    # dry: Sep(9) Oct(10) Nov(11) Dec(12) Jan(1) Feb(2)
    # wet: Mar(3) .. Aug(8)
    # Note: Your previous code only had Jan in dry. I'm extending to Feb,
    # because otherwise Feb goes nowhere. Adjust if you don't want Feb dry.
    if m in [8,9,10,11,12,1,2]:
        return "dry"
    else:
        return "wet"

def get_horizons(cols):
    """
    Given df.columns, find forecast lead columns,
    which look like '006','012',...,'720'.
    Return only 6h steps from 6h through 720h,
    plus 720 if missing in the step sequence.
    """
    horizons_all = sorted(
        [c for c in cols if re.fullmatch(r"\d{3}", c)],
        key=lambda s: int(s)
    )
    wanted = []
    for h in range(6, 721, 6):
        h_str = f"{h:03d}"
        if h_str in horizons_all:
            wanted.append(h_str)
    if "720" in horizons_all and "720" not in wanted:
        wanted.append("720")
    return wanted

def kge(sim, obs):
    """
    Kling-Gupta Efficiency.
    """
    sim = np.asarray(sim, dtype=float)
    obs = np.asarray(obs, dtype=float)

    mask = np.isfinite(sim) & np.isfinite(obs)
    if mask.sum() < 3:
        return np.nan

    sim = sim[mask]
    obs = obs[mask]

    r = np.corrcoef(sim, obs)[0, 1]
    alpha = np.std(sim) / np.std(obs) if np.std(obs) != 0 else np.nan
    beta = np.mean(sim) / np.mean(obs) if np.mean(obs) != 0 else np.nan

    return 1 - np.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2)

def subset_pretty_name(raw_subset):
    """
    Map folder/subset token to publication-ready label.
    """
    mapping = {
        "All": "Ensemble mean",
        "mem1": "Ensemble 1",
        "mem2": "Ensemble 2",
        "mem3": "Ensemble 3",
        "mem4": "Ensemble 4",
    }
    return mapping.get(raw_subset, raw_subset)

########################################
# Core processing
########################################

all_kge_rows = []  # to build a big CSV of seasonal KGE

for station_name in stations:
    # 1. find all *_q_tolerance_*.csv for this station
    station_csvs = list(BASE_DIR.glob(f"{station_name}_*_q_tolerance_*.csv"))
    if not station_csvs:
        print(f"[WARN] No per-subset CSVs found for {station_name}")
        continue

    # read each subset file and prep seasonal data
    subset_data = {}  # {subset_pretty: df_with_season_and_stage}
    for f in station_csvs:
        m = re.search(rf"^{re.escape(station_name)}_(.+?)_q_tolerance_", f.name)
        if not m:
            continue
        subset_raw = m.group(1)
        nice_subset = subset_pretty_name(subset_raw)

        df = pd.read_csv(f, parse_dates=["time"]).set_index("time").sort_index()

        # we assume df["stage_m_USGS_nearest"] is observed stage
        if "stage_m_USGS_nearest" not in df.columns:
            print(f"[WARN] {station_name} {nice_subset} missing stage_m_USGS_nearest, skipping")
            continue

        # assign season
        df["season"] = [season_from_month(ts) for ts in df.index]

        subset_data[nice_subset] = df

    if not subset_data:
        print(f"[WARN] No usable subset data for {station_name}")
        continue

    # determine horizons that exist across any subset
    all_cols = set()
    for df in subset_data.values():
        all_cols.update(df.columns.tolist())
    horizons = get_horizons(list(all_cols))
    if not horizons:
        print(f"[WARN] No horizons found for {station_name}")
        continue

    # compute KGE per season, per subset, per horizon
    # we will collect results into a dict:
    # kge_by_season[season] = DataFrame with columns:
    #   LeadDays, Subset, KGE
    season_list = ["dry", "wet"]
    seasonal_records = {s: [] for s in season_list}

    for nice_subset, df in subset_data.items():
        obs = df["stage_m_USGS_nearest"]

        for h in horizons:
            if h not in df.columns:
                continue
            pred = df[h]

            tmp = pd.concat([obs, pred, df["season"]], axis=1)
            tmp.columns = ["obs_stage", "pred_stage", "season"]
            # split by season
            for seas in season_list:
                subtmp = tmp[tmp["season"] == seas].dropna()
                if subtmp.empty:
                    continue

                kge_val = kge(subtmp["pred_stage"].values, subtmp["obs_stage"].values)

                lead_hours = int(h)
                lead_days = lead_hours / 24.0 if lead_hours >= 24 else lead_hours / 24.0

                seasonal_records[seas].append({
                    "Station": station_name,
                    "Subset": nice_subset,
                    "LeadHours": lead_hours,
                    "LeadDays": lead_days,
                    "Season": seas,
                    "KGE": kge_val
                })

    # turn those seasonal records into dfs for plotting and saving
    df_dry = pd.DataFrame([r for r in seasonal_records["dry"] if r["Station"] == station_name])
    df_wet = pd.DataFrame([r for r in seasonal_records["wet"] if r["Station"] == station_name])

    # append to global CSV list
    all_kge_rows.extend(seasonal_records["dry"])
    all_kge_rows.extend(seasonal_records["wet"])

    # plot helper
    def plot_kge_scatter(df_season, season_label):
        if df_season.empty:
            print(f"[INFO] No {season_label} data to plot for {station_name}")
            return

        fig, ax = plt.subplots(figsize=(9,5))

        for subset_name, dsub in df_season.groupby("Subset"):
            ax.scatter(
                dsub["LeadDays"],
                dsub["KGE"],
                s=30,
                alpha=0.8,
                label=subset_name
            )

        ax.set_xlabel("Lead time (days)")
        ax.set_ylabel("KGE (stage)")
        ax.set_title(f"{station_name}: KGE vs lead time ({season_label} season)")
        ax.grid(True, alpha=0.3)
        ax.legend(title="Subset")
        ax.set_xlim(left=-0.5, right=30.5)  # 0 to ~30 days view
        ax.set_ylim(0.3, 1)             # standard KGE-ish range, adjust if you want
        fig.tight_layout()

        out_png = OUT_DIR / f"{station_name}_KGE_vs_lead_{season_label}.png"
        fig.savefig(out_png, dpi=300)
        plt.close(fig)
        print(f"[plot] wrote {out_png}")

    # make seasonal plots
    plot_kge_scatter(df_dry, "dry")
    plot_kge_scatter(df_wet, "wet")

# build and save combined CSV of all stations / both seasons
all_kge_df = pd.DataFrame(all_kge_rows)
csv_out = OUT_DIR / "Seasonal_KGE_stage_by_station_subset_lead.csv"
all_kge_df.to_csv(csv_out, index=False)
print(f"[csv] wrote {csv_out}")

print("Done generating seasonal KGE vs lead time plots for stage.")


[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/01_Rulo_KGE_vs_lead_dry.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/01_Rulo_KGE_vs_lead_wet.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/02_St_Joseph_KGE_vs_lead_dry.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/02_St_Joseph_KGE_vs_lead_wet.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/03_Kansas_City_KGE_vs_lead_dry.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/03_Kansas_City_KGE_vs_lead_wet.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/04_Waverly_KGE_vs_lead_dry.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/04_Waverly_KGE_vs_lead_wet.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_seasonal_stage/05_Boonville_KGE_vs_lead_dry.png
[p

In [28]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import re
import matplotlib.cm as cm

colors = cm.get_cmap("tab20", 12)   # 12 distinct colors
line_styles = ["-", "--", "-.", ":", "-", "--", "-.", ":", "-", "--", "-.", ":"]

########################################
# Config (same station order)
########################################

stations = [
    "01_Rulo",
    "02_St_Joseph",
    "03_Kansas_City",
    "04_Waverly",
    "05_Boonville",
    "06_Hermann",
    "07_St_Charles",
    "08_Grafton",
    "09_ST_Louis",
    "10_Chester",
    "11_Thebes"
]

# This is where your merged per-subset station CSVs live
# like: 01_Rulo_mem1_q_tolerance_25.0_day_bracket_730.csv
BASE_DIR = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH")

# Output directory for KGE monthly plots and CSV
OUT_DIR = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_monthly_stage")
OUT_DIR.mkdir(parents=True, exist_ok=True)

########################################
# Utilities
########################################

def get_horizons(cols):
    """
    Given df.columns, find forecast lead columns,
    which look like '006','012',...,'720'.
    Return only 6 h steps from 6 h through 720 h,
    plus 720 if missing in the step sequence.
    """
    horizons_all = sorted(
        [c for c in cols if re.fullmatch(r"\d{3}", c)],
        key=lambda s: int(s)
    )
    wanted = []
    for h in range(6, 721, 6):
        h_str = f"{h:03d}"
        if h_str in horizons_all:
            wanted.append(h_str)
    if "720" in horizons_all and "720" not in wanted:
        wanted.append("720")
    return wanted

def kge(sim, obs):
    """
    Kling-Gupta Efficiency.
    """
    sim = np.asarray(sim, dtype=float)
    obs = np.asarray(obs, dtype=float)

    mask = np.isfinite(sim) & np.isfinite(obs)
    if mask.sum() < 3:
        return np.nan

    sim = sim[mask]
    obs = obs[mask]

    r = np.corrcoef(sim, obs)[0, 1]
    alpha = np.std(sim) / np.std(obs) if np.std(obs) != 0 else np.nan
    beta = np.mean(sim) / np.mean(obs) if np.mean(obs) != 0 else np.nan

    return 1 - np.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2)

def subset_pretty_name(raw_subset):
    """
    Map folder/subset token to publication-ready label.
    """
    mapping = {
        "All": "Ensemble mean"
        #"mem1": "Ensemble 1",
        #"mem2": "Ensemble 2",
        #"mem3": "Ensemble 3",
        #"mem4": "Ensemble 4",
    }
    return mapping.get(raw_subset, raw_subset)

########################################
# Core processing (monthly KGE)
########################################

all_kge_rows = []  # to build a big CSV of monthly KGE

for station_name in stations:
    # 1. find all *_q_tolerance_*.csv for this station
    station_csvs = list(BASE_DIR.glob(f"{station_name}_*_q_tolerance_*.csv"))
    if not station_csvs:
        print(f"[WARN] No per-subset CSVs found for {station_name}")
        continue

    # read each subset file and prep data
    subset_data = {}  # {subset_pretty: df}
    for f in station_csvs:
        m = re.search(rf"^{re.escape(station_name)}_(.+?)_q_tolerance_", f.name)
        if not m:
            continue
        subset_raw = m.group(1)
        nice_subset = subset_pretty_name(subset_raw)

        df = pd.read_csv(f, parse_dates=["time"]).set_index("time").sort_index()

        # we assume df["stage_m_USGS_nearest"] is observed stage
        if "stage_m_USGS_nearest" not in df.columns:
            print(f"[WARN] {station_name} {nice_subset} missing stage_m_USGS_nearest, skipping")
            continue

        # add month info
        df["Month"] = df.index.month
        df["MonthName"] = df.index.strftime("%b")  # Jan, Feb, ...

        subset_data[nice_subset] = df

    if not subset_data:
        print(f"[WARN] No usable subset data for {station_name}")
        continue

    # determine horizons that exist across any subset
    all_cols = set()
    for df in subset_data.values():
        all_cols.update(df.columns.tolist())
    horizons = get_horizons(list(all_cols))
    if not horizons:
        print(f"[WARN] No horizons found for {station_name}")
        continue

    # collect monthly records for this station
    station_monthly_records = []

    for nice_subset, df in subset_data.items():
        obs = df["stage_m_USGS_nearest"]

        for h in horizons:
            if h not in df.columns:
                continue
            pred = df[h]

            tmp = pd.concat(
                [obs, pred, df["Month"], df["MonthName"]],
                axis=1
            )
            tmp.columns = ["obs_stage", "pred_stage", "Month", "MonthName"]
            tmp = tmp.dropna()

            if tmp.empty:
                continue

            # group by calendar month
            for month_val, g in tmp.groupby("Month"):
                if g.shape[0] < 3:
                    continue

                kge_val = kge(g["pred_stage"].values, g["obs_stage"].values)

                lead_hours = int(h)
                lead_days = lead_hours / 24.0

                rec = {
                    "Station": station_name,
                    "Subset": nice_subset,
                    "LeadHours": lead_hours,
                    "LeadDays": lead_days,
                    "Month": int(month_val),
                    "MonthName": g["MonthName"].iloc[0],
                    "KGE": kge_val,
                }

                station_monthly_records.append(rec)
                all_kge_rows.append(rec)

    # plotting for this station, month by month
    df_station_month = pd.DataFrame(station_monthly_records)
    if df_station_month.empty:
        print(f"[INFO] No monthly KGE data to plot for {station_name}")
        continue

    def plot_kge_monthly_lines(df_station):
        # df_station has rows for one station with fields:
        # Station, Subset, LeadDays, Month, MonthName, KGE

        # Create consistent month order
        month_order = [1,2,3,4,5,6,7,8,9,10,11,12]

        # Build figure
        fig, ax = plt.subplots(figsize=(12,6))

        # Loop months in calendar order
        for m in month_order:
            df_m = df_station[df_station["Month"] == m]
            if df_m.empty:
                continue

            # Multiple subsets: draw one line per subset for this month
            for subset_name, dsub in df_m.groupby("Subset"):
                dsub_sorted = dsub.sort_values("LeadDays")
                ax.plot(
                df_m["LeadDays"],
                df_m["KGE"],
                linewidth=2,
                linestyle=line_styles[m-1],
                label=df_m["MonthName"].iloc[0]
            )


        ax.set_xlabel("Lead time (days)")
        ax.set_ylabel("KGE (stage)")
        ax.set_title(f"{station_name}: Monthly KGE vs lead time")
        ax.grid(True, alpha=0.3)
        ax.set_xlim(-0.2, 30.2)
        ax.set_ylim(0.3, 1.0)

        # Avoid duplicate legend labels
        handles, labels = ax.get_legend_handles_labels()
        unique = dict(zip(labels, handles))
        ax.legend(unique.values(), unique.keys(), title="Month - Subset", ncol=2)

        fig.tight_layout()

        out_png = OUT_DIR / f"{station_name}_KGE_vs_lead_MONTHLY_LINES.png"
        fig.savefig(out_png, dpi=300)
        plt.close(fig)
        print(f"[plot] wrote {out_png}")

    def plot_kge_monthly_mean(df_station):
        # df_station has columns:
        # Station, Subset, LeadDays, Month, MonthName, KGE

        # Compute mean KGE across subsets for each month and lead time
        df_mean = (
            df_station
            .groupby(["Month", "MonthName", "LeadDays"], as_index=False)
            .agg({"KGE": "mean"})
        )

        month_order = [1,2,3,4,5,6,7,8,9,10,11,12]

        fig, ax = plt.subplots(figsize=(12,6))

        for m in month_order:
            df_m = df_mean[df_mean["Month"] == m].sort_values("LeadDays")
            if df_m.empty:
                continue

            ax.plot(
                df_m["LeadDays"],
                df_m["KGE"],
                linewidth=2,
                alpha=0.9,
                color=colors(m-1),   # month 1â†’index 0
                label=df_m["MonthName"].iloc[0],
            )


        ax.set_xlabel("Lead time (days)")
        ax.set_ylabel("KGE (stage)")
        ax.set_title(f"{station_name}: Monthly mean KGE vs lead time")
        ax.grid(True, alpha=0.3)
        ax.set_xlim(-0.2, 30.2)
        ax.set_ylim(0.1, 1.0)

        ax.legend(title="Month", ncol=3)
        fig.tight_layout()

        out_png = OUT_DIR / f"{station_name}_KGE_vs_lead_MONTHLY_MEAN.png"
        fig.savefig(out_png, dpi=300)
        plt.close(fig)
        print(f"[plot] wrote {out_png}")
        
    plot_kge_monthly_mean(df_station_month)


#plot_kge_monthly_lines(df_station_month)

# build and save combined CSV of all stations (monthly)
all_kge_df = pd.DataFrame(all_kge_rows)
csv_out = OUT_DIR / "Monthly_KGE_stage_by_station_subset_lead.csv"
all_kge_df.to_csv(csv_out, index=False)
print(f"[csv] wrote {csv_out}")

print("Done generating monthly KGE vs lead time plots for stage.")


  colors = cm.get_cmap("tab20", 12)   # 12 distinct colors


[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_monthly_stage/01_Rulo_KGE_vs_lead_MONTHLY_MEAN.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_monthly_stage/02_St_Joseph_KGE_vs_lead_MONTHLY_MEAN.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_monthly_stage/03_Kansas_City_KGE_vs_lead_MONTHLY_MEAN.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_monthly_stage/04_Waverly_KGE_vs_lead_MONTHLY_MEAN.png
[plot] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/KGE_monthly_stage/05_Boonville_KGE_vs_lead_MONTHLY_MEAN.png


KeyboardInterrupt: 

<h1>box plot</h1>

In [9]:
import pandas as pd
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import re
import os

# stations info (same order as before)
stations = [
    "01_Rulo",
    "02_St_Joseph",
    "03_Kansas_City",
    "04_Waverly",
    "05_Boonville",
    "06_Hermann",
    "07_St_Charles",
    "08_Grafton",
    "09_ST_Louis",
    "10_Chester",
    "11_Thebes"
]

out_dir = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH")

# helper: assign season label based on timestamp month
def season_from_month(ts):
    m = ts.month
    # dry: Sep(9) Oct(10) Nov(11) Dec(12) Jan(1)
    # wet: Mar(3) ... Aug(8)
    if m in [9,10,11,12,1,2]:
        return "dry"
    else:
        return "wet"

# horizons to include in boxplots:
# every 24h step starting at 6h up to 720h, plus 720 if missing.
def get_horizons(cols):
    # cols is list of strings like ["006","012","018",...]
    horizons_all = sorted([c for c in cols if re.fullmatch(r"\d{3}", c)], key=lambda s:int(s))
    wanted = []
    for h in range(6, 721, 24):
        h_str = f"{h:03d}"
        if h_str in horizons_all:
            wanted.append(h_str)
    if "720" in horizons_all and "720" not in wanted:
        wanted.append("720")
    return wanted

# plotting function:
def plot_season_boxplots(station_name):
    """
    For one station, load the per-subfolder merged CSVs that were already written:
      {station}_{subset}_q_tolerance_...csv
    We'll search for those in out_dir.
    We'll build seasonal error arrays for each subset and horizon.
    Then make 2 figures: dry, wet.
    Each figure has x-axis = lead time (days), and for each lead time,
    we overlay 5 boxplots side by side? or combine?
    
    You asked: "I think for the box plot it would only have 2 figures per station. dry and wet and each one having 30 days error box plot"
    
    We'll do grouped boxplots per lead time where each subset is a separate color at that lead.
    That way each figure shows mem1, mem2, mem3, mem4, Ensemble mean together.
    """

    # find all CSVs for this station (the merged outputs from earlier)
    # pattern example: "01_Rulo_mem1_q_tolerance_25.0_day_bracket_730.csv"
    station_csvs = list(out_dir.glob(f"{station_name}_*_q_tolerance_*.csv"))
    if not station_csvs:
        print(f"[WARN] no seasonal CSVs found for {station_name}")
        return

    # read each subset file into dict {subset_name: df}
    subset_dfs = {}
    for f in station_csvs:
        # subset is between station_ and _q_tolerance
        # e.g. "01_Rulo_mem1_q_tolerance_25.0_day_bracket_730.csv"
        m = re.search(rf"^{re.escape(station_name)}_(.+?)_q_tolerance_", f.name)
        if not m:
            continue
        subset_raw = m.group(1)  # e.g. "mem1" or "All"
        subset_name = {
            "All": "Ensemble mean",
            "mem1": "Ensemble 1",
            "mem2": "Ensemble 2",
            "mem3": "Ensemble 3",
            "mem4": "Ensemble 4",
        }.get(subset_raw, subset_raw)

        df = pd.read_csv(f, parse_dates=["time"])
        df = df.set_index("time").sort_index()

        # add season column
        df["season"] = [season_from_month(ts) for ts in df.index]

        subset_dfs[subset_name] = df

    if not subset_dfs:
        print(f"[WARN] could not parse subset dfs for {station_name}")
        return

    # figure out consistent horizons across subsets
    all_cols = set()
    for df in subset_dfs.values():
        all_cols.update(df.columns.tolist())
    all_cols = list(all_cols)
    horizons = get_horizons(all_cols)
    if not horizons:
        print(f"[WARN] no horizons for {station_name}")
        return

    # we will build data structure:
    # errors[season][horizon][subset] = array of (forecast - obs)
    seasons = ["dry","wet"]
    errors = {s: {h: {} for h in horizons} for s in seasons}

    for subset_name, df in subset_dfs.items():
        if "stage_m_USGS_nearest" not in df.columns:
            print(f"[WARN] {station_name} {subset_name} missing stage_m_USGS_nearest")
            continue
        obs = df["stage_m_USGS_nearest"]

        for h in horizons:
            if h not in df.columns:
                continue
            pred = df[h]
            err_series = (pred - obs).dropna()

            if err_series.empty:
                continue

            # split by season
            joined = pd.concat([err_series, df["season"]], axis=1)
            joined.columns = ["err","season"]
            for s in seasons:
                vals = joined.loc[joined["season"] == s, "err"].dropna().values
                if vals.size == 0:
                    continue
                errors[s][h].setdefault(subset_name, [])
                errors[s][h][subset_name].extend(vals.tolist())

    # Now we plot two figs: dry and wet
    for s in seasons:
        # Build grouped boxplot data.
        # For each horizon h in horizons, we will have up to 5 subsets.
        subset_order = ["Ensemble mean","Ensemble 1","Ensemble 2","Ensemble 3","Ensemble 4"]
        # Collect data in consistent order
        box_data = []      # list of lists of arrays, shape ~ (len(horizons)*len(subset_order))
        box_positions = [] # x positions for each box
        box_labels_days = []  # one label per horizon group (days)
        group_width = 0.8
        n_sub = len(subset_order)
        if n_sub == 0:
            continue
        dx = group_width / max(n_sub,1)

        xpos = 1
        for h in horizons:
            # lead time label in days
            lead_hours = int(h)
            if lead_hours == 6:
                day_label = "6 h"
            else:
                day_label = f"{lead_hours//24:d}"

            # stash label for tick location (center of group)
            box_labels_days.append(day_label)

            # for each subset in fixed order, append its error array or empty
            for si, subset_name in enumerate(subset_order):
                vals = errors[s][h].get(subset_name, [])
                box_data.append(vals if len(vals)>0 else [np.nan])
                # position offset within group
                box_positions.append(xpos + (si - (n_sub-1)/2.0)*dx)

            xpos += 1  # next horizon group

        if not box_data:
            print(f"[INFO] No data to plot for {station_name} season={s}")
            continue

        fig, ax = plt.subplots(figsize=(14,6))

        bp = ax.boxplot(
            box_data,
            positions=box_positions,
            widths=dx*0.9,
            showfliers=False,
            showmeans=True,
            meanline=True,
            meanprops={"marker": "D", "markersize": 2},
            medianprops={"linewidth":1.2, "color":"blue"},
            patch_artist=True
        )

        # color each subset consistently across horizons
        # we'll just rotate through 5 default matplotlib colors by subset index
        # (not manually setting custom colors, to keep matplotlib defaults valid if you change style)
        # assign facecolors subset-wise:
        import itertools
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', [])
        # build map subset -> color index
        subset_color = {subset_name: color_cycle[i % len(color_cycle)]
                        for i, subset_name in enumerate(subset_order)}
        for i_box, patch in enumerate(bp['boxes']):
            subset_idx = i_box % n_sub
            subset_name = subset_order[subset_idx]
            patch.set_facecolor(subset_color.get(subset_name, "#cccccc"))
            patch.set_alpha(0.6)

        # x ticks at the center of each horizon group
        tick_positions = []
        xpos = 1
        for _h in horizons:
            tick_positions.append(xpos)
            xpos += 1

        ax.set_xticks(tick_positions)
        ax.set_xticklabels(box_labels_days, rotation=0)
        ax.axhline(0, linestyle="--", linewidth=1.2, color="k", alpha=0.8, zorder=0)
        ax.set_ylabel("Gage height forecast error (m)")
        if s == "dry":
            season_title = "Dry season (Sep-Feb)"
        else:
            season_title = "Wet season (Mar-Aug)"

        ax.set_xlabel("Lead time (days)")
        ax.set_title(f"{station_name}: Stage error by lead time, {season_title}")
        ax.grid(axis="y", alpha=0.3)
        ax.set_ylim(-6, 4)
        # legend from subset_color
        handles = []
        labels  = []
        for sub_name in subset_order:
            if any(sub_name in d for d in errors[s].values()):
                handles.append(plt.Line2D([0],[0], marker='s', linestyle='',
                              markerfacecolor=subset_color.get(sub_name,"#ccc"),
                              markeredgecolor='k', alpha=0.6))
                labels.append(sub_name)
        ax.legend(handles, labels, title="Subset", ncol=2, loc="upper left", frameon=False)

        fig.tight_layout()
        out_png = out_dir / f"{station_name}_errors_boxplot_{s}_season.png"
        fig.savefig(out_png, dpi=300)
        plt.close(fig)
        print(f"Wrote {out_png}")

# run for all stations
for st in stations:
    plot_season_boxplots(st)

print("Seasonal dry/wet error boxplots generated for all stations.")


Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/01_Rulo_errors_boxplot_dry_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/01_Rulo_errors_boxplot_wet_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/02_St_Joseph_errors_boxplot_dry_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/02_St_Joseph_errors_boxplot_wet_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/03_Kansas_City_errors_boxplot_dry_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/03_Kansas_City_errors_boxplot_wet_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/04_Waverly_errors_boxplot_dry_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/04_Waverly_errors_boxplot_wet_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/05_Boonville_errors_boxplot_dry_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/05_Boonville_errors_boxplot_wet_season.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_

In [7]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

########################################
# User config
########################################

out_dir = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH")

stl_file     = out_dir / "09_ST_Louis_All_q_tolerance_25.0_day_bracket_1000.csv"
chester_file = out_dir / "10_Chester_All_q_tolerance_25.0_day_bracket_1000.csv"

obs_col  = "stage_m_USGS_nearest"
time_col = "time"

# Lead times (forecast horizons) to include in hours
stl_horizons_hours = [6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72]
chs_horizons_hours = [6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114, 120]

# Lag sweep settings
max_lag_days   = 5        # we'll look +/- 5 days time shift
lag_step_hours = 6        # check lag every 6 hours
########################################


def tag_season(ts):
    """Return 'dry' or 'wet' according to your rule:
       dry = Sep(9), Oct(10), Nov(11), Dec(12), Jan(1)
       wet = Mar(3)...Aug(8)
       Feb(2) is not mentioned in your rule. We have to assign it.
       I'll treat Feb as 'dry' since it's low-flow/cold season for most midwest rivers.
    """
    m = ts.month
    if m in [9, 10, 11, 12, 1, 2]:
        return "dry"
    else:
        return "wet"


def pearson_r(x, y):
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size < 2 or y.size < 2:
        return np.nan
    return float(np.corrcoef(x, y)[0, 1])


def lagged_corr(up_series, down_series, max_lag_days, lag_step_hours):
    """
    up_series:   upstream error series (St. Louis)
    down_series: downstream error series (Chester)

    We shift 'down_series' in time by lag_h hours.
    Positive lag means Chester is moved forward in time (arrives later).
    Returns DataFrame of [lag_hours, corr, N].
    """
    results = []
    max_lag_hours = max_lag_days * 24
    lag_hours_list = np.arange(-max_lag_hours,
                               max_lag_hours + 0.1,
                               lag_step_hours)
    for lag_h in lag_hours_list:
        shifted = down_series.copy()
        shifted.index = shifted.index + pd.Timedelta(hours=lag_h)

        merged = pd.concat(
            [up_series.rename("up"),
             shifted.rename("down")],
            axis=1
        ).dropna()

        if merged.empty:
            results.append({"lag_hours": lag_h, "corr": np.nan, "N": 0})
        else:
            r = pearson_r(merged["up"].values, merged["down"].values)
            results.append({"lag_hours": lag_h, "corr": r, "N": len(merged)})

    return pd.DataFrame(results)


def best_lag_info(lag_df):
    """Given lag_df from lagged_corr, return (best_lag_hours, best_corr, N_at_best)."""
    lag_df_valid = lag_df.dropna(subset=["corr"])
    if lag_df_valid.empty:
        return np.nan, np.nan, 0
    idx_best = lag_df_valid["corr"].idxmax()
    row = lag_df_valid.loc[idx_best]
    return float(row["lag_hours"]), float(row["corr"]), int(row["N"])


def load_station_df(csv_path, obs_col, time_col):
    """Load CSV, set DatetimeIndex, add season label, return DataFrame."""
    df = pd.read_csv(csv_path, parse_dates=[time_col]).set_index(time_col).sort_index()
    df["season"] = [tag_season(ts) for ts in df.index]
    return df


def compute_season_matrix(stl_df, chs_df, stl_horizons_hours, chs_horizons_hours,
                          season_label, max_lag_days, lag_step_hours, obs_col):
    """
    For a given season ('dry' or 'wet'), build:
      1) correlation matrix (max corr at best lag)
      2) lag matrix (best lag in days at that max corr)

    Returns (result_df, heat_corr, heat_lag)
    where:
      - result_df is long form rows for CSV
      - heat_corr is a 2D DataFrame indexed by STL horizon hr, cols=CHS horizon hr
      - heat_lag is same shape but best lag in days
    """
    recs = []

    # convenience subsets for that season only
    stl_season = stl_df[stl_df["season"] == season_label]
    chs_season = chs_df[chs_df["season"] == season_label]

    for h_up in stl_horizons_hours:
        col_up = f"{h_up:03d}"
        if col_up not in stl_season.columns:
            continue

        stl_err = (stl_season[col_up] - stl_season[obs_col]).dropna().sort_index()
        if stl_err.empty:
            continue

        for h_down in chs_horizons_hours:
            col_down = f"{h_down:03d}"
            if col_down not in chs_season.columns:
                continue

            chs_err = (chs_season[col_down] - chs_season[obs_col]).dropna().sort_index()
            if chs_err.empty:
                continue

            # compute lag curve for this horizon pair
            lag_df = lagged_corr(
                up_series=stl_err,
                down_series=chs_err,
                max_lag_days=max_lag_days,
                lag_step_hours=lag_step_hours
            )

            lag_hours, max_corr, n_pairs = best_lag_info(lag_df)

            recs.append({
                "season": season_label,
                "STL_horizon_hr": h_up,
                "CHS_horizon_hr": h_down,
                "best_lag_hours": lag_hours,
                "best_lag_days": (lag_hours / 24.0) if not np.isnan(lag_hours) else np.nan,
                "max_corr": max_corr,
                "N_overlap": n_pairs
            })

    result_df = pd.DataFrame.from_records(recs)

    # pivot for heatmaps
    if not result_df.empty:
        heat_corr = result_df.pivot(
            index="STL_horizon_hr",
            columns="CHS_horizon_hr",
            values="max_corr"
        )
        heat_lag = result_df.pivot(
            index="STL_horizon_hr",
            columns="CHS_horizon_hr",
            values="best_lag_days"
        )
    else:
        # empty fallback
        heat_corr = pd.DataFrame()
        heat_lag = pd.DataFrame()

    return result_df, heat_corr, heat_lag


def plot_heatmap(matrix_df, title, xlabel, ylabel, cbar_label,
                 outfile, cmap="viridis", vmin=None, vmax=None,
                 x_is_hours=True, y_is_hours=True):
    """
    matrix_df is a 2D DataFrame with numeric values.
    We'll label ticks in days (hr/24.0).
    """

    if matrix_df.empty:
        print(f"[WARN] nothing to plot for {outfile}")
        return

    fig, ax = plt.subplots(figsize=(8,6))

    im = ax.imshow(
        matrix_df.values,
        aspect="auto",
        origin="lower",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax
    )

    # tick labels
    x_vals = list(matrix_df.columns)
    y_vals = list(matrix_df.index)

    if x_is_hours:
        x_ticklabels = [f"{h/24:.1f}d" for h in x_vals]
    else:
        x_ticklabels = [str(x) for x in x_vals]

    if y_is_hours:
        y_ticklabels = [f"{h/24:.1f}d" for h in y_vals]
    else:
        y_ticklabels = [str(y) for y in y_vals]

    ax.set_xticks(range(len(x_vals)))
    ax.set_xticklabels(x_ticklabels, rotation=45)

    ax.set_yticks(range(len(y_vals)))
    ax.set_yticklabels(y_ticklabels)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(cbar_label)

    fig.tight_layout()
    plt.savefig(outfile, dpi=200)
    plt.close(fig)
    print(f"Wrote {outfile}")


def main():
    # Load both stations' ensemble-mean data
    stl_df = load_station_df(stl_file,     obs_col, time_col)
    chs_df = load_station_df(chester_file, obs_col, time_col)

    # We'll do this for dry and wet separately
    for season_label in ["dry", "wet"]:
        print(f"Processing season: {season_label}")

        result_df, heat_corr, heat_lag = compute_season_matrix(
            stl_df,
            chs_df,
            stl_horizons_hours,
            chs_horizons_hours,
            season_label,
            max_lag_days,
            lag_step_hours,
            obs_col
        )

        # save numeric summary for this season
        season_csv = out_dir / f"lag_corr_matrix_STL_vs_CHS_{season_label}.csv"
        result_df.to_csv(season_csv, index=False)
        print(f"Wrote {season_csv}")

        # plot max correlation heatmap for this season
        corr_png = out_dir / f"lag_corr_heatmap_corr_{season_label}.png"
        plot_heatmap(
            heat_corr,
            title=f"Max correlation (STL error vs CHS error), {season_label} season",
            xlabel="Chester lead time (days)",
            ylabel="St. Louis lead time (days)",
            cbar_label="corr (Pearson r)",
            outfile=corr_png,
            cmap="plasma",        # good contrast for correlation
            vmin=0.0, vmax=1.0
        )

        # plot best lag heatmap for this season
        # We'll use a diverging cmap centered on 0 days lag
        lag_png = out_dir / f"lag_corr_heatmap_lagdays_{season_label}.png"
        # Estimate symmetric range for color scale
        # We scan heat_lag to get max abs lag to set vmin/vmax balanced
        if not heat_lag.empty:
            max_abs_lag = np.nanmax(np.abs(heat_lag.values))
            if np.isnan(max_abs_lag):
                max_abs_lag = 2.0  # fallback
            vmax_lag = max_abs_lag
        else:
            vmax_lag = 2.0

        plot_heatmap(
            heat_lag,
            title=f"Best lag in days (positive = CHS later), {season_label} season",
            xlabel="Chester lead time (days)",
            ylabel="St. Louis lead time (days)",
            cbar_label="lag (days)",
            outfile=lag_png,
            cmap="RdBu_r",        # blue=negative lag, red=positive lag
            vmin=-vmax_lag, vmax=vmax_lag
        )


if __name__ == "__main__":
    main()


Processing season: dry
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/lag_corr_matrix_STL_vs_CHS_dry.csv
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/lag_corr_heatmap_corr_dry.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/lag_corr_heatmap_lagdays_dry.png
Processing season: wet
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/lag_corr_matrix_STL_vs_CHS_wet.csv
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/lag_corr_heatmap_corr_wet.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/lag_corr_heatmap_lagdays_wet.png
