In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.models import TFTModel
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling
from geospatial_neural_adapter.data.generators import generate_time_synthetic_data


# Global settings & dirs
# One seed drives both NumPy's and PyTorch's global RNGs.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# Anchor all outputs next to this script; when `__file__` is undefined
# (e.g. an interactive session) fall back to the current working directory.
try:
    _script_path = Path(__file__)
except NameError:
    EXP_ROOT = Path.cwd()
else:
    EXP_ROOT = _script_path.resolve().parent

# Checkpoints, training logs and plots each get their own directory under EXP_ROOT.
CKPT_DIR, RUNS_DIR, PLOTS_DIR = (
    (EXP_ROOT / _dir_name).resolve()
    for _dir_name in (
        "darts_ckpt_synth_tft_D",
        "TFT_runs_synth_tft_D",
        "TFT_plots_synth_tft_D_2figs",
    )
)

for _out_dir in (CKPT_DIR, RUNS_DIR, PLOTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)


# Data generation
# Draw the synthetic panel, cast it to float32 and build the shared hourly
# time axis plus the helper arrays the later stages expect.
N_POINTS = 36
T_TOTAL = 1500
EIGENVALUE = 2.0

# Evenly spaced 1-D site coordinates handed to the generator.
locs = np.linspace(-5.0, 5.0, N_POINTS).astype(np.float32)

cat_synth, cont_synth, y_synth = generate_time_synthetic_data(
    locs=locs,
    n_time_steps=T_TOTAL,
    noise_std=0.5,
    eigenvalue=EIGENVALUE,
    eta_rho=0.4,
    f_rho=0.6,
    global_mean=50.0,
    feature_noise_std=0.2,
    non_linear_strength=2.5,
    seed=GLOBAL_SEED,
)

print("Synthetic (Scenario D) y shape :", y_synth.shape)
print("Synthetic (Scenario D) cont shape:", cont_synth.shape)

Y = y_synth.astype(np.float32)
X = cont_synth.astype(np.float32)

# Y is unpacked as (time, point); X carries one extra trailing feature axis.
T_total, N = Y.shape
P_time = X.shape[2]

# Single source of truth for the sampling frequency: the same string is used
# to build the index here and is passed to Darts later.
# NOTE(review): recent pandas releases deprecate the upper-case "H" alias in
# favour of "h"; kept as-is because the printed label and the installed
# pandas/Darts versions depend on it -- confirm before changing.
freq = "1H"
time_start = pd.Timestamp("2000-01-01 00:00:00")
time_index = pd.date_range(start=time_start, periods=T_total, freq=freq)
if len(time_index) != T_total:
    raise RuntimeError("Global time_index length mismatch.")

print("Total time steps:", T_total)
print("Number of points:", N)
print("Freq:", freq)

# One covariate name per continuous feature channel. Derived from the data so
# the names always match X's last axis (yields f1, f2, f3 for three features)
# instead of silently assuming exactly three channels.
PAST_COV_COLS = [f"f{k}" for k in range(1, P_time + 1)]

# Placeholder categorical input (all zeros, one channel) passed to the
# preprocessing helper as `cat_features`.
cat_dummy = np.zeros((T_total, N, 1), dtype=np.int64)


# Split & scaling
# Chronological train / validation / test partition expressed as fractions
# of the time axis; the test split is whatever remains.
train_ratio = 0.70
val_ratio = 0.15

cut_train = int(T_total * train_ratio)
cut_val = int(T_total * (train_ratio + val_ratio))

# Every split must be non-empty and strictly ordered.
if cut_train <= 0 or cut_val <= cut_train or cut_val >= T_total:
    raise ValueError("Bad split indices computed from ratios.")

for _label, _lo, _hi in (
    ("Train:", 0, cut_train),
    ("Val  :", cut_train, cut_val),
    ("Test :", cut_val, T_total),
):
    print(_label, _lo, "->", _hi - 1, "| len =", _hi - _lo)

