In [None]:
import csv
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.utils import (
    get_device_info,
    compute_ols_coefficients,
)
from geospatial_neural_adapter.metrics import compute_metrics
from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# Seed NumPy and PyTorch so that repeated runs draw the same random numbers.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_info = get_device_info()
print(f"Using {device_info['device'].upper()}: {device_info['device_name']}")
if device_info["device"] == "cuda":
    print(f"   Memory: {device_info['memory_gb']} GB")

# Forecast horizon (number of daily steps) used for the val and test rollouts.
H = 30

# Evaluation window: a fixed training span followed by back-to-back
# validation and test spans of PAPER_VAL_DAYS / PAPER_TEST_DAYS days.
PAPER_TRAIN_START = pd.Timestamp("2015-01-01")
PAPER_TRAIN_END = pd.Timestamp("2015-12-01")
PAPER_VAL_DAYS = 30
PAPER_TEST_DAYS = 30

# All bounds are inclusive, hence the "- 1" when turning a day count into
# an end date and the "+ 1 day" when starting the next span.
PAPER_VAL_START = PAPER_TRAIN_END + pd.Timedelta(days=1)
PAPER_VAL_END = PAPER_VAL_START + pd.Timedelta(days=PAPER_VAL_DAYS - 1)
PAPER_TEST_START = PAPER_VAL_END + pd.Timedelta(days=1)
PAPER_TEST_END = PAPER_TEST_START + pd.Timedelta(days=PAPER_TEST_DAYS - 1)

# Full span covered by train + val + test.
PAPER_ALL_START = PAPER_TRAIN_START
PAPER_ALL_END = PAPER_TEST_END

print("\n=== OFFICIAL (paper-like) window ===")
print("Train:", PAPER_TRAIN_START.date(), "→", PAPER_TRAIN_END.date())
print("Val  :", PAPER_VAL_START.date(), "→", PAPER_VAL_END.date())
print("Test :", PAPER_TEST_START.date(), "→", PAPER_TEST_END.date())
print("All  :", PAPER_ALL_START.date(), "→", PAPER_ALL_END.date())


def quantile_risk(y_true: np.ndarray, y_pred: np.ndarray, q: float = 0.5) -> float:
    """Return the normalized quantile (pinball) risk of a forecast.

    Both inputs are flattened to 1-D float64 arrays. The result is
    ``2 * sum(pinball_q(y_true - y_pred)) / sum(|y_true|)``.

    Args:
        y_true: Observed values, any shape.
        y_pred: Predicted values, broadcast-compatible with ``y_true`` once
            both are flattened.
        q: Quantile level of the pinball loss.

    Returns:
        The normalized risk as a Python float, or ``nan`` when the sum of
        absolute observed values is not positive (e.g. all zeros or empty).
    """
    actual = np.asarray(y_true, dtype=np.float64).ravel()
    forecast = np.asarray(y_pred, dtype=np.float64).ravel()

    residual = actual - forecast
    # Pinball loss: under-prediction is weighted by q, over-prediction by 1 - q.
    total_loss = np.maximum(q * residual, (q - 1.0) * residual).sum()

    normalizer = np.abs(actual).sum()
    if normalizer <= 0:
        return float("nan")
    return float(2.0 * total_loss / normalizer)


