# OLS vs. Spatial Adapter Comparison with Tuning Parameter Selection

This notebook implements a comprehensive comparison between:
1. **OLS (Ordinary Least Squares)** - Linear baseline
2. **Unregularized Spatial Adapter** - Neural spatial model without regularization
3. **Regularized Spatial Adapter** - Neural spatial model with optimized tau1, tau2 parameters

The experiment uses Optuna for hyperparameter optimization and evaluates performance across multiple random seeds.

my work: ols 換成 TFT 然後執行模擬

In [1]:
# Import required libraries
import csv
import math
from pathlib import Path

import numpy as np
import optuna
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from optuna.pruners import MedianPruner
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from typing import Tuple, Dict, Any
import pandas as pd

# NEW: Darts TFT backbone
from darts import TimeSeries
from darts.models import TFTModel
import pytorch_lightning as pl

from geospatial_neural_adapter.cpp_extensions import estimate_covariance
from geospatial_neural_adapter.utils.experiment import log_covariance_and_basis
from geospatial_neural_adapter.utils import (
    ModelCache,
    clear_gpu_memory,
    create_experiment_config,
    print_experiment_summary,
    get_device_info,
)
from geospatial_neural_adapter.models.spatial_basis_learner import SpatialBasisLearner
from geospatial_neural_adapter.models.spatial_neural_adapter import SpatialNeuralAdapter

# NEW: 用 TFT 當 backbone 的包裝
from geospatial_neural_adapter.models.pretrained_trend_model import create_pretrained_trend_model
from geospatial_neural_adapter.models.wrapper_examples.tft_wrapper import TFTWrapper

from geospatial_neural_adapter.data.generators import generate_time_synthetic_data
from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling, denormalize_predictions
from geospatial_neural_adapter.metrics import compute_metrics

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
plt.style.use('default')
sns.set_palette("husl")

print("✅ All imports successful!")

  from .autonotebook import tqdm as notebook_tqdm
  __import__("pkg_resources").declare_namespace(__name__)  # type: ignore


✅ Loaded spatial_utils from: /home/wangxc17/work/TFTModel-use/geospatial-neural-adapter-dev/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
✅ All imports successful!


## 1. Parameter Configuration and Setup

In [2]:
# Experiment Configuration
EXPERIMENT_CONFIG = {
    'seed': 42,
    'n_time_steps': 1024,
    'n_locations': 512,
    'noise_std': 4.0,
    'eigenvalue': 16.0,
    'latent_dim': 1,
    'ckpt_dir': "admm_bcd_ckpts",
}

TFT_CONFIG = {
    # 長度
    'tft_input_chunk_length': None,  # None → 依 T_train 自動算
    'tft_output_chunk_length': 1,
    'tft_output_chunk_shift': 0,

    # 模型規模
    'tft_hidden_size': 16, #32         
    'tft_lstm_layers': 1,            
    'tft_num_attention_heads': 2,    
    'tft_full_attention': False,
    'tft_feed_forward': 'GatedResidualNetwork',
    'tft_hidden_continuous_size': 16,             

    # 訓練/優化
    'tft_batch_size': 16, # 64
    'tft_n_epochs': 5, #20
    'tft_optimizer_kwargs': {'lr': 1e-3, 'weight_decay': 1e-4},
}

# Spatial Neural Adapter Configuration using dataclasses
from geospatial_neural_adapter.models.spatial_neural_adapter import (
    SpatialNeuralAdapterConfig, ADMMConfig, TrainingConfig, BasisConfig
)

# ADMM Configuration
admm_config = ADMMConfig(
    rho=1.0,
    dual_momentum=0.2,
    max_iters=3000,
    min_outer=20,
    tol=1e-4,
)

# Training Configuration
training_config = TrainingConfig(
    lr_mu=1e-2,
    batch_size=64,
    pretrain_epochs=5,
    use_mixed_precision=False,
)

# Basis Configuration
basis_config = BasisConfig(
    phi_every=5,
    phi_freeze=200,
    matrix_reg=1e-6,
    irl1_max_iters=10,
    irl1_eps=1e-6,
    irl1_tol=5e-4,
)