# NOTE(review): the helper receives only the ratios, not cut_train/cut_val;
# presumably it derives the same boundaries internally -- verify.
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=X,
    targets=Y,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)

def stitch_targets(dsets) -> np.ndarray:
    """Concatenate the target tensors of several datasets along axis 0.

    Each element of ``dsets`` must expose a ``tensors`` sequence whose entry
    at index 2 is a torch tensor; that tensor is detached, moved to CPU and
    converted to a float32 NumPy array before concatenation.
    """
    pieces = []
    for part in dsets:
        target_tensor = part.tensors[2]
        pieces.append(target_tensor.detach().cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)

def stitch_cont_features(dsets) -> np.ndarray:
    """Concatenate the continuous-feature tensors of several datasets along axis 0.

    Each element of ``dsets`` must expose a ``tensors`` sequence whose entry
    at index 1 is a torch tensor; that tensor is detached, moved to CPU and
    converted to a float32 NumPy array before concatenation.
    """
    pieces = []
    for part in dsets:
        feature_tensor = part.tensors[1]
        pieces.append(feature_tensor.detach().cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)

# Re-assemble the three scaled splits, in chronological order, into
# full-length arrays by concatenating along axis 0.
_scaled_splits = [train_ds, val_ds, test_ds]
Y_scaled = stitch_targets(_scaled_splits)
X_scaled = stitch_cont_features(_scaled_splits)

# Scaling must not have changed the layout of either array.
if Y.shape != Y_scaled.shape:
    raise ValueError(f"Y_scaled shape {Y_scaled.shape} != Y {Y.shape}")
if X.shape != X_scaled.shape:
    raise ValueError(f"X_scaled shape {X_scaled.shape} != X {X.shape}")

print("Y_scaled:", Y_scaled.shape, "| finite:", bool(np.all(np.isfinite(Y_scaled))))
print("X_scaled:", X_scaled.shape, "| finite:", bool(np.all(np.isfinite(X_scaled))))


# Darts TimeSeries
# Build one univariate target series and one multivariate past-covariate
# series per spatial point, all sharing `time_index`.
series_all = []
pcov_all = []

# Static (lat, lon, alt) covariates: a K_LAT x K_LON grid when it matches the
# number of points, otherwise a 1-D fallback with lon fixed at zero.
# NOTE(review): these coordinates are placeholders built here; they are not
# derived from `locs`, the coordinates the data was generated on -- confirm
# this is intended.
K_LAT = 6
K_LON = 6
if K_LAT * K_LON != N:
    lat_pts = np.linspace(-5.0, 5.0, N).astype(np.float32)
    lon_pts = np.zeros_like(lat_pts, dtype=np.float32)
else:
    lat_lin = np.linspace(18.0, 53.6, K_LAT, dtype=np.float32)
    lon_lin = np.linspace(73.5, 134.8, K_LON, dtype=np.float32)
    # "ij" indexing + row-major flatten: point j maps to (lat j // K_LON, lon j % K_LON).
    lat_grid, lon_grid = np.meshgrid(lat_lin, lon_lin, indexing="ij")
    lat_pts = lat_grid.reshape(-1).astype(np.float32)
    lon_pts = lon_grid.reshape(-1).astype(np.float32)

# Altitude is a constant-zero static covariate.
alt_pts = np.zeros_like(lat_pts, dtype=np.float32)

for j in range(N):
    name = f"grid_{j:03d}"

    # Target: column j of the scaled target matrix, kept 2-D as (time, 1).
    ts = TimeSeries.from_times_and_values(
        times=time_index,
        values=Y_scaled[:, j:j + 1].astype(np.float32),
        columns=[name],
        freq=freq,
    )

    # Single-row frame holding this point's static covariates.
    sc = pd.DataFrame(
        {"lat": [float(lat_pts[j])], "lon": [float(lon_pts[j])], "alt": [float(alt_pts[j])]},
        index=[name],
        dtype=np.float32,
    )

    ts = ts.with_static_covariates(sc)
    series_all.append(ts)

    # Past covariates: all scaled continuous features of point j, with
    # component names prefixed by the point name. The same static covariates
    # are attached here as well.
    pc = TimeSeries.from_times_and_values(
        times=time_index,
        values=X_scaled[:, j, :].astype(np.float32),
        columns=[f"{name}_{c}" for c in PAST_COV_COLS],
        freq=freq,
    ).with_static_covariates(sc)

    pcov_all.append(pc)

