In [1]:
# -*- coding: utf-8 -*-
# ============================================================
# Weather2K (Beijing subset) + VAR(1) baseline (NO COV)
# - Scaling: POOLED (train-only) via prepare_all_with_scaling
# - Model  : VAR(1): y_t = c + A y_{t-1}
# - Inference: Rolling non-overlap on TEST (step=H), 24-step INTERNAL recursive
# - Metrics: RMSE/MAE + P50 q-risk (scaled/unscaled) over TEST rolling windows
# - Plots : ALL stations
#          FIG1: first rolling window (24 steps)
#          FIG2: full TEST (stitched rolling windows)
# ============================================================

from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# ============================================================
# 0) Paths & settings
# ============================================================
# Raw Weather2K archive: one .npy array unpacked below as (S, V, T).
W2K_PATH = Path("/home/wangxc1117/Weather2K") / "weather2k.npy"

# Reproducibility: seed NumPy's legacy global RNG, then torch's RNG.
GLOBAL_SEED = 42
for _seed in (np.random.seed, torch.manual_seed):
    _seed(GLOBAL_SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# Split / rolling config (kept identical to the TFT pooled pipeline)
train_ratio, val_ratio = 0.70, 0.15
H = 24                  # forecast horizon: steps predicted per rolling window
L = 56                  # context length; here it only lower-bounds the first rolling start
EPS_RIDGE = 1e-6        # tiny ridge added to the VAR normal equations for numerical stability

# Plot output directory (relative to the current working directory)
PLOTS_DIR = Path("VAR1_noCov_plots_weather2k_beijing_pooled")
PLOTS_DIR.mkdir(exist_ok=True, parents=True)


# ============================================================
# 1) Load data + Beijing subset
# ============================================================
# Memory-map the archive: below we only need station coordinates and one
# target variable for the selected stations. (Loading everything and calling
# `.astype(np.float32)` always materialized a second full copy, because
# `astype` copies by default.)
arr = np.load(W2K_PATH, mmap_mode="r", allow_pickle=False)
S, V, T = arr.shape
if V != 13:
    raise ValueError(f"Expected 13 Weather2K variables on axis 1, got V={V} (shape {arr.shape}).")
print("Weather2K:", arr.shape)

# No timestamps are read from the file: this is a synthetic 3-hourly axis with
# an arbitrary start, used below only for plot x-axes. ("3h" because the
# uppercase "H" alias is deprecated since pandas 2.2.)
time_index = pd.date_range("2000-01-01", periods=T, freq="3h")

# 0 lat, 1 lon, 2 alt, 3 ap, 4 t, 5 mxt, 6 mnt, 7 rh, 8 p3, 9 wd, 10 ws, 11 mwd, 12 mws
TARGET_IDX = 4

LAT_MIN, LAT_MAX = 39.4, 41.1
LON_MIN, LON_MAX = 115.4, 117.5
# Coordinates are read at time index 0 (assumed constant per station -- TODO confirm).
# Cast to float32 so the bounding-box comparison matches the previous float32 pipeline.
lat = np.array(arr[:, 0, 0], dtype=np.float32)
lon = np.array(arr[:, 1, 0], dtype=np.float32)

use_idx = np.where(
    (lat >= LAT_MIN) & (lat <= LAT_MAX) &
    (lon >= LON_MIN) & (lon <= LON_MAX)
)[0]
N = len(use_idx)
print("Beijing stations:", N)
if N == 0:
    raise RuntimeError("No Beijing stations found. Check bounding box.")

# Targets Y: (T, N) float32; column n is station use_idx[n].
Y = np.ascontiguousarray(arr[use_idx, TARGET_IDX, :].T, dtype=np.float32)

# Dummy cont + cat for prepare_all_with_scaling (NO COV => cont features can be zeros)
X_dummy = np.zeros((T, N, 1), dtype=np.float32)
cat_dummy = np.zeros((T, N, 1), dtype=np.int64)

print("Y:", Y.shape, "| X_dummy:", X_dummy.shape)


# ============================================================
# 2) Pooled train-only scaling (same preprocessor style as the TFT pooled pipeline)
# ============================================================
# Standard scalers for both features and targets. `fit_on_train_only=True`
# presumably restricts scaler statistics to the TRAIN split (no VAL/TEST
# leakage) -- confirm against prepare_all_with_scaling.
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    targets=Y,
    cont_features=X_dummy,
    cat_features=cat_dummy,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    fit_on_train_only=True,
    target_scaler_type="standard",
    feature_scaler_type="standard",
)

