In [33]:
import pandas as pd
from pathlib import Path
from matplotlib import pyplot as plt
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
import re

In [34]:
# Station folder name -> USGS gauge number. `station` and `USGS_id` are both
# derived from this single table so their positions can never drift apart.
STATION_GAUGES = {
    "01_Rulo": "06813500",
    "02_St_Joseph": "06818000",
    "03_Kansas_City": "06893000",
    "04_Waverly": "06895500",
    "05_Boonville": "06909000",
    "06_Hermann": "06934500",
    "07_St_Charles": "06935965",
    "08_Grafton": "05587450",
    "09_ST_Louis": "07010000",
    "10_Chester": "07020500",
    "11_Thebes": "07022000",
}
station = list(STATION_GAUGES)
USGS_id = [f"{name}_{gauge}.csv" for name, gauge in STATION_GAUGES.items()]

i = 9   # 0-based index into `station` (9 -> "10_Chester")

hrs1 = "006"          # first horizon drawn in the per-subfolder time-series plot
hrs2 = "720"          # second horizon drawn in the per-subfolder time-series plot
t0 = "2019-01-01"     # x-axis limits of the time-series plots
t1 = "2025-01-01"

day_bracket = 730     # half-width (days) of the USGS search window around each NWM time
day_skip = 30         # lead-in days at the start of each NWM series left un-mapped (NaN)
q_tolerance = 25.0    # discharge tolerance, same units as USGS 'discharge_cumecs'
q_tol_max = None      # optional cap applied to q_tolerance; None = no cap

# NOTE: absolute, machine-specific data root.
NWM_ROOT = Path("/media/12TB/Sujan/NWM")
station_dir = NWM_ROOT / "Csv" / station[i]

files = sorted(station_dir.glob("**/timeseries_*.csv"))
if not files:
    raise RuntimeError(f"No timeseries_* files under {station_dir}")

# Group horizon files by the folder that holds them (one group per subset).
groups = {}
for f in files:
    groups.setdefault(f.parent, []).append(f)

print(f"Found {len(groups)} subfolder(s):", [p.name for p in groups])

all_summaries = []   # the processing cell re-initialises these before use
outputs = []

df_USGS = pd.read_csv(NWM_ROOT / "USGS_data" / USGS_id[i],
                      parse_dates=["timestamp"]).set_index("timestamp")

Found 5 subfolder(s): ['All', 'mem1', 'mem2', 'mem3', 'mem4']


