In [1]:
# -*- coding: utf-8 -*-
# ============================================================
# Weather2K (Beijing subset) + OLS baseline (NO COV)
# - Scaling: POOLED (train-only) via prepare_all_with_scaling
# - Model  : OLS-AR(1) pooled across all stations:
#            y_t = b0 + a * y_{t-1}
# - Inference: Rolling non-overlap (step=H), H-step INTERNAL recursive
# - Metrics: RMSE/MAE + P50 q-risk (scaled/unscaled) over TEST rolling windows
# - Plots : ALL stations
#          FIG1: first rolling window (24 steps)
#          FIG2: full TEST (stitched rolling windows)
# ============================================================

from pathlib import Path
from typing import Tuple, List

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# ============================================================
# 0) Paths & settings
# ============================================================
W2K_PATH = Path("/home/wangxc1117/Weather2K/weather2k.npy")

GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Split / rolling config (match your TFT pooled setup)
train_ratio = 0.70
val_ratio = 0.15
L, H = 56, 24  # L only used for rolling start constraint
STEP = H

# Plot config
PLOTS_DIR = Path("OLS_noCov_plots_weather2k_beijing_pooled")
PLOTS_DIR.mkdir(parents=True, exist_ok=True)


# ============================================================
# 1) Load data
# ============================================================
arr = np.load(W2K_PATH, allow_pickle=False).astype(np.float32)
S, V, T = arr.shape
if V != 13:
    raise ValueError(f"Expected 13 variables, got V={V}")
print("Weather2K:", arr.shape)

time_index = pd.date_range("2000-01-01", periods=T, freq="3H")

# Variable indices:
# 0 lat, 1 lon, 2 alt, 3 ap, 4 t, 5 mxt, 6 mnt, 7 rh, 8 p3, 9 wd, 10 ws, 11 mwd, 12 mws
TARGET_IDX = 4

# Beijing subset
LAT_MIN, LAT_MAX = 39.4, 41.1
LON_MIN, LON_MAX = 115.4, 117.5
lat = arr[:, 0, 0]
lon = arr[:, 1, 0]
use_idx = np.where(
    (lat >= LAT_MIN) & (lat <= LAT_MAX) &
    (lon >= LON_MIN) & (lon <= LON_MAX)
)[0]
N = len(use_idx)
print("Beijing stations:", N)
if N == 0:
    raise RuntimeError("No Beijing stations found. Check bounding box.")


# ============================================================
# 2) Build pooled arrays (targets Y, and features Xfull = [lag1] only)
# ============================================================
# Y: (T, N)
Y = np.stack([arr[s, TARGET_IDX] for s in use_idx], axis=1).astype(np.float32)

# lag1 feature: (T, N, 1)
lag1 = np.zeros((T, N, 1), dtype=np.float32)
lag1[1:, :, 0] = Y[:-1, :]

Xfull = lag1  # (T, N, 1)
feat_names_full = ["lag1"]
p_dim = 1

print("Y:", Y.shape, "| Xfull:", Xfull.shape, "| p_dim:", p_dim)

# Dummy cat placeholder for prepare_all_with_scaling
cat_dummy = np.zeros((T, N, 1), dtype=np.int64)


# ============================================================
# 3) Pooled train-only scaling (match TFT pooled pipeline)
# ============================================================
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=Xfull,
    targets=Y,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)

def stitch_from_dsets(dsets) -> Tuple[np.ndarray, np.ndarray]:
    Xs = np.concatenate([ds.tensors[1].cpu().numpy().astype(np.float32) for ds in dsets], axis=0)
    Ys = np.concatenate([ds.tensors[2].cpu().numpy().astype(np.float32) for ds in dsets], axis=0)
    return Xs, Ys

X_s, y_s = stitch_from_dsets([train_ds, val_ds, test_ds])
assert X_s.shape == Xfull.shape
assert y_s.shape == Y.shape

print("\n=== After scaling ===")
print("X_s:", X_s.shape, "| y_s:", y_s.shape)
print("X_s finite:", bool(np.isfinite(X_s).all()), "| y_s finite:", bool(np.isfinite(y_s).all()))
if not np.isfinite(X_s).all() or not np.isfinite(y_s).all():
    raise ValueError("Non-finite values after scaling.")


# ============================================================
# 4) OLS fit (pooled) on TRAIN
#    y = b0 + a * lag1
# ============================================================
cut_train = int(T * train_ratio)
cut_val = int(T * (train_ratio + val_ratio))

train_len = cut_train

# Pooled flatten
Xtr = X_s[:train_len].reshape(-1, p_dim)   # (train_len*N, 1)
ytr = y_s[:train_len].reshape(-1, 1)       # (train_len*N, 1)

Xtr_t = torch.from_numpy(Xtr).to(DEVICE)
ytr_t = torch.from_numpy(ytr).to(DEVICE)

# Add intercept
ones = torch.ones((Xtr_t.shape[0], 1), device=DEVICE, dtype=Xtr_t.dtype)
X_aug = torch.cat([ones, Xtr_t], dim=1)  # (M, 2)

beta_aug = torch.linalg.lstsq(X_aug, ytr_t).solution.squeeze(1)  # (2,)

b0 = float(beta_aug[0].detach().cpu().item())
a = float(beta_aug[1].detach().cpu().item())

print("\n=== OLS-AR(1) fitted (pooled, NO COV) ===")
print("b0:", b0)
print("a :", a)


# ============================================================
# 5) Rolling non-overlap on TEST, H-step INTERNAL recursive
# ============================================================
def qrisk(y_true, y_pred, q=0.5, eps=1e-8):
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    e = y_true - y_pred
    return float(2.0 * np.sum(np.maximum(q * e, (q - 1) * e)) / (np.sum(np.abs(y_true)) + eps))

