In [None]:
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.models import TFTModel
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling
from geospatial_neural_adapter.data.generators import generate_time_synthetic_data


# Reproducibility, compute device and output folders for this experiment.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

# Prefer CUDA when PyTorch reports it as available, otherwise run on CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# `__file__` does not exist in interactive sessions (e.g. a notebook cell),
# in which case the current working directory is used as the experiment root.
try:
    _this_file = Path(__file__)
except NameError:
    EXP_ROOT = Path.cwd()
else:
    EXP_ROOT = _this_file.resolve().parent

# Checkpoints, logger runs and plots each get their own folder under EXP_ROOT.
CKPT_DIR, RUNS_DIR, PLOTS_DIR = (
    (EXP_ROOT / _folder).resolve()
    for _folder in (
        "darts_ckpt_synth_tft_Bprime",
        "TFT_runs_synth_tft_Bprime",
        "TFT_plots_synth_tft_Bprime_2figs",
    )
)

for _out_dir in (CKPT_DIR, RUNS_DIR, PLOTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)


# Generate synthetic data
# Size of the synthetic panel: N_POINTS locations observed over T_TOTAL steps.
N_POINTS = 36
T_TOTAL = 1500
EIGENVALUE = 3.0

# One evenly spaced 1-D coordinate in [-5, 5] per location.
locs = np.linspace(-5.0, 5.0, N_POINTS).astype(np.float32)

# Project-provided generator. Its three return values are unpacked as
# categorical features, continuous features and targets (going by the
# variable names); `cat_synth` is never read in the visible code.
cat_synth, cont_synth, y_synth = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=T_TOTAL,
    noise_std=0.5,
    eigenvalue=EIGENVALUE,
    eta_rho=0.6,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.2,
    non_linear_strength=1.0,
    seed=GLOBAL_SEED,
)

print("Synthetic y shape :", y_synth.shape)
print("Synthetic cont shape:", cont_synth.shape)

# Work in float32 throughout (targets Y, continuous features X).
Y = y_synth.astype(np.float32)
X = cont_synth.astype(np.float32)

# Y unpacks to two axes (time, location); X unpacks to three, the last one
# being the per-location feature axis. `P_time` is not read in the visible code.
T_total, N = Y.shape
_, _, P_time = X.shape

# Hourly timestamps, one per time step, starting at 2000-01-01 00:00.
# NOTE(review): the "1H" literal is duplicated here and in `freq`; recent
# pandas releases deprecate the upper-case "H" alias in favour of "h" —
# confirm against the pinned pandas/darts versions before changing it.
time_start = pd.Timestamp("2000-01-01 00:00:00")
time_index = pd.date_range(start=time_start, periods=T_total, freq="1H")
freq = "1H"

print("Total time steps:", T_total)
print("Number of points:", N)
print("Freq:", freq)

# Column suffixes for the past-covariate series built later; the number of
# names has to match the size of X's last axis.
PAST_COV_COLS = ["f1", "f2", "f3"]

# All-zero integer placeholder passed as `cat_features` to the preprocessing
# helper in place of real categorical inputs.
cat_dummy = np.zeros((T_total, N, 1), dtype=np.int64)


# Train/Val/Test split & scaling
train_ratio = 0.70
val_ratio = 0.15

# Integer positions on the time axis: `cut_train` ends the training span and
# `cut_val` is where the test span starts in the code further down.
# NOTE(review): these are computed here independently of
# `prepare_all_with_scaling`; presumably the helper splits at the same
# indices for the same ratios — verify against its implementation.
cut_train = int(T_total * train_ratio)
cut_val = int(T_total * (train_ratio + val_ratio))

# Project helper returning three datasets plus the fitted preprocessor.
# Standard scaling is requested for both features and targets;
# `fit_on_train_only=True` presumably fits the scalers on the training
# portion only — TODO confirm.
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=X,
    targets=Y,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)