In [35]:
def find_GH_with_tolerance(
    df_nwm,
    df_usgs,
    day_skip=7,
    day_bracket=7,
    q_tolerance=1.0,
    nwm_col="streamflow",
    usgs_q_col="discharge_cumecs",
    usgs_h_col="stage_ft",
):
    """
    Map every NWM discharge to an observed USGS stage via a local rating lookup.

    For each NWM time t (from ``day_skip`` days after the first NWM time on),
    the USGS records inside the inclusive window [t - day_bracket, t + day_bracket]
    whose discharge lies within ±q_tolerance of the NWM discharge q(t) are
    candidates. The candidate whose discharge is closest to q(t) is chosen;
    ties are broken by the record closest in time to t (the earliest one on an
    exact time tie). Its stage is returned as-is: no averaging, no
    interpolation. If q(t) is NaN or no record qualifies, the result is NaN.

    Parameters
    ----------
    df_nwm : DataFrame with a DatetimeIndex and column ``nwm_col``.
    df_usgs : DataFrame with a DatetimeIndex and columns ``usgs_q_col`` and
        ``usgs_h_col``. Rows with a NaN in either column are ignored.
    day_skip : days after the first NWM time that are left un-mapped (NaN).
    day_bracket : half-width of the search window, in days.
    q_tolerance : discharge tolerance, same units as ``df_usgs[usgs_q_col]``.
    nwm_col, usgs_q_col, usgs_h_col : column names.

    Returns
    -------
    DataFrame sorted by time with exactly one row per ``df_nwm`` row (duplicate
    timestamps are kept, never multiplied), holding ``nwm_col`` and
    'Computed_GH_m'. The values of 'Computed_GH_m' are in the units of
    ``usgs_h_col``; the "_m" suffix is only a name.

    Raises
    ------
    TypeError : either frame lacks a DatetimeIndex.
    KeyError : a required column is missing.
    """
    if not isinstance(df_nwm.index, pd.DatetimeIndex):
        raise TypeError("df_nwm must have a DatetimeIndex")
    if not isinstance(df_usgs.index, pd.DatetimeIndex):
        raise TypeError("df_usgs must have a DatetimeIndex")
    if nwm_col not in df_nwm.columns:
        raise KeyError(f"df_nwm needs '{nwm_col}' column")
    if usgs_q_col not in df_usgs.columns or usgs_h_col not in df_usgs.columns:
        raise KeyError(f"df_usgs needs '{usgs_q_col}' and '{usgs_h_col}' columns")

    nwm = df_nwm[[nwm_col]].sort_index()
    usgs = df_usgs[[usgs_q_col, usgs_h_col]].dropna().sort_index()

    # Work on raw numpy arrays: slicing them per NWM time is far cheaper than
    # per-row iterrows + .loc + boolean DataFrame indexing.
    us_t = usgs.index.values                       # sorted datetime64 array
    us_q = usgs[usgs_q_col].to_numpy(dtype=float)
    us_h = usgs[usgs_h_col].to_numpy(dtype=float)
    nwm_t = nwm.index.values
    nwm_q = nwm[nwm_col].to_numpy(dtype=float, na_value=np.nan)

    start_time = nwm.index.min() + pd.Timedelta(days=day_skip)
    in_use = nwm.index >= start_time
    halfwin = pd.Timedelta(days=day_bracket).to_timedelta64()

    gh = np.full(len(nwm), np.nan)
    for k in np.flatnonzero(in_use):
        q = nwm_q[k]
        if np.isnan(q):
            continue
        t = nwm_t[k]

        # Inclusive time window [t - halfwin, t + halfwin] on the sorted index.
        lo = np.searchsorted(us_t, t - halfwin, side="left")
        hi = np.searchsorted(us_t, t + halfwin, side="right")
        if lo >= hi:
            continue

        win_q = us_q[lo:hi]
        in_band = np.flatnonzero((win_q >= q - q_tolerance) & (win_q <= q + q_tolerance))
        if in_band.size == 0:
            continue

        # 1) closest discharge within the band
        dq = np.abs(win_q[in_band] - q)
        cand = in_band[dq == dq.min()]
        # 2) tie-break by closest time to t (argmin -> earliest on a tie)
        dt = np.abs(us_t[lo + cand] - t)
        gh[k] = us_h[lo + cand[np.argmin(dt)]]

    # Positional assignment keeps one output row per input row even when the
    # NWM index has duplicate timestamps (an index join would cross-multiply).
    return nwm.assign(Computed_GH_m=gh)

def run_one(file_path, df_usgs, code,
            day_skip=0, day_bracket=10, q_tolerance=5.0):
    """
    Convert one NWM discharge CSV into a stage series named after its horizon.

    Reads ``file_path`` (needs 'time' and 'streamflow' columns), maps each
    discharge to a USGS stage with find_GH_with_tolerance (USGS columns
    'discharge_cumecs' / 'stage_m'), prints the fraction of rows that received
    a stage, and returns the 'Computed_GH_m' Series renamed to ``code``.
    """
    df_nwm = pd.read_csv(file_path, parse_dates=["time"]).set_index("time")
    out = find_GH_with_tolerance(
        df_nwm, df_usgs,
        day_skip=day_skip,
        day_bracket=day_bracket,
        q_tolerance=q_tolerance,
        nwm_col="streamflow",
        usgs_q_col="discharge_cumecs",
        usgs_h_col="stage_m"
    )

    gh = out["Computed_GH_m"]
    fill = gh.notna().mean()
    print(f"{code}: filled {fill:.1%}  "
          f"window=±{day_bracket}d, Q_tol={q_tolerance}")
    return gh.rename(code)

