In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling
from geospatial_neural_adapter.data.generators import generate_time_synthetic_data


# ---------------------------------------------------------------------------
# Global settings & output directories
# ---------------------------------------------------------------------------
# Seed NumPy's legacy global RNG and torch's RNG so reruns are reproducible.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Device:", DEVICE)

# Forecasting geometry:
#   P    - VAR lag order (passed as `p` to the design / rolling helpers)
#   L    - lookback length; here it only bounds the earliest forecast origin
#   H    - forecast horizon (steps per rolling window)
#   STEP - stride between rolling origins (STEP == H -> back-to-back windows)
P, L, H = 1, 56, 24
STEP = H

# Split fractions for train and validation; the test split is the remainder.
train_ratio, val_ratio = 0.70, 0.15

PLOTS_DIR = Path("VARX_synth_E_TFTaligned")
FIG1_DIR = PLOTS_DIR / "FIG1_first"
FIG2_DIR = PLOTS_DIR / "FIG2_stitched"
for _fig_dir in (FIG1_DIR, FIG2_DIR):
    _fig_dir.mkdir(parents=True, exist_ok=True)


# Data generation
N_POINTS = 36
T_TOTAL = 1500
EIGENVALUE = 3.0

NOISE_STD = 0.5
FEATURE_NOISE_STD = 0.2
NON_LINEAR_STRENGTH = 2.5
ETA_RHO = 0.4
F_RHO = 0.8
GLOBAL_MEAN = 50.0

# N_POINTS evenly spaced 1-D locations on [-5, 5].
locs = np.linspace(-5.0, 5.0, N_POINTS).astype(np.float32)

# The first return value is discarded; cont_synth is used as the continuous
# covariates and y_synth as the targets in the preprocessing call below.
_, cont_synth, y_synth = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=T_TOTAL,
    noise_std=NOISE_STD,
    eigenvalue=EIGENVALUE,
    eta_rho=ETA_RHO,
    f_rho=F_RHO,
    global_mean=GLOBAL_MEAN,
    feature_noise_std=FEATURE_NOISE_STD,
    non_linear_strength=NON_LINEAR_STRENGTH,
    seed=GLOBAL_SEED,
)

Y = y_synth.astype(np.float32)
Cont = cont_synth.astype(np.float32)

T_total, N = Y.shape   # the unpacking enforces a 2-D (time, location) target
P_cov = Cont.shape[2]  # number of continuous covariates per location

print(f"Synthetic E (T={T_TOTAL}) y shape :", Y.shape)
print(f"Synthetic E (T={T_TOTAL}) cont shape:", Cont.shape)

# Hourly calendar used as the time axis of the plots. "h" is the canonical
# hourly alias; the upper-case "H" alias is deprecated since pandas 2.2 and
# emits a FutureWarning.
freq = "1h"
time_start = pd.Timestamp("2000-01-01 00:00:00")
time_index = pd.date_range(start=time_start, periods=T_total, freq=freq)
if len(time_index) != T_total:
    raise RuntimeError("time_index length mismatch")

# A single all-zero categorical channel, passed as cat_features below; being
# constant it carries no information.
cat_dummy = np.zeros((T_total, N, 1), dtype=np.int64)


# ---------------------------------------------------------------------------
# Split & scaling
# ---------------------------------------------------------------------------
# Build train/val/test datasets with standard-scaled features and targets.
# Per the fit_on_train_only flag, the scalers are fitted on the training part
# only. The datasets are later concatenated in train/val/test order to rebuild
# the full series, so the split is presumably chronological (confirm in
# prepare_all_with_scaling).
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    targets=Y,
    cont_features=Cont,
    cat_features=cat_dummy,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    target_scaler_type="standard",
    feature_scaler_type="standard",
    fit_on_train_only=True,
)

