# TFT vs. Spatial Adapter Comparison with Tuning Parameter Selection

This notebook implements a comprehensive comparison between:
1. **TFT** - Temporal Fusion Transformer baseline (no spatial term)
2. **Unregularized Spatial Adapter** - Neural spatial model without regularization
3. **Regularized Spatial Adapter** - Neural spatial model with optimized tau1, tau2 parameters

The experiment uses Optuna for hyperparameter optimization and evaluates performance across multiple random seeds.

my work: ols 換成 TFT 然後執行模擬

epoch 30 ->10
'n_time_steps' 1024 -> 512
'n_locations': 512 -> 256 
GPU n_trials_per_seed 20 -> 10 
n_dataset_seeds 10 -> 5
完整實驗待跑

In [None]:
# Import required libraries
import csv
import math
from pathlib import Path

import numpy as np
import optuna
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pytorch_lightning as pl
from optuna.pruners import MedianPruner
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from typing import Tuple, Dict, Any, List

from darts import TimeSeries
from darts.models import TFTModel

from geospatial_neural_adapter.cpp_extensions import estimate_covariance
from geospatial_neural_adapter.utils.experiment import log_covariance_and_basis
from geospatial_neural_adapter.utils import (
    ModelCache,
    clear_gpu_memory,
    create_experiment_config,
    print_experiment_summary,
    get_device_info,
)
from geospatial_neural_adapter.models.spatial_basis_learner import SpatialBasisLearner
from geospatial_neural_adapter.models.spatial_neural_adapter import SpatialNeuralAdapter
from geospatial_neural_adapter.models.trend_model import TrendModel
from geospatial_neural_adapter.models.wrapper_examples.tft_wrapper import TFTWrapper
from geospatial_neural_adapter.data.generators import generate_time_synthetic_data
from geospatial_neural_adapter.data.preprocessing import (
    prepare_all_with_scaling,
    denormalize_predictions,
)
from geospatial_neural_adapter.metrics import compute_metrics
from geospatial_neural_adapter.models.pretrained_trend_model import (
    create_pretrained_trend_model,
)

# Plotting defaults: stock matplotlib style with seaborn's "husl" palette.
plt.style.use("default")
sns.set_palette("husl")

# Seed PyTorch and NumPy so repeated runs draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)

print("✅ Imports successful (TFT backbone enabled; all OLS utilities removed).")


## 1. Parameter Configuration and Setup

In [None]:
# Experiment Configuration
# Some values were reduced for a quick run; the previous full-size value is
# noted next to each reduced entry.
EXPERIMENT_CONFIG = dict(
    seed=42,
    n_time_steps=512,   # previously 1024
    n_locations=256,    # previously 512
    noise_std=4.0,
    eigenvalue=16.0,
    latent_dim=1,
    ckpt_dir="admm_bcd_ckpts",
)

# Train / validation fractions of the timeline.
SPLIT_CONFIG = dict(train_ratio=0.70, val_ratio=0.15)

# Keyword arguments unpacked into darts' TFTModel constructor.
TFT_CONFIG = dict(
    input_chunk_length=48,
    output_chunk_length=1,
    hidden_size=64,
    lstm_layers=1,
    num_attention_heads=4,
    dropout=0.10,
    batch_size=64,
    n_epochs=10,        # previously 30
    random_state=42,
    add_relative_index=True,
)

# Lightning trainer options: the GPU and CPU variants differ only in the
# accelerator name, so a single literal covers both.
PL_TRAINER_KWARGS = {
    "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
    "devices": 1,
    "logger": True,
    "enable_progress_bar": True,
    "enable_model_summary": False,
    "num_sanity_val_steps": 0,
}

# Spatial Neural Adapter Configuration using dataclasses
from geospatial_neural_adapter.models.spatial_neural_adapter import (
    ADMMConfig,
    BasisConfig,
    SpatialNeuralAdapterConfig,
    TrainingConfig,
)

# ADMM settings
admm_config = ADMMConfig(
    rho=1.0,
    dual_momentum=0.2,
    max_iters=3000,
    min_outer=20,
    tol=1e-4,
)

# Training settings
training_config = TrainingConfig(
    lr_mu=1e-2,
    batch_size=64,
    pretrain_epochs=5,
    use_mixed_precision=False,
)

# Basis settings
basis_config = BasisConfig(
    phi_every=5,
    phi_freeze=200,
    matrix_reg=1e-6,
    irl1_max_iters=10,
    irl1_eps=1e-6,
    irl1_tol=5e-4,
)

# Bundle the three groups into the adapter's top-level configuration object.
SPATIAL_CONFIG = SpatialNeuralAdapterConfig(
    admm=admm_config,
    training=training_config,
    basis=basis_config
)

# Flat dict view kept for older call sites: adapter settings first, then the
# experiment settings, then the split and TFT dicts under their own keys.
CFG = SPATIAL_CONFIG.to_dict()
CFG.update(EXPERIMENT_CONFIG)
CFG.update(split=SPLIT_CONFIG, tft=TFT_CONFIG)

# Re-seed from the experiment config and make sure the checkpoint folder exists.
np.random.seed(EXPERIMENT_CONFIG["seed"])
torch.manual_seed(EXPERIMENT_CONFIG["seed"])
Path(EXPERIMENT_CONFIG["ckpt_dir"]).mkdir(exist_ok=True)

# Device setup: first CUDA device when available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_info = get_device_info()
print(f"Using {device_info['device'].upper()}: {device_info['device_name']}")
if device_info['device'] == 'cuda':
    print(f"   Memory: {device_info['memory_gb']} GB")

# Print every plain-dict configuration section as "key: value" lines.
for _heading, _section in (
    ("\n=== Experiment Configuration ===", EXPERIMENT_CONFIG),
    ("\n=== Split Configuration ===", SPLIT_CONFIG),
    ("\n=== TFT Configuration (Baseline) ===", TFT_CONFIG),
):
    print(_heading)
    for _name, _setting in _section.items():
        print(f"{_name}: {_setting}")

# The adapter configuration reports itself through its own logging helper.
print("\n=== Spatial Neural Adapter Configuration ===")
SPATIAL_CONFIG.log_config()

## 2. Initialize Utilities
 目前減少跑的次數 確認完之後要跑完整的

In [None]:
# Trial / seed plan for the study: 10 trials per seed when CUDA is available,
# 50 otherwise; 5 dataset seeds with the seed range set to 1..6.
EXPERIMENT_TRIALS_CONFIG = create_experiment_config(
    seed_range_start=1,
    seed_range_end=6,
    n_dataset_seeds=5,
    n_trials_per_seed=(50 if not torch.cuda.is_available() else 10),
)

print_experiment_summary(EXPERIMENT_TRIALS_CONFIG)
print("Utilities initialized successfully!")

## 3. Data Generation and Preprocessing

In [None]:
# Generate synthetic data with meaningful correlations
print("Generating correlated synthetic data...")

locs = np.linspace(-3, 3, CFG["n_locations"])
cat_features, cont_features, targets = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=CFG["n_time_steps"],
    noise_std=CFG["noise_std"],
    eigenvalue=CFG["eigenvalue"],
    eta_rho=0.8,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.1,
    non_linear_strength=0.2,
    seed=CFG["seed"]
)