def worker(args):
    """
    ProcessPoolExecutor entry point: unpack one task tuple and run it.

    ``args`` is ``(file_path, code, day_skip, day_bracket, q_tol)``. Returns
    ``(code, series)`` so the caller can label results that arrive out of order.

    The USGS frame is not part of the task; it is looked up from the
    module-level ``df_USGS``. NOTE(review): child processes only see that
    global when the pool forks the parent process; confirm the start method
    if this is run on a platform whose default is "spawn"/"forkserver".
    """
    file_path, code, skip_days, bracket_days, tol = args
    series = run_one(file_path, df_USGS, code, skip_days, bracket_days, tol)
    return code, series

def mae(y_true, y_pred):
    """Mean absolute error, mean(|y_pred - y_true|)."""
    residual = np.asarray(y_pred) - np.asarray(y_true)
    return np.mean(np.abs(residual))

def mse(y_true, y_pred):
    """Mean squared error, mean((y_pred - y_true) ** 2)."""
    residual = np.asarray(y_pred) - np.asarray(y_true)
    return np.mean(residual ** 2)

def kge(y_true, y_pred):
    """
    Kling-Gupta Efficiency (Gupta et al., 2009).

        KGE = 1 - sqrt((r - 1)^2 + (alpha - 1)^2 + (beta - 1)^2)

    with r the Pearson correlation, alpha = std(pred) / std(true) (both with
    ddof=1) and beta = mean(pred) / mean(true).

    Returns NaN when there are fewer than two points, when r is undefined,
    or when the observations have zero spread or zero mean.
    """
    obs = np.asarray(y_true)
    sim = np.asarray(y_pred)
    if obs.size < 2:
        return np.nan

    r = np.corrcoef(obs, sim)[0, 1]
    if np.isnan(r):
        return np.nan

    obs_mean = np.mean(obs)
    obs_std = np.std(obs, ddof=1)
    if obs_std == 0 or obs_mean == 0:
        return np.nan

    alpha = np.std(sim, ddof=1) / obs_std
    beta = np.mean(sim) / obs_mean
    distance = np.sqrt((r - 1.0) ** 2 + (alpha - 1.0) ** 2 + (beta - 1.0) ** 2)
    return 1.0 - distance

def bias(y_true, y_pred):
    """Mean error mean(y_pred - y_true) as a Python float; positive means pred > true on average."""
    return float(np.mean(np.subtract(y_pred, y_true)))

def rmse(y_true, y_pred):
    """Root-mean-square error, sqrt(mean((y_pred - y_true) ** 2)), as a Python float."""
    diff = np.subtract(y_pred, y_true)
    mean_sq = np.mean(diff ** 2)
    return float(np.sqrt(mean_sq))

def pearson_r(y_true, y_pred):
    """Pearson correlation between true and predicted values; NaN for fewer than two points."""
    obs, sim = np.asarray(y_true), np.asarray(y_pred)
    return np.nan if obs.size < 2 else float(np.corrcoef(obs, sim)[0, 1])

def coverage_fraction(y_true, y_pred):
    """
    Share of positions where both the observed and simulated values are present.

    A position counts as present when neither value is NaN. Returns a float in
    [0, 1], or NaN for empty input. (Previously this returned the raw element
    count, contradicting its name; it is not called by summarize_file, which
    reports the overlap count as '<horizon>_N'.)

    Raises
    ------
    ValueError : the two inputs have different shapes.
    """
    obs = np.asarray(y_true, dtype=float)
    sim = np.asarray(y_pred, dtype=float)
    if obs.shape != sim.shape:
        raise ValueError(f"shape mismatch: {obs.shape} vs {sim.shape}")
    if obs.size == 0:
        return np.nan
    both_present = ~np.isnan(obs) & ~np.isnan(sim)
    return float(np.mean(both_present))