def stitch(dsets):
    """Concatenate the covariate and target tensors of several datasets.

    Each element of ``dsets`` must expose a ``tensors`` tuple (for example a
    ``torch.utils.data.TensorDataset``). Position 1 is gathered as the
    covariates and position 2 as the targets. Each is moved to CPU, converted
    to a float32 NumPy array, and concatenated along axis 0 in the order given.

    Returns:
        (X, y): the concatenated covariate and target arrays (float32).
    """

    def _gather(slot):
        pieces = []
        for ds in dsets:
            host = ds.tensors[slot].detach().cpu().numpy()
            pieces.append(host.astype(np.float32))
        return np.concatenate(pieces, axis=0)

    return _gather(1), _gather(2)

X_s, y_s = stitch([train_ds, val_ds, test_ds])

if X_s.shape != Cont.shape:
    raise ValueError(f"X_s shape {X_s.shape} != Cont {Cont.shape}")
if y_s.shape != Y.shape:
    raise ValueError(f"y_s shape {y_s.shape} != Y {Y.shape}")

# Take the split boundaries from the datasets the preprocessor actually built,
# instead of recomputing int(T_total * ratio) here. The shape checks above show
# that train/val/test concatenate back to the full series, so the first
# len(train) rows are exactly the training slice. Recomputing the cut with a
# separate float formula could drift from the preprocessor's own rounding.
# That would let the VARX fit (rows [P, cut_train)) use validation rows.
cut_train = int(train_ds.tensors[2].shape[0])
cut_val = cut_train + int(val_ds.tensors[2].shape[0])
test_start = cut_val

print(f"\n=== Split (Scenario E, T_TOTAL={T_TOTAL}) ===")
print("Train:", 0, "->", cut_train - 1, "| len =", cut_train)
print("Val  :", cut_train, "->", cut_val - 1, "| len =", cut_val - cut_train)
print("Test :", cut_val, "->", T_total - 1, "| len =", T_total - cut_val)


# VARX design & fit
def build_varx_design(y: np.ndarray, x: np.ndarray, t_start: int, t_end: int, p: int):
    """Build the least-squares design matrix for a VARX(p) model.

    Row ``r`` corresponds to time ``t = t_start + r`` for ``t`` in
    ``[t_start, t_end)`` and is laid out as::

        [1, y[t-1], y[t-2], ..., y[t-p], x[t].ravel()]

    This is an intercept, followed by the ``p`` lagged target vectors (most
    recent first), followed by the contemporaneous covariates flattened in
    row-major order. The matching regression target is ``y[t]``.

    Args:
        y: targets indexed by time on axis 0; each ``y[t]`` must be 1-D.
        x: covariates indexed by time on axis 0.
        t_start: first target time (inclusive); must be ``>= p``.
        t_end: last target time (exclusive); must be ``> t_start``.
        p: lag order; must be ``>= 1``.

    Returns:
        (X, Y):
            X: float32, shape ``(t_end - t_start, 1 + p * y.shape[1] + x[t].size)``.
            Y: ``y[t_start:t_end]``, with the dtype of ``y``.

    Raises:
        ValueError: if ``p < 1``, if the time range is empty, or if
            ``t_start < p``. In the last case a lag would index before time 0,
            and NumPy's negative indexing would silently wrap around to the end
            of the series (future-data leakage).
    """
    if p < 1:
        raise ValueError(f"lag order p must be >= 1, got {p}")
    if t_start < p:
        raise ValueError(f"t_start={t_start} must be >= p={p} so every lag index is >= 0")
    if t_end <= t_start:
        raise ValueError(f"empty time range [{t_start}, {t_end})")

    t_idx = np.arange(t_start, t_end)
    n_rows = t_idx.shape[0]

    intercept = np.ones((n_rows, 1), dtype=np.float32)
    # Lag blocks ordered most recent first: [y[t-1] | y[t-2] | ... | y[t-p]].
    lags = np.concatenate([y[t_idx - i] for i in range(1, p + 1)], axis=1).astype(np.float32)
    covs = x[t_idx].reshape(n_rows, -1).astype(np.float32)

    design = np.concatenate([intercept, lags, covs], axis=1)
    return design, y[t_idx]

