In [None]:
import csv
import glob
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.models import TFTModel
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

from geospatial_neural_adapter.utils import (
    clear_gpu_memory,
    create_experiment_config,
    print_experiment_summary,
    get_device_info,
)
from geospatial_neural_adapter.metrics import compute_metrics
from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# Global settings & dirs
# MODE selects the branch run at the bottom of the file:
#   "train"      -> train, then evaluate/plot for every configured seed
#   "infer_plot" -> reload an existing checkpoint and only evaluate/plot
MODE = "train"

# `__file__` is undefined when this code runs as a notebook cell, so fall back
# to the current working directory in that case.
try:
    EXP_ROOT = Path(__file__).resolve().parent
except NameError:
    EXP_ROOT = Path.cwd()

# Output locations: checkpoints (passed to TFTModel as work_dir), CSV training
# logs, and per-store figures.
CKPT_DIR = (EXP_ROOT / "darts_ckpt_favorita_store_paperWindow_h30_clean_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd").resolve()
RUNS_DIR = (EXP_ROOT / "TFT_runs_favorita_store_paperWindow_h30_clean_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd").resolve()
PLOTS_DIR = (EXP_ROOT / "TFT_plots_favorita_store_paperWindow_h30_clean_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd").resolve()

# Created eagerly at import time so later writers never hit a missing directory.
CKPT_DIR.mkdir(parents=True, exist_ok=True)
RUNS_DIR.mkdir(parents=True, exist_ok=True)
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Seed NumPy and PyTorch once up front; GLOBAL_SEED is also reused as the
# model's `random_state` in TREND_CONFIG.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

# DEVICE is only used to place tensors for metric computation; the trainer's
# accelerator is configured separately in TREND_CONFIG.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_info = get_device_info()
print(f"Using {device_info['device'].upper()}: {device_info['device_name']}")
if device_info["device"] == "cuda":
    print(f"   Memory: {device_info['memory_gb']} GB")


# TFT config
# Single source of hyper-parameters for the TFT model; `build_model` reads its
# entries and forwards them to `TFTModel`.
TREND_CONFIG: Dict[str, Any] = {
    "input_chunk_length": 90,   # look-back window, in days
    "output_chunk_length": 30,  # forecast horizon produced per forward pass
    "n_epochs": 20,
    "hidden_size": 64,
    "num_attention_heads": 4,
    "dropout": 0.1,
    "batch_size": 32,
    "optimizer_kwargs": {"lr": 3e-4},
    "random_state": GLOBAL_SEED,
    "force_reset": True,
    "full_attention": False,
    "add_relative_index": True,
    "pl_trainer_kwargs": {
        "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
        "devices": 1,
        "enable_progress_bar": True,
        "enable_model_summary": False,
        "enable_checkpointing": True,
        "gradient_clip_val": 1.0,
        # NOTE: this EarlyStopping object is instantiated exactly once, here,
        # at import time.
        "callbacks": [
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=5,
                min_delta=5e-4,
            ),
        ],
    },
}

# Short aliases used throughout the file: L = look-back length, H = horizon.
L = int(TREND_CONFIG["input_chunk_length"])
H = int(TREND_CONFIG["output_chunk_length"])
# Tags embedded in plot titles, model names and output file names.
TAG = f"H{H}"
TAG_LOWER = f"h{H}"

# Human-readable summary of the experiment setup (informational only).
print(f"\n=== TFT (STORE-level): input={L}, output={H} (30-step direct predict) ===")
print("Config: add_relative_index=True (only time feature).")
print("Future covariates: NONE")
print("Past covariates  : oil_std ONLY")
print("Static covariates: one-hot store meta")
print("OFFICIAL inference: BATCH predict")
print("oil standardized using TRAIN-only mean/std")


# Paper window & splits (reporting only)
# The evaluation window is three contiguous, inclusive date ranges:
# train -> val (PAPER_VAL_DAYS days) -> test (PAPER_TEST_DAYS days).
PAPER_TRAIN_START = pd.Timestamp("2015-01-01")
PAPER_TRAIN_END = pd.Timestamp("2015-12-01")

PAPER_VAL_DAYS = 30
PAPER_TEST_DAYS = 30

# Validation starts the day after training ends; the `- 1` makes the end date
# inclusive so the range spans exactly PAPER_VAL_DAYS days.
PAPER_VAL_START = PAPER_TRAIN_END + pd.Timedelta(days=1)
PAPER_VAL_END = PAPER_VAL_START + pd.Timedelta(days=PAPER_VAL_DAYS - 1)

# Test follows validation with the same inclusive-end convention.
PAPER_TEST_START = PAPER_VAL_END + pd.Timedelta(days=1)
PAPER_TEST_END = PAPER_TEST_START + pd.Timedelta(days=PAPER_TEST_DAYS - 1)

# Full span covered by the daily index built later.
PAPER_ALL_START = PAPER_TRAIN_START
PAPER_ALL_END = PAPER_TEST_END

print("\n=== OFFICIAL (paper-like) window (for reporting/eval only) ===")
print("Train:", PAPER_TRAIN_START.date(), "→", PAPER_TRAIN_END.date())
print("Val  :", PAPER_VAL_START.date(), "→", PAPER_VAL_END.date())
print("Test :", PAPER_TEST_START.date(), "→", PAPER_TEST_END.date())
print("All  :", PAPER_ALL_START.date(), "→", PAPER_ALL_END.date())


# Load raw files
# NOTE: DATA_ROOT is an absolute, machine-specific path.
DATA_ROOT = Path("/home/wangxc1117/experiment_data/sales_forecasting_data")
TRAIN_PATH = DATA_ROOT / "train.csv"
STORES_PATH = DATA_ROOT / "stores.csv"
OIL_PATH = DATA_ROOT / "oil.csv"
TRANSACTIONS_PATH = DATA_ROOT / "transactions.csv"

