In [2]:
import os
from pathlib import Path
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

########################################
# Config
########################################

stations = [
    "01_Rulo",
    "02_St_Joseph",
    "03_Kansas_City",
    "04_Waverly",
    "05_Boonville",
    "06_Hermann",
    "07_St_Charles",
    "08_Grafton",
    "09_ST_Louis",
    "10_Chester",
    "11_Thebes"
]

usgs_ids = [
    "01_Rulo_06813500.csv",
    "02_St_Joseph_06818000.csv",
    "03_Kansas_City_06893000.csv",
    "04_Waverly_06895500.csv",
    "05_Boonville_06909000.csv",
    "06_Hermann_06934500.csv",
    "07_St_Charles_06935965.csv",
    "08_Grafton_05587450.csv",
    "09_ST_Louis_07010000.csv",
    "10_Chester_07020500.csv",
    "11_Thebes_07022000.csv"
]

BASE_NWM_DIR = Path("/media/12TB/Sujan/NWM/Csv")
BASE_USGS_DIR = Path("/media/12TB/Sujan/NWM/USGS_data")
OUT_DIR = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/Discharge_forecast_compare")
OUT_DIR.mkdir(parents=True, exist_ok=True)

########################################
# Helpers
########################################

def kge(sim, obs):
    """
    Kling-Gupta Efficiency
    """
    sim = np.asarray(sim, dtype=float)
    obs = np.asarray(obs, dtype=float)

    mask = np.isfinite(sim) & np.isfinite(obs)
    if mask.sum() < 3:
        return np.nan

    sim = sim[mask]
    obs = obs[mask]

    r = np.corrcoef(sim, obs)[0, 1]
    alpha = np.std(sim) / np.std(obs) if np.std(obs) != 0 else np.nan
    beta = np.mean(sim) / np.mean(obs) if np.mean(obs) != 0 else np.nan

    return 1 - np.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2)


def subset_label(raw_name):
    """
    Map folder name to nice label for legend/CSV.
    """
    mapping = {
        "All": "Ensemble mean",
        "mem1": "Ensemble member 1",
        "mem2": "Ensemble member 2",
        "mem3": "Ensemble member 3",
        "mem4": "Ensemble member 4",
    }
    return mapping.get(raw_name, raw_name)


def get_lead_hours_from_filename(fname):
    """
    Extract lead hours from filename like timeseries_006.csv OR timeseries_090.csv.
    We'll pull the numeric part after underscore.
    """
    m = re.search(r"timeseries_(\d+)\.csv$", fname)
    if m:
        return int(m.group(1))
    else:
        raise ValueError(f"Cannot parse lead hours from filename {fname}")


def plot_kge_vs_lead(df_station_kge, station_name, out_dir):
    """
    Scatter plot: KGE vs lead days per subset for one station.
    """
    fig, ax = plt.subplots(figsize=(9,5))

    for subset, subdf in df_station_kge.groupby("Subset"):
        ax.scatter(
            subdf["LeadDays"],
            subdf["KGE"],
            s=25,
            alpha=0.8,
            label=subset
        )

    ax.set_xlabel("Lead time (days)")
    ax.set_ylabel("KGE")
    ax.set_title(f"{station_name}: KGE vs lead time (NWM discharge vs USGS)")
    ax.grid(True, alpha=0.3)
    ax.legend(title="Subset")
    fig.tight_layout()

    outfile = out_dir / f"{station_name}_KGE_vs_lead.png"
    fig.savefig(outfile, dpi=300)
    plt.close(fig)
    print(f"[plot] wrote {outfile}")


def plot_sample_timeseries(df_merged_station, station_name, out_dir, lead_hours_list=[24,72,120]):
    """
    QC plot: USGS vs Ensemble mean forecast for a few leads.
    df_merged_station columns: valid_time, Q_USGS_cms, Q_fcst_cms, lead_hours, Subset
    We'll only plot Subset == 'Ensemble mean'.
    """
    fig, ax = plt.subplots(figsize=(10,4))

    # Observations
    obs_only = (
        df_merged_station[["valid_time","Q_USGS_cms"]]
        .drop_duplicates()
        .sort_values("valid_time")
    )
    ax.plot(
        obs_only["valid_time"],
        obs_only["Q_USGS_cms"],
        linewidth=2,
        label="Observed (USGS)"
    )

    # Ensemble mean
    emean = df_merged_station[df_merged_station["Subset"] == "Ensemble mean"].copy()
    for lh in lead_hours_list:
        sub = emean[emean["lead_hours"] == lh].sort_values("valid_time")
        if sub.empty:
            continue
        ax.plot(
            sub["valid_time"],
            sub["Q_fcst_cms"],
            linewidth=1,
            alpha=0.8,
            label=f"Ens mean +{lh/24:.0f}d"
        )

    ax.set_title(f"{station_name}: Discharge timeseries (obs vs NWM Ens mean)")
    ax.set_xlabel("Time")
    ax.set_ylabel("Discharge (cumecs)")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()

    outfile = out_dir / f"{station_name}_timeseries_check.png"
    fig.savefig(outfile, dpi=300)
    plt.close(fig)
    print(f"[plot] wrote {outfile}")


