In [None]:
import os
from pathlib import Path
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# Paths & settings
# Location of the Weather2K array. Set the W2K_PATH environment variable to point
# elsewhere; when it is unset or empty the original hard-coded path is used, so
# the notebook is no longer tied to a single user's home directory.
W2K_PATH = Path(os.environ.get("W2K_PATH") or "/home/wangxc1117/Weather2K/weather2k.npy")

# Seed NumPy's global RNG and torch for repeatable runs.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Chronological split fractions handed to prepare_all_with_scaling; the remaining
# tail of the time axis (after train + val) is TEST.
train_ratio = 0.70
val_ratio = 0.15
# L: context length; in this notebook it only bounds the first forecast origin
# from below (presumably to mirror the TFT encoder length).
# H: forecast horizon, in 3-hour steps.
L, H = 56, 24
STEP = H  # stride between forecast origins; STEP == H gives non-overlapping windows
EPS_STD = 1e-8  # features whose TRAIN std is <= this are dropped as near-constant

PLOTS_DIR = Path("OLS_plots_weather2k_beijing_pooled_fair_vs_tft_ALL_wd_sincos")
FIG1_DIR = PLOTS_DIR / "FIG1_first_window"
FIG2_DIR = PLOTS_DIR / "FIG2_full_test_stitched"
FIG1_DIR.mkdir(parents=True, exist_ok=True)
FIG2_DIR.mkdir(parents=True, exist_ok=True)

# Include mxt / mnt (presumably max / min temperature) among the past covariates.
USE_MXT_MNT = True

# Load data
if not W2K_PATH.exists():
    raise FileNotFoundError(f"Weather2K npy not found at: {W2K_PATH}")

# copy=False: when the file is already float32 this avoids duplicating the whole
# (station, variable, time) array in memory (plain .astype() always copies).
arr = np.load(W2K_PATH, allow_pickle=False).astype(np.float32, copy=False)
if arr.ndim != 3:
    raise ValueError(f"Expected a 3-D (station, variable, time) array, got shape {arr.shape}")
S, V, T = arr.shape
# Explicit check rather than `assert`, which is stripped under `python -O`.
if V != 13:
    raise ValueError(f"Expected 13 variables per station, got V={V}")
print("Weather2K:", arr.shape)

# 3-hourly time axis, used only to label the stitched TEST plots.
# NOTE(review): the Weather2K release is described as 3-hourly observations from
# 2017-01-01 to 2021-08-31 (1704 days x 8 = 13632 steps, matching the T this
# notebook prints); the former placeholder start "2000-01-01" mislabelled every
# date. Verify against the dataset documentation.
W2K_START = "2017-01-01"
# "3h" rather than "3H": the upper-case hourly alias is deprecated in pandas >= 2.2.
time_index = pd.date_range(W2K_START, periods=T, freq="3h")

# Variable (channel) indices into axis 1 of `arr`.
TARGET_IDX = 4
IDX_AP  = 3
IDX_MXT = 5
IDX_MNT = 6
IDX_RH  = 7
IDX_P3  = 8
IDX_WD  = 9
IDX_WS  = 10
IDX_MWD = 11
IDX_MWS = 12

# Beijing bounding box. Channels 0/1 are station latitude/longitude, read at
# t = 0 (assumed time-invariant).
LAT_MIN, LAT_MAX = 39.4, 41.1
LON_MIN, LON_MAX = 115.4, 117.5
lat = arr[:, 0, 0]
lon = arr[:, 1, 0]
use_idx = np.where(
    (lat >= LAT_MIN) & (lat <= LAT_MAX) &
    (lon >= LON_MIN) & (lon <= LON_MAX)
)[0]
N = len(use_idx)
print("Beijing stations:", N)
if N == 0:
    raise RuntimeError("No Beijing stations found. Check bounding box.")


# Build pooled arrays
def _station_matrix(var_idx: int) -> np.ndarray:
    """Return channel ``var_idx`` of the selected stations as a C-ordered (T, N) float32 array."""
    # arr[use_idx, var_idx] is (N, T); transpose so time runs down the rows and
    # make an independent contiguous copy.
    return np.ascontiguousarray(arr[use_idx, var_idx].T, dtype=np.float32)


# Target and scalar covariates, one column per selected station.
Y = _station_matrix(TARGET_IDX)
ap, rh, p3, ws, mws = map(_station_matrix, (IDX_AP, IDX_RH, IDX_P3, IDX_WS, IDX_MWS))
mxt, mnt = map(_station_matrix, (IDX_MXT, IDX_MNT))
wd_deg, mwd_deg = map(_station_matrix, (IDX_WD, IDX_MWD))