def inverse_standardize_targets(preprocessor: Any, y_scaled: np.ndarray) -> np.ndarray:
    """Undo standard scaling of targets using the preprocessor's target scaler.

    The scaler is looked up on ``preprocessor`` under the attribute names
    ``target_scaler``, ``target_scaler_``, ``y_scaler`` and ``y_scaler_``
    (first non-``None`` wins). It must expose ``mean_`` and ``scale_``.

    Args:
        preprocessor: Object carrying the fitted target scaler.
        y_scaled: Standardized targets; converted to float32.

    Returns:
        ``y_scaled * scale_ + mean_`` as a float32 array. A single-element
        statistic is applied as a scalar; a multi-element statistic is
        reshaped to ``(1, -1)`` so it broadcasts over the last axis.

    Raises:
        AttributeError: If no target scaler is found, or the scaler lacks
            ``mean_`` / ``scale_``.
    """
    y_scaled = np.asarray(y_scaled, dtype=np.float32)

    scaler = None
    for name in ("target_scaler", "target_scaler_", "y_scaler", "y_scaler_"):
        scaler = getattr(preprocessor, name, None)
        if scaler is not None:
            break

    if scaler is None:
        raise AttributeError(
            "preprocessor does not expose a target scaler attribute (target_scaler/target_scaler_). "
            "Please check prepare_all_with_scaling() return object and adjust inverse_standardize_targets()."
        )

    if not (hasattr(scaler, "mean_") and hasattr(scaler, "scale_")):
        raise AttributeError(
            "Found a target scaler object but it does not have mean_ and scale_. "
            "Please adapt inverse_standardize_targets() to your scaler type."
        )

    def _broadcastable(stat: Any, neutral: float) -> Any:
        """Turn a scaler statistic into a scalar or a (1, n) row vector."""
        # A scaler fitted without centering/scaling may store None here;
        # coercing None to float32 would yield NaN and poison the output,
        # so substitute the neutral element of the operation instead.
        if stat is None:
            return np.float32(neutral)
        arr = np.asarray(stat, dtype=np.float32)
        if arr.ndim == 0:
            arr = arr.reshape(1)
        return arr[0] if arr.shape[0] == 1 else arr.reshape(1, -1)

    scale_b = _broadcastable(scaler.scale_, 1.0)
    mean_b = _broadcastable(scaler.mean_, 0.0)
    return y_scaled * scale_b + mean_b


# Location of the raw sales data; fail early with a clear message if the
# training file is missing.
DATA_ROOT = Path("/home/wangxc1117/experiment_data/sales_forecasting_data")
TRAIN_PATH = DATA_ROOT / "train.csv"

if not TRAIN_PATH.exists():
    raise FileNotFoundError(f"train.csv not found at {TRAIN_PATH}")

print("\n=== Loading train.csv ===")
df = pd.read_csv(TRAIN_PATH, low_memory=False)
df["date"] = pd.to_datetime(df["date"])
# Missing promotion flags are treated as 0 (not on promotion).
df["onpromotion"] = df["onpromotion"].fillna(0).astype(int)

# Keep only rows inside the evaluation window (both bounds inclusive).
df = df[(df["date"] >= PAPER_ALL_START) & (df["date"] <= PAPER_ALL_END)].copy()
print(f"Paper-window date range (in df): {df['date'].min()} → {df['date'].max()}")

# Collapse the rows of df to one total per (date, store).
df_store_sales = (
    df.groupby(["date", "store_nbr"], as_index=False)["unit_sales"]
    .sum()
    .rename(columns={"unit_sales": "store_sales"})
)
# Floor negative daily totals at zero so the log1p transform below never
# sees a negative input.
df_store_sales["store_sales"] = df_store_sales["store_sales"].clip(lower=0.0)

# Dense (date x store) grid covering every day of the window for every store.
stores = sorted(df_store_sales["store_nbr"].unique())
date_index = pd.date_range(start=PAPER_ALL_START, end=PAPER_ALL_END, freq="D")
T_all = len(date_index)
N = len(stores)
print(f"\nUsing ALL {N} stores in paper window.")
print(f"T_all (days): {T_all}")

full_idx = pd.MultiIndex.from_product([date_index, stores], names=["date", "store_nbr"])

# Reindexing onto the dense grid introduces NaN for missing store-days.
panel_sales = (
    df_store_sales.set_index(["date", "store_nbr"])
    .reindex(full_idx)
    .sort_index()
)
# Fill gaps with the store's most recent earlier value; anything still
# missing (no earlier observation for that store) becomes 0.
panel_sales["store_sales"] = panel_sales.groupby("store_nbr")["store_sales"].ffill()
panel_sales.loc[panel_sales["store_sales"].isna(), "store_sales"] = 0.0
# Model target is log(1 + sales).
panel_sales["log_sales"] = np.log1p(panel_sales["store_sales"]).astype("float32")