def rolling_ols_ar1_recursive_nonoverlap(
    y_s_full: np.ndarray,       # (T, N), scaled targets (true)
    start_ctx: int,
    end_T: int,
    H: int,
    step: int,
    b0: float,
    a: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    For each window starting at t0:
      - anchor prev = true y at (t0-1)
      - for k in 0..H-1:
          yhat = b0 + a * prev
          prev = yhat
    Return:
      yhat_roll: (W, H, N)
      ytrue_roll:(W, H, N)
      t0_list:   (W,)
    """
    t0_list = list(range(start_ctx, end_T - H + 1, step))
    yh_all, yt_all, kept_t0 = [], [], []

    for t0 in t0_list:
        if t0 < 1:
            continue

        ytru = y_s_full[t0:t0+H].astype(np.float32)  # (H,N)

        prev = y_s_full[t0 - 1].astype(np.float32)   # (N,)
        yhat = np.empty((H, y_s_full.shape[1]), dtype=np.float32)

        for k in range(H):
            ypred = (b0 + a * prev).astype(np.float32)
            yhat[k] = ypred
            prev = ypred

        yh_all.append(yhat)
        yt_all.append(ytru)
        kept_t0.append(t0)

    return np.stack(yh_all, axis=0), np.stack(yt_all, axis=0), np.asarray(kept_t0, dtype=int)

start_ctx = max(cut_val, L)  # match TFT: start from TEST start, ensure enough context
yhat_roll, ytrue_roll, t0_list = rolling_ols_ar1_recursive_nonoverlap(
    y_s_full=y_s,
    start_ctx=start_ctx,
    end_T=T,
    H=H,
    step=STEP,
    b0=b0,
    a=a,
)

W = yhat_roll.shape[0]
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(diff ** 2)))
mae = float(np.mean(np.abs(diff)))

print("\n=== Rolling non-overlap (POOLED, TEST) [OLS-AR(1) recursive | NO COV] ===")
print(f"windows={W} | step={STEP} | each window predicts H={H} | stations={N}")
print("RMSE:", rmse)
print("MAE :", mae)

yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

qr_scaled = qrisk(ytrue_f, yhat_f, q=0.5)

ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
qr_unscaled = qrisk(ytrue_un, yhat_un, q=0.5)

print("\n=== P50 q-risk (ROLLING, TEST) ===")
print("scaled  :", qr_scaled)
print("unscaled:", qr_unscaled)


# ============================================================
# 6) Plots: ALL stations
#    FIG1: first rolling window (24 steps)
#    FIG2: full TEST (stitched rolling windows)
# ============================================================
# Prepare stitched timeline (non-overlap => contiguous)
t_start = int(t0_list[0])
t_end_excl = int(t0_list[0] + W * H)
times_fig2 = time_index[t_start:t_end_excl]

ytrue_fig2_all = ytrue_roll.reshape(W * H, N)
yhat_fig2_all = yhat_roll.reshape(W * H, N)

print("\n=== Plotting ALL stations ===")
for j in range(N):
    station_id = int(use_idx[j])

    # FIG1: first window
    ytrue_fig1 = ytrue_roll[0, :, j]
    yhat_fig1 = yhat_roll[0, :, j]

    plt.figure(figsize=(9, 4), dpi=140)
    plt.plot(range(H), ytrue_fig1, "-o", linewidth=2, markersize=3, label="True (scaled)")
    plt.plot(range(H), yhat_fig1, "-o", linewidth=2, markersize=3, label="Pred (scaled)")
    plt.title(f"FIG1 OLS-AR(1) NO COV | Beijing st_{station_id} | first window | H={H}")
    plt.xlabel("3-hour step within window")
    plt.ylabel("Scaled t")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(PLOTS_DIR / f"FIG1_OLS_AR1_noCov_firstWindow_st{station_id}_H{H}_scaled.png")
    plt.close()

    # FIG2: full TEST stitched
    ytrue_fig2 = ytrue_fig2_all[:, j]
    yhat_fig2 = yhat_fig2_all[:, j]

    plt.figure(figsize=(12, 4), dpi=140)
    plt.plot(times_fig2, ytrue_fig2, "-", linewidth=2, label="True (scaled)")
    plt.plot(times_fig2, yhat_fig2, "-", linewidth=2, label="Pred (scaled)")
    plt.title(f"FIG2 OLS-AR(1) NO COV | Beijing st_{station_id} | full TEST (stitched) | H={H}")
    plt.xlabel("Time")
    plt.ylabel("Scaled t")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(PLOTS_DIR / f"FIG2_OLS_AR1_noCov_testAll_st{station_id}_H{H}_scaled.png")
    plt.close()

print(f"Saved plots to: {PLOTS_DIR.resolve()}")
print("All done.")


âœ… Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Weather2K: (1866, 13, 13632)
Beijing stations: 31
Y: (13632, 31) | Xfull: (13632, 31, 1) | p_dim: 1

=== After scaling ===
X_s: (13632, 31, 1) | y_s: (13632, 31)
X_s finite: True | y_s finite: True

=== OLS-AR(1) fitted (pooled, NO COV) ===
b0: -1.1928759313661885e-08
a : 0.9588793516159058

=== Rolling non-overlap (POOLED, TEST) [OLS-AR(1) recursive | NO COV] ===
windows=85 | step=24 | each window predicts H=24 | stations=31
RMSE: 0.6354440450668335
MAE : 0.5235574841499329

=== P50 q-risk (ROLLING, TEST) ===
scaled  : 0.5998525710005296
unscaled: 0.4215433376719579

=== Plotting ALL stations ===
Saved plots to: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/weather2k/OLS_nocov/OLS_noCov_plots_weather2k_beijing_pooled
All done.
