In [None]:
from pathlib import Path
from typing import Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# Paths & settings
# Location of the Weather2K array on disk; it is read with np.load further down.
W2K_PATH = Path("/home/wangxc1117/Weather2K/weather2k.npy")

# Seed NumPy's global RNG.
# NOTE(review): nothing in the visible code draws random numbers, so this
# presumably only matters for code outside this section — confirm.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)

# Chronological split fractions; whatever remains after train + val is test.
train_ratio = 0.70
val_ratio = 0.15

# L only enters as a lower bound on the first forecast origin (start_ctx).
# NOTE(review): presumably the look-back length of the model this baseline is
# compared against — confirm.
L = 56
# H: number of consecutive time steps predicted in each rolling window.
H = 24
# STEP: spacing between window origins; equal to H so windows do not overlap.
STEP = H

# P: number of target lags fed to the VARX regression.
P = 1

# Latitude/longitude bounding box used to pick the stations kept in the fit.
LAT_MIN, LAT_MAX = 39.4, 41.1
LON_MIN, LON_MAX = 115.4, 117.5

# Figures go under the current working directory; the folder name embeds P and
# H so runs with different settings land in different folders.
EXP_ROOT = Path.cwd()
PLOTS_DIR = (EXP_ROOT / f"VARX_plots_weather2k_beijing_pooled_freezeCov_wdSinCos_P{P}_H{H}").resolve()
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

print("PLOTS_DIR:", PLOTS_DIR)


# Load data
if not W2K_PATH.exists():
    raise FileNotFoundError(f"Weather2K npy not found at: {W2K_PATH}")

# copy=False avoids duplicating the whole array when the file already stores
# float32; np.load returns a fresh in-memory array, so nothing else aliases it.
arr = np.load(W2K_PATH, allow_pickle=False).astype(np.float32, copy=False)
if arr.ndim != 3:
    # Fail with a readable message instead of an unpacking error on the next line.
    raise ValueError(f"Expected a 3-D array, got shape {arr.shape}")
# Axis order as used below: (station, variable, time step).
S, V, T = arr.shape
if V != 13:
    raise ValueError(f"Expected 13 variables, got {V}")

print("Weather2K:", arr.shape)
# Evenly spaced 3-hourly timestamps, one per step of the last axis. The
# lowercase "h" alias is used because uppercase "H" is deprecated in pandas 2.2+.
# NOTE(review): the 2000-01-01 origin looks like a placeholder rather than the
# dataset's real start date — confirm before reading dates off the plots.
time_index = pd.date_range("2000-01-01", periods=T, freq="3h")

# Position of the forecast target along the variable axis.
# NOTE(review): the plots label it "t", presumably temperature — confirm.
TARGET_IDX = 4

# Positions of the covariates along the variable axis.
IDX_AP  = 3
IDX_MXT = 5
IDX_MNT = 6
IDX_RH  = 7
IDX_P3  = 8
IDX_WD  = 9
IDX_WS  = 10
IDX_MWD = 11
IDX_MWS = 12

# Variables 0 and 1 are read as latitude/longitude; only the first time step is
# used, i.e. station coordinates are treated as constant over time.
lat = arr[:, 0, 0]
lon = arr[:, 1, 0]
# Indices (along the station axis) of the stations inside the bounding box.
use_idx = np.where(
    (lat >= LAT_MIN) & (lat <= LAT_MAX) &
    (lon >= LON_MIN) & (lon <= LON_MAX)
)[0]
N = len(use_idx)

print("Beijing stations:", N)
if N == 0:
    raise RuntimeError("No Beijing stations found. Check bounding box.")


# Build pooled raw arrays
def _station_series(var_idx: int) -> np.ndarray:
    """Return variable ``var_idx`` of the selected stations as float32, shape (T, N)."""
    # A single indexing call replaces the per-station Python loop. The result
    # of the index is (N, T); the transpose puts time first, and
    # ascontiguousarray restores the C layout a stacked array would have.
    return np.ascontiguousarray(arr[use_idx, var_idx, :].T, dtype=np.float32)


# Target series, one column per selected station.
Y_raw = _station_series(TARGET_IDX)

# Covariates that are used as-is.
ap  = _station_series(IDX_AP)
mxt = _station_series(IDX_MXT)
mnt = _station_series(IDX_MNT)
rh  = _station_series(IDX_RH)
p3  = _station_series(IDX_P3)
ws  = _station_series(IDX_WS)
mws = _station_series(IDX_MWS)