# Complete Spatial Neural Adapter Configuration
SPATIAL_CONFIG = SpatialNeuralAdapterConfig(
    admm=admm_config,
    training=training_config,
    basis=basis_config
)

# Legacy dict for convenience
CFG: Dict[str, Any] = SPATIAL_CONFIG.to_dict()
CFG.update(EXPERIMENT_CONFIG)
CFG.update(TFT_CONFIG) 

# Set random seed & paths
torch.manual_seed(EXPERIMENT_CONFIG["seed"])
Path(EXPERIMENT_CONFIG["ckpt_dir"]).mkdir(exist_ok=True)

# Device setup
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_info = get_device_info()
print(f"Using {device_info['device'].upper()}: {device_info['device_name']}")
if device_info['device'] == 'cuda':
    print(f"   Memory: {device_info['memory_gb']} GB")

# Print configuration summary
print("\n=== Experiment Configuration ===")
for key, value in EXPERIMENT_CONFIG.items():
    print(f"{key}: {value}")

print("\n=== Spatial Neural Adapter Configuration ===")
SPATIAL_CONFIG.log_config()

2025-08-11 18:01:01,829 - spatial_neural_adapter - INFO - SpatialNeuralAdapterConfig:
2025-08-11 18:01:01,829 - spatial_neural_adapter - INFO -   ADMM Config:
2025-08-11 18:01:01,830 - spatial_neural_adapter - INFO -     rho: 1.0
2025-08-11 18:01:01,830 - spatial_neural_adapter - INFO -     dual_momentum: 0.2
2025-08-11 18:01:01,830 - spatial_neural_adapter - INFO -     max_iters: 3000
2025-08-11 18:01:01,831 - spatial_neural_adapter - INFO -     min_outer: 20
2025-08-11 18:01:01,831 - spatial_neural_adapter - INFO -     tol: 0.0001
2025-08-11 18:01:01,832 - spatial_neural_adapter - INFO -   Training Config:
2025-08-11 18:01:01,832 - spatial_neural_adapter - INFO -     lr_mu: 0.01
2025-08-11 18:01:01,833 - spatial_neural_adapter - INFO -     batch_size: 64
2025-08-11 18:01:01,833 - spatial_neural_adapter - INFO -     pretrain_epochs: 5
2025-08-11 18:01:01,834 - spatial_neural_adapter - INFO -     use_mixed_precision: False
2025-08-11 18:01:01,835 - spatial_neural_adapter - INFO -   Bas

Using CUDA: NVIDIA GeForce RTX 4060 Laptop GPU
   Memory: 8.6 GB

=== Experiment Configuration ===
seed: 42
n_time_steps: 1024
n_locations: 512
noise_std: 4.0
eigenvalue: 16.0
latent_dim: 1
ckpt_dir: admm_bcd_ckpts

=== Spatial Neural Adapter Configuration ===


## 2. Initialize Utilities
 目前減少跑的次數 確認完之後要跑完整的

In [3]:
# Initialize model cache for hyperparameter optimization
cache = ModelCache()

# Create experiment configuration
EXPERIMENT_TRIALS_CONFIG = create_experiment_config(
    n_trials_per_seed=5 if torch.cuda.is_available() else 5,
    n_dataset_seeds=2,
    seed_range_start=1,
    seed_range_end=3,
)

print_experiment_summary(EXPERIMENT_TRIALS_CONFIG)
print("Utilities initialized successfully!")

Experiment Configuration:
  Trials per seed: 5
  Dataset seeds: 1 to 2
  Total experiments: 10
  Device: GPU
Utilities initialized successfully!


## 3. Data Generation and Preprocessing

In [4]:
# Generate synthetic data with meaningful correlations
print("Generating correlated synthetic data...")

locs = np.linspace(-3, 3, CFG["n_locations"]).astype(np.float32)

cat_features, cont_features, targets = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=CFG["n_time_steps"],
    noise_std=CFG["noise_std"],
    eigenvalue=CFG["eigenvalue"],
    eta_rho=0.8,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.1,
    non_linear_strength=0.2,
    seed=CFG["seed"],
)