# Fail fast on the three required inputs. transactions.csv is deliberately not
# in this list: it is optional and handled with a placeholder further down.
for p, name in [
    (TRAIN_PATH, "train.csv"),
    (STORES_PATH, "stores.csv"),
    (OIL_PATH, "oil.csv"),
]:
    if not p.exists():
        raise FileNotFoundError(f"{name} not found at {p}")

print("\n=== Loading train.csv ===")
# The whole file is read first and only then restricted to the paper window.
df = pd.read_csv(TRAIN_PATH)
df["date"] = pd.to_datetime(df["date"])
# Missing promotion flags are treated as "not on promotion".
df["onpromotion"] = df["onpromotion"].fillna(0).astype(int)
df = df[(df["date"] >= PAPER_ALL_START) & (df["date"] <= PAPER_ALL_END)].copy()
print(f"Paper-window date range (in df): {df['date'].min()} → {df['date'].max()}")


# Build store-level target
# Aggregate unit sales to one value per (date, store).
df_store_sales = (
    df.groupby(["date", "store_nbr"], as_index=False)["unit_sales"]
    .sum()
    .rename(columns={"unit_sales": "store_sales"})
)
# Daily totals below zero are floored at zero, which also keeps log1p
# well-defined further down.
df_store_sales["store_sales"] = df_store_sales["store_sales"].clip(lower=0.0)

# Canonical store order and daily calendar; every (T, N) matrix in this file
# uses rows = date_index and columns = stores.
stores = sorted(df_store_sales["store_nbr"].unique())
date_index = pd.date_range(start=PAPER_ALL_START, end=PAPER_ALL_END, freq="D")
T_all = len(date_index)
N = len(stores)

print(f"\nUsing ALL {N} stores in paper window.")
print(f"T_all (days): {T_all}")

# Dense (date x store) grid so that missing store-days become explicit NaNs.
full_idx = pd.MultiIndex.from_product([date_index, stores], names=["date", "store_nbr"])

panel_sales_raw = (
    df_store_sales.set_index(["date", "store_nbr"])
    .reindex(full_idx)
    .sort_index()
)

# Gap filling: carry each store's last observed value forward, then set
# whatever is still missing (gaps before a store's first observation) to zero.
panel_sales = panel_sales_raw.copy()
panel_sales["store_sales"] = panel_sales.groupby("store_nbr")["store_sales"].ffill()
panel_sales.loc[panel_sales["store_sales"].isna(), "store_sales"] = 0.0
# The modelling target is log(1 + sales).
panel_sales["log_sales"] = np.log1p(panel_sales["store_sales"]).astype("float32")

# Pivot to a (T_all, N) matrix with a guaranteed row/column order.
Y_df = (
    panel_sales["log_sales"]
    .unstack("store_nbr")
    .reindex(index=date_index, columns=stores)
    .astype("float32")
)
targets_full = Y_df.to_numpy(dtype=np.float32)

print("\n=== Target sanity ===")
print("targets_full shape:", targets_full.shape)
print("targets_full has_na:", bool(np.isnan(targets_full).any()))


# Static cov + oil + transactions (loaded only)
# Store metadata, restricted to and ordered like `stores`.
stores_meta = pd.read_csv(STORES_PATH).set_index("store_nbr").loc[stores]

# All four metadata columns are cast to str first so that `cluster` is one-hot
# encoded as a category as well.
static_cat = stores_meta[["city", "state", "type", "cluster"]].copy()
static_oh = pd.get_dummies(static_cat.astype(str), drop_first=False)
static_cov_df = static_oh.astype("float32")
print("\nStatic cov (one-hot) shape:", static_cov_df.shape)

# Oil price: align to the daily calendar, then forward-fill and finally
# back-fill so that leading gaps are covered too.
oil = pd.read_csv(OIL_PATH)
oil["date"] = pd.to_datetime(oil["date"])
oil = oil.set_index("date")["dcoilwtico"].reindex(date_index).astype("float32")
oil = oil.ffill().bfill()
oil_vals = oil.to_numpy().reshape(-1, 1)
# The same oil series is replicated once per store -> shape (T_all, N).
oil_mat = np.repeat(oil_vals, N, axis=1).astype("float32")
print("Oil covariate: has_na =", bool(np.isnan(oil_mat).any()))

# Transactions are loaded for completeness only; `tr_mat` is not fed to the
# model anywhere in this file.
if TRANSACTIONS_PATH.exists():
    tr = pd.read_csv(TRANSACTIONS_PATH)
    tr["date"] = pd.to_datetime(tr["date"])
    tr = tr[(tr["date"] >= PAPER_ALL_START) & (tr["date"] <= PAPER_ALL_END)].copy()

    tr_panel = (
        tr.set_index(["date", "store_nbr"])["transactions"]
        .reindex(full_idx)
        .sort_index()
        .astype("float32")
    )
    # NOTE(review): every NaN is replaced by 0 here, so the per-store ffill on
    # the next line has nothing left to fill. If forward-filling was intended
    # (as is done for sales), the two steps need to be swapped - confirm.
    tr_panel = tr_panel.fillna(0.0)
    tr_panel = tr_panel.groupby("store_nbr").ffill().fillna(0.0)
    tr_mat = (
        tr_panel.unstack("store_nbr")
        .reindex(index=date_index, columns=stores)
        .astype("float32")
        .to_numpy()
    )
    print("Transactions loaded (NOT used): has_na =", bool(np.isnan(tr_mat).any()))
else:
    # Zero placeholder with the same (T_all, N) shape.
    tr_mat = np.zeros((T_all, N), dtype=np.float32)
    print("Transactions not found (NOT used): using zeros placeholder.")