def stitch_targets(dsets):
    """Join the target tensors of several datasets along axis 0.

    For every dataset in ``dsets`` the third tensor (``tensors[2]``) is
    detached, moved to CPU and converted to a float32 NumPy array; the pieces
    are concatenated in the order given.
    """
    pieces = []
    for dataset in dsets:
        target_tensor = dataset.tensors[2]
        as_numpy = target_tensor.detach().cpu().numpy()
        pieces.append(as_numpy.astype(np.float32))
    return np.concatenate(pieces, axis=0)

def stitch_cont_features(dsets):
    """Join the continuous-feature tensors of several datasets along axis 0.

    For every dataset in ``dsets`` the second tensor (``tensors[1]``) is
    detached, moved to CPU and converted to a float32 NumPy array; the pieces
    are concatenated in the order given.
    """
    pieces = []
    for dataset in dsets:
        feature_tensor = dataset.tensors[1]
        as_numpy = feature_tensor.detach().cpu().numpy()
        pieces.append(as_numpy.astype(np.float32))
    return np.concatenate(pieces, axis=0)

# Rebuild full-length scaled arrays by concatenating train, val and test in
# that order.
# NOTE(review): assumes the three datasets are contiguous, chronologically
# ordered pieces of the original time axis — confirm against the helper.
Y_scaled = stitch_targets([train_ds, val_ds, test_ds])
X_scaled = stitch_cont_features([train_ds, val_ds, test_ds])


# Build TimeSeries lists
# NOTE(review): `List` is already imported from `typing` at the top of the
# file; this alias is redundant but is kept so the names below still resolve.
from typing import List as _ListTs

# One target series and one past-covariate series per location.
series_all: _ListTs[TimeSeries] = []
pcov_all: _ListTs[TimeSeries] = []

# Static coordinates: when N equals K_LAT * K_LON the locations are laid out
# on a regular lat/lon grid (lat 18.0..53.6, lon 73.5..134.8); otherwise they
# fall back to a line of N latitudes in [-5, 5] with longitude 0.
K_LAT = 6
K_LON = 6
if K_LAT * K_LON != N:
    lat_pts = np.linspace(-5.0, 5.0, N).astype(np.float32)
    lon_pts = np.zeros_like(lat_pts, dtype=np.float32)
else:
    lat_lin = np.linspace(18.0, 53.6, K_LAT, dtype=np.float32)
    lon_lin = np.linspace(73.5, 134.8, K_LON, dtype=np.float32)
    lat_grid, lon_grid = np.meshgrid(lat_lin, lon_lin, indexing="ij")
    lat_pts = lat_grid.reshape(-1).astype(np.float32)
    lon_pts = lon_grid.reshape(-1).astype(np.float32)

# Altitude is a constant zero for every location.
alt_pts = np.zeros_like(lat_pts, dtype=np.float32)

for j in range(N):
    name = f"grid_{j:03d}"

    # Univariate target series: column j of the scaled targets.
    ts = TimeSeries.from_times_and_values(
        times=time_index,
        values=Y_scaled[:, j:j + 1].astype(np.float32),
        columns=[name],
        freq=freq,
    )

    # One-row frame of static covariates (lat, lon, alt) for this location.
    sc = pd.DataFrame(
        {"lat": [float(lat_pts[j])], "lon": [float(lon_pts[j])], "alt": [float(alt_pts[j])]},
        index=[name],
        dtype=np.float32,
    )

    ts = ts.with_static_covariates(sc)
    series_all.append(ts)

    # Past covariates: the scaled features of location j, one column per
    # entry of PAST_COV_COLS, carrying the same static covariates.
    pc = TimeSeries.from_times_and_values(
        times=time_index,
        values=X_scaled[:, j, :].astype(np.float32),
        columns=[f"{name}_{c}" for c in PAST_COV_COLS],
        freq=freq,
    ).with_static_covariates(sc)

    pcov_all.append(pc)


