In [2]:
import gc
import os
import random
from pathlib import Path

# TensorFlow consults these variables while it is being imported / while it
# initialises devices, so they have to be set *before* `import tensorflow`.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"   # hide INFO/WARNING/ERROR from TF's C++ runtime
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # disable oneDNN custom ops (avoids their round-off differences)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs -> CPU-only run

import numpy as np  # noqa: E402  (imports intentionally follow the env setup above)
import tensorflow as tf  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras.layers import Dense  # noqa: E402
from tensorflow.keras import regularizers  # noqa: E402


# Global settings & dirs
BASE_SEED = 2024
# NOTE: CPython reads PYTHONHASHSEED only at interpreter start-up, so this
# assignment does not change str/bytes hash randomisation of *this* process
# (it only reaches child processes). Export it before launching Python/Jupyter
# if hash-order reproducibility matters.
os.environ["PYTHONHASHSEED"] = str(BASE_SEED)

DATA_PATH = Path("/home/wangxc1117/Weather2K/weather2k.npy")

TARGET_IDX = 4          # variable index (axis 1 of the data array) used as the target
OBS_RATIO = 0.10        # fraction of stations sampled as "observed" per repetition
EPOCHS = 300            # max training epochs (early stopping may end sooner)
BATCH_SIZE = 512        # batch size for fit and predict
LR = 1e-3               # Adam learning rate

GRID_SIZES = (5, 9, 12)  # knots per side of each spatial Wendland grid
H_LIST = (10, 15, 45)    # number of Gaussian knots for each temporal resolution

STATIONS_PER_CHUNK = 100  # missing stations predicted per batch of the evaluation loop
T_KEEP = 100              # number of most recent time steps kept

VAL_STATION_RATIO = 0.10  # fraction of observed stations held out for validation
PATIENCE = 30             # early-stopping patience (epochs) on val_loss

TAUS = [0.05, 0.25, 0.5, 0.75, 0.95]  # quantile levels, one output head each

N_REP = 10        # number of repeated random station splits
SEED_OFFSET = 1000  # per-rep seed is BASE_SEED + SEED_OFFSET + rep


tf.get_logger().setLevel("ERROR")
gpus = tf.config.list_physical_devices("GPU")
# With CUDA_VISIBLE_DEVICES="-1" this list is normally empty; memory growth is
# still configured for runs where the GPU mask above is removed.
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
print("GPUs:", gpus)


def wendland_c2(d):
    """Evaluate the compactly supported Wendland radial function used as a spatial basis.

    phi(d) = (1 - d)^6 * (35 d^2 + 18 d + 3) / 3  for 0 <= d <= 1, else 0,

    normalised so that phi(0) = 1. In Wendland's family this polynomial is
    usually labelled phi_{3,2} (a C^4 function); the historical name is kept
    for compatibility with existing callers.

    Parameters
    ----------
    d : array_like
        Scaled distances (distance / support radius). Scalars and lists are
        accepted as well as ndarrays.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``d``; entries outside [0, 1]
        (including NaN) are 0.
    """
    d = np.asarray(d)
    inside = (d >= 0.0) & (d <= 1.0)
    # Clip before evaluating so the polynomial never sees out-of-support values
    # (no overflow for huge d); those entries are zeroed by the mask anyway.
    r = np.clip(d, 0.0, 1.0)
    bump = ((1.0 - r) ** 6) * (35.0 * r**2 + 18.0 * r + 3.0) / 3.0
    return np.where(inside, bump, 0.0).astype(np.float32)