print("Targets:", len(series_all))
print("Past covs:", len(pcov_all))


# Internal validation
# L: look-back length, used as the model's input_chunk_length.
L = 56
# H: forecast horizon, used as output_chunk_length and as the rolling-evaluation stride.
H = 24
# Number of trailing training-range steps used as the internal validation window.
INTERNAL_VAL_STEPS = 240

def slice_list(xs, a, b):
    """Return a new list holding ``x[a:b]`` for every element ``x`` of ``xs``."""
    trimmed = []
    for item in xs:
        trimmed.append(item[a:b])
    return trimmed

# Carve the internal validation window out of the tail of the training range
# and keep it OUT of the series the model is fitted on. If the window stayed
# inside the training data, `val_loss` (used for early stopping and
# best-checkpoint selection) would be measured on samples the model has
# already trained on and could not signal over-fitting.
min_needed = L + H  # shortest series that yields one (input, output) sample

iv_end = cut_train
# The window is INTERNAL_VAL_STEPS long, widened to at least one full sample.
iv_start = max(0, iv_end - max(INTERNAL_VAL_STEPS, min_needed))

# What is left in front of the window must itself allow at least one training sample.
if iv_start < min_needed:
    raise ValueError(
        f"Only {iv_start} steps remain for training after holding out the "
        f"internal validation window; need at least {min_needed}."
    )

train_series = slice_list(series_all, 0, iv_start)
train_pcov = slice_list(pcov_all, 0, iv_start)

val_series = slice_list(series_all, iv_start, iv_end)
val_pcov = slice_list(pcov_all, iv_start, iv_end)

print("Train len:", len(train_series[0]))
print("IntVal idx:", iv_start, "->", iv_end - 1, "| len =", len(val_series[0]))
print("IntVal time:", time_index[iv_start], "->", time_index[iv_end - 1])


# Train TFT
# Configure, fit and reload a Darts TFTModel on the per-point series.
# The model name encodes window sizes and seed; it is reused for the
# checkpoint folder, the CSV logger and the reload below.
MODEL_NAME = f"tft_synth_D_L{L}_H{H}_seed{GLOBAL_SEED}"
LOG_ROOT = (RUNS_DIR / f"seed_{GLOBAL_SEED}").resolve()
LOG_ROOT.mkdir(parents=True, exist_ok=True)

# NOTE(review): neither `likelihood` nor `loss_fn` is passed. Per the Darts
# documentation TFTModel then defaults to a QuantileRegression likelihood,
# i.e. a probabilistic model whose predict() draws samples -- confirm that
# downstream point metrics account for this.
tft = TFTModel(
    input_chunk_length=L,
    output_chunk_length=H,
    n_epochs=30,
    hidden_size=64,
    num_attention_heads=4,
    dropout=0.1,
    batch_size=32,
    optimizer_kwargs={"lr": 3e-4},
    # No future covariates are supplied anywhere in this script; presumably
    # this flag provides the decoder input TFT needs in that case -- see Darts docs.
    add_relative_index=True,
    random_state=GLOBAL_SEED,
    # Presumably discards any earlier run stored under the same model_name/work_dir -- see Darts docs.
    force_reset=True,
    model_name=MODEL_NAME,
    work_dir=str(CKPT_DIR),
    # Required so that load_from_checkpoint(..., best=True) below has something to load.
    save_checkpoints=True,
    pl_trainer_kwargs={
        # Same CUDA availability check as the one used for DEVICE.
        "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
        "devices": 1,
        "enable_progress_bar": True,
        "enable_model_summary": False,
        "enable_checkpointing": True,
        # Stop once val_loss has not improved for 6 consecutive validation checks.
        "callbacks": [EarlyStopping(monitor="val_loss", mode="min", patience=6)],
        "logger": CSVLogger(save_dir=str(LOG_ROOT), name=MODEL_NAME),
        "gradient_clip_val": 1.0,
    },
)