# Directions are in degrees: encode each as (sin, cos) so that e.g. 359 and 1
# end up close together. The trig is done in float64, then cast back.
wd_rad = wd_deg.astype(np.float64) * np.pi / 180.0
mwd_rad = mwd_deg.astype(np.float64) * np.pi / 180.0
wd_sin, wd_cos = np.sin(wd_rad).astype(np.float32), np.cos(wd_rad).astype(np.float32)
mwd_sin, mwd_cos = np.sin(mwd_rad).astype(np.float32), np.cos(mwd_rad).astype(np.float32)

# (name, array) pairs keep the covariate names aligned with their blocks.
named_covs = [("ap", ap)]
if USE_MXT_MNT:
    named_covs.extend([("mxt", mxt), ("mnt", mnt)])
named_covs.extend([
    ("rh", rh), ("p3", p3),
    ("wd_sin", wd_sin), ("wd_cos", wd_cos), ("ws", ws),
    ("mwd_sin", mwd_sin), ("mwd_cos", mwd_cos), ("mws", mws),
])
cov_names: List[str] = [name for name, _ in named_covs]
cov_blocks: List[np.ndarray] = [block for _, block in named_covs]

Xcov = np.stack(cov_blocks, axis=2).astype(np.float32)

print("Using past covariates:", cov_names)

# lag1[t] = Y[t - 1]; t = 0 has no predecessor and is zero-padded.
lag1 = np.zeros((T, N, 1), dtype=np.float32)
lag1[1:, :, 0] = Y[:-1]

Xfull = np.concatenate((lag1, Xcov), axis=2).astype(np.float32)
feat_names_full = ["lag1", *cov_names]
p_dim = Xfull.shape[2]
print("Y:", Y.shape, "| Xfull:", Xfull.shape, "| p_dim:", p_dim)

# All-zero categorical placeholder, passed as cat_features to the scaling helper.
cat_dummy = np.zeros((T, N, 1), dtype=np.int64)


# Pooled train-only scaling
# Split the time axis by train_ratio / val_ratio and standard-scale the continuous
# features (lag1 + covariates) and the targets with separate scalers.
# fit_on_train_only=True: per its name, scaler statistics come from the TRAIN split
# only, so VAL/TEST do not leak into the scaling (confirm in the library).
# cat_dummy is the all-zero categorical placeholder built above.
# NOTE(review): "pooled" assumes one scaler per feature across all stations rather
# than per station — confirm in geospatial_neural_adapter.data.preprocessing.
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=Xfull,
    targets=Y,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)

def stitch_from_dsets(dsets) -> Tuple[np.ndarray, np.ndarray]:
    """Concatenate continuous features and targets of several datasets along axis 0.

    Each element of ``dsets`` must expose a ``tensors`` sequence whose entry 1
    holds the continuous features and entry 2 the targets. That is the layout this
    notebook relies on for the datasets from ``prepare_all_with_scaling``; TODO
    confirm against the library. Tensors are moved to the CPU and cast to float32.

    Returns:
        ``(features, targets)`` as NumPy float32 arrays, in ``dsets`` order.
    """
    def as_f32(tensor) -> np.ndarray:
        return tensor.cpu().numpy().astype(np.float32)

    features = np.concatenate([as_f32(ds.tensors[1]) for ds in dsets], axis=0)
    targets = np.concatenate([as_f32(ds.tensors[2]) for ds in dsets], axis=0)
    return features, targets

X_s, y_s = stitch_from_dsets([train_ds, val_ds, test_ds])
# Explicit checks (an `assert` would be stripped under `python -O`): the stitched
# splits must cover every time step of the unscaled arrays, in order.
if X_s.shape != Xfull.shape:
    raise ValueError(f"Stitched features have shape {X_s.shape}, expected {Xfull.shape}.")
if y_s.shape != Y.shape:
    raise ValueError(f"Stitched targets have shape {y_s.shape}, expected {Y.shape}.")

print("\n=== After scaling ===")
print("X_s:", X_s.shape, "| y_s:", y_s.shape)
print("X_s finite:", bool(np.isfinite(X_s).all()), "| y_s finite:", bool(np.isfinite(y_s).all()))
if not np.isfinite(X_s).all() or not np.isfinite(y_s).all():
    raise ValueError("Non-finite values after scaling.")