def build_space_basis(s_xy, grid_sizes=GRID_SIZES, theta_scale=2.5):
    """Build a multi-resolution Wendland basis over the unit square.

    For each resolution ``g`` in ``grid_sizes`` a regular ``g x g`` grid of
    knots on [0, 1]^2 is laid out and every location is evaluated against every
    knot as ``wendland_c2(||s - knot|| / theta)``, with support radius
    ``theta = theta_scale * (knot spacing)``. The per-resolution blocks are
    concatenated column-wise (coarsest grid first, in ``grid_sizes`` order).

    Parameters
    ----------
    s_xy : ndarray, shape (n, 2)
        Locations; the knots span [0, 1]^2, so coordinates are expected to be
        rescaled into that range.
    grid_sizes : iterable of int
        Knots per side for each resolution; each must be >= 2.
    theta_scale : float
        Support radius measured in units of the knot spacing.

    Returns
    -------
    ndarray of float32, shape (n, sum(g * g for g in grid_sizes))

    Raises
    ------
    ValueError
        If a grid size is < 2 (the spacing 1 / (g - 1) is undefined) or
        ``grid_sizes`` is empty.
    """
    s_xy = np.asarray(s_xy)
    n = s_xy.shape[0]
    cols = []
    for g in grid_sizes:
        if g < 2:
            raise ValueError(f"grid sizes must be >= 2, got {g}")
        knots_1d = np.linspace(0.0, 1.0, g, dtype=np.float32)
        kx, ky = np.meshgrid(knots_1d, knots_1d)
        knots = np.column_stack([kx.ravel(), ky.ravel()]).astype(np.float32)

        spacing = 1.0 / (g - 1)
        theta = theta_scale * spacing

        phi = np.zeros((n, knots.shape[0]), dtype=np.float32)
        # One knot at a time keeps peak memory at O(n) per column.
        for j, knot in enumerate(knots):
            d = np.linalg.norm(s_xy - knot, axis=1) / (theta + 1e-12)
            phi[:, j] = wendland_c2(d)

        cols.append(phi)

    if not cols:
        raise ValueError("grid_sizes must contain at least one grid size")
    return np.concatenate(cols, axis=1)


def build_time_basis(t_norm, H_list=H_LIST):
    """Evaluate multi-resolution Gaussian time bases at normalised times.

    For every knot count ``h`` in ``H_list`` the knots are ``h`` equally spaced
    points on [0, 1]; the bandwidth equals the knot spacing (or 1.0 when fewer
    than two knots exist). Each block holds ``exp(-0.5 * ((t - knot) / bw)^2)``
    and the blocks are concatenated column-wise as float32.

    Parameters
    ----------
    t_norm : ndarray, shape (T,)
        Times rescaled to [0, 1].
    H_list : iterable of int
        Knot counts, one per resolution.

    Returns
    -------
    ndarray of float32, shape (T, sum(H_list))
    """
    def _gaussian_block(n_knots):
        centers = np.linspace(0.0, 1.0, n_knots, dtype=np.float32)
        width = 1.0 if n_knots < 2 else abs(centers[1] - centers[0])
        z = (t_norm[:, None] - centers[None, :]) / (width + 1e-12)
        return np.exp(-0.5 * z**2).astype(np.float32)

    return np.concatenate([_gaussian_block(h) for h in H_list], axis=1)


def tilted_loss(tau):
    """Return a Keras-compatible pinball (quantile / check) loss for level ``tau``.

    For the residual e = y_true - y_pred the per-element loss is
    max(tau * e, (tau - 1) * e); the returned function averages it over all
    elements of the batch.
    """
    level = float(tau)
    below_slope = level - 1.0

    # The inner function keeps the name ``loss``.
    def loss(y_true, y_pred):
        residual = y_true - y_pred
        return tf.reduce_mean(tf.maximum(level * residual, below_slope * residual))

    return loss


