In [None]:
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling
from geospatial_neural_adapter.data.generators import generate_time_synthetic_data


# Global settings & dirs
# Seed both random number generators up front so repeated runs draw the same
# NumPy and PyTorch random values.
GLOBAL_SEED = 42
torch.manual_seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# Chronological split fractions; whatever follows train + val is the test span.
train_ratio, val_ratio = 0.70, 0.15
# L is a lower bound on the first forecast origin, H is the forecast horizon.
L = 56
H = 24
# Consecutive forecast origins are spaced one full horizon apart.
STEP = H
# Features whose training std does not exceed this are treated as constant.
EPS_STD = 1e-8

# Output folders for the two figure families; created eagerly so that the
# later savefig calls cannot fail on a missing directory.
PLOTS_DIR = Path("OLS_synth_Bprime_TFTaligned")
FIG1_DIR = PLOTS_DIR.joinpath("FIG1_first")
FIG2_DIR = PLOTS_DIR.joinpath("FIG2_stitched")
for _fig_dir in (FIG1_DIR, FIG2_DIR):
    _fig_dir.mkdir(parents=True, exist_ok=True)


# Generate synthetic data
N_POINTS = 36
T_TOTAL = 1500
EIGENVALUE = 3.0

# Evenly spaced 1-D site coordinates on [-5, 5], stored as float32.
locs = np.linspace(-5.0, 5.0, N_POINTS).astype(np.float32)

# NOTE(review): the meaning of the individual generator arguments is defined
# by generate_time_synthetic_data and is not visible from this script.
# The categorical output (cat_synth) is not used anywhere below.
cat_synth, cont_synth, y_synth = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=T_TOTAL,
    noise_std=0.5,
    eigenvalue=EIGENVALUE,
    eta_rho=0.6,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.2,
    non_linear_strength=1.0,
    seed=GLOBAL_SEED,
)

print("Synthetic B′ y shape :", y_synth.shape)
print("Synthetic B′ cont shape:", cont_synth.shape)

# Work in float32 throughout.
Y = y_synth.astype(np.float32)
Cont = cont_synth.astype(np.float32)

# Y must be 2-D and Cont 3-D, otherwise the unpacking below raises.
T_total, N = Y.shape
_, _, P_cov = Cont.shape

# Synthetic hourly time axis, used only to label the x-axis of the plots.
# The lower-case "h" alias is used because pandas 2.2 deprecated the
# upper-case "H" spelling; older pandas versions accept "h" as well.
time_start = pd.Timestamp("2000-01-01 00:00:00")
time_index = pd.date_range(start=time_start, periods=T_total, freq="h")

print("Total time steps:", T_total)
print("Points:", N)


# Build features
# One-step lag of the target: row t holds Y[t - 1]. Row 0 has no predecessor
# and keeps its zero fill.
lag1 = np.zeros((T_total, N, 1), dtype=np.float32)
lag1[1:, :, 0] = Y[:-1]

# Column 0 is the lag feature, followed by the P_cov continuous covariates.
Xfull = np.concatenate([lag1, Cont], axis=2).astype(np.float32)
# Derive the names from the covariate count so they always line up with the
# columns of Xfull (a hard-coded list silently misaligns when P_cov != 3).
feat_names_full = ["lag1"] + [f"f{i}" for i in range(1, P_cov + 1)]
P_full = Xfull.shape[2]

print("Xfull shape:", Xfull.shape)


# Scaling
# No categorical inputs are used in this experiment; an all-zero integer
# placeholder with one channel per (time, point) is passed as cat_features.
cat_dummy = np.zeros(Y.shape + (1,), dtype=np.int64)

# Split chronologically and standardise features and targets.
# NOTE(review): fit_on_train_only=True presumably fits the scalers on the
# training span only — confirm against prepare_all_with_scaling.
_scaled_splits = prepare_all_with_scaling(
    targets=Y,
    cont_features=Xfull,
    cat_features=cat_dummy,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    fit_on_train_only=True,
    feature_scaler_type="standard",
    target_scaler_type="standard",
)
train_ds, val_ds, test_ds, preprocessor = _scaled_splits