train_dataset, val_dataset, test_dataset, preprocessor = prepare_all_with_scaling(
    cat_features=cat_features,
    cont_features=cont_features,
    targets=targets,
    train_ratio=0.7,
    val_ratio=0.15,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)

# 建 DataLoader：用 SPATIAL_CONFIG.training.batch_size 做保險
bs = int(CFG.get("batch_size", SPATIAL_CONFIG.training.batch_size))
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

# Extract tensors
_, train_X, train_y = train_dataset.tensors
_, val_X,   val_y   = val_dataset.tensors
_, test_X,  test_y  = test_dataset.tensors

p_dim  = train_X.shape[-1]
T_full = cont_features.shape[0]
N      = cont_features.shape[1]
F      = p_dim

print(f"Data shapes: cont={cont_features.shape}, targets={targets.shape}")
print(f"Original targets - Mean: {targets.mean():.2f}, Std: {targets.std():.2f}")
print(f"Original targets - Range: {targets.min():.2f} to {targets.max():.2f}")
print(f"N = {N} | F = {F} | T = {T_full}")

Generating correlated synthetic data...
Data shapes: cont=(1024, 512, 3), targets=(1024, 512)
Original targets - Mean: 50.95, Std: 4.22
Original targets - Range: 31.31 to 72.22
N = 512 | F = 3 | T = 1024


In [5]:
# # Visualize data characteristics
# fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# # Plot 1: Target distribution
# axes[0, 0].hist(targets.flatten(), bins=30, alpha=0.7, edgecolor='black')
# axes[0, 0].set_title('Target Distribution')
# axes[0, 0].set_xlabel('Target Value')
# axes[0, 0].set_ylabel('Frequency')
# axes[0, 0].grid(True, alpha=0.3)

# # Plot 2: Spatial pattern at first time step
# axes[0, 1].plot(locs, targets[0, :], 'o-', linewidth=2, markersize=4)
# axes[0, 1].set_title('Spatial Pattern at t=0')
# axes[0, 1].set_xlabel('Location')
# axes[0, 1].set_ylabel('Target Value')
# axes[0, 1].grid(True, alpha=0.3)

# # Plot 3: Temporal pattern at middle location
# time_steps = np.arange(len(targets))
# axes[1, 0].plot(time_steps, targets[:, 25], linewidth=2)
# axes[1, 0].set_title('Temporal Pattern at Location 25')
# axes[1, 0].set_xlabel('Time Step')
# axes[1, 0].set_ylabel('Target Value')
# axes[1, 0].grid(True, alpha=0.3)

# # Plot 4: Feature correlations
# feature_corrs = []
# for i in range(cont_features.shape[-1]):
#     corr = np.corrcoef(targets.flatten(), cont_features[:, :, i].flatten())[0, 1]
#     feature_corrs.append(corr)

# axes[1, 1].bar(range(len(feature_corrs)), feature_corrs, alpha=0.7, edgecolor='black')
# axes[1, 1].set_title('Feature-Target Correlations')
# axes[1, 1].set_xlabel('Feature Index')
# axes[1, 1].set_ylabel('Correlation')
# axes[1, 1].grid(True, alpha=0.3)

# plt.tight_layout()
# plt.show()


## 4. Baseline Implementation
 OLS 改成 TFT 

In [6]:
# Compute true spatial basis for comparison
phi_true = np.exp(-(locs**2))[:, None]
phi_true /= np.linalg.norm(phi_true)
sigma_true_spatial = CFG["eigenvalue"] * (phi_true @ phi_true.T)

print("Training TFT baseline...")

# === 用「經過 scaling 的張量」拼回全長，確保與後續 tft_trend(val_X) 尺度一致 ===
# train/val/test 的切分是沿時間切，所以直接按時間串接即可
y_all = torch.cat([train_y, val_y, test_y], dim=0).squeeze(-1).cpu().numpy()  # (T, N)
x_all = torch.cat([train_X, val_X, test_X], dim=0).cpu().numpy()              # (T, N, F)

T_full = x_all.shape[0]
N      = x_all.shape[1]
F      = x_all.shape[2]

# Build pandas DataFrames for Darts
time_index = pd.RangeIndex(start=0, stop=T_full, step=1)
target_df  = pd.DataFrame(
    y_all,  # 已經是 (T, N)
    index=time_index,
    columns=[f"y_loc_{i}" for i in range(N)]
)