def build_model_multi_quantile(input_dim):
    """Build and compile the multi-output quantile network.

    Architecture: ``input_dim`` features -> 8 ReLU Dense layers of width 100
    (L1/L2 kernel penalty on the first two only) -> 2 ReLU Dense layers of
    width 50 -> one linear Dense(1) head per level in ``TAUS``. Each head is
    named ``q<tau with '.' replaced by '_'>`` (e.g. ``q0_05``) and trained with
    ``tilted_loss(tau)``; the model is compiled with Adam at learning rate ``LR``.
    """
    penalty = regularizers.L1L2(l1=1e-5, l2=1e-4)
    trunk_widths = [100] * 8 + [50] * 2
    n_penalized = 2

    features = keras.Input(shape=(input_dim,), name="X")
    hidden = features
    # Layers are created in the same order as a hand-written stack, so seeded
    # initialisation and auto-generated layer names are unaffected.
    for depth, width in enumerate(trunk_widths):
        hidden = Dense(
            width,
            activation="relu",
            kernel_initializer="random_normal",
            kernel_regularizer=penalty if depth < n_penalized else None,
        )(hidden)

    head_outputs = {}
    head_losses = {}
    for tau in TAUS:
        head_name = "q" + str(tau).replace(".", "_")
        head_outputs[head_name] = Dense(1, kernel_initializer="random_normal", name=head_name)(hidden)
        head_losses[head_name] = tilted_loss(tau)

    model = keras.Model(inputs=features, outputs=head_outputs, name="STDK_MultiQuantile")
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=LR), loss=head_losses)
    return model


def crps_from_quantiles_weighted(y_true, preds_dict, taus=TAUS):
    """Approximate the CRPS of each sample from a finite set of predicted quantiles.

    Uses CRPS(F, y) = 2 * integral_0^1 QS_tau(q_tau, y) d tau with the pinball
    score QS_tau(q, y) = max(tau * (y - q), (tau - 1) * (y - q)). The integral
    is approximated with trapezoid weights on the sorted tau grid, so only
    [min(taus), max(taus)] is covered: the weights sum to max(taus) - min(taus)
    (0.9 for the default levels) and the tails are not included.

    Parameters
    ----------
    y_true : array_like
        Observed values; flattened to 1-D.
    preds_dict : mapping
        Predicted quantiles keyed ``q<tau with '.' -> '_'>`` (e.g. ``"q0_05"``);
        each value is flattened to 1-D and must match ``y_true`` in length.
    taus : sequence of float
        Quantile levels; at least two are required.

    Returns
    -------
    ndarray of float64, shape (n,)
        Per-sample approximate CRPS.

    Raises
    ------
    ValueError
        If fewer than two levels are given or a prediction has the wrong length.
    KeyError
        If a level has no entry in ``preds_dict``.
    """
    y = np.asarray(y_true, dtype=np.float64).reshape(-1)

    levels = np.sort(np.asarray(taus, dtype=np.float64))
    if levels.size < 2:
        raise ValueError(f"need at least two quantile levels, got {levels.size}")

    # Trapezoid-rule weights on the (possibly non-uniform) tau grid.
    w = np.empty_like(levels)
    w[0] = 0.5 * (levels[1] - levels[0])
    w[-1] = 0.5 * (levels[-1] - levels[-2])
    w[1:-1] = 0.5 * (levels[2:] - levels[:-2])

    check_sum = np.zeros_like(y)
    for tau, wk in zip(levels, w):
        # Python float, so the key text matches keys built from Python-float taus.
        key = f"q{str(float(tau)).replace('.', '_')}"
        q = np.asarray(preds_dict[key], dtype=np.float64).reshape(-1)
        if q.shape != y.shape:
            # Guard against silent broadcasting of a length-1 prediction.
            raise ValueError(
                f"prediction {key!r} has {q.size} values but y_true has {y.size}"
            )
        e = y - q
        check_sum += wk * np.maximum(tau * e, (tau - 1.0) * e)

    return 2.0 * check_sum


def crossing_rate_from_preds(preds_dict, taus=TAUS):
    """Fraction of samples whose predicted quantiles are not non-decreasing in tau.

    The predictions for each level (keyed ``q<tau with '.' -> '_'>``) are
    stacked column-wise in increasing-tau order. A sample counts as monotone
    when every adjacent difference is >= 0 (ties are allowed); any NaN
    difference makes the sample count as crossing.

    Returns
    -------
    float
        1 - (fraction of monotone samples).
    """
    levels = np.array(sorted(taus), dtype=np.float64)

    columns = []
    for level in levels:
        key = "q" + str(level).replace(".", "_")
        columns.append(np.asarray(preds_dict[key]).astype(np.float64).reshape(-1))
    quantile_matrix = np.column_stack(columns)

    steps = np.diff(quantile_matrix, axis=1)
    ordered = (steps >= 0.0).all(axis=1)
    return float(1.0 - ordered.mean())