# Prepare datasets with scaling (scalers are fitted on the training split only)
train_dataset, val_dataset, test_dataset, preprocessor = prepare_all_with_scaling(
    cat_features=cat_features,
    cont_features=cont_features,
    targets=targets,
    train_ratio=SPLIT_CONFIG["train_ratio"],
    val_ratio=SPLIT_CONFIG["val_ratio"],
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True
)

train_loader = DataLoader(train_dataset, batch_size=CFG["tft"]["batch_size"], shuffle=True)

# Extract tensors (scaled); the first tensor of each dataset is not used here.
_, train_X, train_y = train_dataset.tensors
_, val_X,   val_y   = val_dataset.tensors
_, test_X,  test_y  = test_dataset.tensors

# Give 2-D targets a trailing singleton axis so all three splits are 3-D.
if train_y.ndim == 2: train_y = train_y.unsqueeze(-1)
if val_y.ndim   == 2: val_y   = val_y.unsqueeze(-1)
if test_y.ndim  == 2: test_y  = test_y.unsqueeze(-1)

# Stitch the three splits back together along the first (time) axis.
y_all_scaled = torch.cat([train_y, val_y, test_y], dim=0).squeeze(-1).cpu().numpy()  # (T_full, N)
x_all_scaled = torch.cat([train_X, val_X, test_X], dim=0).cpu().numpy()              # (T_full, N, F)

T_full = y_all_scaled.shape[0]
N      = y_all_scaled.shape[1]
F      = x_all_scaled.shape[2]

# Split boundaries are taken from the actual split lengths rather than
# recomputed from the ratios: the concatenated arrays above are cut exactly at
# these positions, so the indices cannot drift from val_y / test_y through
# float rounding (e.g. int(T * 0.15) versus int(T * 0.85) - int(T * 0.70)).
train_T = train_y.shape[0]
val_T   = train_T + val_y.shape[0]
test_T  = T_full

series_all   = TimeSeries.from_values(y_all_scaled)      # shape (T_full, N)
series_train = series_all[:train_T]
series_val   = series_all[train_T:val_T]
series_test  = series_all[val_T:]

y_true_future = y_all_scaled[train_T:]   # (T_val+T_test, N)


p_dim = train_X.shape[-1]
print(f"Data shapes (cont_features, targets): {cont_features.shape}, {targets.shape}")
print(f"Original targets - Mean: {targets.mean():.2f}, Std: {targets.std():.2f}")
print(f"Original targets - Range: {targets.min():.2f} to {targets.max():.2f}")
print(f"Feature dimension: {p_dim}")
print(f"T_full={T_full}, N={N}, F={F} | splits: train={train_T}, val={val_T-train_T}, test={test_T-val_T}")
print("Scaled series prepared for TFT (Darts).")

In [None]:
# # Visualize data characteristics
# fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# # Plot 1: Target distribution
# axes[0, 0].hist(targets.flatten(), bins=30, alpha=0.7, edgecolor='black')
# axes[0, 0].set_title('Target Distribution')
# axes[0, 0].set_xlabel('Target Value')
# axes[0, 0].set_ylabel('Frequency')
# axes[0, 0].grid(True, alpha=0.3)

# # Plot 2: Spatial pattern at first time step
# axes[0, 1].plot(locs, targets[0, :], 'o-', linewidth=2, markersize=4)
# axes[0, 1].set_title('Spatial Pattern at t=0')
# axes[0, 1].set_xlabel('Location')
# axes[0, 1].set_ylabel('Target Value')
# axes[0, 1].grid(True, alpha=0.3)

# # Plot 3: Temporal pattern at middle location
# time_steps = np.arange(len(targets))
# axes[1, 0].plot(time_steps, targets[:, 25], linewidth=2)
# axes[1, 0].set_title('Temporal Pattern at Location 25')
# axes[1, 0].set_xlabel('Time Step')
# axes[1, 0].set_ylabel('Target Value')
# axes[1, 0].grid(True, alpha=0.3)

# # Plot 4: Feature correlations
# feature_corrs = []
# for i in range(cont_features.shape[-1]):
#     corr = np.corrcoef(targets.flatten(), cont_features[:, :, i].flatten())[0, 1]
#     feature_corrs.append(corr)

# axes[1, 1].bar(range(len(feature_corrs)), feature_corrs, alpha=0.7, edgecolor='black')
# axes[1, 1].set_title('Feature-Target Correlations')
# axes[1, 1].set_xlabel('Feature Index')
# axes[1, 1].set_ylabel('Correlation')
# axes[1, 1].grid(True, alpha=0.3)

# plt.tight_layout()
# plt.show()


