In [None]:
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.models import TFTModel
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

from geospatial_neural_adapter.data.preprocessing import prepare_all_with_scaling


# Paths & settings
# The whole script depends on this array, so fail immediately if it is missing.
W2K_PATH = Path("/home/wangxc1117/Weather2K/weather2k.npy")
if not W2K_PATH.exists():
    raise FileNotFoundError(f"Weather2K npy not found: {W2K_PATH}")

# One seed shared by NumPy, PyTorch and (further down) the model itself.
GLOBAL_SEED = 42
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# __file__ is undefined in an interactive session; fall back to the working directory.
try:
    _this_file = Path(__file__)
except NameError:
    EXP_ROOT = Path.cwd()
else:
    EXP_ROOT = _this_file.resolve().parent

# Output locations: model checkpoints, training logs and figures.
CKPT_DIR = (EXP_ROOT / "darts_ckpt_weather2k_beijing_tft_pooled_wd_sincos").resolve()
RUNS_DIR = (EXP_ROOT / "TFT_runs_weather2k_beijing_tft_pooled_wd_sincos").resolve()
PLOTS_DIR = (EXP_ROOT / "TFT_plots_weather2k_beijing_tft_pooled_wd_sincos_2figs").resolve()

for _out_dir in (CKPT_DIR, RUNS_DIR, PLOTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)


# Load data
arr = np.load(W2K_PATH, allow_pickle=False)
# The script indexes arr as (stations, variables, time); reject anything else
# with a clear message instead of an opaque tuple-unpacking error.
if arr.ndim != 3:
    raise ValueError(
        f"Expected a 3-D (stations, variables, time) array, got shape {arr.shape}"
    )
# copy=False skips the full-array duplicate when the file is already float32.
arr = arr.astype(np.float32, copy=False)
S, V, T = arr.shape
if V != 13:
    raise ValueError(f"Expected 13 variables, got V={V}")
print("Weather2K shape:", arr.shape)
# 8 three-hourly steps make one day.
print("T steps:", T, "| 3-hour freq | days =", T / 8.0)


# Beijing subset (bounding box)
LAT_MIN, LAT_MAX = 39.4, 41.1
LON_MIN, LON_MAX = 115.4, 117.5

# Station coordinates are read from variable slots 0 and 1 at the first time step.
lat = arr[:, 0, 0]
lon = arr[:, 1, 0]

# Keep the stations whose coordinates fall inside the inclusive box.
_in_lat_band = (lat >= LAT_MIN) & (lat <= LAT_MAX)
_in_lon_band = (lon >= LON_MIN) & (lon <= LON_MAX)
use_idx = np.flatnonzero(_in_lat_band & _in_lon_band)

N = len(use_idx)
print("Total stations:", S)
print("Beijing stations:", N)
if N == 0:
    raise RuntimeError("No stations found in Beijing bounding box. Consider expanding bounds.")
print("First 10 use_idx:", use_idx[:10])

_lat_sel = lat[use_idx]
_lon_sel = lon[use_idx]
print("lat range:", float(_lat_sel.min()), "to", float(_lat_sel.max()))
print("lon range:", float(_lon_sel.min()), "to", float(_lon_sel.max()))


# Timestamp axis with one entry per time step, spaced 3 hours apart.
# NOTE(review): the start date is hard-coded rather than read from the data
# file — confirm it matches the dataset's real first timestamp if calendar
# dates matter (they appear in plots and in the printed validation range).
time_index = pd.date_range("2000-01-01", periods=T, freq="3H")


# Position on arr's variable axis (axis 1) of the series being forecast;
# plot labels in this script call it "t".
TARGET_IDX = 4
# Variable slots read once per station (at time step 0) and attached as
# static covariates named lat, lon, alt.
STATIC_IDXS = [0, 1, 2]

# Positions on arr's variable axis of the past-covariate inputs.
# NOTE(review): presumably air pressure, max/min temperature, relative humidity,
# 3-hour precipitation, wind direction/speed and max-wind direction/speed —
# TODO confirm against the Weather2K variable list.
AP_IDX = 3
MXT_IDX = 5
MNT_IDX = 6
RH_IDX = 7
P3_IDX = 8
WD_IDX = 9
WS_IDX = 10
MWD_IDX = 11
MWS_IDX = 12

