In [None]:
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling
from geospatial_neural_adapter.data.generators import generate_time_synthetic_data

# Reproducibility: seed NumPy's global RNG and torch.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Device:", DEVICE)

# Forecasting setup.
P = 1      # VAR lag order (passed as p=P to the design / rolling helpers)
L = 56     # lower bound on the first forecast origin: start_ctx = max(test_start, L, P)
H = 24     # forecast horizon, in time steps
STEP = H   # stride between forecast origins; STEP == H gives back-to-back, non-overlapping windows

# Chronological split fractions (the test split takes what remains).
train_ratio, val_ratio = 0.70, 0.15

# Synthetic-data settings.
N_POINTS, T_TOTAL, EIGENVALUE = 36, 1500, 3.0

# Plots: one folder per figure family (mkdir with parents=True also creates PLOTS_DIR).
PLOTS_DIR = Path("VAR_synth_A_TFTaligned")
FIG1_DIR, FIG2_DIR = PLOTS_DIR / "FIG1_first", PLOTS_DIR / "FIG2_stitched"
FIG1_DIR.mkdir(parents=True, exist_ok=True)
FIG2_DIR.mkdir(parents=True, exist_ok=True)


# Generate synthetic data

locs = np.linspace(-5.0, 5.0, N_POINTS).astype(np.float32)  # (N,) 1-D location grid

cat_synth, cont_synth, y_synth = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=T_TOTAL,
    noise_std=0.3,
    eigenvalue=EIGENVALUE,
    eta_rho=0.8,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.0,
    non_linear_strength=0.0,
    seed=GLOBAL_SEED,
)

Y = y_synth.astype(np.float32)          # (T, N) targets
X_cov = cont_synth.astype(np.float32)   # (T, N, P_cov) continuous covariates

T_total, N = Y.shape
_, _, P_cov = X_cov.shape
# Every downstream step pairs x[t] with y[t]; fail fast if the leading dims disagree.
if X_cov.shape[:2] != Y.shape:
    raise ValueError(f"X_cov leading dims {X_cov.shape[:2]} != Y shape {Y.shape}")

print("Y shape   :", Y.shape)
print("X_cov shape:", X_cov.shape)

# Hourly timestamps (in this file they only label plot axes). "h" is the hour
# alias in current pandas; "H" is deprecated since pandas 2.2 (FutureWarning).
time_start = pd.Timestamp("2000-01-01 00:00:00")
time_index = pd.date_range(start=time_start, periods=T_total, freq="1h")

# The generator's categorical output is not used here; a single all-zero
# categorical column is passed to the scaling helper instead.
cat_dummy = np.zeros((T_total, N, 1), dtype=np.int64)


# Scaling 

# Split into train/val/test and standardize targets and continuous features.
# fit_on_train_only=True presumably restricts scaler fitting to the training
# portion (behaviour lives in prepare_all_with_scaling — confirm there).
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    targets=Y,
    cont_features=X_cov,
    cat_features=cat_dummy,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    target_scaler_type="standard",
    feature_scaler_type="standard",
    fit_on_train_only=True,
)

def stitch(dsets):
    """Re-assemble full-length arrays from consecutive dataset splits.

    Each element of ``dsets`` must expose a ``.tensors`` sequence (for example
    a ``torch.utils.data.TensorDataset``). For every split, tensor index 1 and
    tensor index 2 are moved to the CPU, converted to NumPy float32 and
    concatenated along axis 0 in the order the splits are given.

    Returns:
        (X, y): the concatenated index-1 and index-2 tensors.
    """
    def _concat_field(field):
        # Gather one tensor position across all splits, oldest split first.
        return np.concatenate(
            [ds.tensors[field].cpu().numpy().astype(np.float32) for ds in dsets],
            axis=0,
        )

    features = _concat_field(1)
    targets = _concat_field(2)
    return features, targets

# Scaled full-length series: X_s (T, N, P_cov) and y_s (T, N), train+val+test in order.
X_s, y_s = stitch([train_ds, val_ds, test_ds])

if X_s.shape != X_cov.shape:
    raise ValueError(f"X_s shape {X_s.shape} != X_cov {X_cov.shape}")
