In [None]:
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling
from geospatial_neural_adapter.data.generators import generate_time_synthetic_data


# Global settings & dirs
# A single seed is applied to both NumPy and PyTorch RNGs.
GLOBAL_SEED = 42
torch.manual_seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# P: VARX lag order, H: forecast horizon, STEP: stride between rolling origins.
# L only enters the lower bound of the first rolling origin further down
# (presumably the context length of the model this baseline is aligned with).
P, L, H = 1, 56, 24
STEP = H

# Fractions of the time axis for train / validation; the remainder is test.
train_ratio, val_ratio = 0.70, 0.15

# Output folders for the two figure families, created up front.
PLOTS_DIR = Path("VAR_synth_C_TFTaligned")
FIG1_DIR = PLOTS_DIR / "FIG1_first"
FIG2_DIR = PLOTS_DIR / "FIG2_stitched"
for _fig_dir in (FIG1_DIR, FIG2_DIR):
    _fig_dir.mkdir(parents=True, exist_ok=True)


# Generate synthetic data
N_POINTS = 36
T_TOTAL = 1500
EIGENVALUE = 3.0

# Evenly spaced 1-D locations on [-5, 5].
locs = np.linspace(-5.0, 5.0, N_POINTS).astype(np.float32)

# The categorical output is unpacked here but not consumed by the rest of this
# cell (a zero placeholder is passed to the scaler instead).
cat_synth, cont_synth, y_synth = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=T_TOTAL,
    noise_std=0.7,
    eigenvalue=EIGENVALUE,
    eta_rho=0.8,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.5,
    non_linear_strength=2.0,
    seed=GLOBAL_SEED,
)

print("Synthetic y shape:", y_synth.shape)
print("Synthetic cont shape:", cont_synth.shape)

# Work in float32 throughout.
Y = y_synth.astype(np.float32)
X_cov = cont_synth.astype(np.float32)

# Y is unpacked as (time, points); X_cov carries the covariate count on axis 2.
T_total, N = Y.shape
P_cov = X_cov.shape[2]

# Hourly timestamps, used only to label the time axis of the plots.
# The lowercase "h" alias is used because the uppercase "H" alias is
# deprecated in recent pandas releases and emits a FutureWarning.
time_start = pd.Timestamp("2000-01-01 00:00:00")
time_index = pd.date_range(start=time_start, periods=T_total, freq="h")

print("Total time steps:", T_total)
print("Points:", N)


# Scaling
# No categorical inputs are used: an all-zero int64 placeholder with a single
# trailing channel per (time, point) entry is passed in their place.
cat_dummy = np.zeros(Y.shape + (1,), dtype=np.int64)

# Standard scaling for features and targets, with fit_on_train_only=True.
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    targets=Y,
    cont_features=X_cov,
    cat_features=cat_dummy,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
)

def stitch(dsets):
    """Concatenate tensors 1 and 2 of several datasets along the first axis.

    Each element of ``dsets`` must expose a ``tensors`` sequence. Slot 1 and
    slot 2 are moved to CPU, converted to float32 NumPy arrays and joined in
    the order the datasets are given. The caller in this file treats slot 1
    as the covariates and slot 2 as the targets.

    Returns:
        A pair ``(X_out, y_out)`` of float32 arrays.
    """

    def _gather(slot):
        # Pull one tensor slot out of every dataset and join along axis 0.
        parts = [ds.tensors[slot].cpu().numpy().astype(np.float32) for ds in dsets]
        return np.concatenate(parts, axis=0)

    return _gather(1), _gather(2)

# Re-assemble the scaled train/val/test pieces into one continuous series.
X_s, y_s = stitch([train_ds, val_ds, test_ds])

# The stitched arrays must line up with the raw arrays element-for-element.
if X_cov.shape != X_s.shape:
    raise ValueError(f"X_s shape {X_s.shape} != X_cov {X_cov.shape}")
if Y.shape != y_s.shape:
    raise ValueError(f"y_s shape {y_s.shape} != Y {Y.shape}")

# Split boundaries on the time axis.
# NOTE(review): these are recomputed here from the ratios; presumably they
# match the split made inside prepare_all_with_scaling - confirm.
cut_train, cut_val = (
    int(T_total * _frac) for _frac in (train_ratio, train_ratio + val_ratio)
)
test_start = cut_val

for _label, _lo, _hi in (
    ("Train:", 0, cut_train),
    ("Val  :", cut_train, cut_val),
    ("Test :", cut_val, T_total),
):
    print(_label, _lo, "->", _hi - 1, "| len =", _hi - _lo)


