In [None]:
# -*- coding: utf-8 -*-
# ============================================================
# Synthetic spatio-temporal data + TFT baseline (Darts)
# [NAR(1) DGP | 1D locs | Gate A2_global_v2 centered | NO static cov | ONLY 2 GAMMAS]
#
# - Data: generate_time_NAR_synthetic_data (A2_global_v2, centered by rolling mean)  <-- SAME AS OLS
# - Locations: 1D synthetic locations (N points)
# - Target: y (T, N), multi-site
# - Dynamic covariates: synthetic features (f1, f2, f3) as past covariates
# - Static covariates: NONE
# - Scaling: POOLED (train-only) via prepare_all_with_scaling (targets scaled train-only)
# - Train/Val/Test: time split (70/15/15)
# - Early stopping: INTERNAL validation (tail of train)
# - Evaluation: Rolling non-overlap over TEST (step=H)
# - Metrics: RMSE/MAE + P50 q-risk (scaled/unscaled)
# - Runs ONLY 2 settings:
#     (y_rho=0.85, win=50, y_gamma=0.10) and (y_rho=0.85, win=50, y_gamma=0.20)
# - Plots (ALL grid points, 2 figs/point):
#     FIG1: TEST first H steps (first rolling window)
#     FIG2: ALL TEST (stitched rolling predictions)
# - Saves summary CSV with 2 rows (one per gamma)
#
# IMPORTANT:
# - This script prints y_md5 (hash) so you can verify the TFT run uses EXACTLY the same y
#   as your OLS/VAR/VARX scripts under identical parameters + seed.
# ============================================================

from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any