cont_np       = x_all.reshape(T_full, N * F)
locs_expanded = np.tile(locs.astype(np.float32), (T_full, 1))
cov_df = pd.DataFrame(
    np.concatenate([cont_np, locs_expanded], axis=1),
    index=time_index,
    columns=[f"x{j}_loc_{i}" for i in range(N) for j in range(F)] + [f"loc_{i}" for i in range(N)]
)

# Create TimeSeries
target_ts    = TimeSeries.from_dataframe(target_df,  fill_missing_dates=False)
covariate_ts = TimeSeries.from_dataframe(cov_df,     fill_missing_dates=False)

# Train TFT（T_train 用 train_X 長度）
T_train = len(train_X)
input_chunk = CFG.get('tft_input_chunk_length') or min(32, max(8, T_train // 4))
tft_model = TFTModel(
    input_chunk_length=input_chunk,
    output_chunk_length=int(CFG['tft_output_chunk_length']),
    hidden_size=int(CFG['tft_hidden_size']),
    lstm_layers=int(CFG['tft_lstm_layers']),
    num_attention_heads=int(CFG['tft_num_attention_heads']),
    full_attention=bool(CFG['tft_full_attention']),
    hidden_continuous_size=int(CFG['tft_hidden_continuous_size']),
    batch_size=int(CFG['tft_batch_size']),
    n_epochs=int(CFG['tft_n_epochs']),
    optimizer_kwargs=CFG['tft_optimizer_kwargs'],
    add_relative_index=True,
    force_reset=True,
    save_checkpoints=False,
    random_state=int(CFG['seed']),
    dropout = 0.1
)
tft_model.fit(
    series=target_ts[:T_train],
    past_covariates=covariate_ts[:T_train],
    verbose=True
)

# Wrap TFT as trend model for Spatial Adapter
tft_wrapper = TFTWrapper(tft_model=tft_model, num_locations=N, num_features=F)
tft_trend = create_pretrained_trend_model(
    pretrained_model=tft_wrapper,
    input_shape=(None, N, F),   # ← 輸入：T x N x F
    output_shape=(None, N, 1),  # ← 輸出：T x N x 1
)
tft_trend=tft_trend.to(DEVICE)
# Residual eigen-basis from validation residuals
with torch.no_grad():
    y_tft_val  = tft_trend(val_X.to(DEVICE))
    y_tft_test = tft_trend(test_X.to(DEVICE))

# ✅ 壓成 2D 再做共變異
residuals_val = (val_y.to(DEVICE) - y_tft_val).squeeze(-1)  # (T_val, N)
covariance_matrix = residuals_val.T @ residuals_val         # (N, N)
K = CFG["latent_dim"]
eigenvectors = torch.linalg.eigh(covariance_matrix).eigenvectors[:, -K:]  # top-K

tft_basis = SpatialBasisLearner(CFG["n_locations"], K).to(DEVICE)
tft_basis.basis.data.copy_(eigenvectors)

# TFT metrics
rmse_tft_val,  mae_tft_val,  r2_tft_val  = compute_metrics(val_y.to(DEVICE),  y_tft_val)
rmse_tft_test, mae_tft_test, r2_tft_test = compute_metrics(test_y.to(DEVICE), y_tft_test)

print(f"TFT Validation - RMSE: {rmse_tft_val:.4f}, R²: {r2_tft_val:.4f}")
print(f"TFT Test       - RMSE: {rmse_tft_test:.4f}, R²: {r2_tft_test:.4f}")


Training TFT baseline...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | 

Epoch 4: 100%|██████████| 43/43 [01:56<00:00,  0.37it/s, train_loss=4.130]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 43/43 [01:56<00:00,  0.37it/s, train_loss=4.130]
TFT Validation - RMSE: 1.1373, R²: -0.2951
TFT Test       - RMSE: 1.1822, R²: -0.3789


## 5. Main Experiment Function
存檔從seed改為TFT_seed, 
 呼叫 OLS 的函式改成呼叫 TFT 函式

In [7]:
def run_one_experiment(dataset_seed: int, n_trials: int = 30):
    """Run a complete experiment for one dataset seed (TFT baseline + Spatial Adapter)."""
    import copy  # 為了 deepcopy trend/basis

    log_root = Path("runs") / f"TFT_seed_{dataset_seed}"
    log_root.mkdir(parents=True, exist_ok=True)

    # ---- Generate data for this seed ----
    cat_features, cont_features, targets = generate_time_synthetic_data(
        locs=locs,
        n_time_steps=CFG["n_time_steps"],
        noise_std=CFG["noise_std"],
        eigenvalue=CFG["eigenvalue"],
        eta_rho=0.8,
        f_rho=0.6,
        global_mean=50.0,
        feature_noise_std=0.1,
        non_linear_strength=0.2,
        seed=dataset_seed
    )

    # ---- Prepare datasets with scaling (沿用你原本 0.7/0.15 切分) ----
    train_dataset, val_dataset, test_dataset, preprocessor = prepare_all_with_scaling(
        cat_features=cat_features,
        cont_features=cont_features,
        targets=targets,
        train_ratio=0.7,
        val_ratio=0.15,
        feature_scaler_type="standard",
        target_scaler_type="standard",
        fit_on_train_only=True
    )
    train_loader = DataLoader(train_dataset, batch_size=SPATIAL_CONFIG.training.batch_size, shuffle=True)

    _, train_X, train_y = train_dataset.tensors
    _, val_X,   val_y   = val_dataset.tensors
    _, test_X,  test_y  = test_dataset.tensors

    # ---- Shapes ----
    T_full = cont_features.shape[0]
    N      = cont_features.shape[1]
    F      = train_X.shape[-1]

    # ---- TFT Baseline（取代 OLS）----
    print("Training TFT baseline...")
    time_index = pd.RangeIndex(start=0, stop=T_full, step=1)

    target_df = pd.DataFrame(
        targets.reshape(T_full, N),
        index=time_index,
        columns=[f"y_loc_{i}" for i in range(N)]
    )
    cont_np        = cont_features.reshape(T_full, N * F)
    locs_expanded  = np.tile(locs.astype(np.float32), (T_full, 1))
    cov_df = pd.DataFrame(
        np.concatenate([cont_np, locs_expanded], axis=1),
        index=time_index,
        columns=[f"x{j}_loc_{i}" for i in range(N) for j in range(F)] + [f"loc_{i}" for i in range(N)]
    )

    target_ts    = TimeSeries.from_dataframe(target_df, fill_missing_dates=False)
    covariate_ts = TimeSeries.from_dataframe(cov_df,   fill_missing_dates=False)

    T_train = len(train_X)
    input_chunk = CFG.get('tft_input_chunk_length') or min(64, max(8, T_train // 4))
    tft_model = TFTModel(
        input_chunk_length=input_chunk,
        output_chunk_length=int(CFG.get('tft_output_chunk_length', 1)),
        hidden_size=int(CFG.get('tft_hidden_size', 64)),
        n_epochs=int(CFG.get('tft_n_epochs', 30)),
        dropout=float(CFG.get('tft_dropout', 0.1)),
        batch_size=int(CFG.get('tft_batch_size', CFG.get('batch_size', 64))),
        random_state=int(CFG.get('seed', 42)),
        add_relative_index=True,
        force_reset=True,
        save_checkpoints=False,
    )
    tft_model.fit(
        series=target_ts[:T_train],
        past_covariates=covariate_ts[:T_train],
        verbose=True
    )

    # Wrap TFT as trend model for Spatial Adapter
    tft_wrapper = TFTWrapper(tft_model=tft_model, num_locations=N, num_features=F)

    trend_base = create_pretrained_trend_model(
        pretrained_model=tft_wrapper,
        input_shape=(None, N, F),   # 時間×地點×特徵
        output_shape=(None, N, 1),  # 時間×地點×1
    )
    trend_base = trend_base.to(DEVICE)  # ← 這裡再搬到裝置

    # Baseline metrics
    with torch.no_grad():
        y_tft_val  = trend_base(val_X.to(DEVICE))
        y_tft_test = trend_base(test_X.to(DEVICE))
    rmse_tft,  mae_tft,  r2_tft    = compute_metrics(val_y.to(DEVICE),  y_tft_val)
    rmse_tft_t, mae_tft_t, r2_tft_t = compute_metrics(test_y.to(DEVICE), y_tft_test)

    # ---- 用 TFT 驗證殘差初始化空間基底（top-K eigenvectors）----
    with torch.no_grad():
        resid_val = val_y.to(DEVICE) - y_tft_val
    covM = resid_val.T @ resid_val
    K = CFG["latent_dim"]
    eigvecs = torch.linalg.eigh(covM).eigenvectors[:, -K:]  # top-K

    basis_init = SpatialBasisLearner(N, K).to(DEVICE)
    basis_init.basis.data.copy_(eigvecs)

    # ---- 清空 cache，準備 bootstrap ----
    cache.clear()
    clear_gpu_memory()

    # ---- Bootstrap: tau1=tau2=0（Unregularized）----
    boot_trend = copy.deepcopy(trend_base).to(DEVICE)
    boot_basis = SpatialBasisLearner(N, K).to(DEVICE)
    boot_basis.basis.data.copy_(basis_init.basis.data)

    boot_writer = SummaryWriter(log_dir=log_root / "bootstrap")
    boot = SpatialNeuralAdapter(
        boot_trend,
        boot_basis,
        train_loader,
        val_cont=val_X.to(DEVICE),
        val_y=val_y.to(DEVICE),
        locs=locs,
        config=SPATIAL_CONFIG,
        device=DEVICE,
        writer=boot_writer,
        tau1=0.0,
        tau2=0.0,
    )
    boot.pretrain_trend(epochs=5)
    boot.init_basis_dense()
    boot.run()
    cache.store(0.0, 0.0, boot_trend.state_dict(), boot_basis.state_dict())
    boot_writer.close()

    # Unreg predictions
    y_boot_val  = boot.predict(val_X.to(DEVICE),  val_y.to(DEVICE))
    rmse_boot,  mae_boot,  r2_boot  = compute_metrics(val_y.to(DEVICE),  y_boot_val)
    y_boot_test = boot.predict(test_X.to(DEVICE), test_y.to(DEVICE))
    rmse_boot_t, mae_boot_t, r2_boot_t = compute_metrics(test_y.to(DEVICE), y_boot_test)

    # Clean up bootstrap models
    del boot_trend, boot_basis, boot
    clear_gpu_memory()

    # ---- Optuna objective（搜尋 tau1, tau2）----
    def objective(trial):
        dev  = DEVICE
        tau1 = trial.suggest_float("tau1", 1e-4, 1e8, log=True)
        tau2 = trial.suggest_float("tau2", 1e-4, 1e8, log=True)

        clear_gpu_memory()

        trend = copy.deepcopy(trend_base).to(dev)
        basis = SpatialBasisLearner(N, K).to(dev)
        basis.basis.data.copy_(basis_init.basis.data)

        # 若 cache 內有鄰近的起點，載入
        cache.load_nearest(trend, basis, tau1, tau2)

        writer = SummaryWriter(log_dir=log_root / f"trial_{trial.number:03d}")
        trainer = SpatialNeuralAdapter(
            trend,
            basis,
            train_loader,
            val_cont=val_X.to(dev),
            val_y=val_y.to(dev),
            locs=locs,
            config=SPATIAL_CONFIG,
            device=dev,
            writer=writer,
            tau1=tau1,
            tau2=tau2,
        )
        trainer.pretrain_trend(epochs=3)
        trainer.init_basis_dense()
        trainer.run()

        with torch.no_grad():
            y_pred = trainer.predict(val_X.to(dev), val_y.to(dev))
        rmse, mae, r2 = compute_metrics(val_y.to(dev), y_pred)

        trial.set_user_attr("rmse", rmse)
        trial.set_user_attr("mae", mae)
        trial.set_user_attr("r2", r2)

        writer.close()
        cache.store(tau1, tau2, trend.state_dict(), basis.state_dict())

        del trend, basis, trainer, y_pred
        clear_gpu_memory()

        return rmse

    study = optuna.create_study(
        study_name=f"spatial_adapter_ds{dataset_seed}",
        direction="minimize",
        sampler=optuna.samplers.TPESampler(),
        pruner=MedianPruner(n_warmup_steps=5),
        load_if_exists=False,
    )
    study.optimize(objective, n_trials=n_trials, n_jobs=1)

    # ---- Best results ----
    best     = study.best_trial
    rmse_opt = best.user_attrs["rmse"]
    mae_opt  = best.user_attrs["mae"]
    r2_opt   = best.user_attrs["r2"]
    tau1_opt = best.params["tau1"]
    tau2_opt = best.params["tau2"]
    best_no  = best.number

    # ---- Test best model（用 cache 中的最佳權重還原 trend/basis，做簡單投影評估）----
    dev_best = DEVICE
    trend_best = copy.deepcopy(trend_base).to(dev_best)
    basis_best = SpatialBasisLearner(N, K).to(dev_best)
    # 從 cache 還原（若不存在會 KeyError；正常來說已 store）
    sd_t, sd_b = cache.cache[(tau1_opt, tau2_opt)]
    trend_best.load_state_dict(sd_t)
    basis_best.load_state_dict(sd_b)

    trend_best.eval()
    basis_best.eval()
    with torch.no_grad():
        X_test = test_X.to(dev_best)
        y_test = test_y.to(dev_best)
        y_trend = trend_best(X_test)
        # 將殘差投影到基底子空間，再加回趨勢（與你原本做法一致）
        residual = y_test - y_trend
        y_basis  = (residual @ basis_best.basis) @ basis_best.basis.T
        y_reg_test = y_trend + y_basis
    rmse_test, mae_test, r2_test = compute_metrics(y_test, y_reg_test)

    # ---- Write results to CSV ----
    csv_path = Path("metrics_summary.csv")
    write_header = not csv_path.exists()
    with csv_path.open("a", newline="") as f:
        w = csv.writer(f)
        if write_header:
            w.writerow([
                "seed", "model", "trial", "tau1", "tau2",
                "rmse_val", "mae_val", "r2_val",
                "rmse_test", "mae_test", "r2_test"
            ])

        # Baseline（TFT）
        w.writerow([
            dataset_seed, "TFT", "", "", "",
            f"{rmse_tft:.6f}", f"{mae_tft:.6f}", f"{r2_tft:.6f}",
            f"{rmse_tft_t:.6f}", f"{mae_tft_t:.6f}", f"{r2_tft_t:.6f}"
        ])
        # Unregularized
        w.writerow([
            dataset_seed, "Unreg", "", "0", "0",
            f"{rmse_boot:.6f}", f"{mae_boot:.6f}", f"{r2_boot:.6f}",
            f"{rmse_boot_t:.6f}", f"{mae_boot_t:.6f}", f"{r2_boot_t:.6f}"
        ])
        # Regularized
        w.writerow([
            dataset_seed, "Reg", best_no, f"{tau1_opt:.6g}", f"{tau2_opt:.6g}",
            f"{rmse_opt:.6f}", f"{mae_opt:.6f}", f"{r2_opt:.6f}",
            f"{rmse_test:.6f}", f"{mae_test:.6f}", f"{r2_test:.6f}"
        ])

    print(
        f"Dataset {dataset_seed}:  "
        f"TFT RMSE={rmse_tft:.3f} | "
        f"Unreg RMSE={rmse_boot:.3f} | "
        f"Reg RMSE={rmse_opt:.3f} (test {rmse_test:.3f})"
    )

    # 為了相容外部程式若仍取 'ols' key，這裡沿用 'ols'，內容是 TFT 指標
    return {
        'ols':   {'rmse_val': rmse_tft,  'rmse_test': rmse_tft_t,  'r2_val': r2_tft,  'r2_test': r2_tft_t},
        'unreg': {'rmse_val': rmse_boot, 'rmse_test': rmse_boot_t, 'r2_val': r2_boot, 'r2_test': r2_boot_t},
        'reg':   {'rmse_val': rmse_opt,  'rmse_test': rmse_test,   'r2_val': r2_opt,  'r2_test': r2_test,
                  'tau1': tau1_opt, 'tau2': tau2_opt}
    }


## 6. Run Full Experiment Suite

In [8]:
all_results = []
for seed in range(EXPERIMENT_TRIALS_CONFIG['seed_range_start'], EXPERIMENT_TRIALS_CONFIG['seed_range_end']):
    print(f"\nStarting experiment for seed {seed}")
    results = run_one_experiment(seed, n_trials=EXPERIMENT_TRIALS_CONFIG['n_trials_per_seed'])
    all_results.append(results)
    cache.clear()
    clear_gpu_memory()
    print(f"✅ Completed seed {seed}")

print("\n🎉 All experiments completed!")


Starting experiment for seed 1
Training TFT baseline...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | static_covariates_vsn             | _VariableSelectionNetwork        | 0      | train
4  | encoder_vsn                       | _VariableSelectionNetwork        | 1.8 M  | train
5  | decoder_vsn                       | _VariableSelectionNetwork        | 528    | train
6  | static_context_grn                | _GatedResidualNetwork            | 1.1 K  | train
7  | static_cont

Epoch 2:  95%|█████████▌| 39/41 [02:56<00:09,  0.22it/s, train_loss=407.0]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

## 7. Results Analysis and Visualization

還未完全改完

In [None]:
# Load results
results_df = pd.read_csv("metrics_summary.csv")
print("📊 Results Summary (mean across seeds):")
print(results_df.groupby('model')[['rmse_val', 'rmse_test', 'r2_val', 'r2_test']].mean())

# Plot comparison
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# RMSE comparison
sns.boxplot(data=results_df, x='model', y='rmse_val', ax=axes[0,0])
axes[0,0].set_title('Validation RMSE'); axes[0,0].set_ylabel('RMSE'); axes[0,0].grid(True, alpha=0.3)

sns.boxplot(data=results_df, x='model', y='rmse_test', ax=axes[0,1])
axes[0,1].set_title('Test RMSE'); axes[0,1].set_ylabel('RMSE'); axes[0,1].grid(True, alpha=0.3)

# R² comparison
sns.boxplot(data=results_df, x='model', y='r2_val', ax=axes[1,0])
axes[1,0].set_title('Validation R²'); axes[1,0].set_ylabel('R²'); axes[1,0].grid(True, alpha=0.3)

sns.boxplot(data=results_df, x='model', y='r2_test', ax=axes[1,1])
axes[1,1].set_title('Test R²'); axes[1,1].set_ylabel('R²'); axes[1,1].grid(True, alpha=0.3)

plt.tight_layout(); plt.show()

# Show best hyperparameters for regularized model
reg_results = results_df[results_df['model'] == 'Reg']
print("\n🔧 Best Hyperparameters for Regularized Model (first 10 rows):")
print(reg_results[['tau1', 'tau2', 'rmse_val', 'rmse_test']].head(10))


In [None]:
# Performance comparison summary
print("=== Performance Comparison Summary ===")
tft_mean_rmse   = results_df[results_df['model'] == 'TFT']['rmse_test'].mean()
unreg_mean_rmse = results_df[results_df['model'] == 'Unreg']['rmse_test'].mean()
reg_mean_rmse   = results_df[results_df['model'] == 'Reg']['rmse_test'].mean()

print(f"TFT (baseline) - Mean Test RMSE: {tft_mean_rmse:.4f}")
print(f"Unregularized  - Mean Test RMSE: {unreg_mean_rmse:.4f} ({(1 - unreg_mean_rmse/tft_mean_rmse)*100:.1f}% improvement)")
print(f"Regularized    - Mean Test RMSE: {reg_mean_rmse:.4f} ({(1 - reg_mean_rmse/tft_mean_rmse)*100:.1f}% improvement)")

from scipy import stats
wide = results_df.pivot_table(index='seed', columns='model', values='rmse_test', aggfunc='mean')
paired = wide[['TFT', 'Reg']].dropna()
t_stat, p_value = stats.ttest_rel(paired['TFT'], paired['Reg'])

print(f"\nStatistical Test (TFT vs Regularized): t={t_stat:.4f}, p={p_value:.4f}, "
      f"significant: {'Yes' if p_value < 0.05 else 'No'}")
