In [None]:
import math
from typing import Sequence, List, Tuple
import pandas as pd

# -------- linear interpolation on ascending tables --------
# Inverse lookup (y -> x) first, then the forward lookup (x -> y).
def interp_inverse(x_tab, y_tab, y):
    """Invert a tabulated relation: return the x whose tabulated value is ~``y``.

    The lookup interpolates linearly on the (y -> x) pairs and clamps to the
    first/last x when ``y`` lies outside the tabulated y range.  If ``y_tab``
    is not non-decreasing, the (y, x) pairs are re-sorted by y first.

    Returns ``nan`` when ``x_tab`` is empty.
    """
    count = len(x_tab)
    if count == 0:
        return float("nan")

    # The search below needs y in non-decreasing order; detect the first
    # out-of-order neighbour pair and, if found, re-pair and sort by y.
    in_order = True
    for k in range(count - 1):
        if not (y_tab[k] <= y_tab[k + 1]):
            in_order = False
            break
    if not in_order:
        ordered = sorted(zip(y_tab, x_tab))
        y_tab = [y_val for y_val, _ in ordered]
        x_tab = [x_val for _, x_val in ordered]
        count = len(x_tab)

    # Clamp outside the tabulated range.
    if y <= y_tab[0]:
        return float(x_tab[0])
    if y >= y_tab[-1]:
        return float(x_tab[-1])

    # Binary search for neighbours with y_tab[left] <= y < y_tab[right].
    left, right = 0, count - 1
    while right - left > 1:
        middle = left + (right - left) // 2
        if y_tab[middle] <= y:
            left = middle
        else:
            right = middle

    y_left, y_right = y_tab[left], y_tab[right]
    x_left, x_right = x_tab[left], x_tab[right]
    # A flat segment has no unique inverse; return its left end.
    if y_right == y_left:
        return float(x_left)
    frac = (y - y_left) / (y_right - y_left)
    return float(x_left + frac * (x_right - x_left))

def interp_linear(x_tab: Sequence[float], y_tab: Sequence[float], x: float) -> float:
    """Linearly interpolate y(x) on an ascending ``x_tab``, clamping at the ends.

    Args:
        x_tab: Abscissae in ascending order.
        y_tab: Ordinates, one per entry of ``x_tab``.
        x: Point at which the table is evaluated.

    Returns:
        The interpolated value as a float.  ``y_tab[0]`` / ``y_tab[-1]`` are
        returned when ``x`` lies at or beyond the table ends.  ``nan`` is
        returned for an empty table or a NaN ``x``.
    """
    n = len(x_tab)
    if n == 0:
        return float("nan")

    # NaN compares False against everything, so without this guard a
    # one-row table would answer y_tab[0] for a NaN query.
    x = float(x)
    if math.isnan(x):
        return float("nan")

    # Table entries are cast on access (not up front) so the lookup stays
    # O(log n) while still tolerating values read as objects/strings.
    if x <= float(x_tab[0]):
        return float(y_tab[0])
    if x >= float(x_tab[-1]):
        return float(y_tab[-1])

    # Binary search for neighbours with x_tab[lo] <= x < x_tab[hi].
    lo, hi = 0, n - 1
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if float(x_tab[mid]) <= x:
            lo = mid
        else:
            hi = mid

    x0, x1 = float(x_tab[lo]), float(x_tab[hi])
    y0, y1 = float(y_tab[lo]), float(y_tab[hi])
    # Avoid divide-by-zero if the table contains a duplicate x.
    if x1 == x0:
        return y0
    w = (x - x0) / (x1 - x0)
    return y0 + w * (y1 - y0)

# -------- shift accessors --------

def shift_time_based(t_idx: int, shift_series: Sequence[float]) -> float:
    """Return the entry of ``shift_series`` at position ``t_idx`` as a float."""
    shift_at_t = shift_series[t_idx]
    return float(shift_at_t)