def nse(y_true, y_pred):
    """
    Nash-Sutcliffe efficiency: 1 - sum((pred - true)^2) / sum((true - mean(true))^2).

    Returns NaN for empty input or when the observations are constant
    (zero variance makes the ratio undefined).
    """
    obs = np.asarray(y_true)
    sim = np.asarray(y_pred)
    if obs.size == 0:
        return np.nan

    spread = np.sum((obs - np.mean(obs)) ** 2)
    if spread == 0:
        return np.nan

    sse = np.sum((sim - obs) ** 2)
    return 1.0 - sse / spread

# ----------------------------
# Helper: compute metrics for one file
# ----------------------------
def summarize_file(file_path):
    """
    Compute per-horizon stage-error metrics for one merged CSV.

    The file must hold an observed-stage column 'stage_m_USGS_nearest' and one
    simulated-stage column per forecast horizon, named as a 3-digit string
    ('006', '012', ...). A 'time' column, when present, becomes the index.

    Returns
    -------
    Single-row DataFrame with, for every horizon h in ascending numeric order,
    the columns h_MAE, h_MSE, h_KGE, h_NSE, h_Bias, h_RMSE, h_R and h_N, where
    h_N is the number of rows with both obs and sim present. When a horizon has
    no such rows its metrics are NaN and h_N is 0.

    Raises
    ------
    KeyError : 'stage_m_USGS_nearest' is missing.
    ValueError : no 3-digit horizon columns are found.
    """
    # Read only the header to decide whether a 'time' column exists; this
    # avoids leaking an open file handle and avoids asking read_csv to parse
    # a 'time' column that is not actually there.
    header = pd.read_csv(file_path, nrows=0).columns
    if "time" in header:
        df = pd.read_csv(file_path, parse_dates=["time"]).set_index("time")
    else:
        df = pd.read_csv(file_path)

    if "stage_m_USGS_nearest" not in df.columns:
        raise KeyError(f"'stage_m_USGS_nearest' not found in {file_path}")
    y = df["stage_m_USGS_nearest"]

    # horizon columns are 3-digit strings
    horizons = [c for c in df.columns if re.fullmatch(r"\d{3}", str(c))]
    if not horizons:
        raise ValueError(f"No horizon columns like '006' found in {file_path}")

    # Order fixes the output column order: MAE, MSE, KGE, NSE, Bias, RMSE, R, then N.
    metric_funcs = (
        ("MAE", mae), ("MSE", mse), ("KGE", kge), ("NSE", nse),
        ("Bias", bias), ("RMSE", rmse), ("R", pearson_r),
    )

    out = {}
    for h in sorted(horizons, key=lambda s: int(s)):
        # Keep only timestamps where both obs and sim are present.
        pair = pd.concat([y, df[h]], axis=1, keys=["obs", "sim"]).dropna()
        yt = pair["obs"].to_numpy()
        yp = pair["sim"].to_numpy()
        for name, func in metric_funcs:
            out[f"{h}_{name}"] = func(yt, yp) if len(pair) else np.nan
        out[f"{h}_N"] = len(pair)

    return pd.Series(out).to_frame().T  # single row



In [36]:
# ------------ Per-subfolder processing ------------
# For every subset folder ('All', 'mem1', ...): map each horizon's NWM
# discharge to stage in parallel, attach the nearest USGS stage, write the
# merged CSV, draw diagnostics, and collect one row of skill metrics.
out_dir = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH")
out_dir.mkdir(parents=True, exist_ok=True)

all_summaries = []   # rows from summarize_file(...), one per subfolder
outputs = []         # paths to per-subfolder CSVs written below

# Loop-invariant settings, hoisted out of the per-subfolder / per-file loops.
q_tol = q_tolerance if q_tol_max is None else min(q_tolerance, q_tol_max)
max_workers = min(32, os.cpu_count() or 8)
hrs1_h, hrs2_h = int(hrs1), int(hrs2)   # e.g. "006" -> 6, for titles / file names