# Official cuts (reporting only)
# Row offsets into `date_index`: rows [0, cut_train) are train,
# [cut_train, cut_val) are validation and [cut_val, T_all) are test.
cut_train = int((PAPER_TRAIN_END - PAPER_ALL_START).days + 1)
cut_val = cut_train + PAPER_VAL_DAYS
# Sanity check that the three segments tile the calendar exactly.
assert cut_val + PAPER_TEST_DAYS == T_all

print("\n=== OFFICIAL split lengths (reporting only) ===")
print("T_all:", T_all, "| train:", cut_train, "| val:", (cut_val - cut_train), "| test:", (T_all - cut_val))


# Scaling targets only (fit on train only)
# The helper requires categorical and continuous feature arrays; all-zero
# placeholders are passed because only the scaled targets are used afterwards.
cat_dummy = np.zeros((T_all, N, 1), dtype=np.int64)
cont_dummy = np.zeros((T_all, N, 1), dtype=np.float32)

# NOTE(review): the split is communicated as float ratios. Whether these map
# back to exactly `cut_train` / `PAPER_VAL_DAYS` rows depends on how
# prepare_all_with_scaling rounds them - confirm against its implementation.
train_ratio = cut_train / T_all
val_ratio = PAPER_VAL_DAYS / T_all

train_dataset_scl, val_dataset_scl, test_dataset_scl, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=cont_dummy,
    targets=targets_full,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)


def stitch_y(dsets) -> np.ndarray:
    """Concatenate the target tensors of several datasets along the first axis.

    Each dataset must expose a ``tensors`` sequence whose element at index 2
    is the target tensor. Targets are moved to CPU, converted to float32
    NumPy arrays and stacked in the order the datasets are given.
    """
    pieces = []
    for dataset in dsets:
        target = dataset.tensors[2]
        pieces.append(target.cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)


# Re-assemble the scaled train/val/test targets into one continuous matrix
# and check that it has the same shape as the unscaled targets.
y_all_s = stitch_y([train_dataset_scl, val_dataset_scl, test_dataset_scl])
assert y_all_s.shape == targets_full.shape
print("y_all_s shape:", y_all_s.shape)


# Standardize oil (train only)
# Standard deviations below this threshold are treated as zero.
EPS_STD = 1e-6


def train_only_standardize_2d(x_2d: np.ndarray, cut: int) -> Tuple[np.ndarray, float, float]:
    """Z-score a 2-D array with statistics taken from its first ``cut`` rows.

    A single global mean and population standard deviation (``ddof=0``) are
    computed over all columns of ``x_2d[:cut]`` and applied to every row.

    Args:
        x_2d: Array to standardize.
        cut: Number of leading rows used to estimate the statistics.

    Returns:
        ``(standardized float32 array, mean, std)``. ``std`` is reported as
        ``1.0`` when the measured value is below ``EPS_STD``.
    """
    fit_rows = x_2d[:cut, :].astype(np.float64)
    center = float(np.mean(fit_rows))
    spread = float(np.std(fit_rows, ddof=0))
    # A (near-)constant fitting slice would blow up the division below, so
    # fall back to unit scale in that case.
    spread = 1.0 if spread < EPS_STD else spread
    standardized = (x_2d.astype(np.float64) - center) / spread
    return standardized.astype(np.float32), center, spread


# Standardize the oil matrix with mean/std estimated on the first `cut_train`
# rows only; the two statistics are kept for the metrics summary CSV.
oil_mat_s, oil_mean, oil_std = train_only_standardize_2d(oil_mat, cut_train)
print("\n=== TRAIN-only standardization stats ===")
print(f"oil: mean={oil_mean:.6f}, std={oil_std:.6f}")


# q-risk helpers
def quantile_loss(y_true: np.ndarray, y_pred: np.ndarray, q: float) -> np.ndarray:
    """Element-wise pinball loss of ``y_pred`` against ``y_true`` at quantile ``q``."""
    residual = y_true - y_pred
    under_forecast_cost = q * residual
    over_forecast_cost = (q - 1.0) * residual
    return np.maximum(under_forecast_cost, over_forecast_cost)