print("Training TFT (Scenario D)")
# One model is fitted jointly on the list of per-point series.
tft.fit(
    series=train_series,
    past_covariates=train_pcov,
    val_series=val_series,
    val_past_covariates=val_pcov,
    verbose=True,
)

# Replace the last-epoch weights with the checkpoint Darts marks as best.
tft = TFTModel.load_from_checkpoint(model_name=MODEL_NAME, work_dir=str(CKPT_DIR), best=True)
print("Loaded best checkpoint.")


# Rolling evaluation
# Number of Monte-Carlo samples drawn per forecast when the model is
# probabilistic and the caller does not choose a value.
_DEFAULT_PROBABILISTIC_SAMPLES = 100


def _model_is_probabilistic(model) -> bool:
    """Best-effort check whether a Darts model produces sampled forecasts."""
    flag = getattr(model, "supports_probabilistic_prediction", None)
    if flag is None:
        # Fallback for Darts versions without that property.
        return getattr(model, "likelihood", None) is not None
    return bool(flag)


def _point_forecast(pred) -> np.ndarray:
    """Reduce a (possibly sampled) forecast to one value per step for component 0."""
    # all_values() is (time, component, sample); the median over the sample
    # axis is the P50 forecast and equals the single value when there is
    # only one sample.
    return np.median(pred.all_values(copy=False)[:, 0, :], axis=-1)


def rolling_nonoverlap(model, series, pcov, start, end, L, H, num_samples=None):
    """Forecast consecutive, non-overlapping H-step windows with growing context.

    For every origin ``t0`` in ``range(start, end - H + 1, H)`` with
    ``t0 >= L``, the model is given each series (and its past covariates) up
    to but excluding ``t0`` and asked for the next ``H`` steps.

    A TFTModel built without ``likelihood``/``loss_fn`` is probabilistic, and
    ``predict`` with a single sample returns one random draw rather than a
    point forecast. Probabilistic models are therefore sampled several times
    and reduced to the per-step median.

    Args:
        model: fitted Darts forecasting model.
        series: list of target ``TimeSeries`` (first component is evaluated).
        pcov: list of past-covariate ``TimeSeries`` aligned with ``series``.
        start: first candidate forecast origin (integer position).
        end: exclusive end position of the evaluation range.
        L: minimum context length; origins with ``t0 < L`` are skipped.
        H: forecast horizon and stride between origins.
        num_samples: samples drawn per forecast. ``None`` (default) uses
            ``_DEFAULT_PROBABILISTIC_SAMPLES`` for probabilistic models and 1
            otherwise.

    Returns:
        ``(yhat, ytrue, t0_list)`` where ``yhat`` and ``ytrue`` are float32
        arrays of shape ``(n_windows, H, n_series)`` and ``t0_list`` holds the
        forecast origin of each window.

    Raises:
        ValueError: if ``num_samples`` is given and smaller than 1.
        RuntimeError: if no window fits into the requested range.
    """
    if num_samples is None:
        num_samples = _DEFAULT_PROBABILISTIC_SAMPLES if _model_is_probabilistic(model) else 1
    elif num_samples < 1:
        raise ValueError("num_samples must be >= 1.")

    yh, yt = [], []
    t0_list = []

    for t0 in range(start, end - H + 1, H):
        if t0 < L:
            # Not enough history for one input chunk.
            continue

        # Context: everything strictly before the forecast origin.
        ctx_s = [s[:t0] for s in series]
        ctx_p = [p[:t0] for p in pcov]
        preds = model.predict(
            n=H,
            series=ctx_s,
            past_covariates=ctx_p,
            num_samples=num_samples,
            verbose=False,
        )
        preds = preds if isinstance(preds, list) else [preds]

        yh.append(np.stack([_point_forecast(p) for p in preds], axis=1).astype(np.float32))
        yt.append(np.stack([s[t0:t0 + H].values(copy=False)[:, 0] for s in series], axis=1).astype(np.float32))
        t0_list.append(int(t0))

    if not yh:
        raise RuntimeError("No rolling windows produced. Check start/end/L/H.")
    return np.stack(yh, axis=0), np.stack(yt, axis=0), t0_list