# Direction variables are handled as angles in degrees.
wd_deg  = _station_series(IDX_WD)
mwd_deg = _station_series(IDX_MWD)

# Convert to radians in float64 before taking sin/cos, then store as float32.
wd_rad  = (wd_deg.astype(np.float64)  * np.pi / 180.0)
mwd_rad = (mwd_deg.astype(np.float64) * np.pi / 180.0)

# A sine/cosine pair makes 0 and 360 degrees map to the same feature values.
wd_sin  = np.sin(wd_rad).astype(np.float32)
wd_cos  = np.cos(wd_rad).astype(np.float32)
mwd_sin = np.sin(mwd_rad).astype(np.float32)
mwd_cos = np.cos(mwd_rad).astype(np.float32)

# Names listed in the same order as the arrays stacked on the last axis below.
PAST_COV_COLS = ["ap", "mxt", "mnt", "rh", "p3", "wd_sin", "wd_cos", "ws", "mwd_sin", "mwd_cos", "mws"]
# Covariate cube of shape (T, N, K).
Xcov_raw = np.stack(
    [ap, mxt, mnt, rh, p3, wd_sin, wd_cos, ws, mwd_sin, mwd_cos, mws],
    axis=2
).astype(np.float32, copy=False)

K = Xcov_raw.shape[2]
print("Y_raw:", Y_raw.shape, "| Xcov_raw:", Xcov_raw.shape, "| K:", K)
print("Past covs:", PAST_COV_COLS)

# All-zero placeholder handed to the scaling helper as its categorical input.
cat_dummy = np.zeros((T, N, 1), dtype=np.int64)


# POOLED train-only scaling
# Hand the raw arrays to the project's preprocessing helper, which returns
# train/val/test datasets plus a preprocessor object (used later to undo the
# target scaling). cat_dummy is the all-zero placeholder built above.
# NOTE(review): fit_on_train_only=True presumably means scaler statistics come
# from the training span only — confirm against the helper's implementation.
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=Xcov_raw,
    targets=Y_raw,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)

def stitch_scaled(dsets) -> Tuple[np.ndarray, np.ndarray]:
    """Concatenate the scaled covariates and targets of several datasets.

    Each dataset must expose ``tensors``; entry 1 is taken as the covariates
    and entry 2 as the targets. Pieces are joined along axis 0 in the order
    given and returned as float32 NumPy arrays ``(covariates, targets)``.
    """

    def _joined(slot: int) -> np.ndarray:
        # Move each tensor to host memory, convert to NumPy, then join on axis 0.
        pieces = []
        for part in dsets:
            pieces.append(part.tensors[slot].cpu().numpy().astype(np.float32))
        return np.concatenate(pieces, axis=0)

    return _joined(1), _joined(2)

# Rebuild full-length scaled arrays by joining train, val and test in time order.
Xcov_s, y_s = stitch_scaled([train_ds, val_ds, test_ds])

# The stitched arrays must line up element-for-element with the raw inputs.
shapes_match = (Xcov_s.shape == Xcov_raw.shape) and (y_s.shape == Y_raw.shape)
if not shapes_match:
    raise ValueError("Scaled arrays shape mismatch. Check prepare_all_with_scaling output.")

# Evaluate the finiteness scans once and reuse them for both report and guard.
x_all_finite = bool(np.isfinite(Xcov_s).all())
y_all_finite = bool(np.isfinite(y_s).all())

print("\n=== After scaling ===")
print("Xcov_s:", Xcov_s.shape, "| y_s:", y_s.shape)
print("Xcov_s finite:", x_all_finite, "| y_s finite:", y_all_finite)
if not (x_all_finite and y_all_finite):
    raise ValueError("Non-finite values after scaling.")


# Time splits
# Recompute the chronological cut points from the same ratios that were passed
# to prepare_all_with_scaling.
# NOTE(review): this assumes the helper truncates its split indices the same
# way; if it rounds differently, the scaler's "train" span and cut_train would
# disagree — confirm.
cut_train = int(T * train_ratio)
cut_val = int(T * (train_ratio + val_ratio))
if not (0 < cut_train < cut_val < T):
    raise ValueError("Bad split indices.")

