In [5]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

########################################
# User config
########################################

# Directory that receives the summary CSV and the PNG figures written by main().
out_dir = Path("/media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/Hypothesis")

# Per-station input tables; both are addressed relative to out_dir (one level up).
stl_file = out_dir / "../09_ST_Louis_All_q_tolerance_25.0_day_bracket_1000.csv"
chester_file = out_dir / "../10_Chester_All_q_tolerance_25.0_day_bracket_1000.csv"

# obs_col is subtracted from each forecast column to form the error series;
# time_col is parsed as datetime and used as the index.
obs_col, time_col = "stage_m_USGS_nearest", "time"

# Forecast horizons in hours. main() turns each value into a zero-padded,
# three-character column name (e.g. 6 -> "006") when looking it up in the CSV.
# St. Louis: 6 h to 72 h in 6 h steps.
stl_horizons_hours = list(range(6, 72 + 1, 6))

# Chester: 36 h to 120 h in 6 h steps.
chs_horizons_hours = list(range(36, 120 + 1, 6))

# Lag search settings: lags from -max_lag_days to +max_lag_days are scanned
# in increments of lag_step_hours.
max_lag_days = 5
lag_step_hours = 6
########################################


def load_error_series(csv_path, horizon_col, obs_col, time_col):
    """
    Return forecast error (pred - obs) as a pandas Series indexed by timestamp.

    Parameters
    ----------
    csv_path : str or pathlib.Path
        CSV file to read.
    horizon_col : str
        Name of the forecast column (the minuend).
    obs_col : str
        Name of the observation column (the subtrahend).
    time_col : str
        Name of the column parsed as datetime and used as the index.

    Returns
    -------
    pandas.Series
        ``df[horizon_col] - df[obs_col]`` with NaN rows removed, sorted by
        timestamp.

    Raises
    ------
    KeyError
        If ``horizon_col`` or ``obs_col`` is not a column of the file.
    """
    # Normalise so that plain string paths work too: the error messages
    # below rely on the ``.name`` attribute, which ``str`` does not have.
    csv_path = Path(csv_path)

    df = pd.read_csv(csv_path, parse_dates=[time_col]).set_index(time_col)

    # Check the forecast column first, then the observation column.
    for required in (horizon_col, obs_col):
        if required not in df.columns:
            raise KeyError(f"{required} not found in {csv_path.name}")

    return (df[horizon_col] - df[obs_col]).dropna().sort_index()


def pearson_r(x, y):
    """
    Pearson correlation coefficient between two array-likes.

    Both inputs are converted to float arrays. If either holds fewer than
    two values the result is NaN; otherwise the off-diagonal entry of
    ``np.corrcoef`` is returned as a Python float.
    """
    a = np.asarray(x, dtype=float)
    b = np.asarray(y, dtype=float)

    # A correlation needs at least two samples on each side.
    if min(a.size, b.size) < 2:
        return np.nan

    return float(np.corrcoef(a, b)[0, 1])


def lagged_corr(upstream_series,
                downstream_series,
                max_lag_days=5,
                lag_step_hours=6):
    """
    Correlate two timestamp-indexed series over a symmetric range of lags.

    For every lag from ``-max_lag_days`` to ``+max_lag_days`` (expressed in
    hours, stepped by ``lag_step_hours``) the lag is added to the timestamps
    of ``downstream_series``; the two series are then aligned on their index
    and rows missing on either side are discarded. Consequently, at lag L
    the upstream value at time t is paired with the downstream value that
    was originally stamped t - L.

    Returns a DataFrame with one row per lag and the columns ``lag_hours``,
    ``corr`` (Pearson r, NaN when it cannot be computed) and ``N`` (number
    of aligned pairs).
    """
    up = upstream_series.rename("up")
    half_window_h = max_lag_days * 24

    rows = []
    # The small 0.1 excess on the stop value keeps the +half_window_h
    # endpoint inside the half-open arange interval.
    for lag_h in np.arange(-half_window_h, half_window_h + 0.1, lag_step_hours):
        # rename() hands back a new Series, so re-stamping its index leaves
        # the caller's downstream_series untouched.
        down = downstream_series.rename("down")
        down.index = down.index + pd.Timedelta(hours=lag_h)

        paired = pd.concat([up, down], axis=1).dropna()
        n_pairs = len(paired)

        if n_pairs:
            corr = pearson_r(paired["up"].values, paired["down"].values)
        else:
            corr = np.nan

        rows.append({"lag_hours": lag_h, "corr": corr, "N": n_pairs})

    return pd.DataFrame(rows)