# Pivot to a (T_all, N) matrix: rows are days, columns are stores in
# the order of `stores`.
Y_df = (
    panel_sales["log_sales"]
    .unstack("store_nbr")
    .reindex(index=date_index, columns=stores)
    .astype("float32")
)

print("\n=== Target sanity ===")
print("targets_full:", Y_df.to_numpy(dtype=np.float32).shape, "| has_na:", bool(np.isnan(Y_df.to_numpy()).any()))

# Row counts (in the unshifted T_all-day calendar) at which train and val end.
cut_train = int((PAPER_TRAIN_END - PAPER_ALL_START).days + 1)
cut_val = cut_train + PAPER_VAL_DAYS
assert cut_val + PAPER_TEST_DAYS == T_all

print("\n=== OFFICIAL split lengths ===")
print("T_all:", T_all, "| train:", cut_train, "| val:", (cut_val - cut_train), "| test:", (T_all - cut_val))

# Lag-1 feature: each store's previous-day target.
lag1_df = Y_df.shift(1)

# The first day has no lag, so row 0 is dropped from both target and feature.
Y_shift = Y_df.iloc[1:, :].to_numpy(dtype=np.float32)
lag1_shift = lag1_df.iloc[1:, :].to_numpy(dtype=np.float32)

T = Y_shift.shape[0]
assert T == T_all - 1

feat_names_full = ["lag1"]

# Design tensor is (T, N, 1): a trailing feature axis is added to the lag matrix.
X_raw = lag1_shift[..., None].astype("float32")
y_raw = Y_shift.astype("float32")

p_dim = X_raw.shape[2]
print("\n=== OLS data (raw, before scaling) ===")
print("X_raw:", X_raw.shape, "| y_raw:", y_raw.shape, "| p_dim:", p_dim)

if np.isnan(X_raw).any():
    raise ValueError("NaN found in X_raw. Check lag1 construction.")
if np.isnan(y_raw).any():
    raise ValueError("NaN found in y_raw. Check target construction.")

# Dropping the first day shortens only the training span by one row;
# val and test keep their full lengths.
train_len = cut_train - 1
val_len = PAPER_VAL_DAYS
test_len = PAPER_TEST_DAYS

# Start offsets of val and test in the shifted (T-row) arrays.
val_start = train_len
test_start = (cut_val - 1)
assert val_start + val_len == test_start
assert test_start + test_len == T

print("\n=== SHIFTED indices (for lag1 model) ===")
print("T (shifted)     =", T)
print("train_len       =", train_len)
print("val_start       =", val_start, "| val_end(excl) =", val_start + val_len)
print("test_start      =", test_start, "| test_end(excl)=", test_start + test_len)

# All-zero placeholder passed as the categorical input; this script builds
# no real categorical features.
cat_dummy = np.zeros((T, N, 1), dtype=np.int64)
# Split sizes are handed over as fractions of T.
# NOTE(review): how prepare_all_with_scaling turns these ratios back into row
# counts is not visible here; floating-point rounding could shift a boundary
# by one row. Confirm the returned splits have train_len / val_len rows.
train_ratio = train_len / T
val_ratio = val_len / T

# Standard-scale features and targets (fit_on_train_only=True is requested).
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=X_raw,
    targets=y_raw,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)


def _stitch_x(dsets) -> np.ndarray:
    """Concatenate the second tensor of every dataset along axis 0 as float32.

    Each element of ``dsets`` must expose a ``tensors`` sequence whose item
    at index 1 is a torch tensor (used here as the scaled feature tensor).
    """
    pieces = []
    for dataset in dsets:
        feature_tensor = dataset.tensors[1]
        pieces.append(feature_tensor.cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)


def _stitch_y(dsets) -> np.ndarray:
    """Concatenate the third tensor of every dataset along axis 0 as float32.

    Each element of ``dsets`` must expose a ``tensors`` sequence whose item
    at index 2 is a torch tensor (used here as the scaled target tensor).
    """
    pieces = []
    for dataset in dsets:
        target_tensor = dataset.tensors[2]
        pieces.append(target_tensor.cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)


