In [None]:
from pathlib import Path
from typing import Tuple, Optional, Dict, List, Any
import hashlib

import numpy as np
import pandas as pd
import torch

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# Global settings
# Both RNGs are seeded once here; generate_time_NAR_synthetic_data reseeds NumPy per call.
GLOBAL_SEED = 42
torch.manual_seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# Chronological split fractions (whatever is left after train + val is the test span).
train_ratio, val_ratio = 0.70, 0.15
L = 56  # context length; used as the minimum rolling-forecast origin
H = 24  # forecast horizon per rolling window
STEP = H  # stride between rolling origins (STEP == H -> non-overlapping windows)
EPS_STD = 1e-8  # features whose train std is <= this are dropped

# Scenario "C" simulation parameters.
SCENARIO_KEY = "C"
SC: Dict[str, float] = dict(
    noise_std=0.5,
    feature_noise_std=0.2,
    non_linear_strength=2.0,
    eta_rho=0.8,
    f_rho=0.6,
)

# Nonlinear-autoregressive target parameters; one run per entry of Y_GAMMA_LIST.
Y_RHO = 0.85
NAR_SCALE_WINDOW = 50
Y_GAMMA_LIST = [0.1, 0.2]

# Gate settings passed through to the generator.
GATE_TYPE = "A2_global_v2"
GATE_THRESHOLD = 0.0
GATE_LOW = 0.5
GATE_HIGH = 1.0

# Spatial/temporal grid.
N_POINTS = 36
T_TOTAL = 1500
EIGENVALUE = 3.0
locs = np.linspace(-5.0, 5.0, num=N_POINTS).astype(np.float32)


# Data generation
def _generate_spatial_basis(locations: np.ndarray) -> np.ndarray:
    spatial_basis = np.exp(-(locations ** 2))[:, None]
    spatial_basis /= (np.linalg.norm(spatial_basis) + 1e-12)
    return spatial_basis