for subdir, file_list in groups.items():
    # One task per timeseries_<HHH>.csv horizon file in this subfolder.
    tasks = []
    for f in sorted(file_list):
        m = re.search(r"timeseries_(\d{3})\.csv", f.name)
        if m:
            tasks.append((f, m.group(1), day_skip, day_bracket, q_tol))

    if not tasks:
        print(f"Skip {subdir}: no matching files")
        continue

    # Run in parallel. `worker` reads the module-level df_USGS, which child
    # processes inherit only when the pool forks this process.
    series_by_code = {}
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(worker, t) for t in tasks]
        for fut in as_completed(futures):
            code, s = fut.result()
            series_by_code[code] = s

    if not series_by_code:
        print(f"Skip {subdir}: produced no output")
        continue

    # as_completed() yields in completion order, which varies between runs;
    # join horizons in sorted order so the CSV column order is deterministic.
    codes = sorted(series_by_code)
    wide = pd.DataFrame(index=series_by_code[codes[0]].index)
    for code in codes:
        wide = wide.join(series_by_code[code], how="outer")
    wide = wide.sort_index()

    # Nearest-time join with USGS stage (±3 h). Missing stages are dropped
    # first so a NaN reading cannot shadow a valid one slightly farther away.
    usgs_stage_obs = (
        df_USGS[["stage_m"]].dropna()
        .reset_index().rename(columns={"timestamp": "time"})
        .sort_values("time")
    )
    usgs_stage = (
        pd.merge_asof(
            left=wide.reset_index().rename(columns={"index": "time"}).sort_values("time"),
            right=usgs_stage_obs,
            on="time", direction="nearest", tolerance=pd.Timedelta("3h")
        )
        .set_index("time")["stage_m"]
        .rename("stage_m_USGS_nearest")
    )

    merged_all = wide.join(usgs_stage, how="left")

    # Suffix from subfolder name (e.g., 'All', 'mem1')
    suffix = subdir.name

    # Save per-subfolder merged file
    out_csv = out_dir / f"{station[i]}_{suffix}_q_tolerance_{q_tolerance}_day_bracket_{day_bracket}.csv"
    merged_all.to_csv(out_csv)
    outputs.append(str(out_csv))
    print(f"Wrote {out_csv}")

    # --- Time-series plot of discharge & stage for horizons hrs1 / hrs2 ---
    f_hrs1 = subdir / f"timeseries_{hrs1}.csv"
    f_hrs2 = subdir / f"timeseries_{hrs2}.csv"
    if f_hrs1.exists() and f_hrs2.exists():
        df_hrs1_NWM = pd.read_csv(f_hrs1, parse_dates=["time"]).set_index("time")
        df_hrs2_NWM = pd.read_csv(f_hrs2, parse_dates=["time"]).set_index("time")
        xlim = (pd.to_datetime(t0), pd.to_datetime(t1))

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(9, 12))
        ax1.plot(df_USGS.index, df_USGS["discharge_cumecs"], "o", color="green", label="USGS discharge", markersize=1)
        ax1.plot(df_hrs1_NWM.index, df_hrs1_NWM["streamflow"], "o", color="blue", label=f"NWM {hrs1} discharge", markersize=1)
        ax1.plot(df_hrs2_NWM.index, df_hrs2_NWM["streamflow"], "o", color="orange", label=f"NWM {hrs2} discharge", markersize=1)
        ax1.set_ylabel("Discharge (cumecs)")
        ax1.set_title(f"{station[i]} ({suffix}): USGS, {hrs1_h}h and {hrs2_h}h Discharge")
        ax1.legend(); ax1.grid(True); ax1.set_xlim(*xlim)

        ax2.plot(df_USGS.index, df_USGS["stage_m"], "o", color="green", label="USGS stage", markersize=1)
        ax2.plot(merged_all.index, merged_all.get(hrs1), "o", color="blue", label=f"NWM {hrs1} stage", alpha=0.7, markersize=1)
        ax2.plot(merged_all.index, merged_all.get(hrs2), "o", color="orange", label=f"NWM {hrs2} stage", alpha=0.7, markersize=1)
        ax2.set_xlabel("Time"); ax2.set_xlim(*xlim)
        ax2.set_ylabel("Gage height (m)")
        ax2.set_title(f"{station[i]} ({suffix}): USGS, {hrs1_h}h and {hrs2_h}h Stage")
        ax2.legend(); ax2.grid(True)

        fig.tight_layout()
        out_ts_png = out_dir / f"{station[i]}_{suffix}_timeseries_{hrs1_h}h_{hrs2_h}h.png"
        fig.savefig(out_ts_png, dpi=300)
        plt.close(fig)
        print(f"Wrote {out_ts_png}")
    else:
        print(f"Timeseries plot skipped for {suffix}: missing {hrs1} or {hrs2}")

    # --- Error boxplot (sim - obs) every 24 h starting at 006, plus 720 ---
    horizon_cols = [c for c in merged_all.columns if re.fullmatch(r"\d{3}", str(c))]
    if horizon_cols:
        cols = [f"{h:03d}" for h in range(6, 721, 24) if f"{h:03d}" in horizon_cols]
        if "720" in horizon_cols and "720" not in cols:
            cols.append("720")
        if cols:
            err_df = merged_all[cols].subtract(merged_all["stage_m_USGS_nearest"], axis=0).dropna(how="all")
            if not err_df.empty:
                fig, ax = plt.subplots(figsize=(12, 6))
                err_df.boxplot(
                    column=cols,
                    ax=ax,
                    showfliers=False,
                    showmeans=True,
                    meanline=True,
                    meanprops={"marker": "D", "markersize": 1},
                    medianprops={"linewidth": 1.5, "color": "blue"},
                )
                ax.axhline(0, linestyle="--", linewidth=1.5, color="k", alpha=0.8, zorder=0)
                ax.set_xlabel("Lead time (days)")
                ax.set_ylabel("Gage height forecast error (m)")
                ax.set_title(f"{station[i]} ({suffix}): Stage errors by lead time")

                labels = ["6 h" if int(c) == 6 else f"{int(c)//24:d}" for c in cols]
                ax.set_xticklabels(labels, rotation=0)
                ax.grid(axis="y", alpha=0.3)

                out_png = out_dir / f"{station[i]}_{suffix}_errors_boxplot.png"
                fig.tight_layout()
                fig.savefig(out_png, dpi=300)
                plt.close(fig)
                print(f"Wrote {out_png}")
            else:
                print(f"{suffix}: err_df empty; boxplot skipped")
        else:
            print(f"{suffix}: no 24h-stepped horizons found; boxplot skipped")
    else:
        print(f"{suffix}: no horizon columns found; boxplot skipped")

    # --- Collect per-subfolder summary row ---
    row = summarize_file(out_csv)
    row.index = [f"{station[i]}_{suffix}"]
    row.insert(0, "Station", station[i])
    row.insert(1, "Subset",  suffix)
    all_summaries.append(row)