# Re-assemble the scaled splits into single time-ordered arrays
# (train, then val, then test) so they can be indexed with the shifted offsets.
X_s = _stitch_x([train_ds, val_ds, test_ds])
y_s = _stitch_y([train_ds, val_ds, test_ds])

print("\n=== After scaling ===")
print("X_s:", X_s.shape, "| y_s:", y_s.shape)
print("[NaN/inf check after scaling]")
print("X_s has NaN:", bool(np.isnan(X_s).any()), "| inf:", bool(np.isinf(X_s).any()))
print("y_s has NaN:", bool(np.isnan(y_s).any()), "| inf:", bool(np.isinf(y_s).any()))

# Abort before fitting if scaling introduced any NaN or infinity.
if (not np.isfinite(X_s).all()) or (not np.isfinite(y_s).all()):
    raise ValueError("Scaling produced non-finite values. Stop.")

# Features whose training standard deviation is at or below this threshold
# are treated as constant and excluded from the fit.
EPS_STD = 1e-8

# Pool the training rows over time and stores: (train_len * N, p_dim).
X_train_s = X_s[:train_len, :, :]
X_flat_train = X_train_s.reshape(-1, p_dim)

feat_std = X_flat_train.std(axis=0)

keep_mask = feat_std > EPS_STD
drop_mask = ~keep_mask

kept_idx = np.where(keep_mask)[0].tolist()
dropped_idx = np.where(drop_mask)[0].tolist()

print("\n=== Feature variance screening (TRAIN ONLY, pooled) ===")
print(f"EPS_STD = {EPS_STD}")
print("All feature std (TRAIN):")
for i, nm in enumerate(feat_names_full):
    print(f"  [{i:02d}] {nm:12s} std={feat_std[i]:.12e}")

if dropped_idx:
    print("\nDropped near-constant features (std <= EPS_STD):")
    for i in dropped_idx:
        print(f"  DROP [{i:02d}] {feat_names_full[i]}  std={feat_std[i]:.12e}")
else:
    print("\nDropped near-constant features: NONE")

print("\nKept features:")
for i in kept_idx:
    print(f"  KEEP [{i:02d}] {feat_names_full[i]}  std={feat_std[i]:.12e}")

if len(kept_idx) == 0:
    raise ValueError("All features were dropped as near-constant. Cannot fit OLS.")

# Reduced design tensor containing only the kept feature columns
# (all time steps, not just train).
X_s_red = X_s[:, :, keep_mask].astype(np.float32)
p_red = X_s_red.shape[2]
feat_names_red = [feat_names_full[i] for i in kept_idx]

print("\n=== Reduced feature matrix ===")
print("X_s_red:", X_s_red.shape, "| p_red:", p_red)
print("feat_names_red:", feat_names_red)

# Fit OLS on the training rows only, using the scaled reduced features.
X_train_red_t = torch.from_numpy(X_s_red[:train_len]).to(DEVICE)
y_train_t = torch.from_numpy(y_s[:train_len]).to(DEVICE)

beta_red_t, b0 = compute_ols_coefficients(X_train_red_t, y_train_t, device=DEVICE)
beta_red = np.asarray(beta_red_t.detach().cpu().numpy(), dtype=np.float32).reshape(-1)
b0 = float(b0)

# Expand the coefficients back to the full feature space; dropped features
# keep a coefficient of zero.
beta_full = np.zeros((p_dim,), dtype=np.float32)
beta_full[keep_mask] = beta_red

print("\n=== OLS fitted (reduced) ===")
print("beta_red shape:", beta_red.shape, "| b0:", b0)
print("beta_red finite:", bool(np.isfinite(beta_red).all()), "| b0 finite:", bool(np.isfinite(b0)))

print("\n=== OLS coefficients mapped back to FULL feature space ===")
print("beta_full shape:", beta_full.shape)
print("beta_full finite:", bool(np.isfinite(beta_full).all()))

if not np.isfinite(beta_red).all():
    raise ValueError("beta_red is non-finite. Consider CPU solve or ridge.")

# Column of the lag-1 feature in the full (unscreened) feature list.
IDX_LAG1_FULL = 0