In [None]:
# === Multi-seed local coarse tuning (does not touch the notebook globals) ===
def multi_seed_tft_tuning(
    dataset_seed_list=(1,2,3),
    split_cfg=None,
    sim_base=None,
    candidates=None
):
    """Fit each candidate TFT configuration on several synthetic datasets and print a ranking.

    For every seed a dataset is generated and scaled; every candidate is
    fitted on the training slice, and rolling one-step forecasts over the
    validation and test slices are scored on the denormalized scale.
    Results are only printed: nothing is returned and no global is modified.

    Args:
        dataset_seed_list: Seeds passed to ``generate_time_synthetic_data``;
            one dataset is generated per seed.
        split_cfg: Mapping with ``train_ratio`` and ``val_ratio``. ``None``
            means ``{"train_ratio": 0.70, "val_ratio": 0.15}``.
        sim_base: Mapping with the simulation settings read below
            (``n_time_steps``, ``n_locations``, ``noise_std``, ...). ``None``
            selects the built-in defaults.
        candidates: List of dicts, each holding a ``name`` plus the TFTModel
            settings read below, and optionally ``early_stopping`` and
            ``gradient_clip_val``. ``None`` selects three built-in candidates
            that differ only in ``input_chunk_length`` (24 / 48 / 72).

    Raises:
        ValueError: If ``dataset_seed_list`` or ``candidates`` is empty.
    """
    # Validate cheap preconditions before the heavy imports and any training;
    # previously an empty seed list only failed at the final aggregation.
    seeds = tuple(dataset_seed_list)
    if not seeds:
        raise ValueError("dataset_seed_list must contain at least one seed")
    if candidates is not None and len(candidates) == 0:
        raise ValueError("candidates must contain at least one configuration")

    from statistics import mean, median, pstdev

    import numpy as np
    import pytorch_lightning as pl
    import torch
    from darts import TimeSeries
    from darts.models import TFTModel
    from geospatial_neural_adapter.data.generators import generate_time_synthetic_data
    from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling, denormalize_predictions

    # Defaults are built per call so no dict is shared between invocations.
    if split_cfg is None:
        split_cfg = {"train_ratio": 0.70, "val_ratio": 0.15}
    if sim_base is None:
        sim_base = {"n_time_steps": 512, "n_locations": 256, "noise_std": 4.0, "eigenvalue": 16.0,
                    "eta_rho": 0.8, "f_rho": 0.6, "global_mean": 50.0,
                    "feature_noise_std": 0.1, "non_linear_strength": 0.2}

    # ==== Helpers: shape & metrics ====
    def _to_tensor_2d(x):
        """Return ``x`` as a float tensor, dropping a trailing singleton axis of a 3-D input."""
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if x.ndim == 3 and x.shape[-1] == 1:  # (T, N, 1) -> (T, N)
            x = x.squeeze(-1)
        return x.float()

    def _ensure_TN(x, N_expected):
        """Return a 2-D tensor laid out as (T, N); an (N, T) input is transposed."""
        if x.ndim != 2:
            raise ValueError(f"Expect 2D tensor, got shape {tuple(x.shape)}")
        T, N = x.shape
        if N != N_expected and T == N_expected:
            return x.T.contiguous()
        return x

    def _denorm_to_TN(x, preproc, N_expected):
        """Denormalize ``x`` and return it as a float tensor of shape (T, N)."""
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        y_den = denormalize_predictions(x, preproc)        # may be an ndarray or a tensor
        y_den = _to_tensor_2d(y_den)                       # -> tensor (T, N) or (N, T)
        y_den = _ensure_TN(y_den, N_expected)              # -> (T, N)
        return y_den

    def _metrics(true_den, pred_den):
        """Return (RMSE, RMSPE, R²) as Python floats; ``eps`` guards the divisions."""
        eps = 1e-8
        mse = torch.mean((pred_den - true_den) ** 2)
        rmse = torch.sqrt(mse).item()
        rmspe = torch.sqrt(torch.mean(((pred_den - true_den) / (true_den.abs() + eps)) ** 2)).item()
        ss_res = torch.sum((true_den - pred_den) ** 2)
        ss_tot = torch.sum((true_den - torch.mean(true_den, 0, True)) ** 2)
        r2 = 1.0 - (ss_res / (ss_tot + eps)).item()
        return rmse, rmspe, r2

    if candidates is None:
        # Three built-in candidates that differ only in the look-back window.
        # Optional per-candidate keys: "gradient_clip_val": 0.5 and
        # "early_stopping": {"monitor": "val_loss", "patience": 3, "min_delta": 0.0}.
        candidates = [
            {"name": f"{tag}_len{length}_h64",
             "input_chunk_length": length, "output_chunk_length": 1,
             "hidden_size": 64, "lstm_layers": 1, "num_attention_heads": 4,
             "dropout": 0.10, "batch_size": 64, "n_epochs": 20,
             "add_relative_index": True, "random_state": 42,
            }
            for tag, length in (("A", 24), ("B", 48), ("C", 72))
        ]

    # ==== make series ====
    def _mk_series(train_y, val_y, test_y, train_X, val_X, test_X):
        """Concatenate the splits and build one target and one covariate series per location."""
        if train_y.ndim == 2: train_y = train_y.unsqueeze(-1)
        if val_y.ndim   == 2: val_y   = val_y.unsqueeze(-1)
        if test_y.ndim  == 2: test_y  = test_y.unsqueeze(-1)
        y_all = torch.cat([train_y, val_y, test_y], 0).squeeze(-1).cpu().numpy()
        x_all = torch.cat([train_X, val_X, test_X], 0).cpu().numpy()
        N = y_all.shape[1]
        # Boundaries come from the real split lengths (the arrays above are
        # concatenated exactly there), not from re-rounding the ratios.
        T_tr = train_y.shape[0]
        val_len, test_len = val_y.shape[0], test_y.shape[0]
        T_va = T_tr + val_len
        series_list = [TimeSeries.from_values(y_all[:, i]) for i in range(N)]
        past_list = [TimeSeries.from_values(x_all[:, i, :]) for i in range(N)]
        return series_list, past_list, T_tr, T_va, val_len, test_len, N

    # ==== Lightning Trainer kwargs (self-contained, no external dependency) ====
    trainer_kwargs = {
        "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
        "devices": 1, "logger": False, "enable_progress_bar": False,
        "enable_model_summary": False, "num_sanity_val_steps": 0,
    }

    summary = {c["name"]: {"seed_metrics": []} for c in candidates}

    for ds_seed in seeds:
        # Generate the dataset for this seed.
        locs = np.linspace(-3, 3, sim_base["n_locations"])
        catf, conf, tgts = generate_time_synthetic_data(
            locs=locs, n_time_steps=sim_base["n_time_steps"],
            noise_std=sim_base["noise_std"], eigenvalue=sim_base["eigenvalue"],
            eta_rho=sim_base["eta_rho"], f_rho=sim_base["f_rho"],
            global_mean=sim_base["global_mean"], feature_noise_std=sim_base["feature_noise_std"],
            non_linear_strength=sim_base["non_linear_strength"], seed=ds_seed
        )
        train_ds, val_ds, test_ds, preproc = prepare_all_with_scaling(
            cat_features=catf, cont_features=conf, targets=tgts,
            train_ratio=split_cfg["train_ratio"], val_ratio=split_cfg["val_ratio"],
            feature_scaler_type="standard", target_scaler_type="standard",
            fit_on_train_only=True
        )
        _, train_X, train_y = train_ds.tensors
        _, val_X,   val_y   = val_ds.tensors
        _, test_X,  test_y  = test_ds.tensors

        series_list, past_list, T_tr, T_va, val_len, test_len, N = _mk_series(
            train_y, val_y, test_y, train_X, val_X, test_X
        )

        # Ground truth (denormalized -> (T, N)).
        val_y_den  = _denorm_to_TN(val_y.squeeze(-1),  preproc, N)
        test_y_den = _denorm_to_TN(test_y.squeeze(-1), preproc, N)

        # Train/validation slices depend only on the dataset, so build them
        # once per seed instead of once per candidate.
        series_train = [s[:T_tr] for s in series_list]
        series_val   = [s[T_tr:T_va] for s in series_list]
        past_train   = [c[:T_tr] for c in past_list]
        past_val     = [c[T_tr:T_va] for c in past_list]

        for cfg in candidates:
            pl.seed_everything(cfg.get("random_state", 123))

            # Early stopping and gradient clipping go into the trainer kwargs.
            tk = dict(trainer_kwargs)
            if cfg.get("early_stopping"):
                from pytorch_lightning.callbacks import EarlyStopping
                es = EarlyStopping(monitor=cfg["early_stopping"]["monitor"],
                                   patience=cfg["early_stopping"]["patience"],
                                   min_delta=cfg["early_stopping"]["min_delta"], mode="min")
                tk["callbacks"] = list(tk.get("callbacks", [])) + [es]
            if cfg.get("gradient_clip_val") is not None:
                tk["gradient_clip_val"] = cfg["gradient_clip_val"]

            # Build the model (optimizer_kwargs=cfg.get("optimizer_kwargs") could be added if needed).
            model = TFTModel(
                input_chunk_length=cfg["input_chunk_length"],
                output_chunk_length=cfg["output_chunk_length"],
                hidden_size=cfg["hidden_size"], lstm_layers=cfg["lstm_layers"],
                num_attention_heads=cfg["num_attention_heads"],
                dropout=cfg["dropout"], batch_size=cfg["batch_size"],
                n_epochs=cfg["n_epochs"], add_relative_index=cfg["add_relative_index"],
                random_state=cfg["random_state"],
                pl_trainer_kwargs=tk,
            )

            # Train
            model.fit(series=series_train, val_series=series_val,
                      past_covariates=past_train, val_past_covariates=past_val, verbose=True)

            # Rolling one-step forecasts (validation + test), one location at a time.
            yval, ytest = [], []
            for i in range(N):
                preds = model.historical_forecasts(
                    series=series_list[i], past_covariates=past_list[i],
                    start=T_tr, forecast_horizon=1, retrain=False, verbose=False
                ).values()
                yval.append(preds[:val_len])
                ytest.append(preds[val_len:val_len+test_len])
            yval = np.stack(yval, 1)    # locations stacked on axis 1
            ytest = np.stack(ytest, 1)

            # Denormalize -> (T, N)
            yval_den  = _denorm_to_TN(yval,  preproc, N)
            ytest_den = _denorm_to_TN(ytest, preproc, N)

            # Sanity check: forecasts must line up with the ground truth.
            assert val_y_den.shape == yval_den.shape
            assert test_y_den.shape == ytest_den.shape

            # Metrics
            rmse_v, rmspe_v, r2_v   = _metrics(val_y_den,  yval_den)
            rmse_t, rmspe_t, r2_tst = _metrics(test_y_den, ytest_den)

            summary[cfg["name"]]["seed_metrics"].append({
                "seed": ds_seed, "rmse_val": rmse_v, "rmspe_val": rmspe_v, "r2_val": r2_v,
                "rmse_test": rmse_t, "rmspe_test": rmspe_t, "r2_test": r2_tst
            })
            print(f"[{cfg['name']} | seed={ds_seed}] "
                  f"val RMSE={rmse_v:.4f}, R²={r2_v:.4f} | "
                  f"test RMSE={rmse_t:.4f}, R²={r2_tst:.4f}")

    # ==== Aggregate per candidate: mean test RMSE first, std as the stability
    # measure, plus the share of seeds with R² > 0 ====
    # NOTE(review): candidates are ranked by *test* RMSE, so the reported test
    # score of the winner is optimistically biased; consider ranking on
    # rmse_val instead. Behaviour is left unchanged here.
    print("\n=== Multi-seed Summary ===")
    ranking = []
    for name, rec in summary.items():
        mets = rec["seed_metrics"]
        mr = [m["rmse_test"] for m in mets]
        rr = [m["r2_test"] for m in mets]
        entry = {
            "name": name,
            "mean_rmse_test": mean(mr),
            "median_rmse_test": median(mr),
            "std_rmse_test": pstdev(mr) if len(mr) > 1 else 0.0,
            "r2_pos_ratio": sum(1 for x in rr if x > 0.0) / len(rr),
            "detail": mets
        }
        ranking.append(entry)
        print(f"- {name}: mean RMSE={entry['mean_rmse_test']:.4f} | "
              f"median RMSE={entry['median_rmse_test']:.4f} | "
              f"std={entry['std_rmse_test']:.4f} | "
              f"R²>0 ratio={entry['r2_pos_ratio']:.2f}")

    ranking.sort(key=lambda x: (x["mean_rmse_test"], x["std_rmse_test"]))
    best = ranking[0]
    print("\n>>> 建議採用（以 mean test RMSE 最小、std 次要）：", best["name"])
    print("   平均RMSE={:.4f} | 中位數RMSE={:.4f} | std={:.4f} | R²>0比例={:.2f}".format(
        best["mean_rmse_test"], best["median_rmse_test"], best["std_rmse_test"], best["r2_pos_ratio"]
    ))
    chosen = next(c for c in candidates if c["name"] == best["name"])
    suggest_TFT_CONFIG = {
        "input_chunk_length": chosen["input_chunk_length"],
        "output_chunk_length": chosen["output_chunk_length"],
        "hidden_size": chosen["hidden_size"],
        "lstm_layers": chosen["lstm_layers"],
        "num_attention_heads": chosen["num_attention_heads"],
        "dropout": chosen["dropout"],
        "batch_size": chosen["batch_size"],
        "n_epochs": chosen["n_epochs"],
        "add_relative_index": chosen["add_relative_index"],
        "random_state": chosen["random_state"],
    }
    print("\n請手動回填到全域 TFT_CONFIG（核心欄位）：")
    print(suggest_TFT_CONFIG)