if y_s.shape != Y.shape:
    raise ValueError(f"y_s shape {y_s.shape} != Y {Y.shape}")

# Take the split boundaries from the datasets prepare_all_with_scaling actually
# returned instead of recomputing int(T * ratio). Float rounding of
# (train_ratio + val_ratio), or a different rounding rule inside the helper,
# would otherwise shift these cuts away from the split the scaler was fit on.
cut_train = int(train_ds.tensors[2].shape[0])
cut_val   = cut_train + int(val_ds.tensors[2].shape[0])
test_start = cut_val

ratio_cuts = (int(T_total * train_ratio), int(T_total * (train_ratio + val_ratio)))
if (cut_train, cut_val) != ratio_cuts:
    print(
        "WARNING: dataset split lengths", (cut_train, cut_val),
        "differ from ratio-based cuts", ratio_cuts, "- using the dataset lengths.",
    )

print("\n=== Split (by time index) ===")
print("Train:", 0, "->", cut_train - 1, "| len =", cut_train)
print("Val  :", cut_train, "->", cut_val - 1, "| len =", cut_val - cut_train)
print("Test :", cut_val, "->", T_total - 1, "| len =", T_total - cut_val)


# Build VARX design matrix 

def build_varx_design(
    y: np.ndarray,
    x: np.ndarray,
    t_start: int,
    t_end: int,
    p: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build pooled VARX design:
      For each t in [t_start, t_end):
        feat(t) = [1, vec(Y_{t-1..t-p}), vec(X_t)]
        target  = Y_t

    Column layout of every row: intercept, then the lag-1 block y[t-1], the
    lag-2 block y[t-2], ..., the lag-p block, then x[t] flattened in C order.

    Args:
        y: (T, N) scaled target
        x: (T, ...) scaled covariates, e.g. (T, N, P_cov); x[t] is flattened
        t_start: first target index; must satisfy t_start >= p
        t_end: one past the last target index; must satisfy
            t_end <= len(y) and t_end <= len(x)
        p: VAR order, >= 1

    Returns:
        X_design: (M, D) float32, M = t_end - t_start, D = 1 + p * N + x[0].size
        Y_design: (M, N) copy of y[t_start:t_end] (keeps y's dtype)

    Raises:
        ValueError: if p < 1, if t_start < p (a negative lag index would
            silently wrap around to the end of the series), or if the range
            [t_start, t_end) is empty.
        IndexError: if t_end exceeds len(y) or len(x).
    """
    if p < 1:
        raise ValueError(f"VAR order p must be >= 1, got {p}")
    if t_start < p:
        raise ValueError(
            f"t_start={t_start} must be >= p={p} so every lag index is non-negative"
        )
    if t_end <= t_start:
        raise ValueError(f"empty design range: t_start={t_start}, t_end={t_end}")
    if t_end > len(y) or t_end > len(x):
        raise IndexError(
            f"t_end={t_end} exceeds series length (len(y)={len(y)}, len(x)={len(x)})"
        )

    n_rows = t_end - t_start
    # Row i (time t = t_start + i) needs y[t - k] for lag k, so the whole
    # lag-k column block is a single shifted slice of y.
    lag_blocks = [y[t_start - k:t_end - k] for k in range(1, p + 1)]
    cov_block = x[t_start:t_end].reshape(n_rows, -1)
    intercept = np.ones((n_rows, 1))

    X_design = np.concatenate([intercept, *lag_blocks, cov_block], axis=1).astype(np.float32)
    Y_design = y[t_start:t_end].copy()
    return X_design, Y_design

# Pooled VARX training design over the training span [P, cut_train).
Xtr, Ytr = build_varx_design(
    y=y_s,
    x=X_s,
    t_start=P,       # first t with P observed lags
    t_end=cut_train,
    p=P,
)

print("\n=== Train design (VARX) ===")
print("Xtr:", Xtr.shape, "Ytr:", Ytr.shape)

Xtr_t = torch.from_numpy(Xtr).double().to(DEVICE)
Ytr_t = torch.from_numpy(Ytr).double().to(DEVICE)

# torch.linalg.lstsq supports only the "gels" driver on CUDA, and that driver
# assumes a full-rank design. A covariate column that never changes over time
# would be collinear with the intercept and silently corrupt B. The problem is
# small, so solve on the CPU with the SVD-based "gelsd" driver: it handles rank
# deficiency (returning the minimum-norm solution) and reports the rank.
lstsq_out = torch.linalg.lstsq(Xtr_t.cpu(), Ytr_t.cpu(), driver="gelsd")
B = lstsq_out.solution.float().numpy()  # (D, N) coefficient matrix

design_rank = int(lstsq_out.rank)
print(f"Design rank: {design_rank} / {Xtr.shape[1]}")
if design_rank < Xtr.shape[1]:
    print("WARNING: VARX design is rank-deficient; B is the minimum-norm least-squares solution.")

print("B shape (coeff matrix):", B.shape)

# Rolling recursive VARX

def qrisk(y_true: np.ndarray, y_pred: np.ndarray, q: float = 0.5, eps: float = 1e-8) -> float:
    """Normalized quantile (pinball) loss.

    Computes ``2 * sum(pinball_q(y_true - y_pred)) / (sum(|y_true|) + eps)``,
    where ``pinball_q(e) = max(q * e, (q - 1) * e)``. ``eps`` keeps the
    denominator non-zero for an all-zero ``y_true``.

    Args:
        y_true: observed values (any array-like; converted with np.asarray).
        y_pred: predicted q-quantile, same shape as ``y_true``.
        q: quantile level in (0, 1); 0.5 gives the P50 risk.
        eps: small constant added to the denominator.

    Returns:
        The q-risk as a Python float.
    """
    actual = np.asarray(y_true)
    forecast = np.asarray(y_pred)
    residual = actual - forecast
    # Under-prediction (residual > 0) is weighted by q, over-prediction by 1 - q.
    pinball = np.maximum(q * residual, (q - 1) * residual)
    denominator = np.sum(np.abs(actual)) + eps
    return float(2.0 * np.sum(pinball) / denominator)

def rolling_varx(
    y: np.ndarray,
    x: np.ndarray,
    B: np.ndarray,
    start_ctx: int,
    end_T: int,
    p: int,
    H: int,
    step: int,
):
    """Roll recursive multi-step VARX forecasts across forecast origins.

    Origins are ``start_ctx, start_ctx + step, ...`` as long as a full ``H``-step
    window still ends at or before ``end_T``. Origins below ``p`` are skipped
    because they lack ``p`` observed lags. At each origin ``t0``:

    * the lag history is seeded with the observed ``y[t0-1], ..., y[t0-p]``;
    * the covariates are held at ``x[t0 - 1]`` for the whole horizon, so no
      future covariates are used;
    * each one-step prediction ``[1, lags..., covariates] @ B`` is fed back as
      the newest lag for the next step.

    The feature layout matches ``build_varx_design`` (intercept, lag-1 block
    first, then the flattened covariates), so ``B`` is expected to come from a
    fit on that design.

    Returns:
        (yhat, ytrue, origins): the forecasts and the matching observed slices
        ``y[t0:t0 + H]``, each stacked along a new leading window axis, plus
        the list of origins used.

    Raises:
        RuntimeError: if no origin produced a window.
    """
    forecasts = []
    actuals = []
    origins = []

    for origin in range(start_ctx, end_T - H + 1, step):
        if origin < p:
            # Not enough observed history for p lags.
            continue

        # Most-recent lag first, the same order build_varx_design uses.
        lags = [y[origin - lag].copy() for lag in range(1, p + 1)]
        # Last observed covariates, reused for every horizon step.
        frozen_cov = x[origin - 1].reshape(-1)

        actual = y[origin:origin + H].copy()
        forecast = np.empty_like(actual)

        for k in range(H):
            lag_block = np.concatenate(lags, axis=0)
            row = np.concatenate([[1.0], lag_block, frozen_cov]).astype(np.float32)
            one_step = row @ B
            forecast[k] = one_step
            # Shift the history: the new prediction becomes lag 1, the oldest lag drops out.
            lags = [one_step, *lags[:p - 1]]

        forecasts.append(forecast)
        actuals.append(actual)
        origins.append(origin)

    if not forecasts:
        raise RuntimeError("No rolling windows produced. Check start_ctx/end_T/H/step.")

    return np.stack(forecasts, axis=0), np.stack(actuals, axis=0), origins

# Forecast origins start inside the test span and never before L (or P) steps
# of history are available.
start_ctx = max(test_start, L, P)
yhat_roll, ytrue_roll, t0_list = rolling_varx(
    y=y_s, x=X_s, B=B,
    start_ctx=start_ctx, end_T=T_total,
    p=P, H=H, step=STEP,
)

# Point-error metrics over every (window, horizon step, location), in scaled units.
W = len(yhat_roll)
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.square(diff).mean()))
mae  = float(np.abs(diff).mean())

print("\n=== VARX TFT-aligned (Synthetic Scenario A, POOLED, TEST, NO future cov) ===")
print(f"windows={W} | step={STEP} | H={H} | N={N}")
print("RMSE:", rmse)
print("MAE :", mae)

# Flatten windows x horizon into rows of N locations for q-risk and un-scaling.
yhat_f, ytrue_f = (arr.reshape(-1, N) for arr in (yhat_roll, ytrue_roll))

# NOTE: with standard-scaled targets the scaled q-risk divides by sum|y| of
# (presumably near zero-mean) values, so the unscaled figure is the more
# interpretable of the two.
qr_scaled = qrisk(ytrue_f, yhat_f)
ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
qr_unscaled = qrisk(ytrue_un, yhat_un)

print("\n=== P50 q-risk (Synthetic Scenario A, NO future cov) ===")
print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)


# plots

t0_first = t0_list[0]
times1   = time_index[t0_first:t0_first + H]

# Stitch the rolling windows back into one long series per location.
st_ytrue = ytrue_roll.reshape(W * H, N)
st_yhat  = yhat_roll.reshape(W * H, N)
# Timestamp of every stitched point, taken from each window's own origin. With
# STEP == H (back-to-back windows) this equals
# time_index[t0_first:t0_first + W * H]. With any other STEP, that contiguous
# slice would mislabel the x-axis. STEP < H still yields repeated timestamps,
# because the windows overlap.
times2 = time_index[np.concatenate([np.arange(t0, t0 + H) for t0 in t0_list])]

for j in range(N):
    # FIG1: first forecast window only, plotted against the horizon step.
    plt.figure(figsize=(9, 4), dpi=140)
    plt.plot(range(H), ytrue_roll[0, :, j], "-o", label="True")
    plt.plot(range(H), yhat_roll[0, :, j], "-o", label="Pred")
    plt.title(f"FIG1 VARX Synth A | grid{j:03d}")
    plt.xlabel("step")
    plt.ylabel("Scaled y")
    plt.grid()
    plt.legend()  # the line labels above were previously never displayed
    plt.tight_layout()
    plt.savefig(FIG1_DIR / f"FIG1_VAR_grid{j:03d}.png")
    plt.close()

    # FIG2: all windows stitched over time, with the test-split start marked.
    plt.figure(figsize=(12, 4), dpi=140)
    plt.plot(times2, st_ytrue[:, j], label="True")
    plt.plot(times2, st_yhat[:, j], label="Pred")
    plt.axvline(time_index[test_start], linestyle="--", linewidth=1, label="TEST start")
    plt.title(f"FIG2 VARX Synth A | grid{j:03d}")
    plt.xlabel("Time")
    plt.ylabel("Scaled y")
    plt.grid()
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG2_DIR / f"FIG2_VAR_grid{j:03d}.png")
    plt.close()

print("\nAll done (VARX baseline, Scenario A, NO future cov).")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Y shape   : (1500, 36)
X_cov shape: (1500, 36, 3)

=== Split (by time index) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

=== Train design (VARX) ===
Xtr: (1049, 145) Ytr: (1049, 36)
B shape (coeff matrix): (145, 36)

=== VARX TFT-aligned (Synthetic Scenario A, POOLED, TEST, NO future cov) ===
windows=9 | step=24 | H=24 | N=36
RMSE: 1.2062946557998657
MAE : 0.9611192345619202

=== P50 q-risk (Synthetic Scenario A, NO future cov) ===
scaled  : 0.8045394427193219
unscaled: 0.013991723780361455

All done (VARX baseline, Scenario A, NO future cov).