def _map_full_idx_to_red(full_idx: int, keep_mask_arr: np.ndarray) -> int:
    """Translate a full-feature column index into its reduced-space index.

    Returns the number of kept features that precede ``full_idx`` when the
    feature itself is kept, or ``-1`` when it was screened out.
    """
    if keep_mask_arr[full_idx]:
        kept_before = np.asarray(keep_mask_arr[:full_idx]).sum()
        return int(kept_before)
    return -1


# Column of the lag-1 feature after variance screening (-1 if it was dropped).
IDX_LAG1_RED = _map_full_idx_to_red(IDX_LAG1_FULL, keep_mask)

print("\n=== Feature index mapping (full -> reduced) ===")
print("lag1: full =", IDX_LAG1_FULL, "-> red =", IDX_LAG1_RED)

# The rollout overwrites the lag-1 column at every step, so it must exist.
if IDX_LAG1_RED < 0:
    raise ValueError("lag1 was dropped as near-constant, which should not happen. Check your data.")


def ols_direct_h_lag1_only(
    beta_red: np.ndarray,
    b0: float,
    X_s_red_full: np.ndarray,
    y_s_full: np.ndarray,
    start_idx: int,
    horizon: int = 30,
    *,
    lag1_idx: "int | None" = None,
) -> np.ndarray:
    """Roll the fitted linear model forward ``horizon`` steps from ``start_idx``.

    Despite the name, this is a recursive rollout: the lag-1 column is seeded
    with the true target at ``start_idx - 1`` and thereafter overwritten with
    the model's own previous prediction. ``X_s_red_full`` is never modified.

    NOTE(review): the anchor and the fed-back predictions come from target
    space (``y_s_full`` / model output) but are written into a feature column
    of ``X_s_red_full``. That is only consistent if the feature and target
    scalers share the same statistics -- confirm against the preprocessor.

    Args:
        beta_red: Coefficients for the reduced feature set, length ``p``.
        b0: Intercept.
        X_s_red_full: Scaled reduced features, shape ``(T, N, p)``.
        y_s_full: Scaled targets, indexed as ``[t, store]``.
        start_idx: First time index to predict; must be >= 1.
        horizon: Number of consecutive steps to predict.
        lag1_idx: Column of the lag-1 feature in the reduced feature axis.
            Defaults to the module-level ``IDX_LAG1_RED``.

    Returns:
        Float32 array of shape ``(horizon, N)`` with the predictions.

    Raises:
        ValueError: If shapes disagree, the window falls outside the data,
            or a prediction is NaN/inf.
    """
    beta_red = np.asarray(beta_red, dtype=np.float32).reshape(-1)
    T_tot, N_loc, p = X_s_red_full.shape
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if p != beta_red.shape[0]:
        raise ValueError(
            f"Feature dimension mismatch: X has p={p} but beta has {beta_red.shape[0]} entries."
        )

    if lag1_idx is None:
        lag1_idx = IDX_LAG1_RED

    end_idx = start_idx + horizon
    if start_idx < 1:
        raise ValueError("start_idx must be >=1 for lag1 anchor.")
    if end_idx > T_tot:
        raise ValueError(f"Need end_idx={end_idx} <= T_tot={T_tot}")

    # Loop invariants hoisted out of the rollout.
    beta_col = beta_red.reshape(p, 1)
    intercept = float(b0)

    yhat = np.empty((horizon, N_loc), dtype=np.float32)

    # Anchor: last observed target before the forecast window.
    prev = y_s_full[start_idx - 1, :].astype(np.float32)

    for k, t in enumerate(range(start_idx, end_idx)):
        # Copy so the caller's feature tensor is left untouched.
        Xt = X_s_red_full[t, :, :].copy()
        Xt[:, lag1_idx] = prev

        y_pred = (Xt @ beta_col).reshape(-1) + intercept

        if not np.isfinite(y_pred).all():
            bad = ~np.isfinite(y_pred)
            idx = int(np.where(bad)[0][0])
            print("\n[ERROR] non-finite in y_pred at step k =", k, " store_idx =", idx, " t =", t)
            print("  y_pred[idx] =", y_pred[idx])
            print("  prev[idx]   =", prev[idx])
            print("  Xt[idx,:] finite:", bool(np.isfinite(Xt[idx, :]).all()))
            print("  beta_red finite:", bool(np.isfinite(beta_red).all()), "| b0 finite:", bool(np.isfinite(b0)))
            raise ValueError("y_pred contains NaN/inf (likely numerical instability).")

        # Feed this step's prediction back as the next step's lag-1 input.
        prev = y_pred.astype(np.float32)
        yhat[k, :] = prev

    return yhat