print("\n=== Split (by time index) ===")
print("Train:", 0, "->", cut_train - 1, "| len =", cut_train)
print("Val  :", cut_train, "->", cut_val - 1, "| len =", cut_val - cut_train)
print("Test :", cut_val, "->", T - 1, "| len =", T - cut_val)

# First forecast origin: the start of the test span, but never earlier than L
# or P steps into the series.
start_ctx = max(cut_val, L, P)


def build_varx_design(y: np.ndarray, x: np.ndarray, t_start: int, t_end_excl: int, p: int) -> Tuple[np.ndarray, np.ndarray]:
    """Build the one-step VARX regression matrices for ``t_start <= t < t_end_excl``.

    The design row for time ``t`` is laid out as
    ``[1, y[t-1], ..., y[t-p], x[t].ravel()]`` and the matching response row
    is ``y[t]``.

    Args:
        y: Target array of shape (time, stations).
        x: Covariate array of shape (time, stations, features).
        t_start: First time index that gets a row; must be >= ``p`` so that
            every lag exists.
        t_end_excl: One past the last time index that gets a row.
        p: Number of target lags (>= 1).

    Returns:
        ``(Xmat, Ymat)`` as float32 arrays of shape
        ``(t_end_excl - t_start, 1 + stations * p + stations * features)`` and
        ``(t_end_excl - t_start, stations)``.

    Raises:
        ValueError: On a station-dimension mismatch, ``p < 1``,
            ``t_start < p``, or an empty window.
        IndexError: If the window runs past the end of ``y`` or ``x``.
    """
    Ttot, Nloc = y.shape
    Tx, Nloc2, Kloc = x.shape
    if Nloc2 != Nloc:
        raise ValueError("x and y station dimension mismatch.")
    if p < 1:
        raise ValueError("p must be >= 1.")
    if t_start < p:
        raise ValueError("t_start must be >= p.")
    if t_end_excl <= t_start:
        raise ValueError("Empty design window: t_end_excl must be > t_start.")
    if t_end_excl > min(Ttot, Tx):
        # Slicing would silently truncate, so reject out-of-range windows explicitly.
        raise IndexError("t_end_excl exceeds the number of available time steps.")

    n_rows = t_end_excl - t_start
    intercept = np.ones((n_rows, 1), dtype=np.float32)
    # Lag i of every row at once: the same window shifted back by i steps.
    lag_blocks = [
        y[t_start - i:t_end_excl - i, :].astype(np.float32)
        for i in range(1, p + 1)
    ]
    # Row-major flatten of (stations, features) per time step.
    x_block = x[t_start:t_end_excl].reshape(n_rows, Nloc * Kloc).astype(np.float32)

    Xmat = np.concatenate([intercept, *lag_blocks, x_block], axis=1)
    Ymat = y[t_start:t_end_excl, :].astype(np.float32)
    return Xmat, Ymat

# Training design matrix: rows cover t = P .. cut_train - 1, so no validation
# or test time step enters the fit.
Xtr, Ytr = build_varx_design(
    y=y_s,
    x=Xcov_s,
    t_start=P,
    t_end_excl=cut_train,
    p=P,
)

M, D = Xtr.shape
print("\n=== VARX train design ===")
print("Xtr:", Xtr.shape, "| Ytr:", Ytr.shape, "| N:", N, "| P:", P, "| K:", K)
print("D =", D, " (= 1 + N*P + N*K )")

# Solve the least-squares problem in float64; the coefficients are cast back to
# float32 afterwards to match the dtype of the features used at forecast time.
Xtr64 = Xtr.astype(np.float64)
Ytr64 = Ytr.astype(np.float64)

# B has one row per design column and one column per station.
B, residuals, rank, svals = np.linalg.lstsq(Xtr64, Ytr64, rcond=None)
B = B.astype(np.float32)

print("\n=== VARX fitted (train-only) ===")
print("B:", B.shape, "| rank:", rank)
print("B finite:", bool(np.isfinite(B).all()))
if not np.isfinite(B).all():
    raise ValueError("Non-finite coefficients in VARX fit. Check data/scaling.")