# Channel names of the past-covariate array. The order must match the order in
# which the channels are stacked into X; the names are also used as suffixes
# for the covariate TimeSeries columns.
PAST_COV_COLS = [
    "ap", "mxt", "mnt", "rh", "p3",
    "wd_sin", "wd_cos",
    "mwd_sin", "mwd_cos",
    "ws", "mws",
]


def deg2rad_f32(x: np.ndarray) -> np.ndarray:
    """Convert an array of angles from degrees to radians, as float32.

    Args:
        x: Array of angles in degrees (any numeric dtype).

    Returns:
        A float32 array of the same shape holding the angles in radians.
    """
    degrees = x.astype(np.float32)
    radians = degrees * (np.pi / 180.0)
    return radians.astype(np.float32)

# Build targets and past covariates for the selected stations with vectorized
# slicing instead of a per-station Python loop full of float32 recasts.
_sel = arr[use_idx]  # (N, V, T): copy holding only the selected stations

# Targets, time-major: Y[t, j] is the target of station use_idx[j] at step t.
Y = np.ascontiguousarray(_sel[:, TARGET_IDX, :].T, dtype=np.float32)

# The two direction variables are treated as degrees (they go through
# deg2rad_f32) and each is encoded as a sin/cos pair.
_wd_rad = deg2rad_f32(_sel[:, WD_IDX, :])
_mwd_rad = deg2rad_f32(_sel[:, MWD_IDX, :])

# Channel order must match PAST_COV_COLS.
_channels = [
    _sel[:, AP_IDX, :],
    _sel[:, MXT_IDX, :],
    _sel[:, MNT_IDX, :],
    _sel[:, RH_IDX, :],
    _sel[:, P3_IDX, :],
    np.sin(_wd_rad),
    np.cos(_wd_rad),
    np.sin(_mwd_rad),
    np.cos(_mwd_rad),
    _sel[:, WS_IDX, :],
    _sel[:, MWS_IDX, :],
]

# Stack to (N, T, P), then reorder to the time-major (T, N, P) layout.
X = np.ascontiguousarray(
    np.stack(_channels, axis=-1).transpose(1, 0, 2),
    dtype=np.float32,
)

P = X.shape[2]
if P != len(PAST_COV_COLS):
    raise ValueError(f"Past cov dim mismatch: P={P} vs len(PAST_COV_COLS)={len(PAST_COV_COLS)}")

# All-zero placeholder passed as the categorical input of the scaling helper.
cat_dummy = np.zeros((T, N, 1), dtype=np.int64)


# Chronological split fractions; whatever follows train + val is the test span.
train_ratio, val_ratio = 0.70, 0.15

cut_train = int(T * train_ratio)
cut_val = int(T * (train_ratio + val_ratio))
if cut_train <= 0 or cut_val <= cut_train or T <= cut_val:
    raise ValueError("Bad split indices computed from ratios.")

print("\n=== Split (by time index) ===")
for _label, _lo, _hi in (
    ("Train:", 0, cut_train),
    ("Val  :", cut_train, cut_val),
    ("Test :", cut_val, T),
):
    print(_label, _lo, "->", _hi - 1, "| len =", _hi - _lo)

# NOTE(review): cut_train / cut_val above are assumed to coincide with the
# boundaries prepare_all_with_scaling derives from the same ratios — confirm
# against the helper's implementation.
_scaling_options = dict(
    feature_scaler_type="standard",
    target_scaler_type="standard",
    fit_on_train_only=True,
)
train_ds, val_ds, test_ds, preprocessor = prepare_all_with_scaling(
    cat_features=cat_dummy,
    cont_features=X,
    targets=Y,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    **_scaling_options,
)

def stitch_targets(dsets) -> np.ndarray:
    """Concatenate the target tensors of several datasets along axis 0.

    Args:
        dsets: Iterable of dataset objects exposing a ``tensors`` sequence;
            the tensor at position 2 is taken as the target.

    Returns:
        A float32 NumPy array with the per-dataset targets joined in order.
    """
    pieces = []
    for dataset in dsets:
        target = dataset.tensors[2]
        pieces.append(target.detach().cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)

