# TFT vs. Spatial Adapter Comparison with Tuning Parameter Selection

This notebook implements a comprehensive comparison between:
1. **TFT (Temporal Fusion Transformer) trend model** - Trend-only baseline (replaces the original OLS linear baseline)
2. **Unregularized Spatial Adapter** - Neural spatial model without regularization
3. **Regularized Spatial Adapter** - Neural spatial model with optimized tau1, tau2 parameters

The experiment uses Optuna for hyperparameter optimization and evaluates performance across multiple random seeds.

In [None]:
# Import required libraries
import csv
import math
from pathlib import Path

import numpy as np
import optuna
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from optuna.pruners import MedianPruner
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from typing import Tuple, Dict, Any
import pandas as pd

# Darts / TFT
from darts import TimeSeries
from darts.models import TFTModel
import pytorch_lightning as pl  # 供 Darts 背後使用

# Local imports
from geospatial_neural_adapter.cpp_extensions import estimate_covariance
from geospatial_neural_adapter.utils.experiment import log_covariance_and_basis
from geospatial_neural_adapter.utils import (
    ModelCache,
    clear_gpu_memory,
    create_experiment_config,
    print_experiment_summary,
    get_device_info,
)
# NOTE: OLS-related imports have been removed: compute_ols_coefficients, predict_ols, TrendModel

from geospatial_neural_adapter.models.spatial_basis_learner import SpatialBasisLearner
from geospatial_neural_adapter.models.spatial_neural_adapter import SpatialNeuralAdapter
from geospatial_neural_adapter.models.pretrained_trend_model import PretrainedTrendModel
from geospatial_neural_adapter.models.wrapper_examples.tft_wrapper import TFTWrapper

from geospatial_neural_adapter.data.generators import generate_time_synthetic_data
from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling, denormalize_predictions
from geospatial_neural_adapter.metrics import compute_metrics

# Fix the PyTorch and NumPy RNG seeds so repeated runs start from the same state
torch.manual_seed(42)
np.random.seed(42)

# Plot appearance. The matplotlib style is applied first and the seaborn
# palette second, so the palette is not overwritten by the style reset.
plt.style.use('default')
sns.set_palette("husl")

print("✅ All imports successful (TFT version)!")


## 1. Parameter Configuration and Setup

In [None]:
# Experiment Configuration
from geospatial_neural_adapter.models.spatial_neural_adapter import (
    ADMMConfig,
    BasisConfig,
    SpatialNeuralAdapterConfig,
    TrainingConfig,
)

# Settings for synthetic data generation and checkpoint bookkeeping
EXPERIMENT_CONFIG = dict(
    seed=42,
    n_time_steps=1024,
    n_locations=512,
    noise_std=4.0,
    eigenvalue=16.0,
    latent_dim=1,
    ckpt_dir="admm_bcd_ckpts",
)

# --- Spatial Neural Adapter settings, expressed as dataclasses ---

# ADMM solver
admm_config = ADMMConfig(
    rho=1.0,            # base ADMM penalty parameter
    dual_momentum=0.2,  # momentum applied to the dual variable
    max_iters=3000,     # upper bound on ADMM iterations
    min_outer=20,       # outer iterations to run before checking convergence
    tol=1e-4,           # convergence tolerance
)

# Optimisation of the trend parameters
training_config = TrainingConfig(
    lr_mu=1e-2,                 # learning rate for the trend parameters
    batch_size=64,              # batch size used in the theta step
    pretrain_epochs=5,          # default number of pretraining epochs
    use_mixed_precision=False,  # mixed precision disabled
)

# Spatial basis updates
basis_config = BasisConfig(
    phi_every=5,         # refresh the basis every N iterations
    phi_freeze=200,      # no basis updates after N iterations
    matrix_reg=1e-6,     # matrix regularisation in the basis update
    irl1_max_iters=10,   # IRL1 maximum iterations
    irl1_eps=1e-6,       # IRL1 epsilon
    irl1_tol=5e-4,       # IRL1 inner tolerance
)

# Bundle the three parts into the full adapter configuration
SPATIAL_CONFIG = SpatialNeuralAdapterConfig(
    admm=admm_config,
    training=training_config,
    basis=basis_config,
)

# Legacy dict view kept for backward compatibility: adapter settings first,
# then the experiment settings layered on top.
CFG = SPATIAL_CONFIG.to_dict()
CFG.update(EXPERIMENT_CONFIG)

# Seed PyTorch and make sure the checkpoint directory exists
Path(EXPERIMENT_CONFIG["ckpt_dir"]).mkdir(exist_ok=True)
torch.manual_seed(EXPERIMENT_CONFIG["seed"])