# Placeholder kept so the name stays defined; this workflow has no stage-based shift.
def shift_stage_based(GH: float, stage_tab: Sequence[float], shift_tab: Sequence[float]) -> float:
    """Placeholder stage-based shift: ignores its arguments and returns 0.0."""
    # Not used in this workflow; the arguments are accepted but never read.
    no_shift = 0.0
    return no_shift

# -------- Qmodel at a given time t and stage GH --------

def q_model_for_time(
    t_idx: int,
    GH: float,
    cape_stage_t: float,
    rating_stage_tab: Sequence[float],
    rating_q_tab: Sequence[float],
    fall_stage_tab: Sequence[float],
    fall_fr_tab: Sequence[float],
    factor_ratio_tab: Sequence[float],
    factor_val_tab: Sequence[float],
    c_cape: float,
    c_thebes: float,
    use_stage_dependent_shift: bool,
    shift_series: Sequence[float] = None,
    shift_stage_axis: Sequence[float] = None,
    shift_stage_vals: Sequence[float] = None
) -> float:
    """Evaluate the modelled discharge at time index ``t_idx`` for stage ``GH``.

    The result is the rated discharge looked up at the shifted stage
    (``GH`` plus the time-based shift) multiplied by a factor looked up at the
    ratio of measured fall to rated fall.

    ``use_stage_dependent_shift``, ``shift_stage_axis`` and
    ``shift_stage_vals`` are accepted but not read: only the time-based shift
    from ``shift_series`` is applied.

    Returns:
        The modelled discharge, or ``nan`` when the rated fall is non-finite
        or not strictly positive.
    """
    # The shift applies to the rating lookup only, never to the measured fall.
    shift_now = shift_time_based(t_idx, shift_series)
    shifted_stage = GH + shift_now

    rated_q = interp_linear(rating_stage_tab, rating_q_tab, shifted_stage)
    rated_fall = interp_linear(fall_stage_tab, fall_fr_tab, GH)
    measured_fall = (cape_stage_t + c_cape) - (GH + c_thebes)

    # A rated fall that is missing or non-positive makes the ratio meaningless.
    if not (math.isfinite(rated_fall) and rated_fall > 0.0):
        return float("nan")

    fall_ratio = measured_fall / rated_fall
    factor = interp_linear(factor_ratio_tab, factor_val_tab, fall_ratio)
    return rated_q * factor

# -------- bisection solver for GH at a single time --------