# Evaluate on the test range only: the first forecast origin is the start of
# the test split, pushed forward if needed so a full input chunk precedes it.
test_start = cut_val
start_ctx = test_start if test_start >= L else L
yhat_roll, ytrue_roll, t0_list = rolling_nonoverlap(
    model=tft,
    series=series_all,
    pcov=pcov_all,
    start=start_ctx,
    end=T_total,
    L=L,
    H=H,
)

# Pooled point-error metrics over all windows, horizon steps and points
# (computed on the scaled targets).
W = len(yhat_roll)
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(np.square(diff))))
mae = float(np.abs(diff).mean())

print("windows=", W, "| step=", H, "| H=", H, "| points=", N)
print("RMSE:", rmse)
print("MAE :", mae)


# q-risk
def qrisk(y_true, y_pred, q=0.5, eps=1e-8):
    """Return the normalised quantile risk of ``y_pred`` against ``y_true``.

    Computes ``2 * sum(max(q * e, (q - 1) * e)) / (sum(|y_true|) + eps)``
    with ``e = y_true - y_pred``; both inputs are converted to float64 first.
    ``eps`` keeps the division defined when the targets sum to zero.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    forecast = np.asarray(y_pred, dtype=np.float64)
    residual = truth - forecast

    pinball = np.maximum(q * residual, (q - 1) * residual)
    numerator = 2.0 * np.sum(pinball)
    denominator = np.sum(np.abs(truth)) + eps
    return float(numerator / denominator)

# Merge the window and horizon axes so each row is one forecast step and each
# column one point.
yhat_f = yhat_roll.reshape(-1, N)
ytrue_f = ytrue_roll.reshape(-1, N)

_risk_scaled = qrisk(ytrue_f, yhat_f)
print("scaled  :", _risk_scaled)

# Same metric after undoing the target scaling
# (presumably back to the original units -- see the preprocessor).
ytrue_un, yhat_un = (
    preprocessor.inverse_transform_targets(_flat) for _flat in (ytrue_f, yhat_f)
)
_risk_unscaled = qrisk(ytrue_un, yhat_un)
print("unscaled:", _risk_unscaled)


# Plots
# Prepare the arrays and date ranges consumed by the plotting function.
test_end = T_total
test_len = test_end - test_start

y_true_test = Y_scaled[test_start:test_end, :].astype(np.float32)

# Stitch the rolling windows back onto the test time axis; steps that no
# window covers stay NaN so they show up as gaps in the plot.
y_pred_test = np.full((test_len, N), np.nan, dtype=np.float32)
for _window_pred, _origin in zip(yhat_roll, t0_list):
    _lo = _origin - test_start
    _hi = _lo + H
    if 0 <= _lo and _hi <= test_len:
        y_pred_test[_lo:_hi, :] = _window_pred

# The first rolling window gets its own zoomed-in figure.
w0 = 0
t0_first = t0_list[w0]
dates_first = time_index[t0_first:t0_first + H]
dates_test = time_index[test_start:test_end]

def plot_two_figs_all_points(out_dir: Path):
    """Save two PNG figures per point into ``out_dir``.

    FIG1 shows true vs. predicted values of the first rolling window; FIG2
    shows the whole test range with the stitched rolling predictions. All
    data is read from module-level arrays prepared before the call, and all
    values are on the scaled target axis.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    def _decorate_and_save(title, ylabel, target):
        # Shared tail of both figures: labels, grid, legend, write, release.
        plt.title(title)
        plt.xlabel("Time")
        plt.ylabel(ylabel)
        plt.grid(alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.savefig(target)
        plt.close()

    first_true = ytrue_roll[w0]
    first_pred = yhat_roll[w0]

    for j in range(N):
        # FIG1: first test window only.
        plt.figure(figsize=(10, 4), dpi=140)
        plt.plot(dates_first, first_true[:, j], "-o", linewidth=2, markersize=3, label="True")
        plt.plot(dates_first, first_pred[:, j], "-o", linewidth=2, markersize=3, label="Pred")
        _decorate_and_save(
            f"FIG1 Synthetic D | grid_{j:03d} | TEST first {H} steps | y (scaled)",
            "Scaled target",
            out_dir / f"FIG1_grid{j:03d}_test_firstH{H}_t0{t0_first}.png",
        )

        # FIG2: full test range with stitched predictions.
        plt.figure(figsize=(12, 4), dpi=140)
        plt.plot(dates_test, y_true_test[:, j], "-", linewidth=1.8, label="True")
        plt.plot(dates_test, y_pred_test[:, j], "-", linewidth=1.8, label="Pred (stitched rolling)")
        plt.axvline(time_index[test_start], linestyle="--", linewidth=1, label="TEST start")
        _decorate_and_save(
            f"FIG2 Synthetic D | grid_{j:03d} | ALL TEST | y (scaled)",
            "Scaled y",
            out_dir / f"FIG2_grid{j:03d}_test_all_stitched_step{H}.png",
        )

    print(f"Saved 2 figs/point for {N} points under: {out_dir}")

# Write both figures for every point into the plots directory.
plot_two_figs_all_points(PLOTS_DIR)

# Final marker so the end of the run is visible in the output.
print("All done (Scenario D).")


  __import__("pkg_resources").declare_namespace(__name__)  # type: ignore


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Synthetic (Scenario D) y shape : (1500, 36)
Synthetic (Scenario D) cont shape: (1500, 36, 3)
Total time steps: 1500
Number of points: 36
Freq: 1H

=== Split (by time index) ===
Train: 0 -> 1049 | len = 1050
Val  : 1050 -> 1274 | len = 225
Test : 1275 -> 1499 | len = 225

=== After pooled scaling (train-only) ===
Y_scaled: (1500, 36) | finite: True
X_scaled: (1500, 36, 3) | finite: True

Built TimeSeries (Scenario D):
Targets: 36
Past covs: 36

=== INTERNAL validation (Scenario D, for early stopping only) ===
Train len: 1050
IntVal idx: 810 -> 1049 | len = 240
IntVal time: 2000-02-03 18:00:00 -> 2000-02-13 17:00:00


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



=== Training TFT (Scenario D: strong nonlinear, weak AR, weaker spatial) ===



   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | static_covariates_vsn             | _VariableSelectionNetwork        | 5.0 K  | train
4  | encoder_vsn                       | _VariableSelectionNetwork        | 8.8 K  | train
5  | decoder_vsn                       | _VariableSelectionNetwork        | 1.6 K  | train
6  | static_context_grn                | _GatedResidualNetwork            | 16.8 K | train
7  | static_context_hidden_encoder_grn | _GatedResidualNetwork            | 16.8 K | train
8  | static_context_cell_encoder_grn   | _GatedResidualNetwork            | 16.8 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.



=== Loading best checkpoint ===


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loaded best checkpoint.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available


=== Rolling non-overlap (Scenario D, POOLED, TEST) ===
windows=9 | step=24 | each window predicts H=24 | points=36
RMSE: 2.2825212478637695
MAE : 1.5991889238357544

=== P50 q-risk (Scenario D, ROLLING, TEST) ===
scaled  : 1.021848168268284
unscaled: 0.17983688117443752

=== Plotting 2 figs per grid point (Scenario D) ===
Saved 2 figs/point for 36 points under: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/simulation/TFT_simulation_D/TFT_plots_synth_tft_D_2figs

All done (Scenario D).