########################################
# Main processing
########################################

all_kge_records = []
all_merged_allstations = []

for i, stn in enumerate(stations):
    print(f"\n=== Station {stn} ===")

    # 1. Load USGS data
    usgs_file = BASE_USGS_DIR / usgs_ids[i]
    usgs_df = pd.read_csv(usgs_file, parse_dates=["timestamp"])
    # Normalize column names
    usgs_df = usgs_df.rename(columns={
        "timestamp": "valid_time",
        "discharge_cumecs": "Q_USGS_cms"
    })
    usgs_df = usgs_df[["valid_time", "Q_USGS_cms"]].dropna()

    # 2. Discover forecast member folders
    station_dir = BASE_NWM_DIR / stn
    if not station_dir.exists():
        print(f"[warn] Directory not found for {stn}, skipping.")
        continue

    subfolders = [p for p in station_dir.iterdir() if p.is_dir()]
    if not subfolders:
        print(f"[warn] No subfolders in {station_dir}, skipping.")
        continue

    print("  Found subfolders:", [sf.name for sf in subfolders])

    merged_this_station = []

    # 3. Loop forecast subsets (All, mem1, mem2, mem3, mem4)
    for subfolder in subfolders:
        raw_subset = subfolder.name
        pretty_subset = subset_label(raw_subset)

        # gather all timeseries_*.csv in this subfolder
        files = sorted(subfolder.glob("timeseries_*.csv"))
        if not files:
            print(f"    [warn] No timeseries_*.csv in {subfolder}")
            continue

        rows_list = []

        for fpath in files:
            # extract lead hours from filename
            lead_h = get_lead_hours_from_filename(fpath.name)

            df_fcst = pd.read_csv(fpath, parse_dates=["time"])
            # standardize forecast columns
            df_fcst = df_fcst.rename(columns={
                "time": "valid_time",
                "streamflow": "Q_fcst_cms"
            })
            df_fcst["lead_hours"] = lead_h
            df_fcst["LeadDays"] = df_fcst["lead_hours"] / 24.0
            df_fcst["Subset"] = pretty_subset
            df_fcst["Station"] = stn

            rows_list.append(df_fcst[["valid_time","Q_fcst_cms","lead_hours","LeadDays","Subset","Station"]])

        if not rows_list:
            continue

        fcst_all = pd.concat(rows_list, ignore_index=True)

        # 4. Merge with USGS obs on valid_time
        merged = pd.merge(
            fcst_all,
            usgs_df,
            on="valid_time",
            how="inner"
        )
        if merged.empty:
            print(f"    [warn] No overlap in time for {stn} {pretty_subset}")
            continue

        merged_this_station.append(merged)

    # done looping over members for this station
    if not merged_this_station:
        print(f"[warn] nothing merged for station {stn}")
        continue

    merged_station_full = pd.concat(merged_this_station, ignore_index=True)
    all_merged_allstations.append(merged_station_full)

    # 5. Compute KGE for each subset / lead_hours
    for (subset_name, lead_h), dgrp in merged_station_full.groupby(["Subset","lead_hours"]):
        kge_val = kge(dgrp["Q_fcst_cms"].values, dgrp["Q_USGS_cms"].values)

        all_kge_records.append({
            "Station": stn,
            "Subset": subset_name,
            "LeadHours": lead_h,
            "LeadDays": lead_h / 24.0,
            "KGE": kge_val
        })

# Combine results across stations
kge_table = pd.DataFrame(all_kge_records)
merged_all_df = pd.concat(all_merged_allstations, ignore_index=True) if all_merged_allstations else pd.DataFrame()

########################################
# Save CSV outputs
########################################

kge_csv_path = OUT_DIR / "KGE_by_station_and_lead.csv"
kge_table.to_csv(kge_csv_path, index=False)
print(f"[csv] wrote {kge_csv_path}")

merged_csv_path = OUT_DIR / "Merged_obs_vs_fcst_timeseries.csv"
merged_all_df.to_csv(merged_csv_path, index=False)
print(f"[csv] wrote {merged_csv_path}")

########################################
# Plots
########################################

# 1. KGE vs lead time for each station
for stn, df_stn in kge_table.groupby("Station"):
    plot_kge_vs_lead(df_stn, stn, OUT_DIR)

# 2. QC time series, plotting only ensemble mean
for stn, df_stn in merged_all_df.groupby("Station"):
    plot_sample_timeseries(df_stn, stn, OUT_DIR, lead_hours_list=[24,72,120])

print("Done.")



=== Station 01_Rulo ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 02_St_Joseph ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 03_Kansas_City ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 04_Waverly ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 05_Boonville ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 06_Hermann ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 07_St_Charles ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 08_Grafton ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 09_ST_Louis ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 10_Chester ===
  Found subfolders: ['All', 'mem1', 'mem4', 'mem3', 'mem2']

=== Station 11_Thebes ===
  Found subfolders: ['All']
[csv] wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_