def stitch_targets(dsets) -> np.ndarray:
    """Concatenate the target tensors of several datasets along the time axis.

    Each dataset must expose ``.tensors`` whose third entry (index 2) is a
    torch tensor of targets; each is moved to CPU and cast to float32.

    Args:
        dsets: iterable of datasets, in chronological order.

    Returns:
        float32 ndarray formed by concatenating the targets on axis 0.
    """
    pieces = []
    for dataset in dsets:
        target_tensor = dataset.tensors[2]
        pieces.append(target_tensor.cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)

y_s = stitch_targets([train_ds, val_ds, test_ds])  # (T, N) scaled targets, chronological
if y_s.shape != Y.shape:
    raise ValueError(f"Stitched scaled targets {y_s.shape} do not match Y {Y.shape}.")

y_s_all_finite = bool(np.isfinite(y_s).all())
print("\n=== After scaling ===")
print("y_s:", y_s.shape, "| finite:", y_s_all_finite)
if not y_s_all_finite:
    raise ValueError("Non-finite y_s after scaling.")

# Split boundaries recomputed from the ratios (kept as-is for parity with the
# TFT pooled pipeline). They are cross-checked below against the split that
# prepare_all_with_scaling actually produced: if the two disagree, the VAR
# would be fit on a slightly different TRAIN range than the scaler.
cut_train = int(T * train_ratio)
cut_val = int(T * (train_ratio + val_ratio))
train_len = cut_train

n_train_ds = int(train_ds.tensors[2].shape[0])
n_val_ds = int(val_ds.tensors[2].shape[0])
if (cut_train, cut_val) != (n_train_ds, n_train_ds + n_val_ds):
    print(
        f"WARNING: ratio-based cuts (train={cut_train}, val_end={cut_val}) differ from "
        f"prepare_all_with_scaling split (train={n_train_ds}, val_end={n_train_ds + n_val_ds})."
    )


# ============================================================
# 3) Fit VAR(1) on TRAIN (pooled)
#    y_t = c + A y_{t-1}
# ============================================================
# Regression form: Y_next = X @ B, with X = [1, y_{t-1}] and B = [c; A^T].
# Only TRAIN rows are used: pairs (y_{t-1}, y_t) with t < train_len.
# Shapes:
#   y_{t-1}: (train_len-1, N)
#   y_t    : (train_len-1, N)
Y_prev = y_s[:train_len - 1, :].astype(np.float64)   # (M, N)
Y_next = y_s[1:train_len, :].astype(np.float64)      # (M, N)
M = Y_prev.shape[0]
if M < 1:
    raise ValueError(f"Need at least 2 TRAIN time steps to fit VAR(1), got train_len={train_len}.")

X = np.concatenate([np.ones((M, 1), dtype=np.float64), Y_prev], axis=1)  # (M, 1+N)
Ymat = Y_next  # (M, N)

# Ridge-stabilized normal equations:
#   B = (X^T X + lam * P)^(-1) X^T Y
# P is the identity with P[0, 0] = 0, so only the slope block (A^T) is
# regularized and the intercept c is not shrunk (standard ridge practice).
# For lam > 0 and M >= 1 the system stays positive definite because the
# intercept column of X is all ones.
XtX = X.T @ X
lam = EPS_RIDGE
penalty = np.eye(XtX.shape[0], dtype=np.float64)
penalty[0, 0] = 0.0
XtX_reg = XtX + lam * penalty
XtY = X.T @ Ymat
B = np.linalg.solve(XtX_reg, XtY)  # (1+N, N)

c = B[0, :].astype(np.float32)           # (N,)
A = B[1:, :].astype(np.float32).T        # (N, N)  because B stores A^T row-wise

print("\n=== VAR(1) fitted (pooled, NO COV) ===")
print("c shape:", c.shape, "| A shape:", A.shape)
print("A finite:", bool(np.isfinite(A).all()), "| c finite:", bool(np.isfinite(c).all()))
if not np.isfinite(A).all() or not np.isfinite(c).all():
    raise ValueError("Non-finite VAR parameters. Try increasing EPS_RIDGE.")

# Recursive multi-step forecasts repeatedly apply A, so a spectral radius >= 1
# means forecasts do not decay toward the fitted mean and may diverge.
spectral_radius = float(np.max(np.abs(np.linalg.eigvals(A.astype(np.float64)))))
print("spectral radius of A:", spectral_radius)