# Pick the compute device and report what was found
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_info = get_device_info()
print(f"Using {device_info['device'].upper()}: {device_info['device_name']}")
if device_info['device'] == 'cuda':
    print(f"   Memory: {device_info['memory_gb']} GB")

# Dump both configurations for the record
print("\n=== Experiment Configuration ===")
for setting_name, setting_value in EXPERIMENT_CONFIG.items():
    print(f"{setting_name}: {setting_value}")

print("\n=== Spatial Neural Adapter Configuration ===")
SPATIAL_CONFIG.log_config()

## 2. Initialize Utilities

In [None]:
# Model cache used during hyperparameter optimization
cache = ModelCache()

# The per-seed trial budget depends on whether CUDA is available
if torch.cuda.is_available():
    _trials_per_seed = 20
else:
    _trials_per_seed = 50

# Describe the sweep: how many trials per seed and which dataset seeds to use
EXPERIMENT_TRIALS_CONFIG = create_experiment_config(
    n_trials_per_seed=_trials_per_seed,
    n_dataset_seeds=10,
    seed_range_start=1,
    seed_range_end=11,
)

print_experiment_summary(EXPERIMENT_TRIALS_CONFIG)
print("Utilities initialized successfully!")

## 3. Data Generation and Preprocessing

In [None]:
# Generate synthetic data with meaningful correlations
print("Generating correlated synthetic data...")

# Locations are evenly spaced on [-3, 3]
locs = np.linspace(-3, 3, CFG["n_locations"])
cat_features, cont_features, targets = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=CFG["n_time_steps"],
    noise_std=CFG["noise_std"],
    eigenvalue=CFG["eigenvalue"],
    eta_rho=0.8,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.1,
    non_linear_strength=0.2,
    seed=CFG["seed"],
)

# Split into train/val/test (70% / 15% / remainder) and standardise features
# and targets; scalers are requested to be fit on the training split only.
train_dataset, val_dataset, test_dataset, preprocessor = prepare_all_with_scaling(
    cat_features=cat_features,
    cont_features=cont_features,
    targets=targets,
    train_ratio=0.7,
    val_ratio=0.15,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)

# Read the batch size from the dataclass config, exactly as
# run_one_experiment does, instead of relying on the legacy CFG dict
# exposing a flat "batch_size" key.
train_loader = DataLoader(
    train_dataset,
    batch_size=SPATIAL_CONFIG.training.batch_size,
    shuffle=True,
)

# Extract tensors; the first tensor of each dataset is not used in this cell
_, train_X, train_y = train_dataset.tensors
_, val_X, val_y = val_dataset.tensors
_, test_X, test_y = test_dataset.tensors

# Number of continuous features (size of the last axis)
p_dim = train_X.shape[-1]

print(f"Data shapes: {cont_features.shape}, {targets.shape}")
print(f"Original targets - Mean: {targets.mean():.2f}, Std: {targets.std():.2f}")
print(f"Original targets - Range: {targets.min():.2f} to {targets.max():.2f}")
print(f"Feature dimension: {p_dim}")

In [None]:
# # Visualize data characteristics
# fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# # Plot 1: Target distribution
# axes[0, 0].hist(targets.flatten(), bins=30, alpha=0.7, edgecolor='black')
# axes[0, 0].set_title('Target Distribution')
# axes[0, 0].set_xlabel('Target Value')
# axes[0, 0].set_ylabel('Frequency')
# axes[0, 0].grid(True, alpha=0.3)

# # Plot 2: Spatial pattern at first time step
# axes[0, 1].plot(locs, targets[0, :], 'o-', linewidth=2, markersize=4)
# axes[0, 1].set_title('Spatial Pattern at t=0')
# axes[0, 1].set_xlabel('Location')
# axes[0, 1].set_ylabel('Target Value')
# axes[0, 1].grid(True, alpha=0.3)

# # Plot 3: Temporal pattern at middle location
# time_steps = np.arange(len(targets))
# axes[1, 0].plot(time_steps, targets[:, 25], linewidth=2)
# axes[1, 0].set_title('Temporal Pattern at Location 25')
# axes[1, 0].set_xlabel('Time Step')
# axes[1, 0].set_ylabel('Target Value')
# axes[1, 0].grid(True, alpha=0.3)

# # Plot 4: Feature correlations
# feature_corrs = []
# for i in range(cont_features.shape[-1]):
#     corr = np.corrcoef(targets.flatten(), cont_features[:, :, i].flatten())[0, 1]
#     feature_corrs.append(corr)