def build_Xy_for_stations(station_ids, Y_used, phi_space_st, phi_time_t, Ds, D, T):
    """Assemble the (space basis | time basis) design matrix and targets for stations.

    Rows are station-major: the ``T`` rows for ``station_ids[i]`` occupy rows
    ``i*T`` .. ``(i+1)*T - 1``. Within that block, row ``t`` pairs the
    station's spatial basis row with time-basis row ``t``.

    Parameters
    ----------
    station_ids : array_like of int
        Station indices into ``Y_used`` and ``phi_space_st``.
    Y_used : ndarray
        Per-station targets, indexed as ``Y_used[station_ids]``; expected shape
        (S, T) so that the flattened targets line up with the rows of ``X``.
    phi_space_st : ndarray, shape (S, Ds)
        Spatial basis evaluated at each station.
    phi_time_t : ndarray, broadcastable to (T, D - Ds)
        Temporal basis, shared by every station.
    Ds, D, T : int
        Spatial basis width, total feature width, number of time steps.

    Returns
    -------
    X : ndarray of float32, shape (len(station_ids) * T, D)
    y : ndarray of float32
        ``Y_used[station_ids]`` flattened in the same station-major order as ``X``.
    """
    station_ids = np.asarray(station_ids, dtype=np.int64)
    n_st = station_ids.shape[0]

    X = np.empty((n_st * T, D), dtype=np.float32)
    y = Y_used[station_ids].reshape(-1).astype(np.float32)

    # Fill through a (station, time, feature) view instead of a Python loop over
    # stations. X is freshly allocated and C-contiguous, so reshape returns a
    # view and these broadcast writes land directly in X.
    X_blocks = X.reshape(n_st, T, D)
    X_blocks[:, :, :Ds] = phi_space_st[station_ids][:, None, :]
    X_blocks[:, :, Ds:] = phi_time_t

    return X, y


# Load the Weather2K tensor; it is unpacked below as (S stations, V variables, T_full steps).
arr = np.load(DATA_PATH)
if arr.ndim != 3:
    raise ValueError(f"expected a 3-D (S, V, T) array, got shape {arr.shape}")
# copy=False: skip a redundant full-size copy of this large array when the file
# is already stored as float32.
arr = arr.astype(np.float32, copy=False)
S, V, T_full = arr.shape
print("Weather2K shape:", arr.shape)

if T_KEEP < 1:
    # Explicit guard: Y_full[:, -0:] would silently keep *every* time step.
    raise ValueError(f"T_KEEP must be >= 1, got {T_KEEP}")
if T_full < T_KEEP:
    raise ValueError(f"T_full={T_full} < T_KEEP={T_KEEP}")

# Variables 0 and 1 are used as latitude / longitude, read at the first time
# step (assumed static per station).
lat = arr[:, 0, 0]
lon = arr[:, 1, 0]
Y_full = arr[:, TARGET_IDX, :]

Y = Y_full[:, -T_KEEP:]
T = Y.shape[1]
print(f"Using last {T_KEEP} timesteps | Y shape: {Y.shape}")

# Min-max scale each coordinate to [0, 1] independently (the aspect ratio is not
# preserved) so locations live on the unit square used by the knot grids.
lat_n = (lat - lat.min()) / (lat.max() - lat.min() + 1e-12)
lon_n = (lon - lon.min()) / (lon.max() - lon.min() + 1e-12)
s_xy = np.column_stack([lat_n, lon_n]).astype(np.float32)

# Time index of the kept window rescaled to [0, 1].
t_idx = np.arange(T, dtype=np.float32)
t_norm = (t_idx - t_idx.min()) / (t_idx.max() - t_idx.min() + 1e-12)