# Internal validation
# L is passed to the model as `input_chunk_length` and H as
# `output_chunk_length`; H is also the step of the rolling evaluation.
L, H = 56, 24
# Number of time steps, counted back from `cut_train`, used as the internal
# validation window.
INTERNAL_VAL_STEPS = 240

def slice_list(xs, a, b):
    """Return a new list containing ``x[a:b]`` for every ``x`` in ``xs``."""
    window = slice(a, b)
    return [element[window] for element in xs]

# Training data: the first `cut_train` steps of every target / covariate series.
train_series = slice_list(series_all, 0, cut_train)
train_pcov = slice_list(pcov_all, 0, cut_train)

# Internal validation window: the last INTERNAL_VAL_STEPS steps before
# `cut_train`.
# NOTE(review): this window lies inside [0, cut_train), i.e. it overlaps the
# training slices above, so the validation loss is computed on data that is
# also used for fitting — confirm this is intended.
iv_start = max(0, cut_train - INTERNAL_VAL_STEPS)
iv_end = cut_train

# Widen the window when it is shorter than one full input + output sample.
if (iv_end - iv_start) < (L + H):
    iv_start = max(0, iv_end - (L + H))

val_series = slice_list(series_all, iv_start, iv_end)
val_pcov = slice_list(pcov_all, iv_start, iv_end)


# Train TFT
# Run name shared by the checkpoint folder and the CSV logger; it encodes the
# window sizes and the seed.
MODEL_NAME = f"tft_synth_Bprime_L{L}_H{H}_seed{GLOBAL_SEED}"
LOG_ROOT = (RUNS_DIR / f"seed_{GLOBAL_SEED}").resolve()
LOG_ROOT.mkdir(parents=True, exist_ok=True)

tft = TFTModel(
    input_chunk_length=L,
    output_chunk_length=H,
    n_epochs=30,
    hidden_size=64,
    num_attention_heads=4,
    dropout=0.1,
    batch_size=32,
    optimizer_kwargs={"lr": 3e-4},
    # No future covariates are passed to fit(); presumably this is why a
    # relative index is requested instead — confirm in the darts docs.
    add_relative_index=True,
    random_state=GLOBAL_SEED,
    # NOTE(review): presumably discards an earlier run stored under
    # MODEL_NAME in `work_dir` — confirm in the darts docs.
    force_reset=True,
    model_name=MODEL_NAME,
    work_dir=str(CKPT_DIR),
    save_checkpoints=True,
    pl_trainer_kwargs={
        # Same CUDA availability test as DEVICE, expressed as a Lightning accelerator name.
        "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
        "devices": 1,
        "enable_progress_bar": True,
        "enable_model_summary": False,
        "enable_checkpointing": True,
        # Early stopping watches "val_loss" (lower is better) with patience 6.
        "callbacks": [EarlyStopping(monitor="val_loss", mode="min", patience=6)],
        "logger": CSVLogger(save_dir=str(LOG_ROOT), name=MODEL_NAME),
        "gradient_clip_val": 1.0,
    },
)

# Fit on the training slices, using the internal validation slices for the
# monitored validation loss.
tft.fit(
    series=train_series,
    past_covariates=train_pcov,
    val_series=val_series,
    val_past_covariates=val_pcov,
    verbose=True,
)

# Replace the in-memory model with the checkpoint darts flags as "best"
# (presumably the one with the lowest validation loss — confirm).
tft = TFTModel.load_from_checkpoint(
    model_name=MODEL_NAME,
    work_dir=str(CKPT_DIR),
    best=True,
)