def best_lag_info(lag_df):
    """
    Pick the lag with the highest correlation from a lagged_corr-style table.

    Rows whose ``corr`` is NaN are ignored. Returns the tuple
    ``(lag_hours, corr, N)`` of the winning row as ``(float, float, int)``,
    or ``(nan, nan, 0)`` when no row has a usable correlation.
    """
    usable = lag_df[lag_df["corr"].notna()]
    if len(usable) == 0:
        return np.nan, np.nan, 0

    winner = usable.loc[usable["corr"].idxmax()]
    lag_hours = float(winner["lag_hours"])
    corr = float(winner["corr"])
    n_pairs = int(winner["N"])
    return lag_hours, corr, n_pairs


def _save_heatmap(table, title, cbar_label, png_path, cmap=None):
    """
    Draw a horizon-by-horizon table as a heatmap and save it as a PNG.

    Parameters
    ----------
    table : pandas.DataFrame
        Values to plot; the index holds St. Louis horizons and the columns
        hold Chester horizons, both in hours.
    title : str
        Axes title.
    cbar_label : str
        Label for the colorbar.
    png_path : pathlib.Path
        Destination file.
    cmap : str or None
        Colormap name passed to ``imshow``; None selects matplotlib's default.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(table, aspect="auto", origin="lower", cmap=cmap)

    # Tick labels show the horizons converted from hours to days.
    ax.set_xticks(range(len(table.columns)))
    ax.set_xticklabels([f"{c/24:.1f}d" for c in table.columns], rotation=45)
    ax.set_yticks(range(len(table.index)))
    ax.set_yticklabels([f"{r/24:.1f}d" for r in table.index])
    ax.set_xlabel("Chester lead time (days)")
    ax.set_ylabel("St. Louis lead time (days)")
    ax.set_title(title)

    # ``ax`` must be passed by keyword: the second positional parameter of
    # Figure.colorbar is ``cax``, which would draw the colorbar *into* the
    # heatmap axes instead of next to it.
    cbar = fig.colorbar(image, ax=ax)
    cbar.set_label(cbar_label)

    fig.tight_layout()
    fig.savefig(png_path, dpi=200)
    plt.close(fig)
    print(f"Wrote {png_path}")


def main():
    """
    Compute the best-lag correlation between St. Louis and Chester forecast
    errors for every pair of configured horizons, then write the results.

    Outputs, all written into ``out_dir``:
      * lag_corr_matrix_STL_vs_CHS.csv   - one row per horizon pair
      * lag_corr_heatmap_corr.png        - heatmap of the maximum correlation
      * lag_corr_heatmap_lagdays.png     - heatmap of the best lag in days
      * lag_corr_scatter_strong.png      - best lag vs Chester horizon for
                                           pairs with max_corr > 0.4 (only if
                                           such pairs exist)

    Horizons whose column is absent from the corresponding CSV are skipped
    with a warning.
    """
    # Preload the two station dataframes (to avoid re-reading file every loop)
    stl_df = pd.read_csv(stl_file, parse_dates=[time_col]).set_index(time_col).sort_index()
    chs_df = pd.read_csv(chester_file, parse_dates=[time_col]).set_index(time_col).sort_index()

    # The Chester error series do not depend on the St. Louis horizon, so
    # build them once here rather than once per outer-loop iteration. This
    # also reports each missing Chester column a single time.
    chs_errors = {}
    for h_down in chs_horizons_hours:
        col_down = f"{h_down:03d}"
        if col_down not in chs_df.columns:
            print(f"[WARN] {col_down} not found in {chester_file.name}, skipping")
            continue
        chs_errors[h_down] = (chs_df[col_down] - chs_df[obs_col]).dropna().sort_index()

    # We'll store results for every combination
    records = []

    for h_up in stl_horizons_hours:
        col_up = f"{h_up:03d}"
        if col_up not in stl_df.columns:
            print(f"[WARN] {col_up} not found in {stl_file.name}, skipping")
            continue
        stl_err = (stl_df[col_up] - stl_df[obs_col]).dropna().sort_index()

        for h_down, chs_err in chs_errors.items():
            # compute lag correlation curve
            lag_df = lagged_corr(
                upstream_series=stl_err,
                downstream_series=chs_err,
                max_lag_days=max_lag_days,
                lag_step_hours=lag_step_hours
            )

            # get best lag
            lag_hours, max_corr, n_pairs = best_lag_info(lag_df)

            records.append({
                "STL_horizon_hr": h_up,
                "CHS_horizon_hr": h_down,
                "best_lag_hours": lag_hours,
                # NaN propagates through the division, so no special case is needed.
                "best_lag_days": lag_hours / 24.0,
                "max_corr": max_corr,
                "N_overlap": n_pairs
            })

    # Without any evaluated pair the pivots below would fail on the missing
    # columns, so stop early with an explicit message.
    if not records:
        print("[WARN] No horizon pairs could be evaluated; nothing written.")
        return

    result_df = pd.DataFrame.from_records(records)

    # save numeric summary
    matrix_csv = out_dir / "lag_corr_matrix_STL_vs_CHS.csv"
    result_df.to_csv(matrix_csv, index=False)
    print(f"Wrote summary table: {matrix_csv}")

    # pivot for heatmap of max_corr
    heat_corr = result_df.pivot(
        index="STL_horizon_hr",
        columns="CHS_horizon_hr",
        values="max_corr"
    )

    # pivot for best lag (days)
    heat_lag = result_df.pivot(
        index="STL_horizon_hr",
        columns="CHS_horizon_hr",
        values="best_lag_days"
    )

    # quick heatmap of correlation strength
    _save_heatmap(
        heat_corr,
        title="Max correlation (STL error vs CHS error)",
        cbar_label="corr (Pearson r)",
        png_path=out_dir / "lag_corr_heatmap_corr.png",
        cmap="coolwarm",
    )

    # heatmap of best lag in days
    # NOTE(review): lagged_corr adds the lag to the Chester timestamps, so a
    # positive best lag pairs the STL error at time t with the CHS error
    # stamped t - lag. Confirm that the "positive = CHS later" wording in the
    # title matches the intended sign convention.
    _save_heatmap(
        heat_lag,
        title="Best lag in days (positive = CHS later)",
        cbar_label="lag (days)",
        png_path=out_dir / "lag_corr_heatmap_lagdays.png",
    )

    # Bonus: scatter of best lag vs Chester horizon, filtered to strong correlations
    strong = result_df[result_df["max_corr"] > 0.4].copy()
    if not strong.empty:
        fig3, ax3 = plt.subplots(figsize=(8, 4))
        ax3.scatter(strong["CHS_horizon_hr"] / 24.0,
                    strong["best_lag_days"],
                    s=40,
                    alpha=0.7)
        ax3.set_xlabel("Chester lead time (days)")
        ax3.set_ylabel("Best lag (days)")
        ax3.set_title("Lag vs Chester horizon (only pairs with corr > 0.4)")
        ax3.grid(True, alpha=0.3)
        fig3.tight_layout()
        scatter_png = out_dir / "lag_corr_scatter_strong.png"
        fig3.savefig(scatter_png, dpi=200)
        plt.close(fig3)
        print(f"Wrote {scatter_png}")
    else:
        print("No strong correlation pairs above threshold, skip scatter plot.")


if __name__ == "__main__":
    main()


Wrote summary table: /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/Hypothesis/lag_corr_matrix_STL_vs_CHS.csv
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/Hypothesis/lag_corr_heatmap_corr.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/Hypothesis/lag_corr_heatmap_lagdays.png
Wrote /media/12TB/Sujan/NWM/Codes/LRF_RC/NWM_Q_to_GH/Hypothesis/lag_corr_scatter_strong.png