print("\n[Stage 3] Precompute bases ...")
phi_space_st_full = build_space_basis(s_xy)
phi_time_t_full = build_time_basis(t_norm)

# Drop basis functions that are zero at every station / every time step;
# they would only add constant-zero input features.
space_keep = (phi_space_st_full != 0).any(axis=0)
time_keep = (phi_time_t_full != 0).any(axis=0)

phi_space_st = phi_space_st_full[:, space_keep].astype(np.float32)
phi_time_t = phi_time_t_full[:, time_keep].astype(np.float32)

Ds = phi_space_st.shape[1]
Dt = phi_time_t.shape[1]
D = Ds + Dt

print(f"Space basis dim: {Ds}")
print(f"Time  basis dim: {Dt}")
print(f"Total embedding dim: {D}")


# Per-repetition results; one entry is appended to each list per repetition.
cross_list = []
crps_z_list = []
crps_raw_list = []

for rep in range(1, N_REP + 1):
    # Each repetition gets its own seed, so splits and network initialisation
    # differ across reps but are reproducible run-to-run.
    rep_seed = BASE_SEED + SEED_OFFSET + rep

    random.seed(rep_seed)
    np.random.seed(rep_seed)
    tf.random.set_seed(rep_seed)

    # Dedicated RNG for the station split (a separate stream from the global
    # np.random state seeded above).
    split_rs = np.random.RandomState(rep_seed)
    n_obs = int(np.round(OBS_RATIO * S))
    obs_sites = np.sort(split_rs.choice(S, size=n_obs, replace=False))

    # Every station not sampled as observed is an evaluation ("missing") station.
    is_obs_station = np.zeros(S, dtype=bool)
    is_obs_station[obs_sites] = True
    miss_sites = np.where(~is_obs_station)[0]

    # Hold out at least one observed station for early-stopping validation.
    # NOTE(review): if n_val >= n_obs, train_sites is empty and y_mu / y_sd
    # below become NaN — TODO guard if OBS_RATIO / VAL_STATION_RATIO change.
    n_val = max(1, int(np.round(VAL_STATION_RATIO * n_obs)))
    perm = split_rs.permutation(n_obs)
    val_sites = np.sort(obs_sites[perm[:n_val]])
    train_sites = np.sort(obs_sites[perm[n_val:]])

    # Standardise with training-station statistics only; Yz covers all S stations.
    y_train_raw = Y[train_sites].reshape(-1).astype(np.float32)
    y_mu = float(np.mean(y_train_raw))
    y_sd = float(np.std(y_train_raw) + 1e-12)
    Yz = (Y - y_mu) / y_sd

    X_train, y_train_z = build_Xy_for_stations(train_sites, Yz, phi_space_st, phi_time_t, Ds, D, T)
    X_val, y_val_z = build_Xy_for_stations(val_sites, Yz, phi_space_st, phi_time_t, Ds, D, T)

    # Every quantile head is trained against the same standardised target.
    y_train_dict = {f"q{str(t).replace('.','_')}": y_train_z for t in TAUS}
    y_val_dict = {f"q{str(t).replace('.','_')}": y_val_z for t in TAUS}

    n_chunks = int(np.ceil(len(miss_sites) / STATIONS_PER_CHUNK))

    # Reset Keras global state before building this repetition's fresh model.
    keras.backend.clear_session()
    model = build_model_multi_quantile(D)

    # Early stopping monitors val_loss (the combined loss over all heads) and
    # restores the best weights seen.
    model.fit(
        X_train,
        y_train_dict,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(X_val, y_val_dict),
        verbose=0,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=PATIENCE,
                restore_best_weights=True
            )
        ]
    )

    # Running sums over all (station, time) rows of the missing stations.
    sum_crps_z = 0.0
    sum_crps_raw = 0.0
    sum_cross = 0.0
    count = 0

    # Predict the missing stations STATIONS_PER_CHUNK at a time to bound the
    # size of each design matrix.
    for ci in range(n_chunks):
        a = ci * STATIONS_PER_CHUNK
        b = min((ci + 1) * STATIONS_PER_CHUNK, len(miss_sites))
        chunk_sites = miss_sites[a:b]

        X_chunk, y_true_z = build_Xy_for_stations(chunk_sites, Yz, phi_space_st, phi_time_t, Ds, D, T)
        preds_z = model.predict(X_chunk, batch_size=BATCH_SIZE, verbose=0)

        crps_vec_z = crps_from_quantiles_weighted(y_true_z, preds_z, taus=TAUS)
        sum_crps_z += float(np.sum(crps_vec_z))

        # Back-transform to original units. NOTE: y_true_z is float32, so
        # y_true_raw matches Y only up to float32 rounding; since the pinball
        # score is shift-invariant and scales linearly, CRPS_raw ≈ y_sd * CRPS_z.
        y_true_raw = (y_true_z.astype(np.float64) * y_sd + y_mu)
        preds_raw = {}
        for tau in TAUS:
            k = f"q{str(tau).replace('.','_')}"
            preds_raw[k] = (np.asarray(preds_z[k]).reshape(-1).astype(np.float64) * y_sd + y_mu)

        crps_vec_raw = crps_from_quantiles_weighted(y_true_raw, preds_raw, taus=TAUS)
        sum_crps_raw += float(np.sum(crps_vec_raw))

        # Weight the chunk's crossing rate by its row count so the final value
        # is a per-row average over all missing stations.
        cross_rate_chunk = crossing_rate_from_preds(preds_z, taus=TAUS)
        sum_cross += float(cross_rate_chunk) * float(y_true_z.size)

        count += int(y_true_z.size)

        del X_chunk

    # NOTE(review): count is 0 (ZeroDivisionError) if there are no missing
    # stations, e.g. OBS_RATIO = 1.0.
    crps_mean_z = sum_crps_z / count
    crps_mean_raw = sum_crps_raw / count
    crossing_rate = sum_cross / count

    cross_list.append(crossing_rate)
    crps_z_list.append(crps_mean_z)
    crps_raw_list.append(crps_mean_raw)

    print(
        f"[Rep {rep:02d}/{N_REP}] "
        f"obs={len(obs_sites)} train={len(train_sites)} val={len(val_sites)} miss={len(miss_sites)} | "
        f"Crossing rate = {crossing_rate:.6f}"
    )

    # Drop this repetition's model and design matrices before the next one.
    del model
    del X_train, X_val
    gc.collect()