def solve_GH_bisection_for_time(
    t_idx: int,
    Q_target: float,
    cape_stage_t: float,
    gh_low: float,
    gh_high: float,
    rating_stage_tab: Sequence[float],
    rating_q_tab: Sequence[float],
    fall_stage_tab: Sequence[float],
    fall_fr_tab: Sequence[float],
    factor_ratio_tab: Sequence[float],
    factor_val_tab: Sequence[float],
    c_cape: float,
    c_thebes: float,
    use_stage_dependent_shift: bool,
    shift_series: Sequence[float] = None,
    shift_stage_axis: Sequence[float] = None,
    shift_stage_vals: Sequence[float] = None,
    tol_stage: float = 1e-3,
    tol_flow: float = 1.0,
    max_iter: int = 60
    ) -> Tuple[float, int, bool]:
    """Find the stage GH in [gh_low, gh_high] where the model matches ``Q_target``.

    Bisection is run on ``q_model_for_time(...) - Q_target``.

    Args:
        t_idx: Time index forwarded to the discharge model.
        Q_target: Discharge the model should reproduce.
        cape_stage_t: Stage value forwarded to the discharge model.
        gh_low, gh_high: Search bracket for GH.
        tol_stage: Stop once the bracket is narrower than this.
        tol_flow: Stop once the absolute residual is below this.
        max_iter: Maximum number of bisection steps.
        (All remaining arguments are forwarded unchanged to
        ``q_model_for_time``.)

    Returns:
        ``(GH, iterations, converged)``.  ``converged`` is False when the
        bracket does not straddle a root, when the model returns a non-finite
        value inside the bracket, or when ``max_iter`` is exhausted before the
        bracket is narrower than ``tol_stage``.  In the first two cases the
        returned GH is the bracket end with the smaller absolute residual.
    """

    def residual(GH: float) -> float:
        q = q_model_for_time(
            t_idx, GH, cape_stage_t,
            rating_stage_tab, rating_q_tab,
            fall_stage_tab, fall_fr_tab,
            factor_ratio_tab, factor_val_tab,
            c_cape, c_thebes,
            use_stage_dependent_shift,
            shift_series, shift_stage_axis, shift_stage_vals
        )
        return q - Q_target

    f_lo = residual(gh_low)
    f_hi = residual(gh_high)

    # Bracket check: both residuals finite and of opposite sign (or zero).
    bracketed = (
        math.isfinite(f_lo)
        and math.isfinite(f_hi)
        and (f_lo == 0 or f_hi == 0 or f_lo * f_hi < 0)
    )

    if not bracketed:
        # optional diagnostic print; comment if noisy
        print(f"[t={t_idx}] not bracketed: f(lo)={f_lo:.3f}, f(hi)={f_hi:.3f}, "
              f"lo={gh_low:.3f}, hi={gh_high:.3f}, Q_target={Q_target:.3f}")
        # Return the bound that is closer in residual, preferring a finite one.
        if not math.isfinite(f_lo) and math.isfinite(f_hi):
            return gh_high, 0, False
        if not math.isfinite(f_hi) and math.isfinite(f_lo):
            return gh_low, 0, False
        return (gh_low, 0, False) if abs(f_lo) <= abs(f_hi) else (gh_high, 0, False)

    # An end point that is already an exact root needs no iteration.
    if f_lo == 0:
        return gh_low, 0, True
    if f_hi == 0:
        return gh_high, 0, True

    lo, hi = gh_low, gh_high
    for it in range(1, max_iter + 1):
        mid = 0.5 * (lo + hi)

        # Interval-width test first: it does not need a model evaluation.
        if abs(hi - lo) < tol_stage:
            return mid, it, True

        f_mid = residual(mid)

        # A non-finite residual cannot be used to pick a half-interval;
        # comparing against it would silently push the bracket to one side.
        if not math.isfinite(f_mid):
            nearest = lo if abs(f_lo) <= abs(f_hi) else hi
            return nearest, it, False

        if abs(f_mid) < tol_flow:
            return mid, it, True

        # Bisection step: keep the half whose end residuals differ in sign.
        if f_lo * f_mid <= 0:   # root in [lo, mid]
            hi, f_hi = mid, f_mid
        else:                   # root in [mid, hi]
            lo, f_lo = mid, f_mid

    # Iteration budget exhausted: report success only if the bracket is tight.
    return 0.5 * (lo + hi), max_iter, abs(hi - lo) < tol_stage

# -------- driver over the full time series --------