def qrisk(y_true: np.ndarray, y_pred: np.ndarray, q: float = 0.5, eps: float = 1e-8) -> float:
    """Return the normalised quantile risk of ``y_pred`` against ``y_true``.

    The quantile (pinball) loss at level ``q`` is summed over all elements,
    doubled, and divided by the summed absolute value of ``y_true``; ``eps``
    keeps the denominator non-zero. Inputs are converted to float64 first.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    forecast = np.asarray(y_pred, dtype=np.float64)
    err = truth - forecast

    # Element-wise pinball loss: q * err for under-forecasts, (q - 1) * err otherwise.
    pinball_total = np.maximum(q * err, (q - 1.0) * err).sum()
    scale = np.abs(truth).sum() + eps
    return float(2.0 * pinball_total / scale)

def rolling_varx_freeze_nonoverlap(
    y: np.ndarray,
    x: np.ndarray,
    B: np.ndarray,
    start_ctx: int,
    end_T: int,
    p: int,
    H: int,
    step: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Roll a fitted VARX forward over successive windows with frozen covariates.

    For every origin ``t0`` in ``range(start_ctx, end_T - H + 1, step)`` the
    model is iterated ``H`` steps: the lag slots start from the observed
    ``y[t0-1] ... y[t0-p]`` and are then fed with the model's own predictions,
    while the covariates stay fixed at ``x[t0 - 1]`` for the whole window.

    Args:
        y: Target array of shape (time, stations).
        x: Covariate array of shape (time, stations, features).
        B: Coefficient matrix of shape
            ``(1 + stations * p + stations * features, stations)``, laid out
            as ``[intercept, lag 1, ..., lag p, covariates]``.
        start_ctx: First forecast origin.
        end_T: One past the last usable time index.
        p: Number of target lags (>= 1).
        H: Steps predicted per window.
        step: Spacing between consecutive origins.

    Returns:
        ``(yhat, ytrue, t0)`` where ``yhat`` and ``ytrue`` have shape
        ``(windows, H, stations)`` and ``t0`` lists the origins used.

    Raises:
        ValueError: If ``p < 1``.
        RuntimeError: If no window fits in the requested range.
    """
    if p < 1:
        raise ValueError("p must be >= 1.")

    # Take the station count from the data, not from a module-level global, so
    # the function works for any input array.
    n_stations = y.shape[1]
    intercept = np.ones(1, dtype=np.float32)

    yh_all: List[np.ndarray] = []
    yt_all: List[np.ndarray] = []
    kept_t0: List[int] = []

    for t0 in range(start_ctx, end_T - H + 1, step):
        if t0 < p:
            # Not enough history for the lag slots at this origin.
            continue

        ytru = y[t0:t0 + H, :].astype(np.float32)

        # Last covariate vector observed before the window; reused for all H steps.
        x_freeze = x[t0 - 1, :, :].reshape(-1).astype(np.float32)

        # Most recent value first, matching the lag order of the coefficients.
        hist = [y[t0 - i, :].astype(np.float32) for i in range(1, p + 1)]

        yhat = np.empty((H, n_stations), dtype=np.float32)
        for k in range(H):
            feat = np.concatenate([intercept, *hist, x_freeze], axis=0)
            ypred = (feat @ B).astype(np.float32)
            yhat[k, :] = ypred
            # The new prediction becomes lag 1; the oldest lag drops out.
            hist = [ypred] + hist[:p - 1]

        yh_all.append(yhat)
        yt_all.append(ytru)
        kept_t0.append(t0)

    if not yh_all:
        raise RuntimeError("No rolling windows produced. Check start_ctx/end_T/H/step/p.")

    return np.stack(yh_all, axis=0), np.stack(yt_all, axis=0), np.asarray(kept_t0, dtype=int)

# Roll the fitted model over the span that starts at start_ctx and runs to the
# end of the series, one window every STEP steps.
yhat_roll, ytrue_roll, t0_list = rolling_varx_freeze_nonoverlap(
    y=y_s,
    x=Xcov_s,
    B=B,
    start_ctx=start_ctx,
    end_T=T,
    p=P,
    H=H,
    step=STEP,
)

# Error metrics pooled over every window, step and station, in scaled units.
W = yhat_roll.shape[0]
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(diff ** 2)))
mae = float(np.mean(np.abs(diff)))

print("\n=== Rolling non-overlap (POOLED, TEST) [VARX + FREEZE cov | FAIR vs TFT] ===")
print(f"windows={W} | step={STEP} | each window predicts H={H} | stations={N} | P={P} | K={K}")
print("RMSE:", rmse)
print("MAE :", mae)