def stitch(dslist):
    """Concatenate dataset splits back into two float32 NumPy arrays.

    Each element of ``dslist`` must expose a ``tensors`` sequence. Entry 1 is
    taken as the feature tensor and entry 2 as the target tensor.
    NOTE(review): the ordering is presumably (cat, cont, target) — confirm
    against the datasets returned by prepare_all_with_scaling.

    Returns:
        ``(X, y)``: the per-split arrays joined along axis 0, in the order
        the splits appear in ``dslist``.
    """

    def _join(slot):
        # Move each tensor to host memory before converting it to NumPy.
        parts = []
        for ds in dslist:
            parts.append(ds.tensors[slot].cpu().numpy().astype(np.float32))
        return np.concatenate(parts, axis=0)

    return _join(1), _join(2)

# Re-assemble the scaled splits into one chronological array each.
X_s, y_s = stitch([train_ds, val_ds, test_ds])

# The stitched arrays must line up one-to-one with the unscaled inputs;
# anything else means the splits dropped or reordered samples.
if not (X_s.shape == Xfull.shape):
    raise ValueError(f"X_s shape {X_s.shape} != Xfull {Xfull.shape}")
if not (y_s.shape == Y.shape):
    raise ValueError(f"y_s shape {y_s.shape} != Y {Y.shape}")

# Split boundaries along the time axis, recomputed from the ratios.
# NOTE(review): this assumes prepare_all_with_scaling rounds its boundaries
# the same way — confirm.
cut_train = int(T_total * train_ratio)
cut_val = int(T_total * (train_ratio + val_ratio))
test_start = cut_val

print("\n=== Split (by time index) ===")
for _tag, _lo, _hi in (
    ("Train:", 0, cut_train),
    ("Val  :", cut_train, cut_val),
    ("Test :", cut_val, T_total),
):
    print(_tag, _lo, "->", _hi - 1, "| len =", _hi - _lo)


# Drop near-constant features
# Per-feature standard deviation, pooled over every (time, point) pair of the
# training span; features at or below EPS_STD are discarded.
X_train = X_s[:cut_train]
flat = X_train.reshape(-1, P_full)
stds = np.std(flat, axis=0)
keep = stds > EPS_STD

feat_names_red = [name for name, flag in zip(feat_names_full, keep) if flag]
X_s_red = X_s[..., keep]
p_red = X_s_red.shape[-1]

# The recursive forecast overwrites the lag column, so it has to survive the
# filter; locate its position in the reduced feature set.
try:
    IDX_LAG1_RED = feat_names_red.index("lag1")
except ValueError:
    raise RuntimeError("lag1 feature was dropped; check EPS_STD / scaling logic.") from None

print("\n=== Feature selection (near-constant drop) ===")
print("stds:", stds)
print("keep mask:", keep)
print("kept features:", feat_names_red)
print("p_red:", p_red)


# OLS fit
# Pool every (time, point) pair of the training span into one design matrix,
# so a single coefficient vector is shared by all points.
Xtr = X_s_red[:cut_train].reshape(-1, p_red)
ytr = y_s[:cut_train].reshape(-1, 1)

Xtr_t = torch.from_numpy(Xtr).to(device=DEVICE)
ytr_t = torch.from_numpy(ytr).to(device=DEVICE)

# Prepend a column of ones so the intercept is estimated with the slopes.
ones = torch.ones(Xtr_t.shape[0], 1, device=DEVICE)
X_aug = torch.cat((ones, Xtr_t), dim=1)

# Least-squares solution: entry 0 is the intercept, the rest are the slopes
# in the column order of X_s_red.
_lstsq_fit = torch.linalg.lstsq(X_aug, ytr_t)
beta_aug = _lstsq_fit.solution[:, 0]
b0 = float(beta_aug[0].item())
beta = beta_aug[1:].cpu().numpy().astype(np.float32)

print("\n=== OLS fitted (Scenario B′) ===")
print("p_red =", p_red)
print("Intercept b0:", b0)
print("beta:", beta)