# axes[1, 1].bar(range(len(feature_corrs)), feature_corrs, alpha=0.7, edgecolor='black')
# axes[1, 1].set_title('Feature-Target Correlations')
# axes[1, 1].set_xlabel('Feature Index')
# axes[1, 1].set_ylabel('Correlation')
# axes[1, 1].grid(True, alpha=0.3)

# plt.tight_layout()
# plt.show()

## 4. TFT Trend Baseline Implementation (replaces the OLS baseline)

In [None]:
# ===========================================
# === Replace OLS trend with TFT trend ======
# ===========================================
# Re-imported here so that this cell can be run on its own.
import torch
from darts.models import TFTModel
from geospatial_neural_adapter.models.pretrained_trend_model import PretrainedTrendModel
from geospatial_neural_adapter.models.wrapper_examples.tft_wrapper import TFTWrapper

print("⚙️ Training TFT trend model... (TFT replaces the original OLS baseline)")

# Pull the preprocessed tensors out of the datasets built earlier
_, train_X, train_y = train_dataset.tensors
_, val_X,   val_y   = val_dataset.tensors
_, test_X,  test_y  = test_dataset.tensors

# Convert to float32 numpy arrays (these are the standardised values)
train_X_np = train_X.numpy().astype(np.float32)                # (Ttr, N, F)
val_X_np   = val_X.numpy().astype(np.float32)                  # (Tv,  N, F)
test_X_np  = test_X.numpy().astype(np.float32)                 # (Tte, N, F)

train_y_np = train_y.numpy().squeeze(-1).astype(np.float32)    # (Ttr, N)
val_y_np   = val_y.numpy().squeeze(-1).astype(np.float32)      # (Tv,  N)
test_y_np  = test_y.numpy().squeeze(-1).astype(np.float32)     # (Tte, N)

# Basic dimensions: training steps, locations, features
Ttr, N, F = train_X_np.shape[:3]
p_dim = F

# === Build Darts multivariate TimeSeries (target and past covariates) ===
# train/val/test are concatenated into one sequence so Darts sees a single
# index; training itself only uses the part before training_cutoff.
cont_np_full    = np.concatenate([train_X_np, val_X_np, test_X_np], axis=0)   # (T_full, N, F)
targets_np_full = np.concatenate([train_y_np, val_y_np, test_y_np], axis=0)   # (T_full, N)
T_full = cont_np_full.shape[0]

time_index = pd.RangeIndex(start=0, stop=T_full, step=1)

# Target series: one column per location
target_df = pd.DataFrame(targets_np_full, index=time_index, columns=[f"loc_{i}" for i in range(N)])
target_ts = TimeSeries.from_dataframe(target_df, fill_missing_dates=True)

# Past covariates flattened to (T_full, N*F); column order follows the
# row-major reshape (location outer, feature inner).
cov_cols = [f"cov_{j}_loc_{i}" for i in range(N) for j in range(F)]
cov_df   = pd.DataFrame(cont_np_full.reshape(T_full, N*F), index=time_index, columns=cov_cols)
cov_ts   = TimeSeries.from_dataframe(cov_df, fill_missing_dates=True)

# Only the first training_cutoff steps are used for fitting
training_cutoff = len(train_dataset)  # = Ttr

# === Lightweight TFT, intended only for trend extraction ===
_use_gpu = torch.cuda.is_available()
_tft_epochs = 10