# Example call (edit the seed list and the candidate set as needed)
multi_seed_tft_tuning(dataset_seed_list=(1,2,3))


## 4. Baseline Implementation
在Main Experiment Function 做了訓練與測試

In [None]:
# ==== TFT Baseline (demo style; no spatial term, no future covariates) ====
print("Training TFT baseline (demo-only, no spatial)...")

pl.seed_everything(CFG["tft"]["random_state"])

# 1) One target series and one past-covariate series per location
#    (future covariates are deliberately not built).
series_list = [TimeSeries.from_values(y_all_scaled[:, i]) for i in range(N)]
past_cov_list = [TimeSeries.from_values(x_all_scaled[:, i, :]) for i in range(N)]

# 2) Split: the train and validation slices share the same boundaries.
series_train_list, series_val_list = [], []
past_cov_train, val_past_covs = [], []
for target_ts, cov_ts in zip(series_list, past_cov_list):
    series_train_list.append(target_ts[:train_T])
    series_val_list.append(target_ts[train_T:val_T])
    past_cov_train.append(cov_ts[:train_T])
    val_past_covs.append(cov_ts[train_T:val_T])

# 3) Train the TFT with past covariates only (validation gets its own
#    val_past_covariates). TFT_CONFIG sets add_relative_index=True.
tft = TFTModel(pl_trainer_kwargs=PL_TRAINER_KWARGS, **TFT_CONFIG)
tft.fit(
    series=series_train_list,
    past_covariates=past_cov_train,
    val_series=series_val_list,
    val_past_covariates=val_past_covs,
    verbose=True
)

# 4) One-step rolling forecast via historical_forecasts (past covariates only).
val_len = val_y.shape[0]
test_len = test_y.shape[0]
test_stop = val_len + test_len

yhat_val_list, yhat_test_list = [], []
for target_ts, cov_ts in zip(series_list, past_cov_list):
    rolled = tft.historical_forecasts(
        series=target_ts,
        past_covariates=cov_ts,
        start=train_T,
        forecast_horizon=1,
        retrain=False,
        verbose=False
    ).values()  # forecasts from train_T onward: validation part first, then test
    yhat_val_list.append(rolled[:val_len])
    yhat_test_list.append(rolled[val_len:test_stop])

# Stack the per-location forecasts along axis 1 (still in scaled space).
yhat_val_sc = np.stack(yhat_val_list, axis=1)
yhat_test_sc = np.stack(yhat_test_list, axis=1)

# 5) Metrics are computed on the original scale (denormalize -> 2-D float tensor).
def to_tensor_2d(x):
    """Return ``x`` as a float tensor; a 3-D input gets ``squeeze(-1)`` applied."""
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    if tensor.ndim == 3:
        tensor = tensor.squeeze(-1)
    return tensor.float()

y_val_den_t = to_tensor_2d(denormalize_predictions(val_y.squeeze(-1), preprocessor))
y_test_den_t = to_tensor_2d(denormalize_predictions(test_y.squeeze(-1), preprocessor))
yhat_val_den_t = to_tensor_2d(
    denormalize_predictions(torch.from_numpy(yhat_val_sc).float(), preprocessor)
)
yhat_test_den_t = to_tensor_2d(
    denormalize_predictions(torch.from_numpy(yhat_test_sc).float(), preprocessor)
)