def solve_GH_timeseries(
    Q_fore: Sequence[float],
    Cape_stage: Sequence[float],
    gh_bounds: Tuple[float, float],
    rating_stage_tab: Sequence[float],
    rating_q_tab: Sequence[float],
    fall_stage_tab: Sequence[float],
    fall_fr_tab: Sequence[float],
    factor_ratio_tab: Sequence[float],
    factor_val_tab: Sequence[float],
    c_cape: float,
    c_thebes: float,
    use_stage_dependent_shift: bool,
    shift_series: Sequence[float] = None,
    shift_stage_axis: Sequence[float] = None,
    shift_stage_vals: Sequence[float] = None
) -> List[float]:
    """Solve for GH at every time step of ``Q_fore``.

    Each step calls ``solve_GH_bisection_for_time``.  After a step, the search
    window for the next step is tightened to +/- 2.0 around the solution
    (clipped to ``gh_bounds``) for speed.  If the solver reports failure on a
    tightened window, the step is retried once on the full ``gh_bounds`` so a
    jump larger than the window is not silently replaced by a window edge.

    Args:
        Q_fore: Target discharge per time step.
        Cape_stage: Stage per time step, forwarded to the solver; indexed in
            step with ``Q_fore``.
        gh_bounds: ``(min, max)`` outer limits for GH.
        (All remaining arguments are forwarded unchanged to the solver.)

    Returns:
        One GH value per entry of ``Q_fore``.
    """
    outer_min, outer_max = gh_bounds

    def solve_step(t: int, low: float, high: float) -> Tuple[float, int, bool]:
        return solve_GH_bisection_for_time(
            t_idx=t,
            Q_target=float(Q_fore[t]),
            cape_stage_t=float(Cape_stage[t]),
            gh_low=low,
            gh_high=high,
            rating_stage_tab=rating_stage_tab,
            rating_q_tab=rating_q_tab,
            fall_stage_tab=fall_stage_tab,
            fall_fr_tab=fall_fr_tab,
            factor_ratio_tab=factor_ratio_tab,
            factor_val_tab=factor_val_tab,
            c_cape=c_cape,
            c_thebes=c_thebes,
            use_stage_dependent_shift=use_stage_dependent_shift,
            shift_series=shift_series,
            shift_stage_axis=shift_stage_axis,
            shift_stage_vals=shift_stage_vals
        )

    gh_min, gh_max = outer_min, outer_max
    GH_out: List[float] = []
    for t in range(len(Q_fore)):
        GH_t, _, ok = solve_step(t, gh_min, gh_max)

        # The tightened window is only a speed-up; when it fails, fall back to
        # the caller's full bounds before accepting a non-converged value.
        window_was_tightened = (gh_min, gh_max) != (outer_min, outer_max)
        if not ok and window_was_tightened:
            GH_t, _, _ = solve_step(t, outer_min, outer_max)

        GH_out.append(GH_t)
        # tighten next-step window around last solution for speed
        gh_min = max(outer_min, GH_t - 2.0)
        gh_max = min(outer_max, GH_t + 2.0)
    return GH_out

# ------------------------- main: reads CSVs and runs solver -------------------------