def generate_time_NAR_synthetic_data(
    locs: np.ndarray,
    n_time_steps: int,
    noise_std: float,
    eigenvalue: float,
    eta_rho: float = 0.8,
    f_rho: float = 0.6,
    global_mean: float = 50.0,
    feature_noise_std: float = 0.0,
    non_linear_strength: float = 0.0,
    seed: Optional[int] = None,
    y_rho: float = 0.4,
    y_gamma: float = 1.2,
    nar_scale_window: int = 200,
    gate_type: str = "A2_global_v2",
    gate_threshold: float = 0.0,
    gate_low: float = 0.5,
    gate_high: float = 1.0,
    debug_gate_ratio: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate a gated nonlinear-autoregressive (NAR) spatio-temporal series.

    For ``t >= 1`` and every site ``i``::

        y[t, i] = global_mean + drift[t]
                  + eigenvalue * w[t] * basis[i]
                  + feat(cont[t, i, :])
                  + rho_eff[t] * y[t-1, i] + y_gamma * tanh(y[t-1, i] / scale[t])
                  + noise[t, i]

    ``drift`` and ``w`` are AR(1) processes with coefficients ``f_rho`` and
    ``eta_rho`` (innovation std 0.1). ``basis`` is the unit-norm Gaussian bump
    from ``_generate_spatial_basis``. ``feat`` is a fixed linear combination of
    the three covariates plus, when ``non_linear_strength > 0``, quadratic,
    interaction and sine terms. ``y[0]`` has no recurrence term.

    The recurrence coefficient is gated::

        rho_eff[t] = y_rho * (gate_low + (gate_high - gate_low) * gate[t])

    ``gate[t]`` is 1 when ``(mean(y[t-1]) - mu) / scale`` exceeds
    ``gate_threshold`` and 0 otherwise. ``mu`` and ``scale`` are the mean and
    std of the trailing ``nar_scale_window`` rows of ``y``, pooled over sites.

    All randomness comes from NumPy's *global* RNG. Passing ``seed`` reseeds it,
    which also affects any later code using the global RNG. The sequence of
    random draws does not depend on ``y``, so the same seed yields the same
    noise regardless of ``y_gamma`` / ``y_rho``.

    Args:
        locs: 1-D site coordinates, shape (N,).
        n_time_steps: Number of time steps T (>= 1).
        noise_std: Std of the additive observation noise on ``y``.
        eigenvalue: Amplitude of the spatial-basis term.
        eta_rho: AR(1) coefficient of the spatial weight ``w``.
        f_rho: AR(1) coefficient of the trend drift.
        global_mean: Constant level added to every observation.
        feature_noise_std: Extra Gaussian noise added to the covariates.
        non_linear_strength: Weight of the nonlinear covariate terms.
        seed: If given, passed to ``np.random.seed`` before any draw.
        y_rho: Base recurrence coefficient before gating.
        y_gamma: Weight of the ``tanh`` recurrence term.
        nar_scale_window: Length (>= 1) of the trailing window for ``mu``/``scale``.
        gate_type: Label used only in the debug print; it does not change the gate.
        gate_threshold: Standardized threshold above which the gate opens.
        gate_low: Multiplier on ``y_rho`` when the gate is closed.
        gate_high: Multiplier on ``y_rho`` when the gate is open.
        debug_gate_ratio: If True, print the fraction of gated steps.

    Returns:
        ``(cat, cont, y)``:
        - ``cat``: int64 array of shape (T, N, 0) (no categorical features).
        - ``cont``: float32 array of shape (T, N, 3).
        - ``y``: float32 array of shape (T, N).

    Raises:
        ValueError: if ``locs`` is not 1-D, ``n_time_steps < 1``, or
            ``nar_scale_window < 1``. A window below 1 would make every
            trailing window empty and turn ``y`` into NaN.
    """
    if seed is not None:
        np.random.seed(seed)

    locs = np.asarray(locs, dtype=np.float32)
    if locs.ndim != 1:
        raise ValueError(f"expects 1D locs (N,), got {locs.shape}")
    if n_time_steps < 1:
        raise ValueError(f"n_time_steps must be >= 1, got {n_time_steps}")
    if nar_scale_window < 1:
        raise ValueError(f"nar_scale_window must be >= 1, got {nar_scale_window}")

    N = len(locs)
    p = 3  # number of continuous covariates

    cont = np.random.randn(n_time_steps, N, p).astype(np.float32)

    spatial_pattern = np.sin(locs * np.pi / 5.0).astype(np.float32)
    temporal_trend = np.linspace(0, 2 * np.pi, n_time_steps).astype(np.float32)

    # Covariate 0: space-modulated trend; 1: spatially uniform trend;
    # 2: interaction of the (already trended) covariates 0 and 1.
    cont[:, :, 0] += (np.outer(temporal_trend, spatial_pattern) * 0.5).astype(np.float32)
    cont[:, :, 1] += (np.outer(temporal_trend, np.ones(N, dtype=np.float32)) * 0.3).astype(np.float32)
    cont[:, :, 2] += (cont[:, :, 0] * cont[:, :, 1] * 0.2).astype(np.float32)

    if feature_noise_std > 0:
        cont += (np.random.randn(*cont.shape) * feature_noise_std).astype(np.float32)

    spatial_basis = _generate_spatial_basis(locs).reshape(-1).astype(np.float32)

    # Latent AR(1) processes: spatial weight w[t] and trend drift.
    spatial_weights = np.zeros(n_time_steps, dtype=np.float32)
    trend_drift = np.zeros(n_time_steps, dtype=np.float32)

    spatial_weights[0] = float(np.random.randn() * 1.0)
    trend_drift[0] = float(np.random.randn() * 1.0)

    for t in range(1, n_time_steps):
        spatial_weights[t] = eta_rho * spatial_weights[t - 1] + float(np.random.randn() * 0.1)
        trend_drift[t] = f_rho * trend_drift[t - 1] + float(np.random.randn() * 0.1)

    y = np.zeros((n_time_steps, N), dtype=np.float32)
    eps = 1e-6  # keeps the rolling scale strictly positive

    def _feat_term(t: int) -> np.ndarray:
        # Covariate contribution at time t, shape (N,).
        feat = (0.3 * cont[t, :, 0] + 0.4 * cont[t, :, 1] + 0.2 * cont[t, :, 2]).astype(np.float32)
        if non_linear_strength > 0:
            feat = (
                feat
                + non_linear_strength * (cont[t, :, 0] ** 2)
                + non_linear_strength * 0.5 * cont[t, :, 0] * cont[t, :, 1]
                + non_linear_strength * 0.3 * np.sin(cont[t, :, 2])
            ).astype(np.float32)
        return feat

    trend0 = global_mean + float(trend_drift[0])
    spatial0 = eigenvalue * float(spatial_weights[0]) * spatial_basis
    noise0 = (np.random.randn(N) * noise_std).astype(np.float32)
    y[0] = (trend0 + spatial0 + _feat_term(0) + noise0).astype(np.float32)

    gate_cnt = 0.0

    for t in range(1, n_time_steps):
        trend = global_mean + float(trend_drift[t])
        spatial = eigenvalue * float(spatial_weights[t]) * spatial_basis
        noise = (np.random.randn(N) * noise_std).astype(np.float32)

        prev = y[t - 1]
        w0 = max(0, t - nar_scale_window)

        # Pooled (all sites) statistics of the trailing window y[w0:t], which includes prev.
        mu = float(np.mean(y[w0:t]))
        scale = float(np.std(y[w0:t]) + eps)
        z_global = float((np.mean(prev) - mu) / scale)

        gate = 1.0 if (z_global > gate_threshold) else 0.0
        gate_cnt += gate

        rho_eff = y_rho * (gate_low + (gate_high - gate_low) * gate)
        rec = (rho_eff * prev + y_gamma * np.tanh(prev / scale)).astype(np.float32)

        y[t] = (trend + spatial + _feat_term(t) + rec + noise).astype(np.float32)

    if debug_gate_ratio:
        # The gate is evaluated for t = 1 .. T-1; guard T == 1 (no gated steps).
        gate_ratio = gate_cnt / float(max(n_time_steps - 1, 1))
        print(
            f"[Gate {gate_type}] gate_ratio={gate_ratio:.4f} | "
            f"thr={gate_threshold} | low={gate_low} | high={gate_high} | win={nar_scale_window}"
        )

    cat = np.zeros((n_time_steps, N, 0), dtype=np.int64)
    return cat, cont, y


# Utils
def qrisk(y_true, y_pred, q=0.5, eps=1e-8) -> float:
    """Normalized quantile (pinball) risk.

    Computes ``2 * sum(pinball_q(y_true - y_pred)) / (sum(|y_true|) + eps)``
    in float64. With ``q=0.5`` this equals ``sum|err| / sum|y_true|``.

    Args:
        y_true: Ground-truth values (any array-like; broadcast against ``y_pred``).
        y_pred: Predicted ``q``-quantile.
        q: Quantile level in [0, 1].
        eps: Added to the denominator to avoid division by zero.

    Returns:
        The q-risk as a Python float.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    residual = truth - np.asarray(y_pred, dtype=np.float64)
    pinball = np.maximum(q * residual, (q - 1) * residual)
    normaliser = np.abs(truth).sum() + eps
    return float(2.0 * pinball.sum() / normaliser)


def md5_of_array(a: np.ndarray) -> str:
    """Return the hex MD5 of ``a``'s values cast to float32, in C (row-major) order.

    Inputs of other dtypes are converted to float32 first. As a result, float64
    arrays that agree after rounding to float32 hash identically.
    """
    payload = np.asarray(a).astype(np.float32, copy=False).tobytes(order="C")
    digest = hashlib.md5(payload)
    return digest.hexdigest()


def stitch(dslist):
    """Concatenate continuous features and targets from several datasets along axis 0.

    Each element of ``dslist`` is expected to expose a ``tensors`` tuple where
    index 1 holds continuous features and index 2 holds targets. This matches
    the datasets returned by ``prepare_all_with_scaling`` here — TODO confirm
    the layout if other datasets are passed.

    Returns:
        ``(X, y)`` as float32 NumPy arrays.
    """
    def _as_float32(tensor):
        return tensor.detach().cpu().numpy().astype(np.float32)

    features = [_as_float32(ds.tensors[1]) for ds in dslist]
    targets = [_as_float32(ds.tensors[2]) for ds in dslist]
    return np.concatenate(features, axis=0), np.concatenate(targets, axis=0)


# OLS
def fit_ols1_per_site_pastcov_diag(
    y_s: np.ndarray,
    X_s_cov: np.ndarray,
    cut_train: int,
    device: Optional[torch.device] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Fit an independent ARX(1) least-squares model for every site.

    For each site ``i`` this solves, over the training transitions
    ``t = 1 .. cut_train - 1``::

        y_s[t, i] ~= c[i] + a[i] * y_s[t-1, i] + B[i] @ X_s_cov[t-1, i, :]

    The regressors are the site's own lag only ("diag": no cross-site lags)
    plus the previous step's covariates ("pastcov").

    Args:
        y_s: Targets, shape (T, N).
        X_s_cov: Covariates, shape (T', N, P) with ``T' >= cut_train - 1``.
        cut_train: Exclusive end of the training rows (``2 <= cut_train <= T``).
        device: Torch device for the solve. Defaults to the module-level ``DEVICE``.

    Returns:
        ``(c, a, B)``: float32 arrays of shapes (N,), (N,), (N, P).

    Raises:
        ValueError: on shape mismatch, out-of-range ``cut_train``, or fewer
            training transitions than regression parameters (P + 2). Such an
            underdetermined per-site system is not identifiable and is
            rejected by ``torch.linalg.lstsq`` on CUDA.
    """
    y_s = np.asarray(y_s)
    X_s_cov = np.asarray(X_s_cov)
    if y_s.ndim != 2:
        raise ValueError(f"y_s must be 2D (T, N), got shape {y_s.shape}")
    if X_s_cov.ndim != 3 or X_s_cov.shape[1] != y_s.shape[1]:
        raise ValueError(
            f"X_s_cov must be 3D (T, N={y_s.shape[1]}, P), got shape {X_s_cov.shape}"
        )
    if not 2 <= cut_train <= y_s.shape[0] or X_s_cov.shape[0] < cut_train - 1:
        raise ValueError(
            f"cut_train={cut_train} out of range for y_s with {y_s.shape[0]} rows "
            f"and X_s_cov with {X_s_cov.shape[0]} rows"
        )

    dev = DEVICE if device is None else device

    # Row t of the design pairs target y_s[t] with regressors from t-1.
    Y = y_s[1:cut_train].astype(np.float32)
    Yprev_diag = y_s[0:cut_train - 1].astype(np.float32)
    Xprev = X_s_cov[0:cut_train - 1].astype(np.float32)

    Tm1, N = Y.shape
    Pcov = Xprev.shape[2]
    if Tm1 < Pcov + 2:
        raise ValueError(
            f"need at least {Pcov + 2} training transitions for {Pcov + 2} parameters, got {Tm1}"
        )
    ones = np.ones((Tm1, 1), dtype=np.float32)

    c = np.zeros(N, dtype=np.float32)
    a = np.zeros(N, dtype=np.float32)
    B = np.zeros((N, Pcov), dtype=np.float32)

    for i in range(N):
        # Design columns: [intercept, own lag, past covariates].
        Xi = np.concatenate([ones, Yprev_diag[:, i:i + 1], Xprev[:, i, :]], axis=1).astype(np.float32)
        yi = Y[:, i:i + 1].astype(np.float32)

        Xi_t = torch.from_numpy(Xi).to(dev)
        yi_t = torch.from_numpy(yi).to(dev)

        # NOTE(review): on CUDA lstsq uses the 'gels' driver, which assumes a
        # full-rank design; a covariate that is constant at one site makes
        # the design rank-deficient there and the solution unreliable.
        bi = torch.linalg.lstsq(Xi_t, yi_t).solution.squeeze(1)
        bi = bi.detach().cpu().numpy().astype(np.float32)

        c[i] = bi[0]
        a[i] = bi[1]
        B[i, :] = bi[2:]

    return c, a, B


def rolling_ols1_tftaligned_pastcov_frozen(
    y_s_full: np.ndarray,
    X_s_cov_full: np.ndarray,
    c: np.ndarray,
    a: np.ndarray,
    B: np.ndarray,
    start_ctx: int,
    end_T: int,
    H: int,
    step: int,
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
    """Produce rolling multi-step forecasts from a fitted per-site ARX(1) model.

    The forecast origins are ``t0 = start_ctx, start_ctx + step, ...`` while
    ``t0 + H <= end_T``; origins ``t0 <= 0`` are skipped. For each origin the
    model is iterated recursively for ``H`` steps, starting from
    ``y_s_full[t0 - 1]``::

        pred_k = c + a * pred_{k-1} + sum_j B[:, j] * X_s_cov_full[t0 - 1, :, j]

    The covariates are frozen at their last observed value ``t0 - 1`` for the
    whole horizon.

    Args:
        y_s_full: Targets, shape (T, N).
        X_s_cov_full: Covariates, shape (T, N, P).
        c, a: Per-site intercept and lag coefficient, shape (N,).
        B: Per-site covariate coefficients, shape (N, P).
        start_ctx: First candidate forecast origin.
        end_T: Exclusive upper bound on forecast targets (``<= len(y_s_full)``).
        H: Forecast horizon (>= 1).
        step: Stride between consecutive origins (>= 1).

    Returns:
        ``(yhat, ytrue, t0_list)``: arrays of shape (W, H, N) and the W origins.

    Raises:
        ValueError: if ``H < 1``, ``step < 1``, or ``end_T > len(y_s_full)``.
            Without the last check, truncated truth windows would cause an
            IndexError.
        RuntimeError: if no forecast window fits in the given range.
    """
    if H < 1:
        raise ValueError(f"H must be >= 1, got {H}")
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    if end_T > len(y_s_full):
        raise ValueError(f"end_T={end_T} exceeds the series length {len(y_s_full)}")

    t0_raw = list(range(start_ctx, end_T - H + 1, step))
    yh, yt, t0_list = [], [], []

    for t0 in t0_raw:
        if t0 <= 0:
            continue  # no previous observation to condition on

        prev = y_s_full[t0 - 1].astype(np.float32)
        x_fixed = X_s_cov_full[t0 - 1].astype(np.float32)

        ytru = y_s_full[t0:t0 + H].astype(np.float32)
        yhat = np.empty_like(ytru)

        # Covariate contribution is constant over the horizon (frozen covariates).
        bx = np.sum(B * x_fixed, axis=1).astype(np.float32)

        for k in range(H):
            pred = (c + a * prev + bx).astype(np.float32)
            yhat[k] = pred
            prev = pred  # recursive: feed the forecast back as the next lag

        yh.append(yhat)
        yt.append(ytru)
        t0_list.append(int(t0))

    if len(yh) == 0:
        raise RuntimeError("No rolling windows produced. Check start_ctx / end_T / H / step.")

    return np.stack(yh, axis=0), np.stack(yt, axis=0), t0_list


# Run
# Echo the experiment configuration once before the per-gamma runs.
_run_header = (
    ("\n=== RUN (OLS, Scenario C only, Gate A2 v2 centered) ===",),
    ("Scenario:", SCENARIO_KEY, "|", SC),
    ("NAR params:", {"y_rho": Y_RHO, "nar_scale_window": NAR_SCALE_WINDOW, "gate_type": GATE_TYPE}),
    ("Y_GAMMA_LIST:", Y_GAMMA_LIST),
    (f"Gate A2 v2 params: thr={GATE_THRESHOLD}, low={GATE_LOW}, high={GATE_HIGH}",),
)
for _header_args in _run_header:
    print(*_header_args)

# One metrics/config record per y_gamma run.
ols_results: List[Dict[str, Any]] = []

# For each y_gamma: simulate data, scale it, fit per-site OLS on the train span,
# then score frozen-covariate rolling forecasts on the test span.
for Y_GAMMA in Y_GAMMA_LIST:
    print("\n============================================================")
    print(f"SCENARIO={SCENARIO_KEY} | y_rho={Y_RHO} | win={NAR_SCALE_WINDOW} | y_gamma={Y_GAMMA}")
    print("============================================================")

    # 1) Simulate. The same seed is used every iteration, so only y_gamma differs between runs.
    _, cont_synth, y_synth = generate_time_NAR_synthetic_data(
        locs=locs,
        n_time_steps=T_TOTAL,
        noise_std=SC["noise_std"],
        eigenvalue=EIGENVALUE,
        eta_rho=SC["eta_rho"],
        f_rho=SC["f_rho"],
        global_mean=50.0,
        feature_noise_std=SC["feature_noise_std"],
        non_linear_strength=SC["non_linear_strength"],
        seed=GLOBAL_SEED,
        y_rho=float(Y_RHO),
        y_gamma=float(Y_GAMMA),
        nar_scale_window=int(NAR_SCALE_WINDOW),
        gate_type=GATE_TYPE,
        gate_threshold=float(GATE_THRESHOLD),
        gate_low=float(GATE_LOW),
        gate_high=float(GATE_HIGH),
        debug_gate_ratio=True,
    )

    Y = y_synth.astype(np.float32)
    Cont = cont_synth.astype(np.float32)
    T_total, N = Y.shape

    y_mean = float(np.mean(Y))
    y_std = float(np.std(Y))
    y_mean_abs = float(np.mean(np.abs(Y)))
    y_md5 = md5_of_array(Y)  # fingerprint to confirm identical data across runs/machines
    print("Synthetic y shape:", Y.shape)
    print("Synthetic cont shape:", Cont.shape)
    print("y stats (unscaled) mean/std/mean|y|:", y_mean, y_std, y_mean_abs)
    print("y_md5:", y_md5)

    # 2) Features: [lag1 of y] + covariates. Names are derived from the actual
    #    covariate count so they always line up with the `keep` mask below.
    lag1 = np.zeros((T_total, N, 1), dtype=np.float32)
    lag1[1:, :, 0] = Y[:-1]
    Xfull = np.concatenate([lag1, Cont], axis=2).astype(np.float32)
    feat_names_full = ["lag1"] + [f"f{j + 1}" for j in range(Cont.shape[2])]
    P_full = Xfull.shape[2]

    # 3) Scale features/targets (the scaler is fit on the train split only).
    cat_dummy = np.zeros((T_total, N, 1), dtype=np.int64)
    train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
        cat_features=cat_dummy,
        cont_features=Xfull,
        targets=Y,
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        feature_scaler_type="standard",
        target_scaler_type="standard",
        fit_on_train_only=True,
    )

    X_s, y_s = stitch([train_ds, val_ds, test_ds])

    # NOTE(review): assumes prepare_all_with_scaling splits at these same
    # int(T * ratio) boundaries — confirm, otherwise train/test leak.
    cut_train = int(T_total * train_ratio)
    cut_val = int(T_total * (train_ratio + val_ratio))
    test_start = cut_val

    # 4) Drop features that are (numerically) constant on the train span.
    X_train = X_s[:cut_train]
    flat = X_train.reshape(-1, P_full)
    stds = flat.std(axis=0)
    keep = (stds > EPS_STD)

    feat_names_red = [f for f, k in zip(feat_names_full, keep) if k]
    X_s_red = X_s[:, :, keep].astype(np.float32)

    if "lag1" not in feat_names_red:
        raise RuntimeError("lag1 feature was dropped; check EPS_STD / scaling logic.")

    # The OLS uses y_s directly as its own lag, so the scaled lag1 column is excluded here.
    cov_idx_red = [i for i, n in enumerate(feat_names_red) if n != "lag1"]
    X_s_cov = X_s_red[:, :, cov_idx_red].astype(np.float32)
    Pcov = X_s_cov.shape[2]
    if Pcov == 0:
        raise RuntimeError("No covariates kept. This OLS version expects past covariates kept.")

    # 5) Fit on the train span, forecast rolling windows over the test span.
    c_ols, a_ols, B_ols = fit_ols1_per_site_pastcov_diag(y_s, X_s_cov, cut_train=cut_train)

    start_ctx = max(test_start, L)
    yhat_roll, ytrue_roll, _ = rolling_ols1_tftaligned_pastcov_frozen(
        y_s_full=y_s,
        X_s_cov_full=X_s_cov,
        c=c_ols,
        a=a_ols,
        B=B_ols,
        start_ctx=start_ctx,
        end_T=T_total,
        H=H,
        step=STEP,
    )

    W = int(yhat_roll.shape[0])
    print("\n=== Rolling non-overlap (POOLED, TEST) ===")
    # Report the stride actually used for the windows (STEP), not the horizon.
    print(f"windows={W} | step={STEP} | each window predicts H={H} | points={N}")

    # 6) Metrics on the scaled target, pooled over windows, horizon and sites.
    diff = yhat_roll - ytrue_roll
    rmse = float(np.sqrt(np.mean(diff ** 2)))
    mae = float(np.mean(np.abs(diff)))
    print("[OLS] RMSE:", rmse)
    print("[OLS] MAE :", mae)

    yhat_f = yhat_roll.reshape(-1, N)
    ytrue_f = ytrue_roll.reshape(-1, N)

    qr_scaled = qrisk(ytrue_f, yhat_f)

    # Presumably inverse_transform_targets accepts (rows, N) arrays — verify against the preprocessor.
    ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
    yhat_un = preprocessor.inverse_transform_targets(yhat_f)
    qr_unscaled = qrisk(ytrue_un, yhat_un)

    print("\n=== P50 q-risk (ROLLING, TEST) ===")
    print("scaled  :", qr_scaled)
    print("unscaled:", qr_unscaled)

    ols_results.append({
        "scenario": SCENARIO_KEY,
        "noise_std": float(SC["noise_std"]),
        "feature_noise_std": float(SC["feature_noise_std"]),
        "non_linear_strength": float(SC["non_linear_strength"]),
        "eta_rho": float(SC["eta_rho"]),
        "f_rho": float(SC["f_rho"]),
        "y_rho": float(Y_RHO),
        "y_gamma": float(Y_GAMMA),
        "nar_scale_window": int(NAR_SCALE_WINDOW),
        "gate_type": GATE_TYPE,
        "gate_threshold": float(GATE_THRESHOLD),
        "gate_low": float(GATE_LOW),
        "gate_high": float(GATE_HIGH),
        "y_mean_unscaled": y_mean,
        "y_std_unscaled": y_std,
        "y_meanabs_unscaled": y_mean_abs,
        "y_md5": y_md5,
        "rmse": rmse,
        "mae": mae,
        "qrisk_p50_scaled": qr_scaled,
        "qrisk_p50_unscaled": qr_unscaled,
        "windows": W,
        "L": int(L),
        "H": int(H),
        "T_total": int(T_total),
        "N_points": int(N),
        "eigenvalue": float(EIGENVALUE),
        "global_seed": int(GLOBAL_SEED),
    })


# Save
# Persist one row per y_gamma run (utf-8-sig so spreadsheet tools detect UTF-8).
out_ols = Path("OLS_C_y0p85_w50_g0p1_0p2_A2v2.csv").resolve()
df_ols = pd.DataFrame(ols_results)
df_ols.to_csv(out_ols, index=False, encoding="utf-8-sig")

for _summary_args in (
    ("\n============================================================",),
    ("Saved OLS results to:", out_ols),
    (df_ols,),
    ("============================================================",),
    ("ALL RUNS DONE",),
):
    print(*_summary_args)


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda

=== RUN (OLS, Scenario C only, Gate A2 v2 centered) ===
Scenario: C | {'noise_std': 0.5, 'feature_noise_std': 0.2, 'non_linear_strength': 2.0, 'eta_rho': 0.8, 'f_rho': 0.6}
NAR params: {'y_rho': 0.85, 'nar_scale_window': 50, 'gate_type': 'A2_global_v2'}
Y_GAMMA_LIST: [0.1, 0.2]
Gate A2 v2 params: thr=0.0, low=0.5, high=1.0

SCENARIO=C | y_rho=0.85 | win=50 | y_gamma=0.1
[Gate A2_global_v2] gate_ratio=0.5997 | thr=0.0 | low=0.5 | high=1.0 | win=50
Synthetic y shape: (1500, 36)
Synthetic cont shape: (1500, 36, 3)
y stats (unscaled) mean/std/mean|y|: 250.49603271484375 131.10008239746094 250.49603271484375
y_md5: 2f25135ad6de3de96dc68bab1956f95a

=== Rolling non-overlap (POOLED, TEST) ===
windows=9 | step=24 | each window predicts H=24 | points=36
[OLS] RMSE: 0.8124235272407532
[OLS] MAE : 0.52541583776474

=== P50 q-risk (ROLLING, TEST) ===
sca