# H-step rollouts over the validation and test windows. Each rollout is
# anchored on the true scaled target of the day just before its window.
yhat_val_s = ols_direct_h_lag1_only(
    beta_red=beta_red,
    b0=b0,
    X_s_red_full=X_s_red,
    y_s_full=y_s,
    start_idx=val_start,
    horizon=H,
)

yhat_test_s = ols_direct_h_lag1_only(
    beta_red=beta_red,
    b0=b0,
    X_s_red_full=X_s_red,
    y_s_full=y_s,
    start_idx=test_start,
    horizon=H,
)

# Ground truth over the same windows, still in standardized space.
y_val_true_s = y_s[val_start:val_start + H, :]
y_test_true_s = y_s[test_start:test_start + H, :]

# compute_metrics is called with torch tensors on DEVICE.
val_y_t = torch.from_numpy(y_val_true_s).to(DEVICE)
test_y_t = torch.from_numpy(y_test_true_s).to(DEVICE)
yhat_val_t = torch.from_numpy(yhat_val_s).to(DEVICE)
yhat_test_t = torch.from_numpy(yhat_test_s).to(DEVICE)

rmse_v, mae_v, r2_v = compute_metrics(val_y_t, yhat_val_t)
rmse_t, mae_t, r2_t = compute_metrics(test_y_t, yhat_test_t)

print("\n=== OLS (LAG1 ONLY; no cov) OFFICIAL eval (standardized space) ===")
print(f"Val  RMSE={rmse_v:.6f}, MAE={mae_v:.6f}, R2={r2_v:.6f}")
print(f"Test RMSE={rmse_t:.6f}, MAE={mae_t:.6f}, R2={r2_t:.6f}")

# P50 quantile risk is reported in three spaces:
#   1) standardized log-sales, 2) un-standardized log-sales, 3) raw sales.
qrisk_val_scaled_log = quantile_risk(y_val_true_s, yhat_val_s, q=0.5)
qrisk_test_scaled_log = quantile_risk(y_test_true_s, yhat_test_s, q=0.5)

# Undo the target standardization to get back to log1p(sales).
y_val_true_log = inverse_standardize_targets(preprocessor, y_val_true_s)
yhat_val_log = inverse_standardize_targets(preprocessor, yhat_val_s)
y_test_true_log = inverse_standardize_targets(preprocessor, y_test_true_s)
yhat_test_log = inverse_standardize_targets(preprocessor, yhat_test_s)

qrisk_val_unscaled_log = quantile_risk(y_val_true_log, yhat_val_log, q=0.5)
qrisk_test_unscaled_log = quantile_risk(y_test_true_log, yhat_test_log, q=0.5)

# Invert log1p to recover sales in original units.
y_val_true_sales = np.expm1(y_val_true_log).astype(np.float64)
yhat_val_sales = np.expm1(yhat_val_log).astype(np.float64)
y_test_true_sales = np.expm1(y_test_true_log).astype(np.float64)
yhat_test_sales = np.expm1(yhat_test_log).astype(np.float64)

# expm1 of a negative log value is negative; sales cannot be, so floor at 0.
y_val_true_sales = np.clip(y_val_true_sales, 0.0, None)
yhat_val_sales = np.clip(yhat_val_sales, 0.0, None)
y_test_true_sales = np.clip(y_test_true_sales, 0.0, None)
yhat_test_sales = np.clip(yhat_test_sales, 0.0, None)

qrisk_val_sales = quantile_risk(y_val_true_sales, yhat_val_sales, q=0.5)
qrisk_test_sales = quantile_risk(y_test_true_sales, yhat_test_sales, q=0.5)