def stitch_cont_features(dsets) -> np.ndarray:
    """Concatenate the continuous-feature tensors of several datasets along axis 0.

    Args:
        dsets: Iterable of dataset objects exposing a ``tensors`` sequence;
            the tensor at position 1 is taken as the continuous features.

    Returns:
        A float32 NumPy array with the per-dataset features joined in order.
    """
    pieces = []
    for dataset in dsets:
        features = dataset.tensors[1]
        pieces.append(features.detach().cpu().numpy().astype(np.float32))
    return np.concatenate(pieces, axis=0)

# Re-join the three scaled splits, in chronological order, into full-length arrays.
_splits_in_order = (train_ds, val_ds, test_ds)
Y_scaled = stitch_targets(_splits_in_order)
X_scaled = stitch_cont_features(_splits_in_order)

# Scaling must not change the layout: the stitched arrays have to line up with the raw ones.
if Y.shape != Y_scaled.shape:
    raise ValueError(f"Y_scaled shape {Y_scaled.shape} != Y {Y.shape}")
if X.shape != X_scaled.shape:
    raise ValueError(f"X_scaled shape {X_scaled.shape} != X {X.shape}")

_y_all_finite = bool(np.isfinite(Y_scaled).all())
_x_all_finite = bool(np.isfinite(X_scaled).all())
print("\n=== After pooled scaling (train-only) ===")
print("Y_scaled:", Y_scaled.shape, "| finite:", _y_all_finite)
print("X_scaled:", X_scaled.shape, "| finite:", _x_all_finite)


# One target series and one past-covariate series per selected station.
series_all: List[TimeSeries] = []
pcov_all: List[TimeSeries] = []

for j, s in enumerate(use_idx):
    station = int(s)
    name = f"st_{station}"

    # Static covariates: the station's lat / lon / alt read at the first time step.
    lat_s, lon_s, alt_s = arr[station, STATIC_IDXS, 0].astype(np.float32).tolist()
    sc = pd.DataFrame(
        {"lat": [lat_s], "lon": [lon_s], "alt": [alt_s]},
        index=[name],
        dtype=np.float32,
    )

    # Column j of the scaled targets, kept 2-D so it forms a one-component series.
    target_values = Y_scaled[:, j:j+1].astype(np.float32)
    target_ts = TimeSeries.from_times_and_values(
        times=time_index,
        values=target_values,
        columns=[name],
        freq="3H",
    )
    series_all.append(target_ts.with_static_covariates(sc))

    # All scaled covariate channels of the same station, columns prefixed with its name.
    cov_columns = [f"{name}_{c}" for c in PAST_COV_COLS]
    cov_ts = TimeSeries.from_times_and_values(
        times=time_index,
        values=X_scaled[:, j, :].astype(np.float32),
        columns=cov_columns,
        freq="3H",
    )
    pcov_all.append(cov_ts.with_static_covariates(sc))

print("\nBuilt TimeSeries:")
print("Targets:", len(series_all))
print("Past covs:", len(pcov_all))


# Window sizes in 3-hourly steps: L is the model's input length
# (56 steps = 7 days) and H its forecast length (24 steps = 3 days).
L, H = 56, 24
# Number of steps at the end of the training span used as the internal
# validation window for early stopping (240 steps = 30 days).
INTERNAL_VAL_STEPS = 240

def slice_list(xs, a, b):
    """Return ``item[a:b]`` for every item of ``xs``, as a new list."""
    window = slice(a, b)
    return [item[window] for item in xs]

# Internal validation window: the last INTERNAL_VAL_STEPS steps of the training
# span, widened if necessary so it can hold at least one (input, output) sample.
iv_end = cut_train
iv_start = max(0, cut_train - INTERNAL_VAL_STEPS)

min_needed = L + H
if (iv_end - iv_start) < min_needed:
    iv_start = max(0, iv_end - min_needed)

# The model must not be fitted on the window it is validated on; otherwise
# val_loss follows the training loss and early stopping / best-checkpoint
# selection loses its meaning. Training therefore ends where validation starts.
if iv_start < min_needed:
    raise ValueError(
        f"Not enough training steps before the internal validation window: "
        f"{iv_start} available, need at least L + H = {min_needed}."
    )

train_series = slice_list(series_all, 0, iv_start)
train_pcov = slice_list(pcov_all, 0, iv_start)

val_series = slice_list(series_all, iv_start, iv_end)
val_pcov = slice_list(pcov_all, iv_start, iv_end)