rmse_tft_val, mae_tft_val, r2_tft_val = compute_metrics(y_val_den_t, yhat_val_den_t)
rmse_tft_test, mae_tft_test, r2_tft_test = compute_metrics(y_test_den_t, yhat_test_den_t)

print(f"TFT Validation - RMSE: {rmse_tft_val:.4f}, R²: {r2_tft_val:.4f}")
print(f"TFT Test       - RMSE: {rmse_tft_test:.4f}, R²: {r2_tft_test:.4f}")


## 5. Main Experiment Function


In [None]:
def _mk_series_per_station(
    train_y, val_y, test_y, train_X, val_X, test_X, split_cfg
) -> Tuple[List[TimeSeries], List[TimeSeries], List[TimeSeries], int, int, int, int, int, int]:
    """Rebuild the full timeline and create one TimeSeries (plus covariates) per station.

    The three splits are concatenated along the first axis; targets may be
    2-D or carry a trailing singleton axis.

    Args:
        train_y, val_y, test_y: Target tensors of the three splits.
        train_X, val_X, test_X: Feature tensors of the three splits, indexed
            as (time, station, feature).
        split_cfg: Split ratios. Kept for interface compatibility; the
            boundaries are taken from the tensor lengths (see below).

    Returns:
        ``(series_list, past_cov_list, future_cov_list, T_tr, T_va, val_len,
        test_len, N, F)`` where the first three are per-station lists,
        ``T_tr`` / ``T_va`` are the end indices of the training and validation
        slices on the concatenated timeline, and ``N`` / ``F`` are the station
        and feature counts. ``future_cov_list`` holds the same objects as
        ``past_cov_list``.
    """
    if train_y.ndim == 2: train_y = train_y.unsqueeze(-1)
    if val_y.ndim   == 2: val_y   = val_y.unsqueeze(-1)
    if test_y.ndim  == 2: test_y  = test_y.unsqueeze(-1)

    y_all_scaled = torch.cat([train_y, val_y, test_y], dim=0).squeeze(-1).cpu().numpy()  # (T_full, N)
    x_all_scaled = torch.cat([train_X, val_X, test_X], dim=0).cpu().numpy()              # (T_full, N, F)

    N = y_all_scaled.shape[1]
    F = x_all_scaled.shape[2]

    # The concatenated arrays are cut exactly at the split lengths, so use
    # those lengths directly instead of re-rounding the ratios; this keeps the
    # indices aligned with val_y / test_y however the splitter rounded.
    T_tr = train_y.shape[0]
    val_len = val_y.shape[0]
    test_len = test_y.shape[0]
    T_va = T_tr + val_len

    series_list, past_cov_list, future_cov_list = [], [], []
    for i in range(N):
        s_i   = TimeSeries.from_values(y_all_scaled[:, i])    # (T,)
        cov_i = TimeSeries.from_values(x_all_scaled[:, i, :]) # (T, F)
        series_list.append(s_i)
        past_cov_list.append(cov_i)
        future_cov_list.append(cov_i)  # same series object as the past covariates

    return series_list, past_cov_list, future_cov_list, T_tr, T_va, val_len, test_len, N, F