# Take split boundaries from the datasets themselves (axis 0 is time, as the
# shape check above confirms). This keeps the variance screen, the OLS train
# window and the TEST start aligned with the rows the scalers were actually fit
# on. Re-deriving them as int(T * ratio) could silently disagree if the library
# rounds split sizes differently.
cut_train = int(train_ds.tensors[1].shape[0])
cut_val = cut_train + int(val_ds.tensors[1].shape[0])
if (cut_train, cut_val) != (int(T * train_ratio), int(T * (train_ratio + val_ratio))):
    print(
        f"NOTE: dataset split sizes ({cut_train}, {cut_val}) differ from "
        "ratio-based cut points; using dataset sizes."
    )
train_len = cut_train
test_start = cut_val


# Drop near-constant features based on TRAIN ONLY (pooled)
# Per-feature std over all TRAIN time steps and all stations pooled together.
# Only TRAIN rows are used, so VAL/TEST cannot influence which features are kept.
X_train = X_s[:train_len]
X_flat = X_train.reshape(-1, p_dim)
feat_std = X_flat.std(axis=0)

# A feature survives when its pooled TRAIN std exceeds EPS_STD.
keep_mask = feat_std > EPS_STD
kept_idx = np.where(keep_mask)[0].tolist()
dropped_idx = np.where(~keep_mask)[0].tolist()

print("\n=== Feature variance screening (TRAIN ONLY, pooled) ===")
for i, nm in enumerate(feat_names_full):
    print(f"  [{i:02d}] {nm:12s} std={feat_std[i]:.12e}")
print("Dropped:", [feat_names_full[i] for i in dropped_idx] if dropped_idx else "NONE")

if len(kept_idx) == 0:
    raise ValueError("All features dropped; cannot fit OLS.")

# Apply the TRAIN-derived mask to the full time axis (all splits).
X_s_red = X_s[:, :, keep_mask].astype(np.float32)
feat_names_red = [feat_names_full[i] for i in kept_idx]
p_red = X_s_red.shape[2]
print("X_s_red:", X_s_red.shape, "| p_red:", p_red)
print("kept features:", feat_names_red)

# The recursive forecaster overwrites the lag1 column, so it must survive screening.
if "lag1" not in feat_names_red:
    raise ValueError("lag1 was dropped, which should not happen.")


# OLS fit (pooled) on TRAIN
# Fit rows start at t = 1. lag1 at t = 0 is the zero pad written by the feature
# builder (there is no y[-1]), so that row would pair y[0] with a fabricated lag.
OLS_FIT_START = 1
if train_len <= OLS_FIT_START:
    raise ValueError("TRAIN split too short to fit OLS.")
# All stations are pooled: each (time, station) pair is one regression row.
Xtr = X_s_red[OLS_FIT_START:train_len].reshape(-1, p_red)
ytr = y_s[OLS_FIT_START:train_len].reshape(-1, 1)

Xtr_t = torch.from_numpy(Xtr).to(DEVICE)
ytr_t = torch.from_numpy(ytr).to(DEVICE)

# Prepend a column of ones so the first solution entry is the intercept.
ones = torch.ones((Xtr_t.shape[0], 1), device=DEVICE, dtype=Xtr_t.dtype)
X_aug = torch.cat([ones, Xtr_t], dim=1)

beta_aug = torch.linalg.lstsq(X_aug, ytr_t).solution.squeeze(1)
b0 = float(beta_aug[0].detach().cpu().item())
beta = beta_aug[1:].detach().cpu().numpy().astype(np.float32)

print("\n=== OLS fitted ===")
print("b0:", b0)
print("beta shape:", beta.shape)