def _sd_text(values):
    """Sample SD (ddof=1) of ``values`` to 6 decimals, or "NA" when N_REP <= 1."""
    if N_REP <= 1:
        return "NA"
    return f"{float(np.std(values, ddof=1)):.6f}"


# Per-repetition metrics as float64 arrays for aggregation.
cross_arr = np.asarray(cross_list, dtype=np.float64)
crps_z_arr = np.asarray(crps_z_list, dtype=np.float64)
crps_raw_arr = np.asarray(crps_raw_list, dtype=np.float64)

print("\n=== Summary over repetitions (resample split each rep) ===")
print(f"Crossing rate mean = {float(np.mean(cross_arr)):.6f}")
print(f"Crossing rate SD   = {_sd_text(cross_arr)}")
print(f"CRPS_z   mean = {float(np.mean(crps_z_arr)):.6f} | SD = {_sd_text(crps_z_arr)}")
print(f"CRPS_raw mean = {float(np.mean(crps_raw_arr)):.6f} | SD = {_sd_text(crps_raw_arr)}")


GPUs: []
Weather2K shape: (1866, 13, 13632)
Using last 100 timesteps | Y shape: (1866, 100)

[Stage 3] Precompute bases ...
Space basis dim: 238
Time  basis dim: 70
Total embedding dim: 308
[Rep 01/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 02/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 03/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 04/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 05/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 06/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 07/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 08/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 09/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000
[Rep 10/10] obs=187 train=168 val=19 miss=1679 | Crossing rate = 0.000000

=== Summary over repetitions (resample split each rep) ===
Crossing r