import hashlib
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.models import TFTModel
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# ============================================================
# 0) Utilities
# ============================================================
def qrisk(y_true, y_pred, q=0.5, eps=1e-8) -> float:
    """Normalised quantile risk (q-risk) of ``y_pred`` against ``y_true``.

    Computes ``2 * sum(pinball_q(y_true - y_pred)) / (sum(|y_true|) + eps)``,
    where ``pinball_q(e) = max(q * e, (q - 1) * e)``. Inputs are converted to
    float64 arrays of any (matching / broadcastable) shape.

    Args:
        y_true: ground-truth values.
        y_pred: predicted values (the q-quantile forecast).
        q: quantile level; 0.5 gives the P50 risk (equals sum|e| / sum|y|).
        eps: small constant guarding against a zero denominator.

    Returns:
        The q-risk as a Python float.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    resid = truth - np.asarray(y_pred, dtype=np.float64)
    pinball = np.maximum(q * resid, (q - 1) * resid)
    denom = np.abs(truth).sum() + eps
    return float(2.0 * pinball.sum() / denom)


def rolling_nonoverlap(model, series, pcov, start, end, L, H):
    """
    Returns:
      yhat_roll: (W, H, N)
      ytrue_roll:(W, H, N)
      t0_list: list of window start indices
    """
    yh, yt = [], []
    t0_list = []
    for t0 in range(start, end - H + 1, H):
        if t0 < L:
            continue

        ctx_s = [s[:t0] for s in series]
        ctx_p = [p[:t0] for p in pcov]

        preds = model.predict(
            n=H,
            series=ctx_s,
            past_covariates=ctx_p,
            verbose=False,
        )
        preds = preds if isinstance(preds, list) else [preds]

        yh.append(
            np.stack([p.values(copy=False)[:, 0] for p in preds], axis=1).astype(np.float32)
        )  # (H, N)
        yt.append(
            np.stack([s[t0:t0 + H].values(copy=False)[:, 0] for s in series], axis=1).astype(np.float32)
        )  # (H, N)
        t0_list.append(int(t0))

    if len(yh) == 0:
        raise RuntimeError("No rolling windows produced. Check start/end/L/H.")
    return np.stack(yh, axis=0), np.stack(yt, axis=0), t0_list


def plot_two_figs_all_points(
    out_dir: Path,
    *,
    exp_tag: str,
    scenario_key: str,
    N: int,
    H: int,
    w0: int,
    t0_first: int,
    time_index: pd.DatetimeIndex,
    test_start: int,
    test_end: int,
    ytrue_roll: np.ndarray,   # (W,H,N) scaled
    yhat_roll: np.ndarray,    # (W,H,N) scaled
    Y_scaled: np.ndarray,     # (T,N) scaled
    t0_list: List[int],
):
    """Save two PNG figures per grid point into ``out_dir``.

    FIG1: true vs. predicted values of rolling window ``w0`` (H steps starting
    at ``t0_first``).
    FIG2: the whole TEST range ``[test_start, test_end)`` with the rolling
    forecasts stitched together; steps not covered by any window stay NaN.

    All values are plotted in scaled units. ``out_dir`` is created if needed.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    first_dates = time_index[t0_first:t0_first + H]
    test_dates = time_index[test_start:test_end]

    n_test = test_end - test_start
    truth_test = Y_scaled[test_start:test_end, :].astype(np.float32)  # (n_test, N)

    # Place each window's forecast at its offset inside the TEST range; windows
    # that would spill outside it are skipped.
    stitched = np.full((n_test, N), np.nan, dtype=np.float32)
    for w, t0 in enumerate(t0_list):
        lo = t0 - test_start
        hi = lo + H
        if lo >= 0 and hi <= n_test:
            stitched[lo:hi, :] = yhat_roll[w]

    def _decorate_and_save(title: str, path: Path) -> None:
        # Common axis styling, then write the current figure and release it.
        plt.title(title)
        plt.xlabel("Time")
        plt.ylabel("Scaled target")
        plt.grid(alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    for j in range(N):
        gid = f"{j:03d}"

        # FIG1: first rolling window, H steps
        plt.figure(figsize=(10, 4), dpi=140)
        plt.plot(first_dates, ytrue_roll[w0][:, j], "-o", linewidth=2, markersize=3, label="True")
        plt.plot(first_dates, yhat_roll[w0][:, j], "-o", linewidth=2, markersize=3, label="Pred")
        _decorate_and_save(
            f"FIG1 {scenario_key} | NAR(1) A2_global_v2 | grid_{gid} | TEST first {H} steps | y (scaled)",
            out_dir / f"FIG1_{exp_tag}_grid{gid}_test_firstH{H}_t0{t0_first}.png",
        )

        # FIG2: all TEST (stitched rolling predictions)
        plt.figure(figsize=(12, 4), dpi=140)
        plt.plot(test_dates, truth_test[:, j], "-", linewidth=1.8, label="True")
        plt.plot(test_dates, stitched[:, j], "-", linewidth=1.8, label="Pred (stitched rolling)")
        plt.axvline(time_index[test_start], linestyle="--", linewidth=1, label="TEST start")
        _decorate_and_save(
            f"FIG2 {scenario_key} | NAR(1) A2_global_v2 | grid_{gid} | ALL TEST | y (scaled)",
            out_dir / f"FIG2_{exp_tag}_grid{gid}_test_all_stitched_step{H}.png",
        )

    print(f"Saved 2 figs/point for {N} points under: {out_dir}")


# ============================================================
# 1) NAR generator (COPY FROM YOUR OLS SCRIPT)  <<<<<<<< IMPORTANT
# ============================================================
def _generate_spatial_basis(locations: np.ndarray) -> np.ndarray:
    spatial_basis = np.exp(-(locations**2))[:, None]
    spatial_basis /= (np.linalg.norm(spatial_basis) + 1e-12)
    return spatial_basis


def generate_time_NAR_synthetic_data(
    locs: np.ndarray,
    n_time_steps: int,
    noise_std: float,
    eigenvalue: float,
    eta_rho: float = 0.8,
    f_rho: float = 0.6,
    global_mean: float = 50.0,
    feature_noise_std: float = 0.0,
    non_linear_strength: float = 0.0,
    seed: Optional[int] = None,
    y_rho: float = 0.4,
    y_gamma: float = 1.2,
    nar_scale_window: int = 200,
    # ---- Gate A2 params ----
    gate_type: str = "A2_global_v2",
    gate_threshold: float = 0.0,
    gate_low: float = 0.5,
    gate_high: float = 1.0,
    debug_gate_ratio: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate a synthetic multi-site panel whose target follows a gated NAR(1) recursion.

    Target model (all N sites at once):
      y[0] = global_mean + trend_drift[0] + eigenvalue * spatial_weights[0] * basis
             + feat(cont[0]) + noise[0]
      y[t] = global_mean + trend_drift[t] + eigenvalue * spatial_weights[t] * basis
             + feat(cont[t]) + rho_eff(t) * y[t-1] + y_gamma * tanh(y[t-1] / scale(t))
             + noise[t]                                                  (t >= 1)

    Gate A2 (GLOBAL) - Version 2:
      window    = y[w0:t], w0 = max(0, t - nar_scale_window), pooled over time AND sites
      z_global(t) = ( mean(prev) - mean(window) ) / ( std(window) + 1e-6 )   <-- rolling-mean centered
      gate(t) = 1{ z_global > gate_threshold }
      rho_eff = y_rho * ( gate_low + (gate_high-gate_low)*gate )

    NOTE:
      - Centering by the rolling mean lets the gate actually flip around 0,
        instead of being stuck positive due to global_mean=50.
      - The output is fully determined by the sequence of global np.random draws;
        changing the order or number of draws changes y (and its md5 hash).
      - NOTE(review): the tanh term uses prev / scale, which is NOT centered by
        the rolling mean; when the level of y is large relative to its rolling
        std this term is probably close to saturation (~ +1). Confirm this is intended.
      - NOTE(review): with debug_gate_ratio=True and n_time_steps == 1 the
        gate_ratio computation divides by zero.

    Args:
      locs: 1D array of N site coordinates (ValueError if not 1D).
      n_time_steps: number of time steps T.
      noise_std: std of the i.i.d. Gaussian noise added to y at every step.
      eigenvalue: amplitude of the spatial term.
      eta_rho: AR(1) coefficient of the scalar spatial weight process.
      f_rho: AR(1) coefficient of the scalar trend-drift process.
      global_mean: constant level added to y at every step.
      feature_noise_std: std of extra Gaussian noise added to the features (skipped if 0).
      non_linear_strength: weight of the nonlinear feature terms (skipped if 0).
      seed: if not None, passed to np.random.seed (this re-seeds numpy's GLOBAL RNG).
      y_rho: base autoregressive coefficient, scaled by the gate.
      y_gamma: weight of the tanh nonlinearity.
      nar_scale_window: length of the rolling window used for mu / scale.
      gate_type: only used as a label in the debug print; the gate logic below
        is always A2_global_v2 regardless of this value.
      gate_threshold: threshold on z_global for gate = 1.
      gate_low: rho multiplier when gate = 0.
      gate_high: rho multiplier when gate = 1.
      debug_gate_ratio: if True, print the fraction of steps t >= 1 with gate = 1.

    Returns:
      cat:  int64 array (T, N, 0) -- no categorical features.
      cont: float32 array (T, N, 3) -- continuous features.
      y:    float32 array (T, N) -- target.
    """
    if seed is not None:
        np.random.seed(seed)  # re-seeds the global RNG; all draws below depend on it

    locs = np.asarray(locs, dtype=np.float32)
    if locs.ndim != 1:
        raise ValueError(f"expects 1D locs (N,), got {locs.shape}")

    N = len(locs)
    p = 3  # number of continuous features

    # Base features: standard normal draws, shape (T, N, p).
    cont = np.random.randn(n_time_steps, N, p).astype(np.float32)

    spatial_pattern = np.sin(locs * np.pi / 5.0).astype(np.float32)
    temporal_trend = np.linspace(0, 2 * np.pi, n_time_steps).astype(np.float32)

    # Feature 0: time trend modulated by a sinusoidal spatial pattern.
    cont[:, :, 0] += (np.outer(temporal_trend, spatial_pattern) * 0.5).astype(np.float32)
    # Feature 1: time trend shared by all sites.
    cont[:, :, 1] += (np.outer(temporal_trend, np.ones(N, dtype=np.float32)) * 0.3).astype(np.float32)
    # Feature 2: interaction of the (already trended) features 0 and 1.
    cont[:, :, 2] += (cont[:, :, 0] * cont[:, :, 1] * 0.2).astype(np.float32)

    if feature_noise_std > 0:
        cont += (np.random.randn(*cont.shape) * feature_noise_std).astype(np.float32)

    spatial_basis = _generate_spatial_basis(locs).reshape(-1).astype(np.float32)  # (N,)

    # Two scalar AR(1) processes shared by all sites: spatial amplitude and level drift.
    spatial_weights = np.zeros(n_time_steps, dtype=np.float32)
    trend_drift = np.zeros(n_time_steps, dtype=np.float32)

    spatial_weights[0] = float(np.random.randn() * 1.0)
    trend_drift[0] = float(np.random.randn() * 1.0)

    for t in range(1, n_time_steps):
        spatial_weights[t] = eta_rho * spatial_weights[t - 1] + float(np.random.randn() * 0.1)
        trend_drift[t] = f_rho * trend_drift[t - 1] + float(np.random.randn() * 0.1)

    y = np.zeros((n_time_steps, N), dtype=np.float32)
    eps = 1e-6  # keeps the rolling scale strictly positive

    def _feat_term(t: int) -> np.ndarray:
        # Linear combination of the 3 features at time t, plus optional nonlinear terms.
        feat = (0.3 * cont[t, :, 0] + 0.4 * cont[t, :, 1] + 0.2 * cont[t, :, 2]).astype(np.float32)
        if non_linear_strength > 0:
            feat = (
                feat
                + non_linear_strength * (cont[t, :, 0] ** 2)
                + non_linear_strength * 0.5 * cont[t, :, 0] * cont[t, :, 1]
                + non_linear_strength * 0.3 * np.sin(cont[t, :, 2])
            ).astype(np.float32)
        return feat

    # init
    trend0 = global_mean + float(trend_drift[0])
    spatial0 = eigenvalue * float(spatial_weights[0]) * spatial_basis
    noise0 = (np.random.randn(N) * noise_std).astype(np.float32)
    y[0] = (trend0 + spatial0 + _feat_term(0) + noise0).astype(np.float32)

    gate_cnt = 0.0  # number of steps t >= 1 where the gate was open

    for t in range(1, n_time_steps):
        trend = global_mean + float(trend_drift[t])
        spatial = eigenvalue * float(spatial_weights[t]) * spatial_basis
        noise = (np.random.randn(N) * noise_std).astype(np.float32)

        prev = y[t - 1]
        w0 = max(0, t - nar_scale_window)

        # rolling scale + rolling mean (Version 2), pooled over all sites in the window
        mu = float(np.mean(y[w0:t]))                       # rolling mean over window
        scale = float(np.std(y[w0:t]) + eps)               # rolling std over window
        z_global = float((np.mean(prev) - mu) / scale)     # centered global gate signal

        gate = 1.0 if (z_global > gate_threshold) else 0.0
        gate_cnt += gate

        rho_eff = y_rho * (gate_low + (gate_high - gate_low) * gate)

        # NOTE(review): tanh argument is prev / scale (not centered by mu).
        rec = (rho_eff * prev + y_gamma * np.tanh(prev / scale)).astype(np.float32)

        y[t] = (trend + spatial + _feat_term(t) + rec + noise).astype(np.float32)

    if debug_gate_ratio:
        # Division by zero if n_time_steps == 1.
        gate_ratio = gate_cnt / float(n_time_steps - 1)
        print(f"[Gate {gate_type}] gate_ratio={gate_ratio:.4f} | "
              f"thr={gate_threshold} | low={gate_low} | high={gate_high} | win={nar_scale_window}")

    cat = np.zeros((n_time_steps, N, 0), dtype=np.int64)  # no categorical features
    return cat, cont, y


# ============================================================
# 2) Global settings & dirs
# ============================================================
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

try:
    EXP_ROOT = Path(__file__).resolve().parent
except NameError:
    # __file__ is undefined when run interactively (e.g. in a notebook): use the CWD.
    EXP_ROOT = Path.cwd()

# Scenario table
SCENARIOS: Dict[str, Dict[str, float]] = {
    "A":      {"noise_std": 0.3, "feature_noise_std": 0.0, "non_linear_strength": 0.0, "eta_rho": 0.8, "f_rho": 0.6},
    "Aprime": {"noise_std": 0.3, "feature_noise_std": 0.0, "non_linear_strength": 0.0, "eta_rho": 0.6, "f_rho": 0.6},
    "B":      {"noise_std": 0.5, "feature_noise_std": 0.2, "non_linear_strength": 1.0, "eta_rho": 0.8, "f_rho": 0.6},
    "Bprime": {"noise_std": 0.5, "feature_noise_std": 0.2, "non_linear_strength": 1.0, "eta_rho": 0.6, "f_rho": 0.6},
    "C":      {"noise_std": 0.5, "feature_noise_std": 0.2, "non_linear_strength": 2.0, "eta_rho": 0.8, "f_rho": 0.6},
    "D":      {"noise_std": 0.5, "feature_noise_std": 0.2, "non_linear_strength": 2.5, "eta_rho": 0.4, "f_rho": 0.6},
    "E":      {"noise_std": 0.5, "feature_noise_std": 0.2, "non_linear_strength": 2.5, "eta_rho": 0.4, "f_rho": 0.8},
}

# FIX to Scenario C (match your current experiment)
SCENARIO_KEY = "C"
SC = SCENARIOS[SCENARIO_KEY]

# NAR params: ONLY 2 GAMMAS
Y_RHO = 0.85
NAR_SCALE_WINDOW = 50
GATE_TYPE = "A2_global_v2"
GATE_THRESHOLD = 0.0
GATE_LOW = 0.5
GATE_HIGH = 1.0
GAMMA_LIST = [0.10, 0.20]

# World size
N_POINTS = 36
T_TOTAL = 1500
EIGENVALUE = 3.0
locs = np.linspace(-5.0, 5.0, N_POINTS).astype(np.float32)

# Split
train_ratio = 0.70
val_ratio = 0.15

# TFT hyperparams
L, H = 56, 24              # input (encoder) length, forecast horizon
INTERNAL_VAL_STEPS = 240   # tail of the train split used for early stopping

N_EPOCHS = 30
HIDDEN_SIZE = 64
N_HEADS = 4
DROPOUT = 0.1
BATCH_SIZE = 32
LR = 3e-4

# Time index (hourly). The frequency string is defined once and reused for both
# the DatetimeIndex and the Darts series. Use the lowercase "h" alias: the
# uppercase "H" alias is deprecated since pandas 2.2 (FutureWarning) and slated
# for removal, while "h" is also accepted by older pandas versions.
freq = "1h"
time_start = pd.Timestamp("2000-01-01 00:00:00")
time_index = pd.date_range(start=time_start, periods=T_TOTAL, freq=freq)

PAST_COV_COLS = ["f1", "f2", "f3"]

# output dirs
ROOT_TAG = f"synth_tft_{SCENARIO_KEY}_NAR_{GATE_TYPE}_yrho{Y_RHO}_win{NAR_SCALE_WINDOW}_1D_NOSTATIC_ONLY2G"
CKPT_DIR = (EXP_ROOT / f"darts_ckpt_{ROOT_TAG}").resolve()
RUNS_DIR = (EXP_ROOT / f"TFT_runs_{ROOT_TAG}").resolve()
PLOTS_DIR = (EXP_ROOT / f"TFT_plots_{ROOT_TAG}_2figs").resolve()
RESULTS_DIR = (EXP_ROOT / f"TFT_results_{ROOT_TAG}").resolve()

for d in [CKPT_DIR, RUNS_DIR, PLOTS_DIR, RESULTS_DIR]:
    d.mkdir(parents=True, exist_ok=True)

SUMMARY_ROWS: List[Dict[str, Any]] = []


# ============================================================
# 3) Loop only two gammas
# ============================================================
for Y_GAMMA in GAMMA_LIST:
    # One full experiment per gamma:
    # generate -> scale -> fit TFT -> rolling TEST eval -> plots -> summary row.
    print("\n" + "=" * 60)
    print(f"RUN TFT | SCENARIO={SCENARIO_KEY} | y_rho={Y_RHO} | win={NAR_SCALE_WINDOW} | y_gamma={Y_GAMMA}")
    print("=" * 60)

    EXP_TAG = f"{ROOT_TAG}_ygamma{Y_GAMMA}"

    # per-gamma dirs
    ckpt_dir = (CKPT_DIR / f"ygamma_{Y_GAMMA}").resolve()
    runs_dir = (RUNS_DIR / f"ygamma_{Y_GAMMA}").resolve()
    plots_dir = (PLOTS_DIR / f"ygamma_{Y_GAMMA}").resolve()
    for d in [ckpt_dir, runs_dir, plots_dir]:
        d.mkdir(parents=True, exist_ok=True)

    # ============================================================
    # 3.1) Generate synthetic data (NAR, EXACTLY SAME AS OLS)
    # ============================================================
    print("\n=== Scenario settings ===")
    print("Scenario:", SCENARIO_KEY)
    print(SC)
    print("NAR params:", {"y_rho": Y_RHO, "y_gamma": Y_GAMMA, "nar_scale_window": NAR_SCALE_WINDOW, "gate_type": GATE_TYPE})

    cat_synth, cont_synth, y_synth = generate_time_NAR_synthetic_data(
        locs=locs,
        n_time_steps=T_TOTAL,
        noise_std=SC["noise_std"],
        eigenvalue=EIGENVALUE,
        eta_rho=SC["eta_rho"],
        f_rho=SC["f_rho"],
        global_mean=50.0,
        feature_noise_std=SC["feature_noise_std"],
        non_linear_strength=SC["non_linear_strength"],
        seed=GLOBAL_SEED,
        y_rho=float(Y_RHO),
        y_gamma=float(Y_GAMMA),
        nar_scale_window=int(NAR_SCALE_WINDOW),
        gate_type=GATE_TYPE,
        gate_threshold=float(GATE_THRESHOLD),
        gate_low=float(GATE_LOW),
        gate_high=float(GATE_HIGH),
        debug_gate_ratio=True,
    )

    # ---- HARD identity check (compare with OLS side) ----
    y_md5 = hashlib.md5(y_synth.tobytes()).hexdigest()
    print("Synthetic y shape:", y_synth.shape)
    print("Synthetic cont shape:", cont_synth.shape)
    print("y stats (unscaled) mean/std/mean|y|:",
          float(y_synth.mean()), float(y_synth.std()), float(np.mean(np.abs(y_synth))))
    print("y_md5:", y_md5)

    Y = y_synth.astype(np.float32)     # (T, N)
    X = cont_synth.astype(np.float32)  # (T, N, 3)
    T_total, N = Y.shape

    # dummy categorical features (for prepare_all_with_scaling interface)
    cat_dummy = np.zeros((T_total, N, 1), dtype=np.int64)

    # ============================================================
    # 3.2) Train/Val/Test split + pooled scaling (train-only)
    # ============================================================
    cut_train = int(T_total * train_ratio)
    cut_val = int(T_total * (train_ratio + val_ratio))
    if not (0 < cut_train < cut_val < T_total):
        raise ValueError("Bad split indices computed from ratios.")

    print("\n=== Split (by time index) ===")
    print("Train:", 0, "->", cut_train - 1, "| len =", cut_train)
    print("Val  :", cut_train, "->", cut_val - 1, "| len =", cut_val - cut_train)
    print("Test :", cut_val, "->", T_total - 1, "| len =", T_total - cut_val)

    train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
        cat_features=cat_dummy,
        cont_features=X,
        targets=Y,
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        feature_scaler_type="standard",
        target_scaler_type="standard",
        fit_on_train_only=True,
    )

    # Guard: everything below slices the stitched arrays with cut_train / cut_val,
    # so those indices must coincide with the split boundaries that
    # prepare_all_with_scaling used. Otherwise the "train-only" scaler would have
    # been fit on a different range than the one treated as train here.
    split_lens = tuple(int(ds.tensors[2].shape[0]) for ds in (train_ds, val_ds, test_ds))
    expected_lens = (cut_train, cut_val - cut_train, T_total - cut_val)
    if split_lens != expected_lens:
        raise ValueError(
            f"Split mismatch: prepare_all_with_scaling produced split lengths {split_lens}, "
            f"but this script expects {expected_lens}."
        )

    def stitch_targets(dsets) -> np.ndarray:
        # Concatenate the (scaled) targets of consecutive splits along time.
        return np.concatenate([ds.tensors[2].detach().cpu().numpy().astype(np.float32) for ds in dsets], axis=0)

    def stitch_cont_features(dsets) -> np.ndarray:
        # Concatenate the (scaled) continuous features of consecutive splits along time.
        return np.concatenate([ds.tensors[1].detach().cpu().numpy().astype(np.float32) for ds in dsets], axis=0)

    Y_scaled = stitch_targets([train_ds, val_ds, test_ds])        # (T, N)
    X_scaled = stitch_cont_features([train_ds, val_ds, test_ds])  # (T, N, 3)

    print("\n=== After pooled scaling (train-only) ===")
    print("Y_scaled:", Y_scaled.shape, "| finite:", bool(np.isfinite(Y_scaled).all()))
    print("X_scaled:", X_scaled.shape, "| finite:", bool(np.isfinite(X_scaled).all()))

    # ============================================================
    # 3.3) Build TimeSeries lists (target + past cov). NO static cov.
    # ============================================================
    series_all: List[TimeSeries] = []
    pcov_all: List[TimeSeries] = []

    for j in range(N):
        name = f"grid_{j:03d}"

        ts = TimeSeries.from_times_and_values(
            times=time_index,
            values=Y_scaled[:, j:j + 1].astype(np.float32),
            columns=[name],
            freq=freq,
        )
        series_all.append(ts)

        pc = TimeSeries.from_times_and_values(
            times=time_index,
            values=X_scaled[:, j, :].astype(np.float32),
            columns=[f"{name}_{c}" for c in PAST_COV_COLS],
            freq=freq,
        )
        pcov_all.append(pc)

    # ============================================================
    # 3.4) Internal validation (held-out tail of train) for early stopping
    # ============================================================
    def slice_list(xs, a, b):
        return [x[a:b] for x in xs]

    # The last INTERNAL_VAL_STEPS steps of the train split are HELD OUT from
    # fitting and only used to compute val_loss (early stopping + best-checkpoint
    # selection). If these steps were also in the fit range, val_loss would be
    # an in-sample loss and could not detect overfitting.
    iv_end = cut_train
    n_holdout = max(int(INTERNAL_VAL_STEPS), H)   # need >= H held-out targets
    iv_start = iv_end - n_holdout                 # first held-out target index
    if iv_start < L + H:
        raise ValueError(
            f"Train split too short: {iv_start} fit steps remain before the internal "
            f"validation tail, but at least L + H = {L + H} are required."
        )

    train_series = slice_list(series_all, 0, iv_start)
    train_pcov = slice_list(pcov_all, 0, iv_start)

    # The validation series starts L steps before iv_start so the encoder has a
    # full input window; only steps >= iv_start are ever forecast (and scored).
    val_ctx_start = iv_start - L
    val_series = slice_list(series_all, val_ctx_start, iv_end)
    val_pcov = slice_list(pcov_all, val_ctx_start, iv_end)

    print("\n=== INTERNAL validation (held out from fitting; for early stopping only) ===")
    print("Fit idx   :", 0, "->", iv_start - 1, "| len =", len(train_series[0]))
    print("IntVal idx:", iv_start, "->", iv_end - 1, "| held-out len =", iv_end - iv_start,
          "| series len incl. context =", len(val_series[0]))
    print("IntVal time:", time_index[iv_start], "->", time_index[iv_end - 1])

    # ============================================================
    # 3.5) Train TFT
    # ============================================================
    MODEL_NAME = f"tft_{EXP_TAG}_L{L}_H{H}_seed{GLOBAL_SEED}"

    tft = TFTModel(
        input_chunk_length=L,
        output_chunk_length=H,
        n_epochs=N_EPOCHS,
        hidden_size=HIDDEN_SIZE,
        num_attention_heads=N_HEADS,
        dropout=DROPOUT,
        batch_size=BATCH_SIZE,
        optimizer_kwargs={"lr": LR},
        add_relative_index=True,  # TFT needs future inputs; none are provided otherwise
        random_state=GLOBAL_SEED,
        force_reset=True,
        model_name=MODEL_NAME,
        work_dir=str(ckpt_dir),
        save_checkpoints=True,
        pl_trainer_kwargs={
            "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
            "devices": 1,
            "enable_progress_bar": True,
            "enable_model_summary": False,
            "enable_checkpointing": True,
            "callbacks": [EarlyStopping(monitor="val_loss", mode="min", patience=6)],
            "logger": CSVLogger(save_dir=str(runs_dir), name=MODEL_NAME),
            "gradient_clip_val": 1.0,
        },
    )

    print(f"\n=== Training TFT ({SCENARIO_KEY}, NAR A2_global_v2, 1D locs, NO static) | y_gamma={Y_GAMMA} ===")
    tft.fit(
        series=train_series,
        past_covariates=train_pcov,
        val_series=val_series,
        val_past_covariates=val_pcov,
        verbose=True,
    )

    print("\n=== Loading best checkpoint ===")
    tft = TFTModel.load_from_checkpoint(model_name=MODEL_NAME, work_dir=str(ckpt_dir), best=True)
    print("Loaded best checkpoint.")

    # ============================================================
    # 3.6) Rolling non-overlap over TEST (step=H)
    # ============================================================
    test_start = cut_val
    start_ctx = max(test_start, L)

    # Each window's context is the full scaled history before its origin.
    yhat_roll, ytrue_roll, t0_list = rolling_nonoverlap(
        tft, series_all, pcov_all, start_ctx, T_total, L, H
    )

    W = yhat_roll.shape[0]
    print("\n=== Rolling non-overlap (POOLED, TEST) ===")
    print(f"windows={W} | step={H} | each window predicts H={H} | points={N}")

    diff = yhat_roll - ytrue_roll
    rmse = float(np.sqrt(np.mean(diff ** 2)))
    mae = float(np.mean(np.abs(diff)))
    print("RMSE:", rmse)
    print("MAE :", mae)

    # q-risk
    yhat_f = yhat_roll.reshape(-1, N)
    ytrue_f = ytrue_roll.reshape(-1, N)

    print("\n=== P50 q-risk (ROLLING, TEST) ===")
    qr_scaled = qrisk(ytrue_f, yhat_f)
    print("scaled  :", qr_scaled)

    ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
    yhat_un = preprocessor.inverse_transform_targets(yhat_f)
    qr_unscaled = qrisk(ytrue_un, yhat_un)
    print("unscaled:", qr_unscaled)

    # ============================================================
    # 3.7) Plot: 2 figs per grid point
    # ============================================================
    w0 = 0
    t0_first = t0_list[w0]

    print(f"\n=== Plotting 2 figs per grid point | y_gamma={Y_GAMMA} ===")
    plot_two_figs_all_points(
        plots_dir,
        exp_tag=EXP_TAG,
        scenario_key=SCENARIO_KEY,
        N=N,
        H=H,
        w0=w0,
        t0_first=t0_first,
        time_index=time_index,
        test_start=test_start,
        test_end=T_total,
        ytrue_roll=ytrue_roll,
        yhat_roll=yhat_roll,
        Y_scaled=Y_scaled,
        t0_list=t0_list,
    )

    # ============================================================
    # 3.8) Save per-gamma row
    # ============================================================
    row = {
        "scenario": SCENARIO_KEY,
        "noise_std": float(SC["noise_std"]),
        "feature_noise_std": float(SC["feature_noise_std"]),
        "non_linear_strength": float(SC["non_linear_strength"]),
        "eta_rho": float(SC["eta_rho"]),
        "f_rho": float(SC["f_rho"]),
        "y_rho": float(Y_RHO),
        "y_gamma": float(Y_GAMMA),
        "nar_scale_window": int(NAR_SCALE_WINDOW),
        "gate_type": GATE_TYPE,
        "gate_threshold": float(GATE_THRESHOLD),
        "gate_low": float(GATE_LOW),
        "gate_high": float(GATE_HIGH),
        "y_md5": y_md5,
        "rmse_scaled": rmse,
        "mae_scaled": mae,
        "qrisk_p50_scaled": qr_scaled,
        "qrisk_p50_unscaled": qr_unscaled,
        "windows": int(W),
        "L": int(L),
        "H": int(H),
        "T_total": int(T_TOTAL),
        "N_points": int(N_POINTS),
        "eigenvalue": float(EIGENVALUE),
        "global_seed": int(GLOBAL_SEED),
        "exp_tag": EXP_TAG,
        "plots_dir": str(plots_dir),
        "ckpt_dir": str(ckpt_dir),
    }
    SUMMARY_ROWS.append(row)


# ============================================================
# 4) Save summary CSV (2 rows)
# ============================================================
# The output file name is derived from the run parameters, so changing
# SCENARIO_KEY / Y_RHO / NAR_SCALE_WINDOW / GAMMA_LIST / GATE_TYPE cannot silently
# mislabel or overwrite earlier results. For the current settings it is exactly
# "TFT_C_y0p85_w50_g0p1_0p2_A2v2.csv".
summary_df = pd.DataFrame(SUMMARY_ROWS)


def _num_tag(x: float) -> str:
    """Return a compact filename-safe number tag, e.g. 0.85 -> '0p85', 0.10 -> '0p1'."""
    return f"{float(x):g}".replace(".", "p")


_GATE_TAGS = {"A2_global_v2": "A2v2"}  # short gate labels used in the file name
gate_tag = _GATE_TAGS.get(GATE_TYPE, GATE_TYPE)
gamma_tag = "_".join(_num_tag(g) for g in GAMMA_LIST)
out_csv = (
    RESULTS_DIR
    / f"TFT_{SCENARIO_KEY}_y{_num_tag(Y_RHO)}_w{int(NAR_SCALE_WINDOW)}_g{gamma_tag}_{gate_tag}.csv"
).resolve()
summary_df.to_csv(out_csv, index=False, encoding="utf-8-sig")

print("\n" + "=" * 80)
print(f"DONE. Saved TFT summary ({len(summary_df)} rows) to:")
print(out_csv)
print("=" * 80)
print(summary_df)


  __import__("pkg_resources").declare_namespace(__name__)  # type: ignore


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda

RUN TFT | SCENARIO=C | y_rho=0.85 | win=50 | y_gamma=0.1

=== Scenario settings ===
Scenario: C
{'noise_std': 0.5, 'feature_noise_std': 0.2, 'non_linear_strength': 2.0, 'eta_rho': 0.8, 'f_rho': 0.6}
NAR params: {'y_rho': 0.85, 'y_gamma': 0.1, 'nar_scale_window': 50, 'gate_type': 'A2_global_v2'}
[Gate A2_global_v2] gate_ratio=0.5997 | thr=0.0 | low=0.5 | high=1.0 | win=50
Synthetic y shape: (1500, 36)
Synthetic cont shape: (1500, 36, 3)
y stats (unscaled) mean/std/mean|y|: 250.49603271484375 131.10008239746094 250.49603271484375
y_md5: 2f25135ad6de3de96dc68bab1956f95a

=== Split (by time index) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

=== After pooled scaling (train-only) ===
Y_scaled: (1500, 36) | finite: True
X_scaled: (1500, 36, 3) | finite: True

=== INTERNAL validation (for early s

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



=== Training TFT (C, NAR A2_global_v2, 1D locs, NO static) | y_gamma=0.1 ===



   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | static_covariates_vsn             | _VariableSelectionNetwork        | 0      | train
4  | encoder_vsn                       | _VariableSelectionNetwork        | 8.8 K  | train
5  | decoder_vsn                       | _VariableSelectionNetwork        | 1.6 K  | train
6  | static_context_grn                | _GatedResidualNetwork            | 16.8 K | train
7  | static_context_hidden_encoder_grn | _GatedResidualNetwork            | 16.8 K | train
8  | static_context_cell_encoder_grn   | _GatedResidualNetwork            | 16.8 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.



=== Loading best checkpoint ===
Loaded best checkpoint.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available


=== Rolling non-overlap (POOLED, TEST) ===
windows=9 | step=24 | each window predicts H=24 | points=36
RMSE: 0.6261568665504456
MAE : 0.3076390027999878

=== P50 q-risk (ROLLING, TEST) ===
scaled  : 0.2832062245992567
unscaled: 0.15747766019815884

=== Plotting 2 figs per grid point | y_gamma=0.1 ===


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Saved 2 figs/point for 36 points under: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/simulation_NAR/TFT_sim_NAR/TFT_plots_synth_tft_C_NAR_A2_global_v2_yrho0.85_win50_1D_NOSTATIC_ONLY2G_2figs/ygamma_0.1

RUN TFT | SCENARIO=C | y_rho=0.85 | win=50 | y_gamma=0.2

=== Scenario settings ===
Scenario: C
{'noise_std': 0.5, 'feature_noise_std': 0.2, 'non_linear_strength': 2.0, 'eta_rho': 0.8, 'f_rho': 0.6}
NAR params: {'y_rho': 0.85, 'y_gamma': 0.2, 'nar_scale_window': 50, 'gate_type': 'A2_global_v2'}
[Gate A2_global_v2] gate_ratio=0.6104 | thr=0.0 | low=0.5 | high=1.0 | win=50
Synthetic y shape: (1500, 36)
Synthetic cont shape: (1500, 36, 3)
y stats (unscaled) mean/std/mean|y|: 253.73133850097656 130.7637176513672 253.73133850097656
y_md5: abec649232e0950d399c0c830544bcae

=== Split (by time index) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

=== After pooled scaling (train-only) ===
Y_scaled: (1500, 36) | finit


   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | static_covariates_vsn             | _VariableSelectionNetwork        | 0      | train
4  | encoder_vsn                       | _VariableSelectionNetwork        | 8.8 K  | train
5  | decoder_vsn                       | _VariableSelectionNetwork        | 1.6 K  | train
6  | static_context_grn                | _GatedResidualNetwork            | 16.8 K | train
7  | static_context_hidden_encoder_grn | _GatedResidualNetwork            | 16.8 K | train
8  | static_context_cell_encoder_grn   | _GatedResidualNetwork            | 16.8 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



=== Loading best checkpoint ===
Loaded best checkpoint.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available


=== Rolling non-overlap (POOLED, TEST) ===
windows=9 | step=24 | each window predicts H=24 | points=36
RMSE: 0.6630433797836304
MAE : 0.32352277636528015

=== P50 q-risk (ROLLING, TEST) ===
scaled  : 0.29601073214743967
unscaled: 0.16485052854802

=== Plotting 2 figs per grid point | y_gamma=0.2 ===
Saved 2 figs/point for 36 points under: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/simulation_NAR/TFT_sim_NAR/TFT_plots_synth_tft_C_NAR_A2_global_v2_yrho0.85_win50_1D_NOSTATIC_ONLY2G_2figs/ygamma_0.2

DONE. Saved TFT summary (2 rows) to:
/home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/simulation_NAR/TFT_sim_NAR/TFT_results_synth_tft_C_NAR_A2_global_v2_yrho0.85_win50_1D_NOSTATIC_ONLY2G/RESULTS_TFT_NAR_ONLY2GAMMAS_C_A2_global_v2_yrho0.85_win50.csv
  scenario  noise_std  feature_noise_std  non_linear_strength  eta_rho  f_rho  \
0        C        0.5                0.2                  2.0      0.8    0.6   
1        C        0.5          