if __name__ == "__main__":
    # Script entry: read the lookup tables and forecast series from CSV, solve
    # for GH at every time step, and write the result to predicted_stage.csv.
    # Kept as flat top-level statements so every intermediate name stays
    # available after the run.

    # Read tables
    rating_df = pd.read_csv("rating_table.csv")
    fall_df   = pd.read_csv("fall_rated.csv")
    factor_df = pd.read_csv("factor_table.csv")

    # Sort each table by its lookup axis; the interpolators expect ascending x.
    rating_df = rating_df.sort_values("GsH_ft")
    fall_df   = fall_df.sort_values("GH_ft")
    factor_df = factor_df.sort_values("ratio")

    # Cast to float lists (prevents object/str surprises in interp)
    rating_stage_tab = rating_df["GsH_ft"].astype(float).to_list()
    rating_q_tab     = rating_df["Q_cfs"].astype(float).to_list()
    fall_stage_tab   = fall_df["GH_ft"].astype(float).to_list()
    fall_fr_tab      = fall_df["Fr_ft"].astype(float).to_list()
    factor_ratio_tab = factor_df["ratio"].astype(float).to_list()
    factor_val_tab   = factor_df["factor"].astype(float).to_list()

    # Time-based shift only
    use_stage_dependent_shift = False

    shift_ser    = pd.read_csv("shift_series.csv", parse_dates=["time"]).sort_values("time")
    shift_series = shift_ser["shift_ft"].astype(float).to_list()
    shift_stage_axis, shift_stage_vals = None, None

    # Forecast series
    qf   = pd.read_csv("nwm_thebes_q.csv", parse_dates=["time"]).sort_values("time")
    cape = pd.read_csv("cape_stage.csv",    parse_dates=["time"]).sort_values("time")

    # Align by time: keep only timestamps present in both series, then attach
    # the shift where one exists for that timestamp.
    merged = qf.merge(cape, on="time", how="inner")
    if not use_stage_dependent_shift:
        merged = merged.merge(shift_ser[["time", "shift_ft"]], on="time", how="left")

    # If your NWM Q is in m^3/s, convert to cfs here. Uncomment the next line if needed:
    # merged["Q_cfs"] = merged["Q_cfs"] * 35.3147

    Q_fore     = merged["Q_cfs"].astype(float).to_list()
    Cape_stage = merged["Cape_GH_ft"].astype(float).to_list()
    if not use_stage_dependent_shift:
        # Re-derive the shift list from the merged frame so it is indexed in
        # step with Q_fore; timestamps without a shift get 0.0.
        shift_series = merged["shift_ft"].fillna(0.0).astype(float).to_list()

    # Datums (same as your Excel)
    Ccape   = 304.27
    Cthebes = 299.70

    # Bounds for GH from the global table axes with small padding (safety net)
    gh_global_min = float(min(fall_stage_tab[0], rating_stage_tab[0])) - 0.5
    gh_global_max = float(max(fall_stage_tab[-1], rating_stage_tab[-1])) + 0.5

    # Arguments that are identical for every solver call.
    solver_kwargs = dict(
        rating_stage_tab=rating_stage_tab,
        rating_q_tab=rating_q_tab,
        fall_stage_tab=fall_stage_tab,
        fall_fr_tab=fall_fr_tab,
        factor_ratio_tab=factor_ratio_tab,
        factor_val_tab=factor_val_tab,
        c_cape=Ccape,
        c_thebes=Cthebes,
        use_stage_dependent_shift=use_stage_dependent_shift,
        shift_series=shift_series,
        shift_stage_axis=shift_stage_axis,
        shift_stage_vals=shift_stage_vals,
    )

    GH_out = []
    n_unconverged = 0
    for t, q_t in enumerate(Q_fore):
        # 1) initial guess like your "GH from RC for initial value"
        GsH0 = interp_inverse(rating_stage_tab, rating_q_tab, q_t)  # stage (GsH) for target Q on base RC
        sh_t = shift_series[t]
        GH0  = GsH0 - sh_t                                          # convert to GH (unshifted)
        # 2) local bracket around the guess (tune window as you like)
        gh_min = max(gh_global_min, GH0 - 1.5)
        gh_max = min(gh_global_max, GH0 + 1.5)
        # 3) solve on the local bracket
        GH_t, _, ok = solve_GH_bisection_for_time(
            t_idx=t,
            Q_target=float(q_t),
            cape_stage_t=float(Cape_stage[t]),
            gh_low=gh_min,
            gh_high=gh_max,
            **solver_kwargs
        )
        # 4) the local bracket is only a speed-up; if it fails, retry on the
        #    global bounds instead of silently keeping a window edge
        if not ok:
            GH_t, _, ok = solve_GH_bisection_for_time(
                t_idx=t,
                Q_target=float(q_t),
                cape_stage_t=float(Cape_stage[t]),
                gh_low=gh_global_min,
                gh_high=gh_global_max,
                **solver_kwargs
            )
        if not ok:
            n_unconverged += 1
        GH_out.append(GH_t)

    if n_unconverged:
        print(f"Warning: {n_unconverged} of {len(Q_fore)} time steps did not converge")

    # write results
    out = merged.copy()
    out["GH_pred_ft"] = GH_out
    out.to_csv("predicted_stage.csv", index=False)
    print("Wrote predicted_stage.csv with", len(out), "rows")

Wrote predicted_stage.csv with 5 rows