print("\n=== INTERNAL validation (for early stopping only) ===")
print("Train len:", len(train_series[0]))
print("IntVal idx:", iv_start, "->", iv_end - 1, "| len =", len(val_series[0]))
print("IntVal time:", time_index[iv_start], "->", time_index[iv_end - 1])


MODEL_NAME = f"tft_weather2k_beijing_pooled_wd_sincos_L{L}_H{H}_seed{GLOBAL_SEED}"
LOG_ROOT = (RUNS_DIR / f"seed_{GLOBAL_SEED}").resolve()
LOG_ROOT.mkdir(parents=True, exist_ok=True)

# Lightning pieces are built first so the model constructor stays readable.
_early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=6)
_csv_logger = CSVLogger(save_dir=str(LOG_ROOT), name=MODEL_NAME)
_trainer_kwargs = {
    "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
    "devices": 1,
    "enable_progress_bar": True,
    "enable_model_summary": False,
    "enable_checkpointing": True,
    "callbacks": [_early_stopping],
    "logger": _csv_logger,
    "gradient_clip_val": 1.0,
}

# One model is fitted across all station series: L steps in, H steps out.
tft = TFTModel(
    input_chunk_length=L,
    output_chunk_length=H,
    n_epochs=30,
    hidden_size=64,
    num_attention_heads=4,
    dropout=0.1,
    batch_size=32,
    optimizer_kwargs={"lr": 3e-4},
    add_relative_index=True,
    random_state=GLOBAL_SEED,
    force_reset=True,
    model_name=MODEL_NAME,
    work_dir=str(CKPT_DIR),
    save_checkpoints=True,
    pl_trainer_kwargs=_trainer_kwargs,
)

print("\n=== Training TFT ===")
_fit_inputs = dict(
    series=train_series,
    past_covariates=train_pcov,
    val_series=val_series,
    val_past_covariates=val_pcov,
)
tft.fit(**_fit_inputs, verbose=True)

# Replace the in-memory model with the best checkpoint saved during training.
print("\n=== Loading best checkpoint ===")
tft = TFTModel.load_from_checkpoint(model_name=MODEL_NAME, work_dir=str(CKPT_DIR), best=True)
print("Loaded best checkpoint.")


def rolling_nonoverlap(model, series, pcov, start, end, L, H):
    """Forecast consecutive, non-overlapping H-step windows and collect the truth.

    Forecast origins run from ``start`` in steps of ``H`` while a full window
    still ends at or before ``end``; origins with fewer than ``L`` preceding
    steps are skipped. At each origin every series (and its covariates) is cut
    to the steps before the origin and ``model.predict`` is asked for ``H`` steps.

    Args:
        model: Object whose ``predict(n, series, past_covariates, verbose)``
            returns one forecast or a list of forecasts.
        series: List of sliceable target series exposing ``values(copy=...)``.
        pcov: List of past-covariate series aligned with ``series``.
        start: First candidate forecast origin (integer position).
        end: Exclusive upper bound on window ends.
        L: Minimum number of history steps required before an origin.
        H: Forecast horizon, also the stride between origins.

    Returns:
        ``(yhat, ytrue, origins)``: two float32 arrays stacked as
        (windows, steps, series) holding the first component of each
        forecast / observed slice, and the list of integer origins.

    Raises:
        RuntimeError: If no window could be produced.
    """
    forecast_blocks = []
    truth_blocks = []
    origins = []

    usable_origins = (t for t in range(start, end - H + 1, H) if t >= L)
    for origin in usable_origins:
        history = [ts[:origin] for ts in series]
        cov_history = [cov[:origin] for cov in pcov]

        forecast = model.predict(n=H, series=history, past_covariates=cov_history, verbose=False)
        if not isinstance(forecast, list):
            forecast = [forecast]

        # Only the first component of each series is evaluated.
        predicted_cols = [f.values(copy=False)[:, 0] for f in forecast]
        observed_cols = [ts[origin:origin + H].values(copy=False)[:, 0] for ts in series]

        forecast_blocks.append(np.stack(predicted_cols, axis=1).astype(np.float32))
        truth_blocks.append(np.stack(observed_cols, axis=1).astype(np.float32))
        origins.append(int(origin))

    if not forecast_blocks:
        raise RuntimeError("No rolling windows produced. Check start/end/L/H.")

    return np.stack(forecast_blocks, axis=0), np.stack(truth_blocks, axis=0), origins