print("\n=== P50 q-risk (Val) ===")
print(f"qrisk_p50_scaled_log   = {qrisk_val_scaled_log:.6f}")
print(f"qrisk_p50_unscaled_log = {qrisk_val_unscaled_log:.6f}")
print(f"qrisk_p50_sales        = {qrisk_val_sales:.6f}")

print("\n=== P50 q-risk (Test) ===")
print(f"qrisk_p50_scaled_log   = {qrisk_test_scaled_log:.6f}")
print(f"qrisk_p50_unscaled_log = {qrisk_test_unscaled_log:.6f}")
print(f"qrisk_p50_sales        = {qrisk_test_sales:.6f}")

# Output folder for the per-store figures (relative to the working directory).
OLS_PLOTS_DIR = Path(f"OLS_plots_store_paperWindow_h{H}_LAG1_ONLY")
OLS_PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Calendar aligned with the shifted arrays: the first day was dropped when
# the lag-1 feature was built.
date_index_shift = date_index[1:]
assert len(date_index_shift) == y_s.shape[0] == X_s.shape[0]


def plot_store_two_figs_like_tft_ols(
    sid: int,
    y_true_scaled_full_shift: np.ndarray,
    yhat_test_scaled: np.ndarray,
    out_dir: Path,
):
    """Save two PNG figures for one store comparing truth and test forecast.

    FIG1 shows the true series over the validation and test windows plus the
    test prediction; FIG2 shows the true series over the whole shifted
    calendar plus the test prediction. Both are written into ``out_dir``.

    Args:
        sid: Store identifier; must be present in the module-level ``stores``.
        y_true_scaled_full_shift: Scaled true targets, indexed ``[t, store]``.
        yhat_test_scaled: Scaled test predictions, indexed ``[step, store]``.
        out_dir: Directory that receives the two image files.
    """
    col = stores.index(sid)
    val_window = slice(val_start, val_start + H)
    test_window = slice(test_start, test_start + H)

    truth_val = y_true_scaled_full_shift[val_window, col]
    truth_test = y_true_scaled_full_shift[test_window, col]
    pred_test = yhat_test_scaled[:, col]

    val_dates = date_index_shift[val_window]
    test_dates = date_index_shift[test_window]

    def _render(x_true, series_true, true_width, true_label, title, target):
        # Shared layout for both figures: truth in black, forecast in red,
        # dashed verticals at the two split dates.
        plt.figure(figsize=(10, 4), dpi=140)
        plt.plot(x_true, series_true, "-", linewidth=true_width, color="k", label=true_label)
        plt.plot(test_dates, pred_test, "-", linewidth=1.8, color="C3", label="Test Pred")
        plt.axvline(PAPER_VAL_START, linestyle="--", linewidth=1, label="train/val split")
        plt.axvline(PAPER_TEST_START, linestyle="--", linewidth=1, label="val/test split")
        plt.title(title)
        plt.xlabel("Date")
        plt.ylabel("Scaled log_sales")
        plt.grid(alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.savefig(target)
        plt.close()

    # Figure 1: validation + test truth as one continuous line.
    _render(
        val_dates.append(test_dates),
        np.concatenate([truth_val, truth_test], axis=0),
        2.0,
        "True (Val+Test)",
        f"FIG1 H{H} store {sid}: True (continuous) + Test Pred (OLS LAG1 ONLY)",
        out_dir / f"FIG1_store{sid}_h{H}_ols_lag1only.png",
    )

    # Figure 2: truth over the entire shifted calendar.
    _render(
        date_index_shift,
        y_true_scaled_full_shift[:, col],
        1.8,
        "All True",
        f"FIG2 H{H} store {sid}: All True + Test Pred (OLS LAG1 ONLY)",
        out_dir / f"FIG2_store{sid}_h{H}_ols_lag1only.png",
    )


def plot_all_stores_two_figs_like_tft_ols(
    y_true_scaled_full_shift: np.ndarray,
    yhat_test_scaled: np.ndarray,
):
    """Render and save both figures for every store in ``stores``.

    Args:
        y_true_scaled_full_shift: Scaled true targets, indexed ``[t, store]``.
        yhat_test_scaled: Scaled test predictions, indexed ``[step, store]``.
    """
    print(f"\nStart plotting ALL {len(stores)} stores (2 figs/store) ...")
    # Arguments that are the same for every store.
    shared_kwargs = {
        "y_true_scaled_full_shift": y_true_scaled_full_shift,
        "yhat_test_scaled": yhat_test_scaled,
        "out_dir": OLS_PLOTS_DIR,
    }
    for store_id in stores:
        plot_store_two_figs_like_tft_ols(sid=store_id, **shared_kwargs)
    print(f"All figures saved under: {OLS_PLOTS_DIR}")


# Produce the per-store figures from the scaled truth and the test rollout.
plot_all_stores_two_figs_like_tft_ols(
    y_true_scaled_full_shift=y_s,
    yhat_test_scaled=yhat_test_s,
)

# Append one summary row per run to a metrics CSV in the working directory.
csv_path = Path(f"metrics_summary_OLS_Favorita_STORE_H{H}_paperWindow_LAG1_ONLY.csv")
# Write the header when the file is new OR exists but is empty (e.g. left
# behind by an interrupted run); otherwise the rows would have no header.
write_header = (not csv_path.exists()) or csv_path.stat().st_size == 0

# The horizon in the free-text notes follows H instead of a hard-coded 30, so
# the note stays correct if H is changed.
notes = (
    "X_full=[lag1 only]; "
    f"drop_near_constant(EPS_STD={EPS_STD}) on TRAIN ONLY; "
    f"direct {H}-step; paper window; no covariates; plots like TFT (2 figs/store)"
)

with csv_path.open("a", newline="") as f:
    w = csv.writer(f)
    if write_header:
        w.writerow([
            "seed", "model",
            "rmse_val", "mae_val", "r2_val",
            "rmse_test", "mae_test", "r2_test",
            "qrisk_p50_scaled_log_val", "qrisk_p50_scaled_log_test",
            "qrisk_p50_unscaled_log_val", "qrisk_p50_unscaled_log_test",
            "qrisk_p50_sales_val", "qrisk_p50_sales_test",
            "dropped_features",
            "kept_features",
            "notes",
        ])
    w.writerow([
        GLOBAL_SEED,
        f"OLS_STORE_H{H}_paperWindow_LAG1_ONLY",
        float(rmse_v), float(mae_v), float(r2_v),
        float(rmse_t), float(mae_t), float(r2_t),
        float(qrisk_val_scaled_log), float(qrisk_test_scaled_log),
        float(qrisk_val_unscaled_log), float(qrisk_test_unscaled_log),
        float(qrisk_val_sales), float(qrisk_test_sales),
        # Joining an empty selection already yields "", so no special case.
        ";".join(feat_names_full[i] for i in dropped_idx),
        ";".join(feat_names_red),
        notes,
    ])

print(f"\nSaved metrics to {csv_path}")
print(f"Saved OLS plots to folder: {OLS_PLOTS_DIR}")


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Using CUDA: NVIDIA GeForce RTX 4060 Laptop GPU
   Memory: 8.6 GB

=== OFFICIAL (paper-like) window ===
Train: 2015-01-01 → 2015-12-01
Val  : 2015-12-02 → 2015-12-31
Test : 2016-01-01 → 2016-01-30
All  : 2015-01-01 → 2016-01-30

=== Loading train.csv ===
Paper-window date range (in df): 2015-01-01 00:00:00 → 2016-01-30 00:00:00

Using ALL 53 stores in paper window.
T_all (days): 395

=== Target sanity ===
targets_full: (395, 53) | has_na: False

=== OFFICIAL split lengths ===
T_all: 395 | train: 335 | val: 30 | test: 30

=== OLS data (raw, before scaling) ===
X_raw: (394, 53, 1) | y_raw: (394, 53) | p_dim: 1

=== SHIFTED indices (for lag1 model) ===
T (shifted)     = 394
train_len       = 334
val_start       = 334 | val_end(excl) = 364
test_start      = 364 | test_end(excl)= 394

=== After scaling ===
X_s: (394, 53, 1) | y_s: (394, 53)
[NaN/inf check after sc