def qrisk(y_true: np.ndarray, y_pred: np.ndarray, q: float, eps: float = 1e-8) -> float:
    """Normalised quantile risk: ``2 * sum(pinball loss) / (sum(|y_true|) + eps)``.

    Both inputs are converted to float64 before the computation; ``eps``
    keeps the ratio finite when ``y_true`` sums to zero in absolute value.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    forecast = np.asarray(y_pred, dtype=np.float64)
    total_loss = np.sum(quantile_loss(truth, forecast, q))
    scale = np.sum(np.abs(truth)) + eps
    return float(2.0 * total_loss / scale)


def _try_get_target_scaler_params(prep) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    if prep is None:
        return None, None
    ts = getattr(prep, "target_scaler", None)
    if ts is None:
        return None, None
    mean_ = getattr(ts, "mean_", None)
    scale_ = getattr(ts, "scale_", None)
    if mean_ is None or scale_ is None:
        return None, None
    return np.asarray(mean_, dtype=np.float64), np.asarray(scale_, dtype=np.float64)


# Cached target-scaler statistics used by `inverse_scale_log`; both are None
# when they could not be read from the preprocessor.
TARGET_MEAN, TARGET_SCALE = _try_get_target_scaler_params(preprocessor)


def inverse_scale_log(y_scaled: np.ndarray) -> Optional[np.ndarray]:
    """Undo the target standardization: ``y_scaled * scale + mean``.

    Uses the module-level ``TARGET_MEAN`` / ``TARGET_SCALE``. Returns ``None``
    when either of them is unavailable, otherwise a float32 array.
    """
    if TARGET_MEAN is None or TARGET_SCALE is None:
        return None

    values = np.asarray(y_scaled, dtype=np.float64)

    def _as_row(param: np.ndarray) -> np.ndarray:
        # One entry per column: lay it out as a single broadcastable row.
        if param.ndim == 1 and param.shape[0] == values.shape[1]:
            return param.reshape(1, -1)
        # A single scalar statistic: repeat it for every column.
        if param.size == 1:
            return np.full((1, values.shape[1]), float(param.reshape(-1)[0]), dtype=np.float64)
        # Anything else is handed to NumPy broadcasting as-is (1-D -> row).
        if param.ndim == 1:
            return param.reshape(1, -1)
        return param

    mean_row = _as_row(TARGET_MEAN)
    scale_row = _as_row(TARGET_SCALE)

    restored = values * scale_row + mean_row
    return restored.astype(np.float32)


def log_to_sales(y_log: np.ndarray) -> np.ndarray:
    """Map ``log1p``-space values back to non-negative sales as float32.

    Applies ``expm1`` in float64 and floors the result at zero, so inputs
    below zero in log space come back as ``0``.
    """
    log_values = np.asarray(y_log, dtype=np.float64)
    sales = np.clip(np.expm1(log_values), 0.0, None)
    return sales.astype(np.float32)


def compute_p50_qrisk_block(y_true_2d: np.ndarray, y_pred_2d: np.ndarray) -> Dict[str, float]:
    """Compute the P50 q-risk on up to three scales.

    Always returns ``qrisk_p50_scaled_log`` (inputs as given). When the
    target scaling can be inverted, ``qrisk_p50_unscaled_log`` (log space)
    and ``qrisk_p50_sales`` (original sales units) are added as well.
    """
    metrics: Dict[str, float] = {
        "qrisk_p50_scaled_log": qrisk(y_true_2d, y_pred_2d, q=0.5),
    }

    true_log = inverse_scale_log(y_true_2d)
    pred_log = inverse_scale_log(y_pred_2d)
    if true_log is None or pred_log is None:
        # Scaler statistics unavailable: only the scaled metric can be reported.
        return metrics

    metrics["qrisk_p50_unscaled_log"] = qrisk(true_log, pred_log, q=0.5)
    metrics["qrisk_p50_sales"] = qrisk(log_to_sales(true_log), log_to_sales(pred_log), q=0.5)
    return metrics


# Build TimeSeries lists (target + static + past)
# One target series and one past-covariate series per store, appended in the
# same order as `stores`, so list position j corresponds to column j of the
# (T, N) matrices.
series_all: List[TimeSeries] = []
pcov_all_series: List[TimeSeries] = []

for j, sid in enumerate(stores):
    name = f"store_{sid}"

    # Scaled log-sales of store `sid` over the full calendar.
    ts = TimeSeries.from_times_and_values(
        times=date_index,
        values=y_all_s[:, j:j + 1],
        columns=[name],
        freq="D",
    )

    # Attach the store's one-hot metadata row; its index is renamed to the
    # series' component name.
    sc = static_cov_df.loc[[sid]].copy()
    sc.index = [name]
    ts = ts.with_static_covariates(sc)
    series_all.append(ts)

    # Past covariate: the train-standardized oil price (column j of the
    # replicated oil matrix).
    pc_vals = oil_mat_s[:, j:j + 1].astype("float32")
    pcov = TimeSeries.from_times_and_values(
        times=date_index,
        values=pc_vals,
        columns=[f"{name}_oil_std"],
        freq="D",
    )
    pcov_all_series.append(pcov)

print(f"\nBuilt {len(series_all)} TimeSeries. Future covariates: NONE. Past cov: oil_std only.")


# Internal validation slice
# Length in days of the tail of the official training range that is passed to
# `fit` as the validation series (monitored by early stopping).
INTERNAL_VAL_DAYS = 240


def slice_list(ts_list: List[TimeSeries], a: int, b: int) -> List[TimeSeries]:
    """Return ``series[a:b]`` for every series in ``ts_list``, preserving order."""
    window = slice(a, b)
    sliced = []
    for series in ts_list:
        sliced.append(series[window])
    return sliced


# Training input: the official training range [0, cut_train) of every store.
train_series_in = slice_list(series_all, 0, cut_train)
train_pcov_in = slice_list(pcov_all_series, 0, cut_train)

# Internal validation: the last INTERNAL_VAL_DAYS rows of the training range,
# widened if necessary so it holds at least one full (input + output) window.
# NOTE(review): this slice lies entirely inside [0, cut_train), i.e. it is a
# subset of the data the model is trained on - confirm this is intended.
internal_val_start = max(0, cut_train - INTERNAL_VAL_DAYS)
internal_val_end = cut_train
min_needed = L + H
if (internal_val_end - internal_val_start) < min_needed:
    internal_val_start = max(0, internal_val_end - min_needed)

val_series_in = slice_list(series_all, internal_val_start, internal_val_end)
val_pcov_in = slice_list(pcov_all_series, internal_val_start, internal_val_end)

print("\n=== INTERNAL validation (for early stopping only) ===")
print("Train idx: [0 :", cut_train, ") len =", len(train_series_in[0]))
print("IntVal idx: [", internal_val_start, ":", internal_val_end, ") len =", len(val_series_in[0]))
print("IntVal covers dates:", date_index[internal_val_start].date(), "→", date_index[internal_val_end - 1].date())


# Official inference (batch predict)
def _predict_official_window(
    tft: TFTModel,
    context_end: int,
    expected_index: pd.DatetimeIndex,
    n_days: int,
    label: str,
) -> np.ndarray:
    """Forecast one official window for all stores with a single batch call.

    Args:
        tft: Fitted model used for prediction.
        context_end: Exclusive row offset; rows ``[0, context_end)`` of every
            target and past-covariate series are given to the model as history.
        expected_index: Dates the forecast of every store must cover.
        n_days: Expected number of forecast steps per store.
        label: Window name (``"Val"`` / ``"Test"``) used in error messages.

    Returns:
        Array of shape ``(n_days, N)`` with one column per store, in the same
        order as ``stores``.

    Raises:
        RuntimeError: If the number of returned series, a forecast length or
            a forecast time index does not match what is expected.
    """
    ctx_series = slice_list(series_all, 0, context_end)
    ctx_pcov = slice_list(pcov_all_series, 0, context_end)

    preds = tft.predict(
        n=H,
        series=ctx_series,
        past_covariates=ctx_pcov,
        verbose=False,
    )
    # A single input series yields a bare TimeSeries rather than a list.
    if isinstance(preds, TimeSeries):
        preds = [preds]
    if len(preds) != N:
        raise RuntimeError(f"[{label}] expected {N} series predictions, got {len(preds)}")

    # Pre-filled with NaN so that any column left unwritten is detectable.
    yhat = np.full((n_days, N), np.nan, dtype=np.float32)
    for j, pred in enumerate(preds):
        if len(pred) != n_days:
            # Report the length that is actually being checked instead of a
            # hard-coded "30".
            raise RuntimeError(f"[{label}] store idx {j}: expected {n_days} preds, got {len(pred)}")
        if not pred.time_index.equals(expected_index):
            raise RuntimeError(f"[{label}] time_index mismatch")
        yhat[:, j] = pred.values(copy=False).astype(np.float32)[:, 0]
    return yhat


def infer_official_direct_30_batch(tft: TFTModel) -> Tuple[np.ndarray, np.ndarray]:
    """Produce the official validation and test forecasts for every store.

    The validation forecast is conditioned on rows ``[0, cut_train)`` and the
    test forecast on rows ``[0, cut_val)``, i.e. the test context includes the
    true validation values. Each window is predicted in one batch call.

    Args:
        tft: Fitted model used for prediction.

    Returns:
        ``(yhat_val, yhat_test)`` with shapes ``(PAPER_VAL_DAYS, N)`` and
        ``(PAPER_TEST_DAYS, N)``, in scaled log-sales units.

    Raises:
        RuntimeError: On any shape or time-index mismatch, or if a forecast
            contains NaNs.
    """
    yhat_val = _predict_official_window(
        tft,
        context_end=cut_train,
        expected_index=date_index[cut_train:cut_val],
        n_days=PAPER_VAL_DAYS,
        label="Val",
    )
    yhat_test = _predict_official_window(
        tft,
        context_end=cut_val,
        expected_index=date_index[cut_val:],
        n_days=PAPER_TEST_DAYS,
        label="Test",
    )

    if np.isnan(yhat_val).any() or np.isnan(yhat_test).any():
        raise RuntimeError("Inference produced NaNs.")

    return yhat_val, yhat_test


# Loss curve plot
def find_metrics_csv(base_dir: Path, dataset_seed: int, model_name: str) -> Path:
    """Locate the most recently modified ``metrics.csv`` of a training run.

    Searches ``<base_dir>/TFT_Favorita_seed_<dataset_seed>/<model_name>/version_*/``
    and returns the resolved path of the file with the newest modification
    time.

    Raises:
        FileNotFoundError: If no ``version_*`` directory holds a ``metrics.csv``.
    """
    run_dir = Path(base_dir) / f"TFT_Favorita_seed_{dataset_seed}" / model_name
    pattern = str(run_dir / "version_*" / "metrics.csv")

    matches = [Path(hit) for hit in glob.glob(pattern)]
    if len(matches) == 0:
        raise FileNotFoundError(f"metrics.csv not found; pattern: {pattern}")

    def _modified_at(candidate: Path) -> float:
        return candidate.stat().st_mtime

    newest = max(matches, key=_modified_at)
    return newest.resolve()


def plot_tft_loss_curve(metrics_csv_path: Path, out_path: Path):
    """Plot per-epoch training and validation loss from a logger ``metrics.csv``.

    Column names are sanitised (``/`` and ``.`` become ``_``), numeric columns
    whose name contains "loss" are averaged per epoch, and the resulting
    curves are saved to ``out_path``. The function prints a message and
    returns without plotting when the file is missing, has no ``epoch``
    column, or has no numeric loss column.
    """
    if not metrics_csv_path.exists():
        print(f"[plot_tft_loss_curve] {metrics_csv_path} not found, skip.")
        return

    frame = pd.read_csv(metrics_csv_path)
    sanitized = []
    for col in frame.columns:
        sanitized.append(col.replace("/", "_").replace(".", "_"))
    frame.columns = sanitized

    if "epoch" not in frame.columns:
        print("[plot_tft_loss_curve] no epoch col, skip.")
        return

    numeric_loss_cols = []
    for col in frame.columns:
        if "loss" in col.lower() and pd.api.types.is_numeric_dtype(frame[col]):
            numeric_loss_cols.append(col)
    if len(numeric_loss_cols) == 0:
        print("[plot_tft_loss_curve] no numeric loss col, skip.")
        return

    # Average within each epoch so every curve has one point per epoch.
    selected = frame[["epoch"] + numeric_loss_cols]
    per_epoch = selected.groupby("epoch", as_index=False).mean(numeric_only=True)

    # Training curve: prefer "train_loss", fall back to a plain "loss" column.
    if "train_loss" in per_epoch.columns:
        train_col = "train_loss"
    elif "loss" in per_epoch.columns:
        train_col = "loss"
    else:
        train_col = None

    # Validation curve: first column mentioning both "loss" and "val".
    val_col = None
    for col in per_epoch.columns:
        lowered = col.lower()
        if "loss" in lowered and "val" in lowered:
            val_col = col
            break

    epochs = per_epoch["epoch"]
    plt.figure(figsize=(8, 4), dpi=140)
    if train_col is not None:
        plt.plot(epochs, per_epoch[train_col], "-", linewidth=2, label="train_loss (mean/epoch)")
    if val_col is not None:
        plt.plot(epochs, per_epoch[val_col], "-o", linewidth=2, markersize=4, label="val_loss (INTERNAL slice)")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("TFT loss vs epoch (internal val for early stopping)")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"[plot_tft_loss_curve] Saved -> {out_path}")


# Plotting
def _render_store_figure(
    truth_dates,
    truth_values,
    truth_linewidth,
    truth_label,
    pred_dates,
    pred_values,
    title,
    target_path,
) -> None:
    """Draw one truth-vs-test-prediction figure with split markers and save it."""
    plt.figure(figsize=(10, 4), dpi=140)
    plt.plot(truth_dates, truth_values, "-", linewidth=truth_linewidth, color="k", label=truth_label)
    plt.plot(pred_dates, pred_values, "-", linewidth=1.8, color="C3", label="Test Pred")
    plt.axvline(date_index[cut_train], linestyle="--", linewidth=1, label="train/val split")
    plt.axvline(date_index[cut_val], linestyle="--", linewidth=1, label="val/test split")
    plt.title(title)
    plt.xlabel("Date")
    plt.ylabel("Scaled log_sales")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(target_path)
    plt.close()


def plot_store_two_figs(
    sid: int,
    y_true_scaled_full: np.ndarray,
    yhat_test_scaled: np.ndarray,
    out_dir: Path,
    dataset_seed: int,
):
    """Save two diagnostic figures for one store.

    FIG1 shows the true series over validation + test together with the test
    forecast; FIG2 shows the true series over the whole calendar together
    with the test forecast. Both are in scaled log-sales units and are
    written as PNG files into ``out_dir``.

    Args:
        sid: Store number; must be present in ``stores``.
        y_true_scaled_full: True values, one row per date in ``date_index``
            and one column per store.
        yhat_test_scaled: Test-window forecasts, one column per store.
        out_dir: Directory receiving the PNG files.
        dataset_seed: Seed embedded in the file names.
    """
    col = stores.index(sid)

    test_dates = date_index[cut_val:]
    test_forecast = yhat_test_scaled[:, col]

    # FIG1: validation + test truth stitched into one continuous line.
    val_and_test_dates = date_index[cut_train:cut_val].append(test_dates)
    val_and_test_truth = np.concatenate(
        [y_true_scaled_full[cut_train:cut_val, col], y_true_scaled_full[cut_val:, col]],
        axis=0,
    )
    _render_store_figure(
        truth_dates=val_and_test_dates,
        truth_values=val_and_test_truth,
        truth_linewidth=2.0,
        truth_label="True (Val+Test)",
        pred_dates=test_dates,
        pred_values=test_forecast,
        title=f"FIG1 {TAG} store {sid}: True (continuous) + Test Pred",
        target_path=out_dir / f"FIG1_store{sid}_{TAG_LOWER}_seed{dataset_seed}.png",
    )

    # FIG2: truth over the full calendar for context.
    _render_store_figure(
        truth_dates=date_index,
        truth_values=y_true_scaled_full[:, col],
        truth_linewidth=1.8,
        truth_label="All True",
        pred_dates=test_dates,
        pred_values=test_forecast,
        title=f"FIG2 {TAG} store {sid}: All True + Test Pred",
        target_path=out_dir / f"FIG2_store{sid}_{TAG_LOWER}_seed{dataset_seed}.png",
    )


def plot_all_stores_two_figs(
    y_true_scaled_full: np.ndarray,
    yhat_test_scaled: np.ndarray,
    dataset_seed: int,
):
    """Render the two diagnostic figures for every store into ``PLOTS_DIR``."""
    print(f"\nStart plotting ALL {len(stores)} stores (2 figs/store) ...")
    shared_kwargs = dict(
        y_true_scaled_full=y_true_scaled_full,
        yhat_test_scaled=yhat_test_scaled,
        out_dir=PLOTS_DIR,
        dataset_seed=dataset_seed,
    )
    for store_id in stores:
        plot_store_two_figs(sid=store_id, **shared_kwargs)
    print(f"All figures saved under: {PLOTS_DIR}")


# Runner
# Experiment plan. The main loop iterates
# range(seed_range_start, seed_range_end), so with the values below exactly
# one seed (1) is run.
EXPERIMENT_TRIALS_CONFIG = create_experiment_config(
    n_trials_per_seed=1,
    n_dataset_seeds=1,
    seed_range_start=1,
    seed_range_end=2,
)
print_experiment_summary(EXPERIMENT_TRIALS_CONFIG)


def build_model(model_name: str, log_root: Path) -> TFTModel:
    """Create an untrained TFT model configured from ``TREND_CONFIG``.

    Args:
        model_name: Name under which checkpoints (in ``CKPT_DIR``) and CSV
            logs (in ``log_root``) are stored.
        log_root: Directory handed to the ``CSVLogger``.

    Returns:
        A new ``TFTModel`` with its own logger and its own callback instances.

    Note:
        ``random_state`` comes from ``TREND_CONFIG`` (``GLOBAL_SEED``), not
        from the per-run seed.
    """
    import copy

    # Shallow copy: the top-level dict is private to this model, but its
    # values are still the objects stored in TREND_CONFIG.
    pl_kwargs = dict(TREND_CONFIG["pl_trainer_kwargs"])

    # Callbacks such as EarlyStopping keep training state (best score,
    # patience counter). Give every model fresh copies of the pristine
    # templates so that state can never leak from one run into the next.
    template_callbacks = pl_kwargs.get("callbacks")
    if template_callbacks:
        pl_kwargs["callbacks"] = [copy.deepcopy(cb) for cb in template_callbacks]

    pl_kwargs["logger"] = CSVLogger(save_dir=str(log_root), name=model_name)

    return TFTModel(
        input_chunk_length=L,
        output_chunk_length=H,
        n_epochs=TREND_CONFIG["n_epochs"],
        hidden_size=TREND_CONFIG["hidden_size"],
        num_attention_heads=TREND_CONFIG["num_attention_heads"],
        dropout=TREND_CONFIG["dropout"],
        batch_size=TREND_CONFIG["batch_size"],
        random_state=TREND_CONFIG["random_state"],
        force_reset=TREND_CONFIG["force_reset"],
        full_attention=TREND_CONFIG["full_attention"],
        add_relative_index=TREND_CONFIG["add_relative_index"],
        pl_trainer_kwargs=pl_kwargs,
        model_name=model_name,
        work_dir=str(CKPT_DIR),
        save_checkpoints=True,
        optimizer_kwargs=TREND_CONFIG.get("optimizer_kwargs", None),
    )


def train_one_seed(dataset_seed: int) -> str:
    """Train one TFT model for ``dataset_seed`` and return its model name.

    NumPy and PyTorch are re-seeded with ``dataset_seed``, a fresh model is
    built, and it is fitted on the training slices with the internal
    validation slices supplied for validation. CSV logs go to a per-seed
    directory under ``RUNS_DIR``.
    """
    for reseed in (np.random.seed, torch.manual_seed):
        reseed(dataset_seed)

    seed_log_dir = RUNS_DIR / f"TFT_Favorita_seed_{dataset_seed}"
    seed_log_dir.mkdir(parents=True, exist_ok=True)

    model_name = f"tft_favorita_store_paperWindow_{TAG_LOWER}_seed{dataset_seed}_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd"

    banner = (
        f"\n[Seed {dataset_seed}] Training TFT "
        f"(input={L}, output={H}, INTERNAL val for early-stopping, OFFICIAL val/test for reporting) ..."
    )
    print(banner)

    model = build_model(model_name=model_name, log_root=seed_log_dir)
    fit_inputs = dict(
        series=train_series_in,
        past_covariates=train_pcov_in,
        val_series=val_series_in,
        val_past_covariates=val_pcov_in,
    )
    model.fit(verbose=True, **fit_inputs)

    return model_name


def load_best_ckpt(model_name: str) -> TFTModel:
    """Reload the best checkpoint saved for ``model_name`` from ``CKPT_DIR``."""
    restored = TFTModel.load_from_checkpoint(
        model_name=model_name,
        work_dir=str(CKPT_DIR),
        best=True,
    )
    print("Checkpoint loaded (best by INTERNAL val_loss).")
    return restored


def official_eval_and_save(dataset_seed: int, model_name: str) -> None:
    """Evaluate a trained model on the official windows and persist the results.

    Steps: reload the best checkpoint, forecast the validation and test
    windows, compute RMSE/MAE/R2 and P50 q-risk on both, print them, append
    one row to the metrics summary CSV (writing the header when the file is
    new), plot the loss curve on a best-effort basis, and render the
    per-store figures.

    Args:
        dataset_seed: Seed recorded in the CSV row and in output file names.
        model_name: Name of the checkpointed model to load.
    """
    model = load_best_ckpt(model_name=model_name)
    pred_val, pred_test = infer_official_direct_30_batch(model)

    true_val = y_all_s[cut_train:cut_val, :]
    true_test = y_all_s[cut_val:, :]

    def _on_device(array: np.ndarray) -> torch.Tensor:
        return torch.from_numpy(array).to(DEVICE)

    rmse_v, mae_v, r2_v = compute_metrics(_on_device(true_val), _on_device(pred_val))
    rmse_t, mae_t, r2_t = compute_metrics(_on_device(true_test), _on_device(pred_test))

    qrisk_val = compute_p50_qrisk_block(true_val, pred_val)
    qrisk_test = compute_p50_qrisk_block(true_test, pred_test)

    print(f"\n=== TFT_STORE_{TAG} OFFICIAL eval (paper window; report only) ===")
    print(f"Val  RMSE={rmse_v:.6f}, MAE={mae_v:.6f}, R2={r2_v:.6f}")
    print(f"Test RMSE={rmse_t:.6f}, MAE={mae_t:.6f}, R2={r2_t:.6f}")

    for heading, block in (
        ("\n=== P50 q-risk (Val) ===", qrisk_val),
        ("\n=== P50 q-risk (Test) ===", qrisk_test),
    ):
        print(heading)
        for metric_name, metric_value in block.items():
            print(f"{metric_name} = {metric_value:.6f}")

    summary_csv = EXP_ROOT / f"metrics_summary_TFT_Favorita_STORE_{TAG}_paperWindow_clean_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd.csv"
    header = [
        "seed", "model",
        "rmse_val", "mae_val", "r2_val",
        "rmse_test", "mae_test", "r2_test",
        "val_qrisk_p50_scaled_log", "val_qrisk_p50_unscaled_log", "val_qrisk_p50_sales",
        "test_qrisk_p50_scaled_log", "test_qrisk_p50_unscaled_log", "test_qrisk_p50_sales",
        "oil_mean_train", "oil_std_train",
    ]
    model_label = f"TFT_STORE_{TAG}_paperWindow_clean_noFutureCov_oilStd_staticOH_addRelIdx_noEnc_BATCH"

    # The header is only needed when this call creates the file.
    is_new_file = not summary_csv.exists()
    with summary_csv.open("a", newline="") as handle:
        writer = csv.writer(handle)
        if is_new_file:
            writer.writerow(header)

        row = [dataset_seed, model_label]
        row.extend(float(x) for x in (rmse_v, mae_v, r2_v, rmse_t, mae_t, r2_t))
        for block in (qrisk_val, qrisk_test):
            # The scaled metric is always present; the other two exist only
            # when the target scaling could be inverted, hence the NaN default.
            row.append(float(block["qrisk_p50_scaled_log"]))
            row.append(float(block.get("qrisk_p50_unscaled_log", np.nan)))
            row.append(float(block.get("qrisk_p50_sales", np.nan)))
        row.extend([float(oil_mean), float(oil_std)])
        writer.writerow(row)

    # The loss plot is a convenience; a failure here must not abort evaluation.
    try:
        metrics_csv_path = find_metrics_csv(RUNS_DIR, dataset_seed=dataset_seed, model_name=model_name)
        loss_png = EXP_ROOT / f"tft_loss_favorita_store_{TAG_LOWER}_seed{dataset_seed}_BATCH_oilStd_noFutureCov_addRelIdx.png"
        print("Using metrics.csv:", metrics_csv_path)
        plot_tft_loss_curve(metrics_csv_path, loss_png)
    except Exception as e:
        print("[Loss plot] skipped due to:", e)

    plot_all_stores_two_figs(
        y_true_scaled_full=y_all_s,
        yhat_test_scaled=pred_test,
        dataset_seed=dataset_seed,
    )


# Main
# Entry point, dispatched on the module-level MODE switch.
if MODE == "train":
    # Train, evaluate and plot once per seed in
    # [seed_range_start, seed_range_end).
    for seed in range(
        EXPERIMENT_TRIALS_CONFIG["seed_range_start"],
        EXPERIMENT_TRIALS_CONFIG["seed_range_end"],
    ):
        print(f"\nStarting TFT STORE-level {TAG} training for seed {seed}")
        model_name = train_one_seed(seed)

        official_eval_and_save(dataset_seed=seed, model_name=model_name)

        # Called between seeds, after evaluation of the current seed is done.
        clear_gpu_memory()
        print(f"Completed seed {seed}")

    print(f"\nAll TFT STORE-level {TAG} experiments completed!")
    print("All done.")

elif MODE == "infer_plot":
    # Evaluation only: reuse the checkpoint of a previously trained seed.
    # The name below must stay identical to the one built in `train_one_seed`.
    DATASET_SEED = 1
    model_name = f"tft_favorita_store_paperWindow_{TAG_LOWER}_seed{DATASET_SEED}_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd"
    official_eval_and_save(dataset_seed=DATASET_SEED, model_name=model_name)
    print("\nAll done.")
else:
    raise ValueError(f"Unknown MODE: {MODE}")


  __import__("pkg_resources").declare_namespace(__name__)  # type: ignore


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Using CUDA: NVIDIA GeForce RTX 4060 Laptop GPU
   Memory: 8.6 GB

=== TFT (STORE-level): input=90, output=30 (30-step direct predict) ===
Config: add_relative_index=True (only time feature).
Future covariates: NONE
Past covariates  : oil_std ONLY
Static covariates: one-hot store meta
OFFICIAL inference: BATCH predict
oil standardized using TRAIN-only mean/std

=== OFFICIAL (paper-like) window (for reporting/eval only) ===
Train: 2015-01-01 → 2015-12-01
Val  : 2015-12-02 → 2015-12-31
Test : 2016-01-01 → 2016-01-30
All  : 2015-01-01 → 2016-01-30

=== Loading train.csv ===


  df = pd.read_csv(TRAIN_PATH)


Paper-window date range (in df): 2015-01-01 00:00:00 → 2016-01-30 00:00:00


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



Using ALL 53 stores in paper window.
T_all (days): 395

=== Target sanity ===
targets_full shape: (395, 53)
targets_full has_na: False

Static cov (one-hot) shape: (53, 60)
Oil covariate: has_na = False
Transactions loaded (NOT used): has_na = False

=== OFFICIAL split lengths (reporting only) ===
T_all: 395 | train: 335 | val: 30 | test: 30
y_all_s shape: (395, 53)

=== TRAIN-only standardization stats ===
oil: mean=49.765164, std=6.051614

Built 53 TimeSeries. Future covariates: NONE. Past cov: oil_std only.

=== INTERNAL validation (for early stopping only) ===
Train idx: [0 : 335 ) len = 335
IntVal idx: [ 95 : 335 ) len = 240
IntVal covers dates: 2015-04-06 → 2015-12-01
Experiment Configuration:
  Trials per seed: 1
  Dataset seeds: 1 to 1
  Total experiments: 1
  Device: GPU

Starting TFT STORE-level H30 training for seed 1

[Seed 1] Training TFT (input=90, output=30, INTERNAL val for early-stopping, OFFICIAL val/test for reporting) ...


You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | static_covariates_vsn             | _VariableSelectionNetwork        | 138 K  | train
4  | encoder_vsn              

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Checkpoint loaded (best by INTERNAL val_loss).


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



=== TFT_STORE_H30 OFFICIAL eval (paper window; report only) ===
Val  RMSE=0.215322, MAE=0.166161, R2=-2.109778
Test RMSE=0.186304, MAE=0.145831, R2=-1.301501

=== P50 q-risk (Val) ===
qrisk_p50_scaled_log = 0.427480
qrisk_p50_unscaled_log = 0.036526
qrisk_p50_sales = 0.311621

=== P50 q-risk (Test) ===
qrisk_p50_scaled_log = 0.434851
qrisk_p50_unscaled_log = 0.032575
qrisk_p50_sales = 0.321846
Using metrics.csv: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/use_admm_crood/sales_forecasting/TFT/sales_TFT_ADMM/store_level_new/compare_cov/no_future_cov/TFT_runs_favorita_store_paperWindow_h30_clean_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd/TFT_Favorita_seed_1/tft_favorita_store_paperWindow_h30_seed1_noFutureCov_staticOH_oilOnly_addRelIdx_BATCH_oilStd/version_0/metrics.csv
[plot_tft_loss_curve] Saved -> /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/use_admm_crood/sales_forecasting/TFT/sales_TFT_ADMM/store_level_new/compare_cov