# ============================================================
# 4) Rolling non-overlap on TEST, 24-step INTERNAL recursive
# ============================================================
def qrisk(y_true, y_pred, q=0.5, eps=1e-8) -> float:
    """Normalized quantile (pinball) risk.

    Computes ``2 * sum(pinball_q(y_true - y_pred)) / (sum(|y_true|) + eps)``
    in float64, where ``pinball_q(e) = max(q*e, (q-1)*e)``.

    Args:
        y_true: observed values (any array-like; flattened by the sums).
        y_pred: predictions, broadcastable against ``y_true``.
        q: quantile level (0.5 gives the P50 risk).
        eps: added to the denominator to avoid division by zero.

    Returns:
        The risk as a Python float.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    resid = truth - np.asarray(y_pred, dtype=np.float64)
    pinball = np.maximum(q * resid, (q - 1.0) * resid)
    scale = np.abs(truth).sum() + eps
    return float(2.0 * pinball.sum() / scale)

def rolling_var1_recursive_nonoverlap(
    y_s_full: np.ndarray,   # (T, N) scaled true
    start_ctx: int,
    end_T: int,
    H: int,
    step: int,
    c_vec: np.ndarray,      # (N,)
    A_mat: np.ndarray,      # (N,N)
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run VAR(1) recursive H-step forecasts over rolling windows.

    Candidate window starts are ``range(start_ctx, end_T - H + 1, step)``.
    Starts < 1 are skipped because they have no previous observation. For
    each kept start t0::

        prev = y_s_full[t0 - 1]          # last TRUE observation
        for k in 0..H-1:
            yhat_k = c + A @ prev        # computed in float64
            prev   = yhat_k              # feed the prediction back

    Args:
        y_s_full: 2-D array (T, N) of observed values (scaled space in this script).
        start_ctx: first candidate window start.
        end_T: exclusive end; only windows with ``t0 + H <= end_T`` are produced.
        H: forecast horizon (steps per window).
        step: stride between window starts (``step == H`` gives non-overlapping windows).
        c_vec: intercept, shape (N,).
        A_mat: transition matrix, shape (N, N).

    Returns:
        yhat_roll:  (W, H, N) float32 forecasts.
        ytrue_roll: (W, H, N) float32 observed values for the same steps.
        t0_list:    (W,) int window starts actually used.
        W may be 0; the first two arrays then have shape (0, H, N) instead of
        raising inside ``np.stack``.

    Raises:
        ValueError: if ``y_s_full`` is not 2-D or a forecast becomes non-finite.
    """
    y_full = np.asarray(y_s_full)
    if y_full.ndim != 2:
        raise ValueError(f"y_s_full must be 2-D (T, N), got shape {y_full.shape}")
    # Width taken from the input itself rather than a module-level global.
    n_series = y_full.shape[1]

    c64 = np.asarray(c_vec, dtype=np.float64)
    A64 = np.asarray(A_mat, dtype=np.float64)

    kept_t0 = [t0 for t0 in range(start_ctx, end_T - H + 1, step) if t0 >= 1]
    if not kept_t0:
        empty = np.empty((0, H, n_series), dtype=np.float32)
        return empty, empty.copy(), np.asarray([], dtype=int)

    yh_all, yt_all = [], []
    for t0 in kept_t0:
        ytru = y_full[t0:t0 + H].astype(np.float32)     # (H, N)

        prev = y_full[t0 - 1].astype(np.float64)        # (N,)
        yhat = np.empty((H, n_series), dtype=np.float32)

        for k in range(H):
            ypred = c64 + (A64 @ prev)                  # (N,)
            bad = ~np.isfinite(ypred)
            if bad.any():
                idx = int(np.flatnonzero(bad)[0])
                raise ValueError(f"Non-finite ypred at window t0={t0}, step k={k}, station j={idx}")
            yhat[k, :] = ypred.astype(np.float32)
            prev = ypred

        yh_all.append(yhat)
        yt_all.append(ytru)

    return np.stack(yh_all, axis=0), np.stack(yt_all, axis=0), np.asarray(kept_t0, dtype=int)

start_ctx = max(cut_val, L)   # TEST start, ensure context >= L (match TFT)
STEP = H                      # step == H => non-overlapping windows

yhat_roll, ytrue_roll, t0_list = rolling_var1_recursive_nonoverlap(
    y_s_full=y_s,
    start_ctx=start_ctx,
    end_T=T,
    H=H,
    step=STEP,
    c_vec=c,
    A_mat=A,
)

W = yhat_roll.shape[0]
# Fail before computing metrics: with zero windows every mean below would be NaN.
if W == 0:
    raise RuntimeError("No rolling windows produced. Check start_ctx/end_T/H/step.")

diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(diff ** 2)))
mae = float(np.mean(np.abs(diff)))