# ------------ Finalize combined stats across subfolders ------------
if all_summaries:
    station_summary = pd.concat(all_summaries, axis=0)   # this station's rows only

    # WIDE summary: replace this station's rows if the file already exists.
    out_wide = out_dir / "summary_metrics_by_subfolder_wide.csv"
    combined_summary = station_summary
    if out_wide.exists():
        existing = pd.read_csv(out_wide)
        if "Station" in existing.columns:
            existing = existing[existing["Station"] != station[i]]
        combined_summary = pd.concat([existing, station_summary], ignore_index=True)
    combined_summary.to_csv(out_wide, index=False)
    print(f"Updated {out_wide} with {station[i]} data")

    # LONG (tidy) summary, melted from the merged wide table (all stations in it).
    value_cols = [c for c in combined_summary.columns
                  if re.fullmatch(r"\d{3}_(MAE|MSE|KGE|NSE|Bias|RMSE|R|N)", str(c))]
    long = (
        combined_summary
        .reset_index(drop=True)
        .melt(id_vars=["Station", "Subset"], value_vars=value_cols, var_name="H_M", value_name="Value")
    )
    long[["Horizon", "Metric"]] = long["H_M"].str.extract(r"^(\d{3})_(.+)$")
    long = long.drop(columns=["H_M"])
    long["LeadDays"] = long["Horizon"].astype(int) / 24.0

    out_long = out_dir / "summary_metrics_by_subfolder_long.csv"
    if out_long.exists():
        # Read Horizon as text and re-pad it ("6" -> "006") so previously
        # written rows share the format of the freshly melted ones.
        existing_long = pd.read_csv(out_long, dtype={"Horizon": str})
        if "Horizon" in existing_long.columns:
            existing_long["Horizon"] = existing_long["Horizon"].str.zfill(3)
        # `long` already covers every station present in the wide table, so
        # keep only the stations it does not cover; otherwise those stations'
        # rows would be duplicated on every run.
        if "Station" in existing_long.columns:
            existing_long = existing_long[~existing_long["Station"].isin(long["Station"].unique())]
        long = pd.concat([existing_long, long], ignore_index=True)
    long.to_csv(out_long, index=False)
    print(f"Updated {out_long} with {station[i]} data")

    # --- Combined KGE vs lead-time plot (all subfolders together) ---
    try:
        kge_long = long[(long["Metric"] == "KGE") & (long["Station"] == station[i])].copy()
        if not kge_long.empty:
            # Human-readable subset names for the legend
            kge_long["Subset"] = kge_long["Subset"].replace({
                    "All": "Ensemble mean",
                    "mem1": "Ensemble 1",
                    "mem2": "Ensemble 2",
                    "mem3": "Ensemble 3",
                    "mem4": "Ensemble 4",
                })

            fig, ax = plt.subplots(figsize=(9, 5))

            # One scatter series per subset
            for subset, dsub in kge_long.groupby("Subset"):
                ax.scatter(
                    dsub["LeadDays"], dsub["Value"],
                    s=15, label=subset, alpha=0.8
                )

            ax.set_xlabel("Lead time (days)")
            ax.set_ylabel("KGE")
            ax.set_title(f"{station[i]}: KGE vs lead time between different Ensembles")
            ax.grid(True, alpha=0.3)
            ax.legend(title="Subset")
            fig.tight_layout()

            out_kge_png = out_dir / f"{station[i]}_KGE_by_subfolder.png"
            fig.savefig(out_kge_png, dpi=300)
            plt.close(fig)
            print(f"Wrote {out_kge_png}")
        else:
            print("No KGE rows found; KGE plot skipped.")
    except Exception as e:
        # Best-effort plot: report and continue rather than abort the cell.
        print("KGE plot skipped:", e)