# Rolling non-overlap
def rolling_nonoverlap(model, series, pcov, start, end, L, H):
    """Forecast consecutive, non-overlapping windows of ``H`` steps.

    Candidate forecast origins are ``start, start + H, ...`` for as long as
    the window ``[t0, t0 + H)`` ends at or before ``end``; origins with fewer
    than ``L`` preceding steps are skipped. At each origin the model receives
    every series and past covariate truncated to ``[:t0]`` and is asked for
    ``H`` steps.

    Args:
        model: Object exposing ``predict(n=..., series=...,
            past_covariates=..., verbose=...)``.
        series: List of target series supporting slicing and
            ``.values(copy=False)``.
        pcov: List of past-covariate series aligned with ``series``.
        start: First candidate forecast origin (integer position).
        end: Exclusive upper bound for the forecast windows.
        L: Minimum number of history steps required before an origin.
        H: Forecast horizon, also the step between origins.

    Returns:
        ``(yhat, ytrue, t0_list)``: two arrays of shape
        ``(n_windows, H, n_series)`` holding the first component of each
        forecast and of the matching actual values, plus the list of
        origins that were used.

    Raises:
        ValueError: If no origin qualifies. The original code failed later
            inside ``np.stack`` with a message that did not mention the
            window parameters.
    """
    origins = [t0 for t0 in range(start, end - H + 1, H) if t0 >= L]
    if not origins:
        raise ValueError(
            "rolling_nonoverlap: no forecast window fits "
            f"(start={start}, end={end}, L={L}, H={H})"
        )

    yh, yt = [], []
    for t0 in origins:
        # Context is everything strictly before the origin.
        ctx_s = [s[:t0] for s in series]
        ctx_p = [p[:t0] for p in pcov]
        preds = model.predict(n=H, series=ctx_s, past_covariates=ctx_p, verbose=False)
        # A single input series yields a single forecast rather than a list.
        preds = preds if isinstance(preds, list) else [preds]
        # Keep only the first component of each series; columns are series.
        yh.append(np.stack([p.values(copy=False)[:, 0] for p in preds], axis=1))
        yt.append(np.stack([s[t0:t0 + H].values(copy=False)[:, 0] for s in series], axis=1))
    return np.stack(yh), np.stack(yt), origins

# The test span starts at `cut_val`; the first forecast origin additionally
# needs at least L steps of history.
test_start = cut_val
start_ctx = max(test_start, L)

# Roll over the test span in non-overlapping windows of H steps.
# NOTE(review): the upper bound is the constant T_TOTAL rather than the
# measured T_total; both describe the same series length here.
yhat_roll, ytrue_roll, t0_list = rolling_nonoverlap(
    tft, series_all, pcov_all, start_ctx, T_TOTAL, L, H
)

# Number of rolling windows.
W = yhat_roll.shape[0]
# Errors are pooled over all windows, horizon steps and series. They are in
# scaled units, because `series_all` was built from `Y_scaled`.
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(diff ** 2)))
mae = float(np.mean(np.abs(diff)))

print("RMSE:", rmse)
print("MAE :", mae)