# Rolling non-overlap on TEST (FAIR vs TFT)
def qrisk(y_true, y_pred, q=0.5, eps=1e-8):
    """Normalised quantile (pinball) risk.

    Computes ``2 * sum(max(q * e, (q - 1) * e)) / (sum(|y_true|) + eps)`` with
    ``e = y_true - y_pred``, evaluated in float64 over all elements. At ``q = 0.5``
    this reduces to ``sum|e| / sum|y_true|``. ``eps`` guards the division when
    ``y_true`` is all zeros.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    resid = truth - np.asarray(y_pred, dtype=np.float64)
    pinball = np.maximum(q * resid, (q - 1) * resid)
    denom = np.abs(truth).sum() + eps
    return float(2.0 * pinball.sum() / denom)

# Column of the lag1 feature after variance screening; the recursive forecaster
# overwrites this column with the previous step's value.
IDX_LAG1_RED = feat_names_red.index("lag1")

def rolling_ols_recursive_nonoverlap_fair(
    X_s_red_full: np.ndarray,
    y_s_full: np.ndarray,
    start_ctx: int,
    end_T: int,
    H: int,
    step: int,
    *,
    coef: Optional[np.ndarray] = None,
    intercept: Optional[float] = None,
    lag_idx: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Recursive multi-step OLS forecasts over rolling forecast origins.

    For every origin ``t0`` in ``range(start_ctx, end_T - H + 1, step)``, the
    covariates observed at ``t0 - 1`` are frozen for the whole horizon. No future
    covariates are used, hence "fair". Origins with ``t0 < 1`` are skipped. Only
    the lag1 column changes: step 0 is seeded with the observed ``y_s_full[t0 - 1]``,
    and each later step feeds back the previous prediction.

    NOTE(review): the lag1 column of ``X_s_red_full`` was scaled by the feature
    scaler, while ``y_s_full`` and the predictions live in the target scaler's
    space. Feeding one into the other assumes both scalings of the series are
    (near-)identical — confirm.

    Args:
        X_s_red_full: Scaled, screened features, shape ``(T, N, p)``.
        y_s_full: Scaled targets, shape ``(T, N)``.
        start_ctx: First candidate forecast origin (inclusive).
        end_T: Exclusive time bound; every window must end at or before it.
        H: Forecast horizon (steps per window).
        step: Distance between consecutive origins (``step == H`` gives
            non-overlapping windows).
        coef: OLS slopes, length ``p``. Defaults to the module-level ``beta``.
        intercept: OLS intercept. Defaults to the module-level ``b0``.
        lag_idx: Column of lag1 in ``X_s_red_full``. Defaults to the module-level
            ``IDX_LAG1_RED``.

    Returns:
        ``(yhat, ytrue, t0s)``: ``yhat`` and ``ytrue`` are float32 arrays of shape
        ``(W, H, N)``, and ``t0s`` is an int array of the ``W`` origins used.

    Raises:
        ValueError: if the array shapes are inconsistent, or ``end_T`` exceeds
            the time length.
        RuntimeError: if no window fits in ``[start_ctx, end_T)``.
    """
    # Resolve defaults at call time so the fitted globals are picked up lazily.
    coef = beta if coef is None else coef
    intercept = b0 if intercept is None else intercept
    lag_idx = IDX_LAG1_RED if lag_idx is None else lag_idx
    w = np.asarray(coef, dtype=np.float32).reshape(-1, 1)
    intercept = float(intercept)

    X_s_red_full = np.asarray(X_s_red_full)
    y_s_full = np.asarray(y_s_full)
    if X_s_red_full.ndim != 3 or y_s_full.shape != X_s_red_full.shape[:2]:
        raise ValueError(
            f"Expected X of shape (T, N, p) and y of shape (T, N); "
            f"got {X_s_red_full.shape} and {y_s_full.shape}."
        )
    if end_T > X_s_red_full.shape[0]:
        raise ValueError(f"end_T={end_T} exceeds the time length {X_s_red_full.shape[0]}.")
    n_stations = X_s_red_full.shape[1]

    yh_all, yt_all, kept_t0 = [], [], []
    for t0 in range(start_ctx, end_T - H + 1, step):
        if t0 < 1:
            continue  # no observed y at t0 - 1 to seed the recursion

        # astype() returns a fresh float32 copy, so the in-place lag updates
        # below never touch the caller's array.
        Xt = X_s_red_full[t0 - 1].astype(np.float32)
        prev = y_s_full[t0 - 1].astype(np.float32)
        yhat = np.empty((H, n_stations), dtype=np.float32)

        for k in range(H):
            Xt[:, lag_idx] = prev
            prev = ((Xt @ w).reshape(-1) + intercept).astype(np.float32)
            yhat[k] = prev

        yh_all.append(yhat)
        yt_all.append(y_s_full[t0:t0 + H].astype(np.float32))
        kept_t0.append(t0)

    if not yh_all:
        raise RuntimeError("No rolling windows produced. Check start/end/H/step.")

    return np.stack(yh_all, axis=0), np.stack(yt_all, axis=0), np.asarray(kept_t0, dtype=int)

# Earliest forecast origin: the TEST start, but never before step L, so at least
# L past steps exist before the first origin.
start_ctx = max(test_start, L)
yhat_roll, ytrue_roll, t0_list = rolling_ols_recursive_nonoverlap_fair(
    X_s_red_full=X_s_red,
    y_s_full=y_s,
    start_ctx=start_ctx,
    end_T=T,
    H=H,
    step=STEP,
)