def build_tft_and_data(
    *,
    cat_features: np.ndarray,
    cont_features: np.ndarray,
    targets: np.ndarray,
    split_cfg: Dict[str, Any],
    tft_cfg: Dict[str, Any],
    pl_trainer_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Scale the data, fit a TFT on the training slice and forecast validation/test.

    Args:
        cat_features, cont_features, targets: Raw arrays handed to
            ``prepare_all_with_scaling``.
        split_cfg: Mapping with ``train_ratio`` and ``val_ratio``.
        tft_cfg: Keyword arguments for ``TFTModel``; ``random_state``
            (default 42) also seeds Lightning.
        pl_trainer_kwargs: Passed to ``TFTModel`` as ``pl_trainer_kwargs``.

    Returns:
        Dict with the fitted model, the preprocessor, the split indices,
        ``N`` / ``F``, and the true and predicted validation/test values on
        both the scaled (``*_sc``) and denormalized (``*_den_t``) scales.
    """
    # Step 1: scaling (scalers fitted on the training split only).
    scaled_train, scaled_val, scaled_test, preproc = prepare_all_with_scaling(
        cat_features=cat_features,
        cont_features=cont_features,
        targets=targets,
        train_ratio=split_cfg["train_ratio"],
        val_ratio=split_cfg["val_ratio"],
        feature_scaler_type="standard",
        target_scaler_type="standard",
        fit_on_train_only=True,
    )
    _cat_tr, train_X, train_y = scaled_train.tensors
    _cat_va, val_X, val_y = scaled_val.tensors
    _cat_te, test_X, test_y = scaled_test.tensors

    # Step 2: per-station series; the future-covariate list is not used here.
    (series_list, past_cov_list, _unused_future,
     T_tr, T_va, val_len, test_len, N, F) = _mk_series_per_station(
        train_y, val_y, test_y, train_X, val_X, test_X, split_cfg
    )

    # Step 3: seed, build and fit the model with past covariates only.
    pl.seed_everything(tft_cfg.get("random_state", 42))
    tft_model = TFTModel(pl_trainer_kwargs=pl_trainer_kwargs, **tft_cfg)
    train_targets = [ts[:T_tr] for ts in series_list]
    val_targets = [ts[T_tr:T_va] for ts in series_list]
    train_covs = [cov[:T_tr] for cov in past_cov_list]
    val_covs = [cov[T_tr:T_va] for cov in past_cov_list]
    tft_model.fit(
        series=train_targets,
        val_series=val_targets,
        past_covariates=train_covs,
        val_past_covariates=val_covs,
        verbose=True
    )

    # Step 4: rolling one-step forecasts from T_tr onward, one station at a
    # time; the first val_len rows are validation, the next test_len are test.
    rolled = [
        tft_model.historical_forecasts(
            series=target_ts,
            past_covariates=cov_ts,
            start=T_tr, forecast_horizon=1, retrain=False, verbose=False
        ).values()
        for target_ts, cov_ts in zip(series_list, past_cov_list)
    ]
    test_stop = val_len + test_len
    yhat_val_sc = np.stack([r[:val_len] for r in rolled], 1)
    yhat_test_sc = np.stack([r[val_len:test_stop] for r in rolled], 1)

    def _denorm(values):
        """Denormalize and return a float tensor; a 3-D result gets ``squeeze(-1)``."""
        out = denormalize_predictions(values, preproc)
        if isinstance(out, np.ndarray):
            out = torch.from_numpy(out)
        if out.ndim == 3:
            out = out.squeeze(-1)
        return out.float()

    # Step 5: ground truth and forecasts on the original scale.
    val_y_den_t = _denorm(val_y.squeeze(-1))
    test_y_den_t = _denorm(test_y.squeeze(-1))
    yhat_val_den_t = _denorm(torch.from_numpy(yhat_val_sc).float())
    yhat_test_den_t = _denorm(torch.from_numpy(yhat_test_sc).float())

    split_idx = {"T_tr": T_tr, "T_va": T_va, "val_len": val_len, "test_len": test_len}
    return {
        "tft_model": tft_model,
        "preprocessor": preproc,
        "split_idx": split_idx,
        "N": N, "F": F,
        "val_y_den_t": val_y_den_t, "test_y_den_t": test_y_den_t,
        "yhat_val_den_t": yhat_val_den_t, "yhat_test_den_t": yhat_test_den_t,
        "val_y_sc": val_y.squeeze(-1).cpu().numpy(),
        "test_y_sc": test_y.squeeze(-1).cpu().numpy(),
        "yhat_val_sc": yhat_val_sc, "yhat_test_sc": yhat_test_sc,
    }


def build_spatial_adapter_from_demo(
    *,
    tft_model,
    device,
    n_locations: int,
    latent_dim: int,
    num_features: int,
    train_loader,
    val_cont: torch.Tensor,
    val_y: torch.Tensor,
    locs,
    adapter_cfg,
    tau1: float,
    tau2: float,
    writer=None,
):
    """Assemble a SpatialNeuralAdapter around a fitted TFT model.

    Pipeline: TFTWrapper -> create_pretrained_trend_model (frozen backbone
    with a residual head) -> SpatialNeuralAdapter.

    ``val_cont`` and ``val_y`` are moved to ``device`` before being handed to
    the adapter; ``tau1`` / ``tau2`` and the optional ``writer`` are passed
    through unchanged.

    Returns:
        The constructed ``SpatialNeuralAdapter``.
    """
    # Wrap the fitted TFT so it can serve as the pretrained backbone.
    wrapped_backbone = TFTWrapper(
        tft_model=tft_model,
        num_locations=n_locations,
        num_features=num_features,
    )

    # Frozen backbone plus a residual head (hidden width 64, dropout 0.1).
    trend_model = create_pretrained_trend_model(
        pretrained_model=wrapped_backbone,
        model_type="custom",
        input_shape=(None, n_locations, num_features),  # (B, N, F)
        output_shape=(None, n_locations),               # (B, N)
        freeze_backbone=True,
        add_residual_head=True,
        residual_hidden_dim=64,
        dropout_rate=0.1,
    )

    basis_learner = SpatialBasisLearner(n_locations, latent_dim).to(device)

    return SpatialNeuralAdapter(
        trend=trend_model,
        basis=basis_learner,
        train_loader=train_loader,
        val_cont=val_cont.to(device),
        val_y=val_y.to(device),
        locs=locs,
        config=adapter_cfg,
        device=device,
        writer=writer,
        tau1=tau1,
        tau2=tau2,
    )


def train_unregularized_adapter(
    adapter: SpatialNeuralAdapter,
    pretrain_epochs: int = 5
) -> Dict[str, Any]:
    """Run the adapter's three training stages in order and hand it back.

    Calls ``pretrain_trend(epochs=pretrain_epochs)``, then
    ``init_basis_dense()``, then ``run()`` on the given adapter.

    Returns:
        ``{"adapter": adapter}`` with the same (now trained) adapter object.
    """
    adapter.pretrain_trend(epochs=pretrain_epochs)  # stage 1: trend pretraining
    adapter.init_basis_dense()                      # stage 2: basis initialisation
    adapter.run()                                   # stage 3: main run
    return dict(adapter=adapter)


def train_regularized_adapter_with_optuna(
    build_adapter_fn,                             # closure: (tau1, tau2) -> SpatialNeuralAdapter
    val_y_den_t: torch.Tensor,
    predict_val_den_fn,                           # closure: (adapter) -> torch.Tensor denorm predictions on val
    n_trials: int = 30,
    study_name: str = "TFT_spatial_adapter_reg",
    pretrain_epochs: int = 3,
    seed: "int | None" = None,
) -> Dict[str, Any]:
    """Search ``tau1`` / ``tau2`` with Optuna, minimising validation RMSE.

    Each trial samples both values log-uniformly from [1e-4, 1e8], builds a
    fresh adapter, trains it (trend pretraining, dense basis init, run) and
    scores its denormalized validation predictions with ``compute_metrics``.

    Args:
        build_adapter_fn: Callable ``(tau1, tau2) -> adapter``.
        val_y_den_t: Denormalized validation targets.
        predict_val_den_fn: Callable ``(adapter) -> denormalized validation
            predictions``.
        n_trials: Number of Optuna trials (run sequentially, ``n_jobs=1``).
        study_name: Name given to the Optuna study.
        pretrain_epochs: Trend pretraining epochs per trial (previously the
            hard-coded value 3).
        seed: Seed for the TPE sampler. ``None`` keeps the previous unseeded
            behaviour; pass an integer for a reproducible search.

    Returns:
        Dict with the best ``tau1`` / ``tau2``, that trial's ``rmse`` /
        ``mae`` / ``r2``, its number (``best_trial``) and the ``study``.
    """
    def objective(trial: optuna.trial.Trial):
        tau1 = trial.suggest_float("tau1", 1e-4, 1e8, log=True)
        tau2 = trial.suggest_float("tau2", 1e-4, 1e8, log=True)
        adapter = build_adapter_fn(tau1, tau2)
        adapter.pretrain_trend(epochs=pretrain_epochs)
        adapter.init_basis_dense()
        adapter.run()

        y_val_pred_den = predict_val_den_fn(adapter)
        rmse, mae, r2 = compute_metrics(val_y_den_t, y_val_pred_den)

        # Keep all three metrics on the trial so the best one can be reported.
        trial.set_user_attr("rmse", rmse)
        trial.set_user_attr("mae", mae)
        trial.set_user_attr("r2",  r2)
        return rmse

    # NOTE(review): the objective never calls trial.report(), so the
    # MedianPruner below has no intermediate values to act on and no trial is
    # pruned. It is kept so the study configuration is unchanged.
    study = optuna.create_study(
        study_name=study_name, direction="minimize",
        sampler=optuna.samplers.TPESampler(seed=seed),
        pruner=MedianPruner(n_warmup_steps=5),
        load_if_exists=False,
    )
    study.optimize(objective, n_trials=n_trials, n_jobs=1)

    best = study.best_trial
    return {
        "tau1": best.params["tau1"],
        "tau2": best.params["tau2"],
        "rmse": best.user_attrs["rmse"],
        "mae":  best.user_attrs["mae"],
        "r2":   best.user_attrs["r2"],
        "best_trial": best.number,
        "study": study,
    }


def evaluate_adapter_on_test_from_demo(
    adapter: SpatialNeuralAdapter,
    denorm_true_test: torch.Tensor,
    predict_test_den_fn,
) -> Dict[str, float]:
    """Score an adapter's test predictions against the denormalized truth.

    ``predict_test_den_fn(adapter)`` supplies the predictions, which are
    compared with ``denorm_true_test`` through ``compute_metrics``.

    Returns:
        Dict with the keys ``rmse``, ``mae`` and ``r2``.
    """
    rmse, mae, r2 = compute_metrics(denorm_true_test, predict_test_den_fn(adapter))
    return dict(rmse=rmse, mae=mae, r2=r2)

In [None]:
def run_one_experiment(dataset_seed: int, n_trials: int = 30):
    """Run the TFT / unregularized / regularized comparison for one dataset seed.

    Steps, in order:
      1. Generate a synthetic dataset for ``dataset_seed`` and build scaled
         train/val/test splits.
      2. Build the TFT baseline through ``build_tft_and_data`` and score the
         validation/test predictions it returns.
      3. Train a spatial adapter with ``tau1 = tau2 = 0`` and score it.
      4. Search ``tau1``/``tau2`` with Optuna, then rebuild and retrain an
         adapter with the best pair and score it on the test split.
      5. Append one row per model to ``metrics_summary_TFT.csv`` and print a
         one-line summary.

    Uses notebook-level names defined outside this function: ``locs``,
    ``CFG``, ``SPLIT_CONFIG``, ``SPATIAL_CONFIG``, ``TFT_CONFIG``,
    ``PL_TRAINER_KWARGS``, ``EXPERIMENT_CONFIG``, ``DEVICE``,
    ``build_tft_and_data``, ``build_spatial_adapter_from_demo`` and
    ``train_unregularized_adapter``.

    Args:
        dataset_seed: Seed for data generation; also used in log directory,
            study name and CSV rows.
        n_trials: Number of Optuna trials for the tau1/tau2 search.

    Returns:
        Dict with keys ``"TFT"``, ``"unreg"`` and ``"reg"``; each maps to a
        dict of ``rmse_val``, ``rmse_test``, ``r2_val``, ``r2_test`` (plus
        ``tau1``/``tau2`` for ``"reg"``).
    """
    log_root = Path("TFT_runs") / f"TFT_seed_{dataset_seed}"
    log_root.mkdir(parents=True, exist_ok=True)

    # Data
    catf, conf, tgts = generate_time_synthetic_data(
        locs=locs,
        n_time_steps=CFG["n_time_steps"],
        noise_std=CFG["noise_std"],
        eigenvalue=CFG["eigenvalue"],
        eta_rho=0.8,
        f_rho=0.6,
        global_mean=50.0,
        feature_noise_std=0.1,
        non_linear_strength=0.2,
        seed=dataset_seed,
    )
    train_ds, val_ds, test_ds, preproc = prepare_all_with_scaling(
        cat_features=catf, cont_features=conf, targets=tgts,
        train_ratio=SPLIT_CONFIG["train_ratio"], val_ratio=SPLIT_CONFIG["val_ratio"],
        feature_scaler_type="standard", target_scaler_type="standard", fit_on_train_only=True
    )
    train_loader = DataLoader(train_ds, batch_size=SPATIAL_CONFIG.training.batch_size, shuffle=True)
    _, val_X, val_y = val_ds.tensors
    _, test_X, test_y = test_ds.tensors

    # TFT baseline
    tft_pack = build_tft_and_data(
        cat_features=catf, cont_features=conf, targets=tgts,
        split_cfg=SPLIT_CONFIG, tft_cfg=TFT_CONFIG, pl_trainer_kwargs=PL_TRAINER_KWARGS,
    )
    tft_model = tft_pack["tft_model"]
    # NOTE(review): the adapter predictions below are scored against these
    # targets from build_tft_and_data, not against val_y/test_y from
    # prepare_all_with_scaling -- confirm both pipelines produce aligned splits.
    val_y_den_t = tft_pack["val_y_den_t"]
    test_y_den_t = tft_pack["test_y_den_t"]
    yhat_val_den_t = tft_pack["yhat_val_den_t"]
    yhat_test_den_t = tft_pack["yhat_test_den_t"]

    rmse_tft, mae_tft, r2_tft = compute_metrics(val_y_den_t, yhat_val_den_t)
    rmse_tft_test, mae_tft_test, r2_tft_test = compute_metrics(test_y_den_t, yhat_test_den_t)

    def make_adapter(tau1, tau2, writer):
        # Single construction point so the unregularized, trial and best
        # adapters are guaranteed to differ only in tau1/tau2 and writer.
        return build_spatial_adapter_from_demo(
            tft_model=tft_model,
            device=DEVICE,
            n_locations=EXPERIMENT_CONFIG["n_locations"],
            latent_dim=EXPERIMENT_CONFIG["latent_dim"],
            num_features=val_X.shape[-1],
            train_loader=train_loader,
            val_cont=val_X,
            val_y=val_y,
            locs=locs,
            adapter_cfg=SPATIAL_CONFIG,
            tau1=tau1, tau2=tau2,
            writer=writer,
        )

    def predict_den(adapter, X, y):
        # Predict on (X, y), then map the scaled output back to original units.
        with torch.no_grad():
            y_pred_sc = adapter.predict(X.to(DEVICE), y.to(DEVICE))
            if y_pred_sc.ndim == 3:
                y_pred_sc = y_pred_sc.squeeze(-1)
            # denormalize_predictions is fed a CPU numpy array
            y_pred_sc_np = y_pred_sc.detach().cpu().numpy()
            y_pred_den = denormalize_predictions(y_pred_sc_np, preproc)
            if isinstance(y_pred_den, np.ndarray):
                y_pred_den = torch.from_numpy(y_pred_den).float()
            return y_pred_den

    def predict_val_den_fn(adapter):
        return predict_den(adapter, val_X, val_y)

    def predict_test_den_fn(adapter):
        return predict_den(adapter, test_X, test_y)

    # Unregularized adapter
    writer_boot = SummaryWriter(log_dir=log_root / "bootstrap")
    try:
        adapter_unreg = make_adapter(0.0, 0.0, writer_boot)
        train_unregularized_adapter(adapter_unreg, pretrain_epochs=5)
    finally:
        # Close the event file even when construction or training raises.
        writer_boot.close()

    y_unreg_val_den_t = predict_val_den_fn(adapter_unreg)
    y_unreg_test_den_t = predict_test_den_fn(adapter_unreg)
    rmse_unreg, mae_unreg, r2_unreg = compute_metrics(val_y_den_t, y_unreg_val_den_t)
    rmse_unreg_test, mae_unreg_test, r2_unreg_test = compute_metrics(test_y_den_t, y_unreg_test_den_t)

    # Regularized (Optuna)
    # Writers opened for trials and not yet closed; whatever is left after the
    # search (e.g. because a trial raised mid-training) is closed below.
    open_trial_writers = []

    def build_adapter_fn(tau1: float, tau2: float):
        writer = SummaryWriter(log_dir=log_root / f"trial_tau1_{tau1:.3g}_tau2_{tau2:.3g}")
        open_trial_writers.append(writer)
        adapter = make_adapter(tau1, tau2, writer)
        adapter._tmp_writer = writer
        return adapter

    def predict_val_den_for_optuna(adapter):
        # Prediction is the last step of a trial, so release its writer here.
        try:
            return predict_val_den_fn(adapter)
        finally:
            writer = getattr(adapter, "_tmp_writer", None)
            if writer is not None:
                writer.close()
                adapter._tmp_writer = None
                if writer in open_trial_writers:
                    open_trial_writers.remove(writer)

    try:
        reg_search = train_regularized_adapter_with_optuna(
            build_adapter_fn=build_adapter_fn,
            val_y_den_t=val_y_den_t,
            predict_val_den_fn=predict_val_den_for_optuna,
            n_trials=n_trials,
            study_name=f"TFT_spatial_adapter_reg_ds{dataset_seed}",
        )
    finally:
        for leftover in open_trial_writers:
            leftover.close()
        open_trial_writers.clear()

    tau1_opt, tau2_opt = reg_search["tau1"], reg_search["tau2"]
    rmse_opt, mae_opt, r2_opt = reg_search["rmse"], reg_search["mae"], reg_search["r2"]
    best_no = reg_search["best_trial"]

    # Best adapter -> retrain -> test
    # NOTE(review): the validation metrics reported for "Reg" come from the
    # Optuna trial's model, while the test metrics come from this separately
    # retrained model. The retrain also passes pretrain_epochs=5; the epoch
    # count used inside the search is set in
    # train_regularized_adapter_with_optuna -- confirm the two are meant to match.
    writer_best = SummaryWriter(log_dir=log_root / f"best_tau1_{tau1_opt:.3g}_tau2_{tau2_opt:.3g}")
    try:
        adapter_best = make_adapter(tau1_opt, tau2_opt, writer_best)
        train_unregularized_adapter(adapter_best, pretrain_epochs=5)
    finally:
        writer_best.close()

    y_reg_test_den_t = predict_test_den_fn(adapter_best)
    rmse_reg_test, mae_reg_test, r2_reg_test = compute_metrics(test_y_den_t, y_reg_test_den_t)

    def f6(value):
        # Fixed six-decimal formatting shared by every metric column.
        return f"{value:.6f}"

    # Append results; the header is written only when the file is created.
    csv_path = Path("metrics_summary_TFT.csv")
    write_header = not csv_path.exists()
    with csv_path.open("a", newline="") as f:
        w = csv.writer(f)
        if write_header:
            w.writerow(["seed", "model", "trial", "tau1", "tau2",
                        "rmse_val", "mae_val", "r2_val",
                        "rmse_test", "mae_test", "r2_test"])
        w.writerow([dataset_seed, "TFT", "", "", "",
                    f6(rmse_tft), f6(mae_tft), f6(r2_tft),
                    f6(rmse_tft_test), f6(mae_tft_test), f6(r2_tft_test)])
        w.writerow([dataset_seed, "Unreg", "", "0", "0",
                    f6(rmse_unreg), f6(mae_unreg), f6(r2_unreg),
                    f6(rmse_unreg_test), f6(mae_unreg_test), f6(r2_unreg_test)])
        w.writerow([dataset_seed, "Reg", best_no, f"{tau1_opt:.6g}", f"{tau2_opt:.6g}",
                    f6(rmse_opt), f6(mae_opt), f6(r2_opt),
                    f6(rmse_reg_test), f6(mae_reg_test), f6(r2_reg_test)])

    print(
        f"Dataset {dataset_seed}: "
        f"TFT {rmse_tft:.3f} | Unreg {rmse_unreg:.3f} | "
        f"Reg {rmse_opt:.3f} (test {rmse_reg_test:.3f})"
    )

    return {
        "TFT":   {"rmse_val": rmse_tft, "rmse_test": rmse_tft_test, "r2_val": r2_tft, "r2_test": r2_tft_test},
        "unreg": {"rmse_val": rmse_unreg, "rmse_test": rmse_unreg_test, "r2_val": r2_unreg, "r2_test": r2_unreg_test},
        "reg":   {"rmse_val": rmse_opt, "rmse_test": rmse_reg_test, "r2_val": r2_opt, "r2_test": r2_reg_test,
                  "tau1": tau1_opt, "tau2": tau2_opt},
    }


## 6. Run Full Experiment Suite
Note: before a test run, change the number of epochs to a lower value.

In [None]:
# Run the experiment once per configured dataset seed and keep every result dict.
all_results = []
seed_range = range(
    EXPERIMENT_TRIALS_CONFIG['seed_range_start'],
    EXPERIMENT_TRIALS_CONFIG['seed_range_end'],
)
for seed in seed_range:
    print(f"\nStarting experiment for seed {seed}")
    results = run_one_experiment(
        seed,
        n_trials=EXPERIMENT_TRIALS_CONFIG['n_trials_per_seed'],
    )
    all_results.append(results)
    # Free GPU memory between seeds
    clear_gpu_memory()
    print(f"✅ Completed seed {seed}")

print("\n🎉 All experiments completed!")

## 7. Results Analysis and Visualization


In [None]:
# Load the metrics CSV and print the per-model means
results_df = pd.read_csv("metrics_summary_TFT.csv")
print("📊 Results Summary:")
summary_columns = ['rmse_val', 'rmse_test', 'r2_val', 'r2_test']
print(results_df.groupby('model')[summary_columns].mean())

# 2x2 grid of boxplots: RMSE on the top row, R² on the bottom row
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

panel_specs = [
    ((0, 0), 'rmse_val',  'Validation RMSE', 'RMSE'),
    ((0, 1), 'rmse_test', 'Test RMSE',       'RMSE'),
    ((1, 0), 'r2_val',    'Validation R²',   'R²'),
    ((1, 1), 'r2_test',   'Test R²',         'R²'),
]
for (panel_row, panel_col), metric, panel_title, y_label in panel_specs:
    panel_ax = axes[panel_row, panel_col]
    sns.boxplot(data=results_df, x='model', y=metric, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel(y_label)
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Selected tau1/tau2 and RMSEs for the regularized model (first 10 rows)
reg_results = results_df[results_df['model'] == 'Reg']
print("\n🔧 Best Hyperparameters for Regularized Model:")
print(reg_results[['tau1', 'tau2', 'rmse_val', 'rmse_test']].head(10))


In [None]:
# Performance comparison summary (TFT baseline)
print("=== Performance Comparison Summary (TFT as baseline) ===")

# Mean test RMSE per model
tft_mean_rmse   = results_df[results_df['model'] == 'TFT']['rmse_test'].mean()
unreg_mean_rmse = results_df[results_df['model'] == 'Unreg']['rmse_test'].mean()
reg_mean_rmse   = results_df[results_df['model'] == 'Reg']['rmse_test'].mean()

# Improvement is the relative RMSE reduction vs. TFT; negative means worse than TFT.
print(f"TFT (baseline)  - Mean Test RMSE: {tft_mean_rmse:.4f}")
print(f"Unregularized   - Mean Test RMSE: {unreg_mean_rmse:.4f} "
      f"({(1 - unreg_mean_rmse/tft_mean_rmse)*100:.1f}% improvement vs TFT)")
print(f"Regularized     - Mean Test RMSE: {reg_mean_rmse:.4f} "
      f"({(1 - reg_mean_rmse/tft_mean_rmse)*100:.1f}% improvement vs TFT)")

# Statistical significance test (paired by seed): TFT vs Regularized
from scipy import stats

# One row per seed; duplicate rows for a seed are averaged by the pivot.
pivot = results_df.pivot_table(index='seed', columns='model', values='rmse_test', aggfunc='mean')
paired = pivot.dropna(subset=['TFT', 'Reg'])  # keep only seeds that have both
t_stat, p_value = stats.ttest_rel(paired['TFT'].values, paired['Reg'].values)

# ttest_rel is two-sided by default, so a small p-value only says the paired
# RMSEs differ. The regularized model is an improvement only when its RMSE is
# lower, i.e. mean(TFT - Reg) > 0, which corresponds to a positive t-statistic.
# NaN statistics (fewer than two paired seeds) compare False and report "No".
is_significant_improvement = (p_value < 0.05) and (t_stat > 0)

print("\nStatistical Test (TFT vs Regularized):")
print(f"  t-statistic: {t_stat:.4f}")
print(f"  p-value: {p_value:.8f}")
print(f"  Significant improvement: {'Yes' if is_significant_improvement else 'No'}")