# q-risk
def qrisk(y_true, y_pred, q=0.5, eps=1e-8):
    """Return the normalised quantile (pinball) risk of ``y_pred``.

    Computes ``2 * sum(max(q * e, (q - 1) * e)) / (sum(|y_true|) + eps)``
    with ``e = y_true - y_pred``, in float64, and returns a Python float.
    ``eps`` keeps the denominator non-zero when ``y_true`` is all zeros.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    forecast = np.asarray(y_pred, dtype=np.float64)
    residual = truth - forecast
    pinball = np.maximum(q * residual, (q - 1) * residual)
    scale = np.sum(np.abs(truth)) + eps
    return float(2.0 * np.sum(pinball) / scale)

# Flatten the window and horizon axes so that each row is one forecast step
# and each of the N columns is one series.
yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

# Median (q=0.5) q-risk, first on the scaled values, then after the
# preprocessor's inverse target transform (presumably back to original
# units — confirm against the preprocessor).
print("scaled  :", qrisk(ytrue_f, yhat_f))
print("unscaled:", qrisk(
    preprocessor.inverse_transform_targets(ytrue_f),
    preprocessor.inverse_transform_targets(yhat_f),
))


# Plots
# The test span runs from `test_start` to the end of the series.
test_end = T_TOTAL
test_len = test_end - test_start

# Scaled ground truth over the test span, and a same-shaped prediction array
# pre-filled with NaN so that steps not covered by any window stay NaN.
y_true_test = Y_scaled[test_start:test_end, :]
y_pred_test = np.full((test_len, N), np.nan, dtype=np.float32)

# Copy each window's forecast into its position relative to `test_start`;
# windows that would fall outside the test span are ignored.
for w, t0 in enumerate(t0_list):
    a = t0 - test_start
    b = a + H
    if 0 <= a and b <= test_len:
        y_pred_test[a:b, :] = yhat_roll[w]

# Index and time axis of the first rolling window, plus the time axis of the
# whole test span.
w0 = 0
t0_first = t0_list[w0]
dates_first = time_index[t0_first:t0_first + H]
dates_test = time_index[test_start:test_end]

def plot_two_figs_all_points(out_dir: Path) -> None:
    """Save two PNG figures per grid point into ``out_dir``.

    For each of the ``N`` points, ``FIG1_grid###.png`` shows true vs.
    predicted values over the first rolling window, and ``FIG2_grid###.png``
    shows them over the whole test span. Both figures carry a legend, so the
    "True" and "Pred" curves can be told apart.

    Reads the module-level names ``N``, ``w0``, ``dates_first``,
    ``dates_test``, ``ytrue_roll``, ``yhat_roll``, ``y_true_test`` and
    ``y_pred_test``.

    Args:
        out_dir: Target directory; created (with parents) if missing.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    for j in range(N):
        # Figure 1: first rolling window only, with point markers.
        plt.figure(figsize=(10, 4), dpi=140)
        plt.plot(dates_first, ytrue_roll[w0][:, j], "-o", label="True")
        plt.plot(dates_first, yhat_roll[w0][:, j], "-o", label="Pred")
        # Without this call the labels passed above are never displayed.
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_dir / f"FIG1_grid{j:03d}.png")
        plt.close()

        # Figure 2: the full test span.
        plt.figure(figsize=(12, 4), dpi=140)
        plt.plot(dates_test, y_true_test[:, j], label="True")
        plt.plot(dates_test, y_pred_test[:, j], label="Pred")
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_dir / f"FIG2_grid{j:03d}.png")
        plt.close()

# Write the per-point figures into the experiment's plot directory.
plot_two_figs_all_points(PLOTS_DIR)

print("All done.")


  __import__("pkg_resources").declare_namespace(__name__)  # type: ignore


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Synthetic (Scenario B′) y shape : (1500, 36)
Synthetic (Scenario B′) cont shape: (1500, 36, 3)
Total time steps: 1500
Number of points: 36
Freq: 1H

=== Split (Scenario B′, by time index) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

=== After pooled scaling (Scenario B′, train-only) ===
Y_scaled: (1500, 36) | finite: True
X_scaled: (1500, 36, 3) | finite: True

Built TimeSeries (Scenario B′):
Targets: 36
Past covs: 36

=== INTERNAL validation (Scenario B′, for early stopping only) ===
Train len: 1050
IntVal idx: 810 -> 1049 | len = 240
IntVal time: 2000-02-03 18:00:00 -> 2000-02-13 17:00:00

=== Training TFT (Scenario B′: nonlinear world, eta_rho=0.6) ===


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.



=== Loading best checkpoint (Scenario B′) ===


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loaded best checkpoint.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available


=== Rolling non-overlap (Scenario B′, POOLED, TEST) ===
windows=9 | step=24 | each window predicts H=24 | points=36
RMSE: 2.201981544494629
MAE : 1.5484315156936646

=== P50 q-risk (Scenario B′, ROLLING, TEST) ===
scaled  : 1.0103965128501733
unscaled: 0.08466446155645471

=== Plotting 2 figs per grid point (Scenario B′) ===
Saved 2 figs/point for 36 points under: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/simulation/TFT_simulation_B1/TFT_plots_synth_tft_Bprime_2figs

All done (Scenario B′: nonlinear world, eta_rho=0.6).