# Point metrics in the scaled target space, over all windows, steps and stations.
W = yhat_roll.shape[0]
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(diff ** 2)))
mae = float(np.mean(np.abs(diff)))

print("\n=== Rolling non-overlap (POOLED, TEST) [OLS recursive | FAIR vs TFT] ===")
print(f"windows={W} | step={STEP} | each window predicts H={H} | stations={N}")
print("RMSE:", rmse)
print("MAE :", mae)

# Flatten (W, H, N) -> (W*H, N): one row per forecast step, one column per station.
yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

# P50 q-risk in scaled space, and again after mapping targets back through the
# target scaler (inverse_transform_targets; presumably original units).
qr_scaled = qrisk(ytrue_f, yhat_f, q=0.5)
ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
qr_unscaled = qrisk(ytrue_un, yhat_un, q=0.5)

print("\n=== P50 q-risk (ROLLING, TEST) ===")
print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)


# Plots: FIG1 first window (H steps), FIG2 full TEST stitched
t_start = int(t0_list[0])
t_end_excl = t_start + W * H
times_fig2 = time_index[t_start:t_end_excl]
# With STEP == H the W windows are back to back, so together they should span
# exactly W * H consecutive time stamps.
if len(times_fig2) != W * H:
    raise RuntimeError("Time axis length mismatch for stitched TEST plot.")

ytrue_stitched, yhat_stitched = (a.reshape(W * H, N) for a in (ytrue_roll, yhat_roll))


def _plot_true_vs_pred(x, y_true, y_pred, *, figsize, fmt, title, xlabel, out_path, **line_kw):
    """Draw scaled truth and prediction on one axes, save the figure to *out_path*, close it."""
    plt.figure(figsize=figsize, dpi=140)
    for series, tag in ((y_true, "True (scaled)"), (y_pred, "Pred (scaled)")):
        plt.plot(x, series, fmt, linewidth=2, label=tag, **line_kw)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Scaled t")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


print(f"\n=== Plotting ALL {N} stations (2 figs per station) ===")
for j, raw_station in enumerate(use_idx):
    station_id = int(raw_station)

    # FIG1: the first forecast window only, indexed by step within the window.
    _plot_true_vs_pred(
        range(H), ytrue_roll[0, :, j], yhat_roll[0, :, j],
        figsize=(9, 4),
        fmt="-o",
        markersize=3,
        title=f"FIG1 OLS FAIR | Beijing st_{station_id} | first window | H={H} (scaled)",
        xlabel="3-hour step within window",
        out_path=FIG1_DIR / f"FIG1_OLS_fair_firstWindow_st{station_id}_H{H}_scaled.png",
    )

    # FIG2: every TEST window stitched end to end on the real time axis.
    _plot_true_vs_pred(
        times_fig2, ytrue_stitched[:, j], yhat_stitched[:, j],
        figsize=(12, 4),
        fmt="-",
        title=f"FIG2 OLS FAIR | Beijing st_{station_id} | full TEST (stitched) | H={H} (scaled)",
        xlabel="Time",
        out_path=FIG2_DIR / f"FIG2_OLS_fair_testAll_st{station_id}_H{H}_scaled.png",
    )

print(f"Saved FIG1 to: {FIG1_DIR.resolve()}")
print(f"Saved FIG2 to: {FIG2_DIR.resolve()}")
print("All done.")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Weather2K: (1866, 13, 13632)
Beijing stations: 31
Using past covariates: ['ap', 'mxt', 'mnt', 'rh', 'p3', 'wd_sin', 'wd_cos', 'ws', 'mwd_sin', 'mwd_cos', 'mws']
Y: (13632, 31) | Xfull: (13632, 31, 12) | p_dim: 12

=== After scaling ===
X_s: (13632, 31, 12) | y_s: (13632, 31)
X_s finite: True | y_s finite: True

=== Feature variance screening (TRAIN ONLY, pooled) ===
  [00] lag1         std=9.999539852142e-01
  [01] ap           std=9.999643564224e-01
  [02] mxt          std=1.000031113625e+00
  [03] mnt          std=9.999288916588e-01
  [04] rh           std=9.999594688416e-01
  [05] p3           std=9.997370839119e-01
  [06] wd_sin       std=9.999293088913e-01
  [07] wd_cos       std=1.000008702278e+00
  [08] ws           std=1.000204324722e+00
  [09] mwd_sin      std=9.999343752861e-01
  [10] mwd_cos      std=9.999788999557e-01
  [11] mws   