042: filled 95.8%  window=±730d, Q_tol=25.0)_
012: filled 97.5%  window=±730d, Q_tol=25.0)_
114: filled 97.3%  window=±730d, Q_tol=25.0)_
060: filled 97.3%  window=±730d, Q_tol=25.0)_
108: filled 97.4%  window=±730d, Q_tol=25.0)_
156: filled 97.7%  window=±730d, Q_tol=25.0)_
072: filled 97.2%  window=±730d, Q_tol=25.0)_
036: filled 95.6%  window=±730d, Q_tol=25.0)_
018: filled 97.6%  window=±730d, Q_tol=25.0)_
102: filled 97.3%  window=±730d, Q_tol=25.0)_
180: filled 97.8%  window=±730d, Q_tol=25.0)_
006: filled 97.7%  window=±730d, Q_tol=25.0)_
174: filled 97.8%  window=±730d, Q_tol=25.0)_
162: filled 97.8%  window=±730d, Q_tol=25.0)_
192: filled 97.8%  window=±730d, Q_tol=25.0)_
150: filled 97.7%  window=±730d, Q_tol=25.0)_
168: filled 97.8%  window=±730d, Q_tol=25.0)_
096: filled 97.4%  window=±730d, Q_tol=25.0)_
054: filled 97.2%  window=±730d, Q_tol=25.0)_
090: filled 97.4%  window=±730d, Q_tol=25.0)_
132: filled 97.4%  window=±730d, Q_tol=25.0)_
120: filled 97.3%  window=±730d, Q