# VARX design + fit (train-only)
def build_varx_design(
    y: np.ndarray,
    x: np.ndarray,
    t_start: int,
    t_end: int,
    p: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build the regression design matrix and targets for a VARX(p) model.

    For every time index ``t`` in ``[t_start, t_end)`` one design row is
    ``[1, y[t-1], ..., y[t-p], x[t].ravel()]`` and the matching target row is
    ``y[t]``.

    Args:
        y: Target series, indexed as ``(time, series)``.
        x: Exogenous covariates indexed by time on axis 0; all remaining axes
            are flattened into the feature vector.
        t_start: First time index to emit (inclusive). Must be ``>= p`` so
            that every lag refers to a real earlier observation.
        t_end: Last time index to emit (exclusive).
        p: Number of autoregressive lags, at least 1.

    Returns:
        ``(design, targets)`` where ``design`` is float32 with one row per
        time index and ``targets`` is a copy of ``y[t_start:t_end]``.

    Raises:
        ValueError: If ``p < 1``, ``t_start < p`` or the range is empty.
        IndexError: If ``t_end`` runs past the end of ``y`` or ``x``.
    """
    if p < 1:
        raise ValueError(f"p must be >= 1, got {p}")
    if t_start < p:
        # Without this guard y[t - i] would use a negative index and silently
        # wrap around to the end of the series.
        raise ValueError(f"t_start ({t_start}) must be >= p ({p})")
    if t_end <= t_start:
        raise ValueError(f"empty range: t_start={t_start}, t_end={t_end}")
    if t_end > min(len(y), len(x)):
        raise IndexError(
            f"t_end ({t_end}) exceeds available length {min(len(y), len(x))}"
        )

    n_rows = t_end - t_start
    # Lag i for all rows at once is simply the target window shifted back by i.
    lag_block = np.concatenate(
        [y[t_start - i:t_end - i] for i in range(1, p + 1)], axis=1
    )
    x_block = x[t_start:t_end].reshape(n_rows, -1)
    intercept = np.ones((n_rows, 1))

    design = np.concatenate([intercept, lag_block, x_block], axis=1).astype(np.float32)
    targets = y[t_start:t_end].copy()
    return design, targets

# Design rows cover time indices P .. cut_train - 1, i.e. the training span only.
Xtr, Ytr = build_varx_design(y_s, X_s, P, cut_train, P)
print("Train design:", Xtr.shape, "->", Ytr.shape)

# Solve the least-squares problem in float64, then store the coefficients as float32.
X64, Y64 = (torch.from_numpy(_arr).to(torch.float64) for _arr in (Xtr, Ytr))

B_mat = torch.linalg.lstsq(X64, Y64).solution
B_mat = B_mat.to(torch.float32).cpu().numpy().astype(np.float32)

print("VARX fitted B_mat:", B_mat.shape)


# Rolling (FREEZE cov)
def qrisk(y_true, y_pred, q=0.5, eps=1e-8):
    """Return the normalized quantile risk of ``y_pred`` against ``y_true``.

    With ``e = y_true - y_pred`` the value is
    ``2 * sum(max(q * e, (q - 1) * e)) / (sum(|y_true|) + eps)``.

    Args:
        y_true: Observed values (anything ``np.asarray`` accepts).
        y_pred: Predicted values, broadcastable against ``y_true``.
        q: Quantile level weighting over- versus under-prediction.
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        The risk as a Python float.
    """
    truth = np.asarray(y_true)
    residual = truth - np.asarray(y_pred)
    pinball = np.maximum(q * residual, (q - 1) * residual)
    scale = np.abs(truth).sum() + eps
    return float(2.0 * pinball.sum() / scale)

def roll_varx_freeze(
    y: np.ndarray,
    x: np.ndarray,
    B: np.ndarray,
    start_ctx: int,
    end_T: int,
    p: int,
    H: int,
    step: int,
) -> Tuple[np.ndarray, np.ndarray, list]:
    """Produce rolling H-step VARX forecasts with covariates held fixed.

    Forecast origins are ``start_ctx, start_ctx + step, ...`` as long as a full
    window of ``H`` steps fits before ``end_T``; origins smaller than ``p`` are
    skipped. At each origin ``t0`` the model is iterated ``H`` times: the lag
    buffer starts from the observed ``y[t0-1], ..., y[t0-p]`` and is then fed
    with the model's own predictions, while the covariates stay frozen at
    ``x[t0 - 1]`` for the whole window.

    Args:
        y: Target series indexed by time on axis 0.
        x: Covariates indexed by time on axis 0.
        B: Coefficient matrix applied as ``features @ B`` where features are
            ``[1, lags..., frozen covariates]``.
        start_ctx: First candidate forecast origin.
        end_T: Exclusive upper bound on the time indices used.
        p: Number of lags.
        H: Forecast horizon (steps per window).
        step: Stride between consecutive origins.

    Returns:
        ``(y_hat, y_true, origins)``: stacked forecasts, stacked ground truth
        (one window per origin on axis 0) and the list of origins used.

    Raises:
        RuntimeError: If no window could be produced.
    """
    forecasts = []
    truths = []
    origins = []

    for origin in range(start_ctx, end_T - H + 1, step):
        if origin < p:
            # Not enough history to fill the lag buffer.
            continue

        # Most recent observation first: y[origin-1], ..., y[origin-p].
        lags = [y[origin - lag].copy() for lag in range(1, p + 1)]
        target = y[origin:origin + H].copy()
        forecast = np.empty_like(target)

        # Last covariates seen before the window; reused for every step in it.
        frozen_cov = x[origin - 1].reshape(-1).astype(np.float32)

        for h in range(H):
            lag_part = np.concatenate(lags, axis=0).astype(np.float32)
            features = np.concatenate([[1.0], lag_part, frozen_cov]).astype(np.float32)
            step_pred = (features @ B).astype(np.float32)
            forecast[h] = step_pred
            # Push the new prediction to the front and drop the oldest lag.
            lags.insert(0, step_pred)
            del lags[p:]

        forecasts.append(forecast)
        truths.append(target)
        origins.append(int(origin))

    if not forecasts:
        raise RuntimeError("No rolling windows produced. Check start_ctx/end_T/H/step.")
    return np.stack(forecasts, axis=0), np.stack(truths, axis=0), origins

# First forecast origin: never before the test split, and never before L or P steps of history.
start_ctx = max(test_start, L, P)
yhat_roll, ytrue_roll, t0_list = roll_varx_freeze(
    y=y_s,
    x=X_s,
    B=B_mat,
    start_ctx=start_ctx,
    end_T=T_total,
    p=P,
    H=H,
    step=STEP,
)

# Pooled point-error metrics over all windows, horizons and series (scaled units).
W = len(yhat_roll)
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(np.square(diff))))
mae = float(np.abs(diff).mean())

print("windows=", W, "| step=", STEP, "| H=", H, "| N=", N)
print("RMSE:", rmse)
print("MAE :", mae)

# Collapse (window, horizon) into one axis so each row is one forecast step.
yhat_f, ytrue_f = yhat_roll.reshape(-1, N), ytrue_roll.reshape(-1, N)

qr_scaled = qrisk(ytrue_f, yhat_f)

# Same q-risk after mapping targets back through the preprocessor's inverse transform.
ytrue_un, yhat_un = (
    preprocessor.inverse_transform_targets(_arr) for _arr in (ytrue_f, yhat_f)
)
qr_unscaled = qrisk(ytrue_un, yhat_un)

print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)


# Plots
t0_first = t0_list[0]

# Flatten (window, horizon) so that consecutive windows follow each other.
ytrue_stitch = ytrue_roll.reshape(W * H, N)
yhat_stitch = yhat_roll.reshape(W * H, N)

# Timestamp every stitched row from the origin of its own window. When
# STEP == H this equals the contiguous slice starting at t0_first, and it
# stays correct if the stride between windows differs from the horizon.
stitch_idx = np.concatenate([np.arange(t0, t0 + H) for t0 in t0_list])
times2 = time_index[stitch_idx]

for j in range(N):
    # FIG1: first rolling window only, one figure per series.
    plt.figure(figsize=(9, 4), dpi=140)
    plt.plot(range(H), ytrue_roll[0, :, j], "-o", label="True")
    plt.plot(range(H), yhat_roll[0, :, j], "-o", label="Pred")
    plt.title(f"FIG1 VARX Synth C | grid{j:03d}")
    plt.xlabel("step")
    plt.ylabel("Scaled y")
    plt.grid()
    # The curves carry labels, so draw the legend (as FIG2 does).
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG1_DIR / f"FIG1_grid{j:03d}.png")
    plt.close()

    # FIG2: all rolling windows stitched on the real time axis.
    plt.figure(figsize=(12, 4), dpi=140)
    plt.plot(times2, ytrue_stitch[:, j], label="True")
    plt.plot(times2, yhat_stitch[:, j], label="Pred")
    plt.axvline(time_index[test_start], linestyle="--", linewidth=1, label="TEST start")
    plt.title(f"FIG2 VARX Synth C | grid{j:03d}")
    plt.xlabel("Time")
    plt.ylabel("Scaled y")
    plt.grid()
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG2_DIR / f"FIG2_grid{j:03d}.png")
    plt.close()

print("All done.")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Synthetic (Scenario C) y shape : (1500, 36)
Synthetic (Scenario C) cont shape: (1500, 36, 3)
Total time steps: 1500
Points: 36

=== Split (by time index) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

Train design: (1049, 145) -> (1049, 36)
VARX fitted B_mat: (145, 36)

=== VARX TFT-aligned (Synthetic Scenario C, POOLED, TEST) ===
windows=9 | step=24 | H=24 | N=36
RMSE: 1.8280880451202393
MAE : 1.2372597455978394

=== P50 q-risk (Synthetic Scenario C) ===
scaled  : 0.8438975635680369
unscaled: 0.12899202630046752

All done (VARX baseline, Scenario C).