# Rolling OLS
def qrisk(y_true, y_pred, q=0.5, eps=1e-8):
    """Return the normalised quantile risk of ``y_pred`` against ``y_true``.

    With ``e = y_true - y_pred`` the value is
    ``2 * sum(max(q * e, (q - 1) * e)) / (sum(|y_true|) + eps)``.

    Args:
        y_true: Array-like of observed values.
        y_pred: Array-like of predictions, broadcastable against ``y_true``.
        q: Quantile level weighting under- versus over-prediction.
        eps: Small constant added to the denominator so an all-zero
            ``y_true`` does not divide by zero.

    Returns:
        The risk as a Python float.
    """
    truth = np.asarray(y_true)
    resid = truth - np.asarray(y_pred)
    # Element-wise pinball loss, summed over every entry.
    pinball_total = np.maximum(q * resid, (q - 1) * resid).sum()
    scale = np.abs(truth).sum() + eps
    return float(2.0 * pinball_total / scale)

def rolling_ols_tftaligned(X_s_full, y_s_full, start_ctx, end_T, H, step,
                           *, coef=None, intercept=None, lag_idx=None):
    """Produce recursive H-step OLS forecasts from a sequence of origins.

    Forecast origins are ``start_ctx, start_ctx + step, ...`` for as long as
    a full horizon ``[t0, t0 + H)`` fits before ``end_T``. For each origin
    the feature rows of time ``t0 - 1`` are held fixed for the whole horizon;
    only the lag column changes. It starts from the observed target at
    ``t0 - 1`` and is then fed the model's own previous prediction.

    NOTE(review): the lag column is overwritten with values taken from
    ``y_s_full``; this presumes the lag feature and the target share the
    same scale — confirm against the scaler.

    Args:
        X_s_full: Feature array indexed as ``[time, point, feature]``.
        y_s_full: Target array indexed as ``[time, point]``.
        start_ctx: First forecast origin; must be at least 1 because the
            forecast is seeded from time ``t0 - 1``.
        end_T: Exclusive upper bound on the time index.
        H: Forecast horizon (steps per window).
        step: Spacing between consecutive origins.
        coef: Slope vector, one entry per feature. Defaults to the
            module-level ``beta``.
        intercept: Scalar intercept. Defaults to the module-level ``b0``.
        lag_idx: Column of the lag feature. Defaults to the module-level
            ``IDX_LAG1_RED``.

    Returns:
        ``(yhat, ytrue, t0_list)``: two float32 arrays of shape
        ``(windows, H, points)`` and the list of origins used.

    Raises:
        ValueError: If ``start_ctx < 1`` or no complete window fits.
    """
    # Fall back to the fitted module-level model unless overridden.
    coef = beta if coef is None else coef
    intercept = b0 if intercept is None else intercept
    lag_idx = IDX_LAG1_RED if lag_idx is None else lag_idx

    # A non-positive origin would make ``t0 - 1`` wrap around to the end of
    # the series and silently seed the forecast with future data.
    if start_ctx < 1:
        raise ValueError(
            f"start_ctx must be >= 1 to seed from time t0 - 1, got {start_ctx}"
        )

    t0_list = list(range(start_ctx, end_T - H + 1, step))
    if not t0_list:
        raise ValueError(
            f"no forecast window of length H={H} fits in [{start_ctx}, {end_T})"
        )

    # Loop-invariant: the coefficient vector as a column for the matmul.
    coef_col = np.asarray(coef).reshape(-1, 1)

    yh = []
    yt = []
    for t0 in t0_list:
        prev = y_s_full[t0 - 1].astype(np.float32)
        ytru = y_s_full[t0:t0 + H].astype(np.float32)
        yhat = np.empty((H, prev.shape[0]), dtype=np.float32)

        # One private copy per window is enough: only the lag column is
        # rewritten at every step, the other columns stay frozen.
        Xt = X_s_full[t0 - 1].copy()
        for k in range(H):
            Xt[:, lag_idx] = prev
            ypred = (Xt @ coef_col).reshape(-1) + intercept
            yhat[k] = ypred
            prev = ypred

        yh.append(yhat)
        yt.append(ytru)

    return np.stack(yh), np.stack(yt), t0_list