# Evaluate on the test span only; the first origin also needs L steps of history.
test_start = cut_val
start_ctx = max(test_start, L)
yhat_roll, ytrue_roll, t0_list = rolling_nonoverlap(
    model=tft,
    series=series_all,
    pcov=pcov_all,
    start=start_ctx,
    end=T,
    L=L,
    H=H,
)

W = yhat_roll.shape[0]
print("\n=== Rolling non-overlap (POOLED, TEST) ===")
print(f"windows={W} | step={H} | each window predicts H={H} | stations={N}")

# Errors are pooled over every window, step and station, in scaled units.
diff = yhat_roll - ytrue_roll
rmse = float(np.sqrt(np.mean(np.square(diff))))
mae = float(np.abs(diff).mean())
print("RMSE:", rmse)
print("MAE :", mae)


def qrisk(y_true, y_pred, q=0.5, eps=1e-8):
    """Return the normalized quantile risk of ``y_pred`` against ``y_true``.

    The quantile (pinball) loss at level ``q`` is summed over all elements,
    doubled, and divided by the sum of ``|y_true|`` plus ``eps``.

    Args:
        y_true: Observed values (array-like).
        y_pred: Predicted values, broadcastable against ``y_true``.
        q: Quantile level.
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        The risk as a Python float.
    """
    actual = np.asarray(y_true, dtype=np.float64)
    forecast = np.asarray(y_pred, dtype=np.float64)
    residual = actual - forecast

    pinball = np.maximum(q * residual, (q - 1) * residual)
    numerator = 2.0 * np.sum(pinball)
    denominator = np.sum(np.abs(actual)) + eps
    return float(numerator / denominator)

# Collapse (windows, steps, stations) to (windows * steps, stations).
_flat_shape = (-1, N)
yhat_f = yhat_roll.reshape(_flat_shape)
ytrue_f = ytrue_roll.reshape(_flat_shape)

print("\n=== P50 q-risk (ROLLING, TEST) ===")
_risk_scaled = qrisk(ytrue_f, yhat_f)
print("scaled  :", _risk_scaled)

# Same metric after passing both arrays through the preprocessor's inverse target transform.
ytrue_un = preprocessor.inverse_transform_targets(ytrue_f)
yhat_un = preprocessor.inverse_transform_targets(yhat_f)
_risk_unscaled = qrisk(ytrue_un, yhat_un)
print("unscaled:", _risk_unscaled)


test_end = T
test_len = test_end - test_start

y_true_test = Y_scaled[test_start:test_end].astype(np.float32)

# Lay the rolling forecasts back onto the test timeline; steps not covered by
# any window stay NaN.
y_pred_test = np.full((test_len, N), np.nan, dtype=np.float32)
for _window_pred, _origin in zip(yhat_roll, t0_list):
    _lo = _origin - test_start
    _hi = _lo + H
    if _lo >= 0 and _hi <= test_len:
        y_pred_test[_lo:_hi, :] = _window_pred

# Time axes for the plots: the first forecast window and the whole test span.
w0 = 0
t0_first = t0_list[w0]
dates_first = time_index[t0_first:t0_first + H]
dates_test = time_index[test_start:test_end]