input_chunk_length = max(8, min(64, Ttr // 4))
tft_model = TFTModel(
    input_chunk_length=input_chunk_length,
    output_chunk_length=1,
    n_epochs=_tft_epochs,
    hidden_size=64,
    num_attention_heads=2,
    dropout=0.1,
    # No future covariates are passed to fit() below, and Darts' TFTModel
    # refuses to build without future inputs; the relative index supplies
    # them automatically from the series index.
    add_relative_index=True,
    random_state=42,
    force_reset=True,
    pl_trainer_kwargs={
        "accelerator": "gpu" if _use_gpu else "cpu",
        # -1 ("all devices") is only meaningful for GPUs; the CPU
        # accelerator requires a positive device count.
        "devices": -1 if _use_gpu else 1,
        "enable_progress_bar": True,
        "enable_model_summary": False,
        "enable_checkpointing": False,
        "max_epochs": _tft_epochs,
    },
)

print(f"Training TFT (input_chunk_length={input_chunk_length}, epochs={_tft_epochs})...")
tft_model.fit(
    target_ts[:training_cutoff],
    past_covariates=cov_ts[:training_cutoff],
    verbose=True
)
print("✅ TFT training done.")

# === Wrap as a PretrainedTrendModel (TFTWrapper + PretrainedTrendModel) ===
wrapper = TFTWrapper(
    tft_model=tft_model,
    num_locations=N,
    num_features=F
)
tft_trend_model = PretrainedTrendModel(
    pretrained_model=wrapper,
    input_shape=(None, N, F),
    output_shape=(None, N),
    freeze_backbone=True,      # keep the TFT backbone frozen
    add_residual_head=True,    # extra trainable head, intended to stabilise adapter training
    residual_hidden_dim=64,
    dropout_rate=0.1,
).to(DEVICE)

print("✅ TFT trend model (wrapper) ready!")

# === 取得 VAL / TEST 的 trend 預測（用 wrapper 逐步取出）===
import torch

def wrapper_predict_seq(x_np):
    """Run the notebook-level ``tft_trend_model`` one time step at a time.

    Args:
        x_np: Array indexed by time along axis 0; annotated in this notebook
            as ``(T, N, F)``.

    Returns:
        CPU tensor stacking the per-step outputs along a new leading axis;
        annotated in this notebook as ``(T, N)``.
    """
    # Evaluation mode plus no_grad: inference only
    tft_trend_model.eval()

    def _predict_step(t):
        # Slice with t:t+1 so the model still receives a leading batch axis
        frame = torch.tensor(x_np[t:t + 1], dtype=torch.float32, device=DEVICE)
        return tft_trend_model(frame).squeeze(0).cpu()

    with torch.no_grad():
        per_step = [_predict_step(t) for t in range(x_np.shape[0])]
    return torch.stack(per_step, dim=0)

from geospatial_neural_adapter.metrics import compute_metrics
from geospatial_neural_adapter.models.spatial_basis_learner import SpatialBasisLearner

# Trend predictions for the validation and test windows
y_trend_val, y_trend_test = (
    wrapper_predict_seq(_split) for _split in (val_X_np, test_X_np)
)   # (Tv, N), (Tte, N)

# Ground-truth targets as tensors on the compute device
_val_targets  = torch.tensor(val_y_np,  dtype=torch.float32, device=DEVICE)
_test_targets = torch.tensor(test_y_np, dtype=torch.float32, device=DEVICE)

# === Initialise the top-K eigen-basis from the TFT residuals (instead of OLS residuals) ===
residuals_val = _val_targets - y_trend_val.to(DEVICE)   # (Tv, N)
covariance_matrix = residuals_val.T @ residuals_val     # (N, N), unnormalised

K = CFG["latent_dim"]
eig = torch.linalg.eigh(covariance_matrix)
# eigh returns eigenvalues in ascending order, so the last K columns
# belong to the K largest eigenvalues
eigenvectors = eig.eigenvectors[:, -K:]

tft_basis = SpatialBasisLearner(CFG["n_locations"], K).to(DEVICE)
tft_basis.basis.data.copy_(eigenvectors)

print(f"Initialized spatial basis from TFT residuals: basis shape = {tft_basis.basis.shape}")

# === (Optional) reference "true" spatial basis; independent of OLS/TFT ===
phi_true = np.exp(-np.square(locs)).reshape(-1, 1)
phi_true = phi_true / np.linalg.norm(phi_true)
sigma_true_spatial = CFG["eigenvalue"] * (phi_true @ phi_true.T)

# === TFT baseline metrics (these replace the OLS metrics) ===
rmse_tft_val, mae_tft_val, r2_tft_val = compute_metrics(
    _val_targets,
    y_trend_val.to(DEVICE),
)
rmse_tft_test, mae_tft_test, r2_tft_test = compute_metrics(
    _test_targets,
    y_trend_test.to(DEVICE),
)

print(f"TFT Validation - RMSE: {rmse_tft_val:.4f}, R²: {r2_tft_val:.4f}")
print(f"TFT Test       - RMSE: {rmse_tft_test:.4f}, R²: {r2_tft_test:.4f}")


## 5. Main Experiment Function

In [None]:
def run_one_experiment(dataset_seed: int, n_trials: int = 30) -> dict:
    """
    Run a complete experiment for one dataset seed, using TFT as the trend model
    (replacing the original OLS trend). We compare:
      1) TFT (trend-only baseline)
      2) Unregularized Spatial Adapter (tau1=tau2=0)
      3) Regularized Spatial Adapter (Optuna tuned tau1, tau2)

    Side effects:
      - writes TensorBoard logs under ``TFT_runs/TFT_seed_<dataset_seed>/``
      - clears the notebook-level ``cache`` and refills it with the states
        trained during this call
      - appends three rows (TFT / Unreg / Reg) to ``metrics_summary_TFT.csv``,
        creating the file with a header if it does not exist yet

    Args:
        dataset_seed: Seed passed to the synthetic data generator; also used
            to name the log directory and the Optuna study.
        n_trials: Number of Optuna trials for the (tau1, tau2) search.

    Returns:
        Dict with keys ``'tft'``, ``'unreg'`` and ``'reg'``. Each maps to a
        dict holding ``rmse_val``, ``rmse_test``, ``r2_val`` and ``r2_test``;
        the ``'reg'`` entry additionally holds the selected ``tau1`` and
        ``tau2``. Metrics are computed on the scaled data.
    """
    from darts import TimeSeries
    from darts.models import TFTModel
    import pandas as pd
    import numpy as np
    from torch.utils.tensorboard import SummaryWriter
    import optuna
    from optuna.pruners import MedianPruner
    import torch

    log_root = Path("TFT_runs") / f"TFT_seed_{dataset_seed}"
    log_root.mkdir(parents=True, exist_ok=True)

    # -------------------------------
    # 1) Generate dataset for this seed
    # -------------------------------
    cat_features, cont_features, targets = generate_time_synthetic_data(
        locs=locs,
        n_time_steps=CFG["n_time_steps"],
        noise_std=CFG["noise_std"],
        eigenvalue=CFG["eigenvalue"],
        eta_rho=0.8,
        f_rho=0.6,
        global_mean=50.0,
        feature_noise_std=0.1,
        non_linear_strength=0.2,
        seed=dataset_seed
    )

    train_dataset, val_dataset, test_dataset, preprocessor = prepare_all_with_scaling(
        cat_features=cat_features,
        cont_features=cont_features,
        targets=targets,
        train_ratio=0.7,
        val_ratio=0.15,
        feature_scaler_type="standard",
        target_scaler_type="standard",
        fit_on_train_only=True
    )
    train_loader = DataLoader(train_dataset, batch_size=SPATIAL_CONFIG.training.batch_size, shuffle=True)

    # unpack tensors (the first tensor of each dataset is not used here)
    _, train_X, train_y = train_dataset.tensors
    _, val_X,   val_y   = val_dataset.tensors
    _, test_X,  test_y  = test_dataset.tensors

    p_dim = train_X.shape[-1]  # num features
    N     = train_X.shape[1]   # num locations

    # Move the evaluation tensors to the device once; they are reused by the
    # bootstrap run, every Optuna trial and the final test evaluation.
    val_X_dev,  val_y_dev  = val_X.to(DEVICE),  val_y.to(DEVICE)
    test_X_dev, test_y_dev = test_X.to(DEVICE), test_y.to(DEVICE)

    # -------------------------------
    # 2) Train TFT trend (multivariate) on TRAIN only
    # -------------------------------
    # Build TimeSeries for Darts: use scaled data
    train_y_np = train_y.numpy().squeeze(-1)      # (Ttr, N)
    val_X_np   = val_X.numpy()                    # (Tv,  N, F)
    val_y_np   = val_y.numpy().squeeze(-1)        # (Tv,  N)
    test_X_np  = test_X.numpy()                   # (Tte, N, F)
    test_y_np  = test_y.numpy().squeeze(-1)       # (Tte, N)

    Ttr = train_y_np.shape[0]

    # For Darts multivariate target. The index must have exactly one entry
    # per training step; it must not depend on any notebook-level length.
    idx_tr = pd.RangeIndex(start=0, stop=Ttr, step=1)
    target_tr_df = pd.DataFrame(train_y_np, index=idx_tr, columns=[f"loc_{i}" for i in range(N)])
    ts_target_tr = TimeSeries.from_dataframe(target_tr_df, fill_missing_dates=True)

    # Lightweight TFT config for trend
    use_gpu = torch.cuda.is_available()
    tft_epochs = 10
    input_chunk_length = max(8, min(64, Ttr // 4))
    tft_model = TFTModel(
        input_chunk_length=input_chunk_length,
        output_chunk_length=1,
        n_epochs=tft_epochs,
        hidden_size=64,
        num_attention_heads=2,
        dropout=0.1,
        # No future covariates are supplied to fit(), and Darts' TFTModel
        # refuses to build without future inputs; the relative index
        # generates them from the series index.
        add_relative_index=True,
        random_state=42,
        force_reset=True,
        pl_trainer_kwargs={
            "accelerator": "gpu" if use_gpu else "cpu",
            # -1 ("all devices") is only meaningful for GPUs; the CPU
            # accelerator requires a positive device count.
            "devices": -1 if use_gpu else 1,
            "enable_progress_bar": True,
            "enable_model_summary": False,
            "enable_checkpointing": False,
            "max_epochs": tft_epochs,
        },
    )
    # NOTE(review): the notebook-level TFT cell also passes the features as
    # past_covariates; here the model is fit on the targets alone. Confirm
    # which variant TFTWrapper expects.
    tft_model.fit(ts_target_tr, verbose=True)

    # Wrap TFT as PretrainedTrendModel
    wrapper = TFTWrapper(tft_model=tft_model, num_locations=N, num_features=p_dim)

    def make_trend():
        # Every model in this experiment is built the same way: the shared
        # TFT wrapper as a frozen backbone plus a new residual head.
        return PretrainedTrendModel(
            pretrained_model=wrapper,
            input_shape=(None, N, p_dim),
            output_shape=(None, N),
            freeze_backbone=True,      # keep TFT frozen during adapter training
            add_residual_head=True,    # small residual head trainable if needed
            residual_hidden_dim=64,
            dropout_rate=0.1,
        ).to(DEVICE)

    tft_trend_model = make_trend()

    # -------------------------------
    # 3) Evaluate TFT-only (trend baseline) on VAL/TEST
    # -------------------------------
    def predict_with_wrapper(x_np):
        # x_np: (T, N, F), return (T, N)
        # Switch to eval mode first so that dropout does not perturb the
        # baseline predictions (same as the notebook-level helper).
        tft_trend_model.eval()
        preds = []
        with torch.no_grad():
            for t in range(x_np.shape[0]):
                x_t = torch.tensor(x_np[t:t+1], dtype=torch.float32, device=DEVICE)  # (1,N,F)
                y_t = tft_trend_model(x_t)  # (1,N)
                preds.append(y_t.squeeze(0).cpu().numpy())
        return np.stack(preds, axis=0)

    y_tft_val_np  = predict_with_wrapper(val_X_np)   # (Tv, N)
    y_tft_test_np = predict_with_wrapper(test_X_np)  # (Tte, N)

    # metrics (standardized space)
    rmse_tft, mae_tft, r2_tft = compute_metrics(
        torch.from_numpy(val_y_np).to(DEVICE),
        torch.from_numpy(y_tft_val_np).to(DEVICE)
    )
    rmse_tft_test, mae_tft_test, r2_tft_test = compute_metrics(
        torch.from_numpy(test_y_np).to(DEVICE),
        torch.from_numpy(y_tft_test_np).to(DEVICE)
    )

    # -------------------------------
    # 4) Bootstrap: Unregularized Spatial Adapter (tau1=tau2=0)
    # -------------------------------
    cache.clear()
    clear_gpu_memory()

    # Create fresh trend & basis for bootstrap
    boot_trend = make_trend()
    boot_basis = SpatialBasisLearner(N, CFG["latent_dim"]).to(DEVICE)

    boot_writer = SummaryWriter(log_dir=log_root / "bootstrap")
    boot = SpatialNeuralAdapter(
        boot_trend,
        boot_basis,
        train_loader,
        val_cont=val_X_dev,
        val_y=val_y_dev,
        locs=locs,
        config=SPATIAL_CONFIG,
        device=DEVICE,
        writer=boot_writer,
        tau1=0.0,
        tau2=0.0,
    )
    # skip heavy pretrain; TFT backbone is frozen & pretrained
    boot.init_basis_dense()
    boot.run()
    cache.store(0.0, 0.0, boot_trend.state_dict(), boot_basis.state_dict())
    boot_writer.close()

    # Unregularized predictions
    y_boot_val = boot.predict(val_X_dev, val_y_dev)
    rmse_boot, mae_boot, r2_boot = compute_metrics(val_y_dev, y_boot_val)
    y_boot_test = boot.predict(test_X_dev, test_y_dev)
    rmse_boot_test, mae_boot_test, r2_boot_test = compute_metrics(test_y_dev, y_boot_test)

    # Clean up bootstrap models
    del boot_trend, boot_basis, boot
    clear_gpu_memory()

    # -------------------------------
    # 5) Optuna objective: tune (tau1, tau2) for Regularized Adapter
    # -------------------------------
    def objective(trial):
        tau1 = trial.suggest_float("tau1", 1e-4, 1e8, log=True)
        tau2 = trial.suggest_float("tau2", 1e-4, 1e8, log=True)

        clear_gpu_memory()

        # fresh trend & basis
        trend = make_trend()
        basis = SpatialBasisLearner(N, CFG["latent_dim"]).to(DEVICE)

        # warm-start from nearest cached pair (if any)
        cache.load_nearest(trend, basis, tau1, tau2)

        writer = SummaryWriter(log_dir=log_root / f"trial_{trial.number:03d}")
        trainer = SpatialNeuralAdapter(
            trend,
            basis,
            train_loader,
            val_cont=val_X_dev,
            val_y=val_y_dev,
            locs=locs,
            config=SPATIAL_CONFIG,
            device=DEVICE,
            writer=writer,
            tau1=tau1,
            tau2=tau2,
        )
        trainer.init_basis_dense()
        trainer.run()

        y_pred = trainer.predict(val_X_dev, val_y_dev)
        rmse, mae, r2 = compute_metrics(val_y_dev, y_pred)

        # Keep all three validation metrics on the trial so the best trial's
        # numbers can be read back after the study finishes.
        trial.set_user_attr("rmse", rmse)
        trial.set_user_attr("mae", mae)
        trial.set_user_attr("r2", r2)

        writer.close()
        cache.store(tau1, tau2, trend.state_dict(), basis.state_dict())

        # cleanup
        del trend, basis, trainer, y_pred
        clear_gpu_memory()

        return rmse

    study = optuna.create_study(
        study_name=f"spatial_adapter_tft_ds{dataset_seed}",
        direction="minimize",
        sampler=optuna.samplers.TPESampler(),
        pruner=MedianPruner(n_warmup_steps=5),
        load_if_exists=False,
    )
    study.optimize(objective, n_trials=n_trials, n_jobs=1)

    # Best regularized model
    best = study.best_trial
    rmse_opt = best.user_attrs["rmse"]
    mae_opt  = best.user_attrs["mae"]
    r2_opt   = best.user_attrs["r2"]
    tau1_opt = best.params["tau1"]
    tau2_opt = best.params["tau2"]
    best_no  = best.number

    trend_best = make_trend()
    basis_best = SpatialBasisLearner(N, CFG["latent_dim"]).to(DEVICE)

    # load cached states for best taus
    sd_t, sd_b = cache.cache[(tau1_opt, tau2_opt)]
    trend_best.load_state_dict(sd_t)
    basis_best.load_state_dict(sd_b)

    trend_best.eval()
    basis_best.eval()
    with torch.no_grad():
        y_reg_test = SpatialNeuralAdapter.predict_static(trend_best, basis_best, test_X_dev, test_y_dev)
    rmse_test, mae_test, r2_test = compute_metrics(test_y_dev, y_reg_test)

    # -------------------------------
    # 6) Write CSV summary (TFT / Unreg / Reg)
    # -------------------------------
    csv_path = Path("metrics_summary_TFT.csv")
    # The file is opened in append mode, so the header is written only when
    # the file is being created.
    write_header = not csv_path.exists()
    with csv_path.open("a", newline="") as f:
        w = csv.writer(f)
        if write_header:
            w.writerow([
                "seed", "model", "trial", "tau1", "tau2",
                "rmse_val", "mae_val", "r2_val",
                "rmse_test", "mae_test", "r2_test"
            ])

        # TFT trend-only baseline
        w.writerow([
            dataset_seed, "TFT", "", "", "",
            f"{rmse_tft:.6f}", f"{mae_tft:.6f}", f"{r2_tft:.6f}",
            f"{rmse_tft_test:.6f}", f"{mae_tft_test:.6f}", f"{r2_tft_test:.6f}"
        ])

        # Unregularized Adapter
        w.writerow([
            dataset_seed, "Unreg", "", "0", "0",
            f"{rmse_boot:.6f}", f"{mae_boot:.6f}", f"{r2_boot:.6f}",
            f"{rmse_boot_test:.6f}", f"{mae_boot_test:.6f}", f"{r2_boot_test:.6f}"
        ])

        # Regularized Adapter (best)
        w.writerow([
            dataset_seed, "Reg", best_no, f"{tau1_opt:.6g}", f"{tau2_opt:.6g}",
            f"{rmse_opt:.6f}", f"{mae_opt:.6f}", f"{r2_opt:.6f}",
            f"{rmse_test:.6f}", f"{mae_test:.6f}", f"{r2_test:.6f}"
        ])

    print(
        f"Dataset {dataset_seed}:  "
        f"TFT RMSE={rmse_tft:.3f} | "
        f"Unreg RMSE={rmse_boot:.3f} | "
        f"Reg RMSE={rmse_opt:.3f} (test {rmse_test:.3f})"
    )

    return {
        'tft':  {'rmse_val': rmse_tft,  'rmse_test': rmse_tft_test,  'r2_val': r2_tft,  'r2_test': r2_tft_test},
        # r2_test must be the test-split value, not the validation one
        'unreg':{'rmse_val': rmse_boot, 'rmse_test': rmse_boot_test, 'r2_val': r2_boot, 'r2_test': r2_boot_test},
        'reg':  {'rmse_val': rmse_opt,  'rmse_test': rmse_test,      'r2_val': r2_opt,  'r2_test': r2_test,
                 'tau1': tau1_opt, 'tau2': tau2_opt}
    }


## 6. Run Full Experiment Suite

In [None]:
# Sweep bounds and per-seed trial budget, read once from the trials config
_first_seed = EXPERIMENT_TRIALS_CONFIG['seed_range_start']
_stop_seed = EXPERIMENT_TRIALS_CONFIG['seed_range_end']
_trials_for_each_seed = EXPERIMENT_TRIALS_CONFIG['n_trials_per_seed']

# One result dict per dataset seed, in seed order
all_results = []
for seed in range(_first_seed, _stop_seed):
    print(f"\nStarting experiment for seed {seed}")
    results = run_one_experiment(seed, n_trials=_trials_for_each_seed)
    all_results.append(results)

    # Drop cached model states and free GPU memory before the next seed
    cache.clear()
    clear_gpu_memory()
    print(f"✅ Completed seed {seed}")

print("\n🎉 All experiments completed!")

## 7. Results Analysis and Visualization

In [None]:
# Read back the per-seed metrics written by the experiment runs
results_df = pd.read_csv("metrics_summary_TFT.csv")
print("📊 Results Summary:")
_metric_columns = ['rmse_val', 'rmse_test', 'r2_val', 'r2_test']
print(results_df.groupby('model')[_metric_columns].mean())

# 2x2 grid of box plots: RMSE on the top row, R² on the bottom row,
# validation on the left and test on the right
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

_panels = [
    # (grid position, column, title, y-axis label)
    ((0, 0), 'rmse_val',  'Validation RMSE', 'RMSE'),
    ((0, 1), 'rmse_test', 'Test RMSE',       'RMSE'),
    ((1, 0), 'r2_val',    'Validation R²',   'R²'),
    ((1, 1), 'r2_test',   'Test R²',         'R²'),
]
for _position, _column, _title, _ylabel in _panels:
    _ax = axes[_position]
    sns.boxplot(data=results_df, x='model', y=_column, ax=_ax)
    _ax.set_title(_title)
    _ax.set_ylabel(_ylabel)
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Selected hyperparameters of the regularized model (first ten rows)
reg_results = results_df.loc[results_df['model'] == 'Reg']
print("\n🔧 Best Hyperparameters for Regularized Model:")
print(reg_results[['tau1', 'tau2', 'rmse_val', 'rmse_test']].head(10))

In [None]:
# === Performance comparison summary (TFT is the baseline now) ===
from scipy import stats

print("=== Performance Comparison Summary ===")

# Mean test RMSE per model over all rows of the results table
tft_mean_rmse   = results_df[results_df['model'] == 'TFT']['rmse_test'].mean()
unreg_mean_rmse = results_df[results_df['model'] == 'Unreg']['rmse_test'].mean()
reg_mean_rmse   = results_df[results_df['model'] == 'Reg']['rmse_test'].mean()

# "Improvement" is the relative reduction in mean test RMSE versus TFT
print(f"TFT (baseline) - Mean Test RMSE: {tft_mean_rmse:.4f}")
print(f"Unregularized  - Mean Test RMSE: {unreg_mean_rmse:.4f} ({(1 - unreg_mean_rmse/tft_mean_rmse)*100:.1f}% improvement)")
print(f"Regularized    - Mean Test RMSE: {reg_mean_rmse:.4f} ({(1 - reg_mean_rmse/tft_mean_rmse)*100:.1f}% improvement)")

# Statistical significance test (this simply compares the two distributions;
# pair by seed instead if the rows correspond one-to-one)
tft_scores = results_df[results_df['model'] == 'TFT']['rmse_test'].values
reg_scores = results_df[results_df['model'] == 'Reg']['rmse_test'].values

# NOTE: if the rows are not guaranteed to pair up by seed, an independent
# two-sample (Welch) t-test is the safer choice; with strict one-to-one
# pairing this can be switched to stats.ttest_rel.
t_stat, p_value = stats.ttest_ind(tft_scores, reg_scores, equal_var=False)

# The test is two-sided, so a small p-value alone only says the means differ.
# Report an improvement only when the regularized model is also the better one.
significant_improvement = bool(p_value < 0.05) and bool(reg_mean_rmse < tft_mean_rmse)

print(f"\nStatistical Test (TFT vs Regularized):")
print(f"  t-statistic: {t_stat:.4f}")
print(f"  p-value: {p_value:.4f}")
print(f"  Significant improvement: {'Yes' if significant_improvement else 'No'}")