# Evaluate on the test span only; the first origin is never earlier than L.
start_ctx = max(test_start, L)
yhat_roll, ytrue_roll, t0_list = rolling_ols_tftaligned(
    X_s_red, y_s, start_ctx, T_total, H, STEP
)

# Pooled error metrics over every window, horizon step and point, computed on
# the scaled targets.
W = len(yhat_roll)
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.square(diff).mean()))
mae = float(np.abs(diff).mean())

print("\n=== OLS TFT-aligned (Synthetic Scenario B′, POOLED, TEST) ===")
print(f"windows={W} | step={STEP} | H={H} | N={N}")
print("RMSE:", rmse)
print("MAE :", mae)

# Collapse (window, step) into a single axis so each row is one time step.
yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

# q-risk is reported on the scaled targets and again after mapping both
# series back through the preprocessor's inverse target transform.
qr_scaled = qrisk(ytrue_f, yhat_f)
ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
qr_unscaled = qrisk(ytrue_un, yhat_un)

print("\n=== P50 q-risk (Synthetic Scenario B′) ===")
print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)


# Plots
# Flatten the windows into one series per point. The matching slice of the
# time axis starts at the first forecast origin; it lines up with the
# stitched rows only when consecutive windows are contiguous (step == H).
first0 = t0_list[0]
ytrue_stitch = ytrue_roll.reshape(W * H, N)
yhat_stitch = yhat_roll.reshape(W * H, N)
times2 = time_index[first0:first0 + W * H]

for j in range(N):
    # FIG1: the first forecast window of point j, step by step.
    plt.figure(figsize=(9, 4), dpi=140)
    plt.plot(range(H), ytrue_roll[0, :, j], "-o", label="True")
    plt.plot(range(H), yhat_roll[0, :, j], "-o", label="Pred")
    plt.title(f"FIG1 OLS Synth B′ | grid{j:03d}")
    plt.xlabel("step")
    plt.ylabel("Scaled y")
    plt.grid()
    # Without this call the "True"/"Pred" labels set above are never drawn.
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG1_DIR / f"FIG1_grid{j:03d}.png")
    # Close each figure so N iterations do not accumulate open figures.
    plt.close()

    # FIG2: all windows of point j stitched along the time axis.
    plt.figure(figsize=(12, 4), dpi=140)
    plt.plot(times2, ytrue_stitch[:, j], label="True")
    plt.plot(times2, yhat_stitch[:, j], label="Pred")
    plt.axvline(time_index[test_start], linestyle="--", linewidth=1, label="TEST start")
    plt.title(f"FIG2 OLS Synth B′ | grid{j:03d}")
    plt.xlabel("Time")
    plt.ylabel("Scaled y")
    plt.grid()
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG2_DIR / f"FIG2_grid{j:03d}.png")
    plt.close()

print("\nAll done.")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Synthetic B′ y shape : (1500, 36)
Synthetic B′ cont shape: (1500, 36, 3)
Total time steps: 1500
Points: 36
Xfull shape: (1500, 36, 4)

=== Split (by time index) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

=== Feature selection (near-constant drop) ===
stds: [0.9999992  0.99999833 0.99999976 0.9999991 ]
keep mask: [ True  True  True  True]
kept features: ['lag1', 'f1', 'f2', 'f3']
p_red: 4

=== OLS fitted (Scenario B′) ===
p_red = 4
Intercept b0: 8.589343125642301e-10
beta: [0.12906636 0.28881532 0.21167853 0.1931144 ]

=== OLS TFT-aligned (Synthetic Scenario B′, POOLED, TEST) ===
windows=9 | step=24 | H=24 | N=36
RMSE: 2.125718355178833
MAE : 1.4474035501480103

=== P50 q-risk (Synthetic Scenario B′) ===
scaled  : 0.9444728352181325
unscaled: 0.07914048871967956

All done (OLS baseline, Scen