Xtr, Ytr = build_varx_design(y=y_s, x=X_s, t_start=P, t_end=cut_train, p=P)
print("\nTrain design (Scenario E, T=1500):", Xtr.shape, "->", Ytr.shape)

# Ordinary least squares with one column of B per target series, solved in
# float64 and stored back as a float32 NumPy array.
Xtr_t, Ytr_t = (torch.from_numpy(arr).double() for arr in (Xtr, Ytr))
lstsq_result = torch.linalg.lstsq(Xtr_t, Ytr_t)
B = lstsq_result.solution.float().cpu().numpy()
print("VARX fitted B shape:", B.shape)


# Rolling (recursive, NO freeze cov)
def qrisk(y_true: np.ndarray, y_pred: np.ndarray, q: float = 0.5, eps: float = 1e-8) -> float:
    """Normalised quantile risk over all elements, computed in float64.

    Returns ``2 * sum(pinball_q(y_true - y_pred)) / (sum(|y_true|) + eps)``,
    where ``pinball_q(e) = max(q * e, (q - 1) * e)``. With the default
    ``q = 0.5`` this reduces to ``sum(|y_true - y_pred|) / (sum(|y_true|) + eps)``.
    ``eps`` guards against division by zero for an all-zero ``y_true``.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    forecast = np.asarray(y_pred, dtype=np.float64)
    resid = truth - forecast
    pinball = np.maximum(q * resid, (q - 1) * resid)
    denom = np.abs(truth).sum() + eps
    return float(2.0 * pinball.sum() / denom)

def roll_varx(y: np.ndarray, x: np.ndarray, B: np.ndarray, start: int, end: int, p: int, H: int, step: int):
    """Recursive multi-step VARX forecasts from rolling origins.

    Origins are ``t0 in range(start, end - H + 1, step)``; origins with
    ``t0 < p`` are skipped. From each origin the model predicts
    ``y[t0:t0 + H]`` one step at a time:

    - The first step uses the observed lags ``y[t0-1], ..., y[t0-p]``.
    - Each later step feeds the previous predictions back in as lags.
    - Covariates are always the observed ``x[t0 + k]``; they are not frozen
      at the origin.

    Each step's feature row is ``[1, lag_1, ..., lag_p, x[t0+k].ravel()]``
    (float32), and the one-step prediction is ``row @ B``.

    Returns:
        (yhat, ytrue, t0_list): predictions and observed targets stacked over
        windows on a new leading axis (shape ``(W, H) + y.shape[1:]``), plus
        the list of origins used.

    Raises:
        RuntimeError: if no origin yields a window.
    """
    forecasts, actuals, origins = [], [], []
    one = np.array([1.0], dtype=np.float32)

    for t0 in range(start, end - H + 1, step):
        if t0 < p:
            continue  # not enough history for p lags

        # Lag buffer, most recent first: [y[t0-1], ..., y[t0-p]].
        lags = [y[t0 - lag].copy() for lag in range(1, p + 1)]
        truth = y[t0:t0 + H].copy()
        forecast = np.empty_like(truth)

        for k in range(H):
            row = np.concatenate(
                [
                    one,
                    np.concatenate(lags, axis=0).astype(np.float32),
                    x[t0 + k].reshape(-1).astype(np.float32),
                ],
                axis=0,
            )
            step_pred = row @ B
            forecast[k] = step_pred.astype(np.float32)
            # Push the new prediction to the front and drop the oldest lag.
            lags.insert(0, step_pred)
            lags.pop()

        forecasts.append(forecast)
        actuals.append(truth)
        origins.append(int(t0))

    if not forecasts:
        raise RuntimeError("No rolling windows produced. Check start/end/p/H.")
    return np.stack(forecasts, axis=0), np.stack(actuals, axis=0), origins

# First forecast origin: no earlier than the test boundary, and late enough
# that L steps of lookback (presumably the TFT encoder length; confirm) and the
# P VAR lags both fit before it.
start_ctx = max(test_start, L, P)
yhat_roll, ytrue_roll, t0_list = roll_varx(
    y=y_s,
    x=X_s,
    B=B,
    start=start_ctx,
    end=T_total,
    p=P,
    H=H,
    step=STEP,
)

# RMSE / MAE are pooled over windows, horizon steps and locations, on the
# scaled targets.
W = yhat_roll.shape[0]
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(diff ** 2)))
mae = float(np.mean(np.abs(diff)))

print(f"\n=== VARX TFT-aligned (Synthetic Scenario E, T_TOTAL={T_TOTAL}, POOLED, TEST) ===")
print(f"windows={W} | step={STEP} | H={H} | N={N}")
print("RMSE:", rmse)
print("MAE :", mae)

# Flatten windows x horizon into rows of N targets for q-risk and for the
# preprocessor's inverse target transform.
yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

# NOTE: targets were standard-scaled, so sum(|y|) in the q-risk denominator is
# much smaller on the scaled data (values presumably centred near 0). The
# scaled and unscaled q-risk are therefore not comparable in magnitude.
qr_scaled = qrisk(ytrue_f, yhat_f)
ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
qr_unscaled = qrisk(ytrue_un, yhat_un)

print(f"\n=== P50 q-risk (Synthetic Scenario E, T_TOTAL={T_TOTAL}) ===")
print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)


# Plots
# FIG2 lays the rolling windows end to end. That is only a valid contiguous
# series when consecutive origins are exactly H apart; otherwise windows
# overlap or leave gaps and times2 would not line up with the data.
if STEP != H:
    raise ValueError(
        f"Stitched plots require back-to-back windows (STEP == H); got STEP={STEP}, H={H}"
    )

t0_first = t0_list[0]

ytrue_stitch = ytrue_roll.reshape(W * H, N)
yhat_stitch = yhat_roll.reshape(W * H, N)
times2 = time_index[t0_first:t0_first + W * H]

for j in range(N):
    # FIG1: first test window only, indexed by forecast step.
    plt.figure(figsize=(9, 4), dpi=140)
    plt.plot(range(H), ytrue_roll[0, :, j], "-o", label="True")
    plt.plot(range(H), yhat_roll[0, :, j], "-o", label="Pred")
    plt.title(f"FIG1 VARX Synth E | grid{j:03d}")
    plt.xlabel("step")
    plt.ylabel("Scaled y")
    plt.grid(True)  # explicit: a bare plt.grid() toggles and could turn it off
    plt.legend()    # without this the True/Pred labels are never drawn
    plt.tight_layout()
    plt.savefig(FIG1_DIR / f"FIG1_VARX_grid{j:03d}.png")
    plt.close()

    # FIG2: all test windows stitched on the real time axis.
    plt.figure(figsize=(12, 4), dpi=140)
    plt.plot(times2, ytrue_stitch[:, j], label="True")
    plt.plot(times2, yhat_stitch[:, j], label="Pred")
    plt.axvline(time_index[test_start], linestyle="--", linewidth=1, label="TEST start")
    plt.title(f"FIG2 VARX Synth E | grid{j:03d}")
    plt.xlabel("Time")
    plt.ylabel("Scaled y")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG2_DIR / f"FIG2_VARX_grid{j:03d}.png")
    plt.close()

print(f"\nAll done (VARX baseline, Scenario E, T_TOTAL={T_TOTAL}).")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Synthetic E (T=1500) y shape : (1500, 36)
Synthetic E (T=1500) cont shape: (1500, 36, 3)
Total time steps: 1500
Points: 36
Freq: 1H

=== Split (by time index, Scenario E, T_TOTAL=1500) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

Train design (Scenario E, T=1500): (1049, 145) -> (1049, 36)
VARX fitted B shape: (145, 36)

=== VARX TFT-aligned (Synthetic Scenario E, T_TOTAL=1500, POOLED, TEST) ===
windows=9 | step=24 | H=24 | N=36
RMSE: 1.228173017501831
MAE : 0.7897313833236694

=== P50 q-risk (Synthetic Scenario E, T_TOTAL=1500) ===
scaled  : 0.5051114584809987
unscaled: 0.08886847342457452

All done (VARX baseline, Scenario E, T_TOTAL=1500).