# Collapse (windows, H) into a single axis so each row is one forecast step.
yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

qr_scaled = qrisk(ytrue_f, yhat_f, q=0.5)

# Same risk after mapping both series through the preprocessor's inverse
# target transform.
ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
qr_unscaled = qrisk(ytrue_un, yhat_un, q=0.5)

print("\n=== P50 q-risk (ROLLING, TEST) ===")
print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)


# Window shown in the per-station "first window" figure.
w0 = 0

# Windows are laid end to end, so the stitched series spans W * H consecutive
# time steps beginning at the first forecast origin.
t_start = int(t0_list[0])
t_end_excl = t_start + W * H
times_fig2 = time_index[t_start:t_end_excl]

ytrue_stitched = ytrue_roll.reshape(W * H, N)
yhat_stitched = yhat_roll.reshape(W * H, N)

print("\n=== Plotting ALL stations (2 figs/station) ===")
steps_in_window = range(H)
for j, raw_station in enumerate(use_idx):
    station_id = int(raw_station)

    ytrue_fig1 = ytrue_roll[w0, :, j]
    yhat_fig1 = yhat_roll[w0, :, j]

    # Figure 1: truth vs prediction inside the selected window.
    fig1, ax1 = plt.subplots(figsize=(9, 4), dpi=140)
    ax1.plot(steps_in_window, ytrue_fig1, "-o", linewidth=2, markersize=3, label="True (scaled)")
    ax1.plot(steps_in_window, yhat_fig1, "-o", linewidth=2, markersize=3, label="Pred (scaled)")
    ax1.set_title(f"FIG1 VARX-FREEZE | Beijing st_{station_id} | first rolling window | H={H} | P={P}")
    ax1.set_xlabel("3-hour step within window")
    ax1.set_ylabel("Scaled t")
    ax1.grid(alpha=0.3)
    ax1.legend()
    fig1.tight_layout()
    fig1.savefig(PLOTS_DIR / f"FIG1_VARX_freeze_wdSinCos_st{station_id}_H{H}_P{P}_scaled.png")
    plt.close(fig1)

    # Figure 2: every window stitched into one continuous series.
    fig2, ax2 = plt.subplots(figsize=(12, 4), dpi=140)
    ax2.plot(times_fig2, ytrue_stitched[:, j], "-", linewidth=2, label="True (scaled)")
    ax2.plot(times_fig2, yhat_stitched[:, j], "-", linewidth=2, label="Pred (scaled)")
    ax2.set_title(f"FIG2 VARX-FREEZE | Beijing st_{station_id} | full TEST (stitched) | H={H} | P={P}")
    ax2.set_xlabel("Time")
    ax2.set_ylabel("Scaled t")
    ax2.grid(alpha=0.3)
    ax2.legend()
    fig2.tight_layout()
    fig2.savefig(PLOTS_DIR / f"FIG2_VARX_freeze_wdSinCos_st{station_id}_H{H}_P{P}_scaled.png")
    plt.close(fig2)

print(f"Saved plots to: {PLOTS_DIR.resolve()}")
print("All done.")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
PLOTS_DIR: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/weather2k/VAR/VARX_plots_weather2k_beijing_pooled_freezeCov_wdSinCos_P1_H24
Weather2K: (1866, 13, 13632)
Beijing stations: 31
Y_raw: (13632, 31) | Xcov_raw: (13632, 31, 11) | K: 11
Past covs: ['ap', 'mxt', 'mnt', 'rh', 'p3', 'wd_sin', 'wd_cos', 'ws', 'mwd_sin', 'mwd_cos', 'mws']

=== After scaling ===
Xcov_s: (13632, 31, 11) | y_s: (13632, 31)
Xcov_s finite: True | y_s finite: True

=== Split (by time index) ===
Train: 0 -> 9541 | len = 9542
Val  : 9542 -> 11586 | len = 2045
Test : 11587 -> 13631 | len = 2045

=== VARX train design ===
Xtr: (9541, 373) | Ytr: (9541, 31) | N: 31 | P: 1 | K: 11
D = 373  (= 1 + N*P + N*K )

=== VARX fitted (train-only) ===
B: (373, 31) | rank: 373
B finite: True

=== Rolling non-overlap (POOLED, TEST) [VARX + FREEZE cov | FAIR vs TFT] ===
wind