def plot_two_figs_all_stations(out_dir: Path):
    """Save two PNG figures per station into ``out_dir``.

    FIG1 compares truth and prediction over the first rolling window; FIG2
    compares them over the whole test span using the stitched rolling
    forecasts. All plotted data comes from module-level arrays prepared
    earlier in the script, and values are in scaled units.

    Args:
        out_dir: Directory that receives the figures; created if missing.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    for j, raw_id in enumerate(use_idx):
        station_id = int(raw_id)

        # FIG1: first forecast window only.
        fig1, ax1 = plt.subplots(figsize=(10, 4), dpi=140)
        ax1.plot(dates_first, ytrue_roll[w0][:, j], "-o", linewidth=2, markersize=3, label="True")
        ax1.plot(dates_first, yhat_roll[w0][:, j], "-o", linewidth=2, markersize=3, label="Pred")
        ax1.set_title(f"FIG1 Weather2K Beijing | st_{station_id} | TEST first {H} steps | t (scaled)")
        ax1.set_xlabel("Time")
        ax1.set_ylabel("Scaled t")
        ax1.grid(alpha=0.3)
        ax1.legend()
        fig1.tight_layout()
        p1 = out_dir / f"FIG1_st{station_id}_test_firstH{H}_t0{t0_first}.png"
        fig1.savefig(p1)
        plt.close(fig1)

        # FIG2: whole test span, predictions stitched from the rolling windows.
        fig2, ax2 = plt.subplots(figsize=(12, 4), dpi=140)
        ax2.plot(dates_test, y_true_test[:, j], "-", linewidth=1.8, label="True")
        ax2.plot(dates_test, y_pred_test[:, j], "-", linewidth=1.8, label="Pred (stitched rolling)")
        ax2.axvline(time_index[test_start], linestyle="--", linewidth=1, label="TEST start")
        ax2.set_title(f"FIG2 Weather2K Beijing | st_{station_id} | ALL TEST | t (scaled)")
        ax2.set_xlabel("Time")
        ax2.set_ylabel("Scaled t")
        ax2.grid(alpha=0.3)
        ax2.legend()
        fig2.tight_layout()
        p2 = out_dir / f"FIG2_st{station_id}_test_all_stitched_step{H}.png"
        fig2.savefig(p2)
        plt.close(fig2)

    print(f"Saved 2 figs/station for {N} stations under: {out_dir}")

# Render and save the per-station figures into PLOTS_DIR.
print("\n=== Plotting 2 figs per station ===")
plot_two_figs_all_stations(PLOTS_DIR)

# Final marker so a finished run is easy to spot in the log.
print("\nAll done.")


  __import__("pkg_resources").declare_namespace(__name__)  # type: ignore


✅ Loaded spatial_utils from: /home/wangxc1117/geospatial-neural-adapter/geospatial_neural_adapter/cpp_extensions/spatial_utils.so
Device: cuda
Weather2K shape: (1866, 13, 13632)
T steps: 13632 | 3-hour freq | days = 1704.0
Total stations: 1866
Beijing stations: 31
First 10 use_idx: [545 546 548 549 550 552 553 554 555 557]
lat range: 39.41999816894531 to 40.93000030517578
lon range: 115.5 to 117.47000122070312

=== Split (by time index) ===
Train: 0 -> 9541 | len = 9542
Val  : 9542 -> 11586 | len = 2045
Test : 11587 -> 13631 | len = 2045

=== After pooled scaling (train-only) ===
Y_scaled: (13632, 31) | finite: True
X_scaled: (13632, 31, 11) | finite: True

Built TimeSeries:
Targets: 31
Past covs: 31

=== INTERNAL validation (for early stopping only) ===
Train len: 9542
IntVal idx: 9302 -> 9541 | len = 240
IntVal time: 2003-03-08 18:00:00 -> 2003-04-07 15:00:00


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



=== Training TFT ===



   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | train_metrics                     | MetricCollection                 | 0      | train
1  | val_metrics                       | MetricCollection                 | 0      | train
2  | input_embeddings                  | _MultiEmbedding                  | 0      | train
3  | static_covariates_vsn             | _VariableSelectionNetwork        | 5.0 K  | train
4  | encoder_vsn                       | _VariableSelectionNetwork        | 24.0 K | train
5  | decoder_vsn                       | _VariableSelectionNetwork        | 1.6 K  | train
6  | static_context_grn                | _GatedResidualNetwork            | 16.8 K | train
7  | static_context_hidden_encoder_grn | _GatedResidualNetwork            | 16.8 K | train
8  | static_context_cell_encoder_grn   | _GatedResidualNetwork            | 16.8 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.



=== Loading best checkpoint ===


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loaded best checkpoint.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU availa


=== Rolling non-overlap (POOLED, TEST) ===
windows=85 | step=24 | each window predicts H=24 | stations=31
RMSE: 0.3423219621181488
MAE : 0.26078614592552185

=== P50 q-risk (ROLLING, TEST) ===
scaled  : 0.2987890600211251
unscaled: 0.20997248703558258

=== Plotting 2 figs per station ===
Saved 2 figs/station for 31 stations under: /home/wangxc1117/TFTModel-use/geospatial-neural-adapter-dev/examples/try/weather2k/test/TFT_plots_weather2k_beijing_tft_pooled_wd_sincos_2figs

All done.