print("\n=== Rolling non-overlap (POOLED, TEST) [VAR(1) recursive | NO COV] ===")
print(f"windows={W} | step={STEP} | each window predicts H={H} | stations={N}")
print("RMSE:", rmse)
print("MAE :", mae)

yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

qr_scaled = qrisk(ytrue_f, yhat_f, q=0.5)
ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
qr_unscaled = qrisk(ytrue_un, yhat_un, q=0.5)

print("\n=== P50 q-risk (ROLLING, TEST) ===")
print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)

# RMSE/MAE above are in scaled units; also report them in the original target units.
diff_un = np.asarray(yhat_un, dtype=np.float64) - np.asarray(ytrue_un, dtype=np.float64)
rmse_un = float(np.sqrt(np.mean(diff_un ** 2)))
mae_un = float(np.mean(np.abs(diff_un)))
print("\n=== RMSE/MAE (ROLLING, TEST, unscaled) ===")
print("RMSE:", rmse_un)
print("MAE :", mae_un)


# ============================================================
# 5) Plots: ALL stations (2 figs per station)
#     FIG1: first rolling window (24 steps)
#     FIG2: full TEST (stitched rolling windows)
# ============================================================
print("\n=== Plotting ALL stations ===")

if W == 0:
    raise RuntimeError("No rolling windows produced. Check start_ctx/end_T/H/step.")

# Absolute time index of every stitched prediction, built from each window's
# own start (window-major, step-minor -- the same order as reshape(W * H, N)).
# With STEP == H this is exactly the contiguous TEST span; unlike assuming
# t0_list[0] + W * H, it stays aligned if STEP is ever changed.
abs_t = (np.asarray(t0_list, dtype=int)[:, None] + np.arange(H)[None, :]).ravel()
t_start = int(abs_t[0])
t_end_excl = int(abs_t[-1]) + 1
times_fig2 = time_index[abs_t]

ytrue_stitched = ytrue_roll.reshape(W * H, N)
yhat_stitched = yhat_roll.reshape(W * H, N)

for j, station in enumerate(use_idx):
    station_id = int(station)

    # FIG1: first rolling window only
    fig, ax = plt.subplots(figsize=(9, 4), dpi=140)
    ax.plot(range(H), ytrue_roll[0, :, j], "-o", linewidth=2, markersize=3, label="True (scaled)")
    ax.plot(range(H), yhat_roll[0, :, j], "-o", linewidth=2, markersize=3, label="Pred (scaled)")
    ax.set_title(f"FIG1 VAR(1) NO-COV | Beijing st_{station_id} | first window | H={H} (scaled)")
    ax.set_xlabel("3-hour step within window")
    ax.set_ylabel("Scaled t")
    ax.grid(alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / f"FIG1_VAR1_noCov_firstWindow_st{station_id}_H{H}_scaled.png")
    plt.close(fig)

    # FIG2: all TEST windows stitched on the time axis
    fig, ax = plt.subplots(figsize=(12, 4), dpi=140)
    ax.plot(times_fig2, ytrue_stitched[:, j], "-", linewidth=2, label="True (scaled)")
    ax.plot(times_fig2, yhat_stitched[:, j], "-", linewidth=2, label="Pred (scaled)")
    ax.set_title(f"FIG2 VAR(1) NO-COV | Beijing st_{station_id} | full TEST (stitched) | H={H} (scaled)")
    ax.set_xlabel("Time")
    ax.set_ylabel("Scaled t")
    ax.grid(alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / f"FIG2_VAR1_noCov_testAll_st{station_id}_H{H}_scaled.png")
    plt.close(fig)

print(f"Saved plots to: {PLOTS_DIR.resolve()}")
print("\nAll done.")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Weather2K: (1866, 13, 13632)
Beijing stations: 31
Y: (13632, 31) | X_dummy: (13632, 31, 1)

=== After scaling ===
y_s: (13632, 31) | finite: True

=== VAR(1) fitted (pooled, NO COV) ===
c shape: (31,) | A shape: (31, 31)
A finite: True | c finite: True

=== Rolling non-overlap (POOLED, TEST) [VAR(1) recursive | NO COV] ===
windows=85 | step=24 | each window predicts H=24 | stations=31
RMSE: 0.5358965992927551
MAE : 0.438996285200119

=== P50 q-risk (ROLLING, TEST) ===
scaled  : 0.5029687725303151
unscaled: 0.35345874129929217

=== Plotting ALL stations ===
Saved plots to: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/weather2k/VAR_nocov/VAR1_noCov_plots_weather2k_beijing_pooled

All done.
