In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense
from tensorflow.keras import regularizers

# Base seed used for the Python, NumPy and TensorFlow random number generators.
SEED = 2024
# NOTE(review): setting PYTHONHASHSEED here is inherited by child processes only;
# the hash seed of the already-running interpreter was fixed at startup.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Location of the input CSV (machine-specific absolute path).
DATA_DIR = Path("/home/wangxc1117/study-DeepKriging/Space-Time.DeepKriging/simulation_2b-8/data")
FULL_FILE = "2b_8.csv"

# Training hyper-parameters passed to `model.fit` and the Adam optimizer.
EPOCHS = 350
BATCH_SIZE = 512
LR = 1e-3

# Fraction of distinct (x, y) sites treated as observed (training) sites.
OBS_RATIO = 0.1
# Number of independent training repetitions used for the CRPS mean / SD.
N_REP = 10

# Knots per axis for each spatial basis resolution (a g x g lattice per entry).
GRID_SIZES = (5, 9, 12)
# Number of temporal knots for each temporal basis resolution.
H_LIST = (10, 15, 45)

# Quantile levels; one output head and one pinball loss is created per level.
TAUS = [0.05, 0.25, 0.5, 0.75, 0.95]


def wendland_c2(d):
    """Evaluate the compactly supported radial kernel on scaled distances.

    For ``0 <= d <= 1`` the value is ``(1 - d)**6 * (35*d**2 + 18*d + 3) / 3``
    (equal to 1 at ``d == 0`` and 0 at ``d == 1``); everywhere else, including
    negative and NaN entries, the value is 0.

    Parameters
    ----------
    d : array_like
        Distances already divided by the kernel range. Any shape is accepted;
        lists and Python scalars are converted with ``np.asarray``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array with the same shape as ``d`` (0-d for a scalar).
    """
    # Coerce so that plain lists / scalars work with the boolean indexing below.
    d = np.asarray(d)
    out = np.zeros_like(d, dtype=np.float32)
    inside = (d >= 0.0) & (d <= 1.0)
    dm = d[inside]
    # Only entries inside the support are evaluated; the rest stay at zero.
    out[inside] = ((1.0 - dm) ** 6) * (35.0 * dm**2 + 18.0 * dm + 3.0) / 3.0
    return out


def build_space_basis(s_xy, grid_sizes=GRID_SIZES, theta_scale=2.5):
    """Build multi-resolution spatial basis features.

    For every ``g`` in ``grid_sizes`` a regular ``g x g`` lattice of knots is
    placed on ``[0, 1] x [0, 1]``. Each feature column is the kernel
    ``wendland_c2`` evaluated at the Euclidean distance from a location to one
    knot, divided by ``theta_scale`` times the knot spacing ``1 / (g - 1)``.

    Parameters
    ----------
    s_xy : numpy.ndarray
        Array of shape ``(n, 2)`` holding the locations.
        NOTE(review): presumably scaled to the unit square like the knots -
        confirm against the data.
    grid_sizes : iterable of int
        Knots per axis for each resolution.
    theta_scale : float
        Multiplier applied to the knot spacing to obtain the kernel range.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(n, sum(g**2 for g in grid_sizes))``.
    """
    n_points = s_xy.shape[0]
    blocks = []
    for g in grid_sizes:
        axis_knots = np.linspace(0.0, 1.0, g, dtype=np.float32)
        grid_x, grid_y = np.meshgrid(axis_knots, axis_knots)
        centers = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)

        # Kernel range grows with the knot spacing of this resolution.
        theta = theta_scale * (1.0 / (g - 1))
        denom = theta + 1e-12

        # One column per knot; filled column by column to keep memory bounded.
        block = np.zeros((n_points, centers.shape[0]), dtype=np.float32)
        for j, center in enumerate(centers):
            scaled_dist = np.linalg.norm(s_xy - center, axis=1) / denom
            block[:, j] = wendland_c2(scaled_dist)

        blocks.append(block)
    return np.concatenate(blocks, axis=1)


def build_time_basis(t_norm, H_list=H_LIST):
    """Build multi-resolution Gaussian basis features over time.

    For every ``H`` in ``H_list``, ``H`` equally spaced knots are placed on
    ``[0, 1]`` and each feature is ``exp(-0.5 * ((t - knot) / kappa) ** 2)``
    where ``kappa`` is the spacing between the first two knots.

    Parameters
    ----------
    t_norm : numpy.ndarray
        1-D array of time values (indexed as ``t_norm[:, None]``).
        NOTE(review): presumably normalised to [0, 1] like the knots - confirm.
    H_list : iterable of int
        Number of knots per resolution; each entry must be at least 2 because
        the spacing is read from the first two knots.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(len(t_norm), sum(H_list))``.
    """
    blocks = []
    for n_knots in H_list:
        centers = np.linspace(0.0, 1.0, n_knots, dtype=np.float32)
        # Bandwidth equals the knot spacing of this resolution.
        bandwidth = abs(centers[1] - centers[0])
        z = (t_norm[:, None] - centers[None, :]) / (bandwidth + 1e-12)
        gauss = np.exp(-0.5 * z**2)
        blocks.append(gauss.astype(np.float32))
    return np.concatenate(blocks, axis=1)


def remove_all_zero_columns(X):
    """Return ``X`` without the columns whose entries are all equal to zero.

    Parameters
    ----------
    X : numpy.ndarray
        2-D array.

    Returns
    -------
    numpy.ndarray
        Array with the same rows and dtype, keeping (in order) every column
        that contains at least one non-zero entry.
    """
    nonzero_entries = X != 0
    keep_column = nonzero_entries.any(axis=0)
    return X[:, keep_column]


def tilted_loss(tau):
    """Create the pinball (quantile) loss for quantile level ``tau``.

    The returned callable computes the mean over all elements of
    ``max(tau * e, (tau - 1) * e)`` with ``e = y_true - y_pred``.

    Parameters
    ----------
    tau : float
        Quantile level captured by the returned closure.

    Returns
    -------
    callable
        ``loss(y_true, y_pred)`` returning a scalar tensor.
    """
    def loss(y_true, y_pred):
        residual = y_true - y_pred
        # Under-prediction (residual > 0) is weighted by tau,
        # over-prediction by (1 - tau).
        under_penalty = tau * residual
        over_penalty = (tau - 1.0) * residual
        return tf.reduce_mean(tf.maximum(under_penalty, over_penalty))
    return loss


def build_model_multi_quantile(input_dim, lr=LR):
    """Build and compile the multi-head quantile regression network.

    The network is a stack of ReLU ``Dense`` layers: two 100-unit layers with
    an L1/L2 kernel penalty, six further 100-unit layers, then four 50-unit
    layers. On top of the shared trunk, one linear single-unit head is created
    for every level in ``TAUS``; head names follow ``q<tau>`` with the decimal
    point replaced by an underscore, and each head is trained with the
    matching ``tilted_loss``.

    Parameters
    ----------
    input_dim : int
        Number of input features.
    lr : float
        Learning rate for the Adam optimizer.

    Returns
    -------
    keras.Model
        Compiled model whose outputs (and losses) are dicts keyed by head name.
    """
    penalty = regularizers.L1L2(l1=1e-5, l2=1e-4)

    # (units, kernel_regularizer) for each hidden layer, in stacking order.
    hidden_spec = [(100, penalty)] * 2 + [(100, None)] * 6 + [(50, None)] * 4

    inp = keras.Input(shape=(input_dim,), name="X")

    x = inp
    for units, layer_reg in hidden_spec:
        x = Dense(
            units,
            activation="relu",
            kernel_initializer="random_normal",
            kernel_regularizer=layer_reg,
        )(x)

    # One output head per quantile level, created in the order of TAUS.
    head_names = [f"q{str(tau).replace('.','_')}" for tau in TAUS]
    outputs = {
        name: Dense(1, kernel_initializer="random_normal", name=name)(x)
        for name in head_names
    }
    losses = {
        name: tilted_loss(float(tau))
        for name, tau in zip(head_names, TAUS)
    }

    model = keras.Model(inputs=inp, outputs=outputs, name="STDK_MultiQuantile")
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=lr), loss=losses)
    return model


def sample_observations_fixed_uniform(n_sites, obs_ratio, seed):
    """Draw a reproducible uniform sample of site indices without replacement.

    Parameters
    ----------
    n_sites : int
        Total number of sites; indices are drawn from ``range(n_sites)``.
    obs_ratio : float
        Fraction of sites to select; the count is
        ``int(np.round(obs_ratio * n_sites))``.
    seed : int
        Seed for a private ``np.random.RandomState``, so the global NumPy
        generator is left untouched.

    Returns
    -------
    numpy.ndarray
        Selected site indices in ascending order.
    """
    rng = np.random.RandomState(seed)
    n_selected = int(np.round(obs_ratio * n_sites))
    chosen = rng.choice(n_sites, size=n_selected, replace=False)
    chosen.sort()
    return chosen


def build_fixed_uniform_masks(df, obs_ratio=OBS_RATIO, seed=SEED + 12345):
    """Split rows into train/test by sampling whole sites.

    Every distinct ``(x, y)`` pair is a site. A fraction ``obs_ratio`` of the
    sites is drawn with ``sample_observations_fixed_uniform``; all rows of a
    drawn site are training rows and all remaining rows are test rows. The
    input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with at least the columns ``"x"`` and ``"y"``.
    obs_ratio : float
        Fraction of sites to mark as observed.
    seed : int
        Seed forwarded to the site sampler.

    Returns
    -------
    tuple
        ``(train_mask, test_mask, obs_sites, n_sites)``: two boolean row masks
        that are complements of each other, the sorted ids of the observed
        sites, and the number of distinct sites.
    """
    # Site ids are integer codes in order of first appearance of each pair.
    coord_pairs = list(zip(df["x"].values, df["y"].values))
    site_id = pd.factorize(coord_pairs)[0]
    n_sites = int(np.unique(site_id).size)

    obs_sites = sample_observations_fixed_uniform(n_sites, obs_ratio, seed)

    train_mask = np.isin(site_id, obs_sites)
    test_mask = ~train_mask

    return train_mask, test_mask, obs_sites, n_sites


def crps_from_quantiles(y_true, preds_dict):
    """Approximate the per-sample CRPS from the predicted quantiles.

    For each sample the result is ``2 / K`` times the sum over the ``K``
    levels in ``TAUS`` of the pinball loss ``max(tau * e, (tau - 1) * e)``
    with ``e = y - q_tau``; every level gets the same weight.

    Parameters
    ----------
    y_true : numpy.ndarray
        Observed values; flattened to 1-D.
    preds_dict : dict
        Maps each head name ``q<tau>`` (decimal point replaced by an
        underscore) to an array of predictions; each is flattened to 1-D.

    Returns
    -------
    numpy.ndarray
        ``float64`` vector with one CRPS value per sample.
    """
    obs = y_true.reshape(-1).astype(np.float64)

    pinball_total = np.zeros_like(obs, dtype=np.float64)
    for tau in TAUS:
        head = f"q{str(tau).replace('.','_')}"
        quantile = preds_dict[head].reshape(-1).astype(np.float64)
        resid = obs - quantile
        pinball_total += np.maximum(tau * resid, (tau - 1.0) * resid)

    # Equal weight 1/K for every quantile level.
    return 2.0 * (1.0 / len(TAUS)) * pinball_total


# Load the full data set and keep only coordinates, time and response.
df = pd.read_csv(DATA_DIR / FULL_FILE)
df = df[["x", "y", "t", "z"]].copy()

# Min-max scale time to [0, 1]; the small epsilon guards against a zero range.
t_min = float(df["t"].min())
t_max = float(df["t"].max())
df["t_norm"] = ((df["t"].astype(np.float32) - t_min) / (t_max - t_min + 1e-12)).astype(np.float32)

# Site-level train/test split. The seed passed here (SEED + 54321) overrides
# the function's default (SEED + 12345).
train_mask, test_mask, obs_sites, n_sites = build_fixed_uniform_masks(
    df,
    obs_ratio=OBS_RATIO,
    seed=SEED + 54321
)

# Report the size of the split.
print(f"Fixed + Uniform | obs_ratio={OBS_RATIO}")
print(f"Observed sites: {len(obs_sites)} / {n_sites}")
print(f"Observed samples: {int(train_mask.sum())} / {df.shape[0]} ({train_mask.mean()*100:.1f}%)")
print(f"Test samples: {int(test_mask.sum())} / {df.shape[0]} ({test_mask.mean()*100:.1f}%)")

s_xy_all = df[["x", "y"]].to_numpy(dtype=np.float32)
t_norm_all = df["t_norm"].to_numpy(dtype=np.float32)

# Basis expansions are computed once for all rows and then split by mask.
phi_space_all = build_space_basis(s_xy_all)
phi_time_all = build_time_basis(t_norm_all)

# Feature matrix: spatial and temporal basis columns side by side, with
# columns that are zero for every row removed.
X_all = remove_all_zero_columns(np.concatenate([phi_space_all, phi_time_all], axis=1)).astype(np.float32)
y_all = df["z"].to_numpy(dtype=np.float32).reshape(-1, 1)

X_train = X_all[train_mask]
y_train = y_all[train_mask]

X_test = X_all[test_mask]
y_test = y_all[test_mask]

# Every quantile head is trained against the same target vector.
y_train_dict = {f"q{str(t).replace('.','_')}": y_train for t in TAUS}

# Mean test CRPS of each repetition.
crps_runs = []

for rep in range(N_REP):
    # A fresh model is built in every repetition. Release the Keras global
    # state left by the previous one so memory does not grow across reps.
    # This is done before reseeding so the seeds below stay in effect.
    keras.backend.clear_session()

    # Repetition-specific seeds: runs differ from each other but are repeatable.
    tf.random.set_seed(SEED + 1000 + rep)
    np.random.seed(SEED + 1000 + rep)
    random.seed(SEED + 1000 + rep)

    model = build_model_multi_quantile(X_all.shape[1], lr=LR)

    model.fit(
        X_train,
        y_train_dict,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=1
    )

    # `preds` is a dict keyed by head name, as expected by crps_from_quantiles.
    preds = model.predict(X_test, verbose=0)
    crps_vec = crps_from_quantiles(y_test, preds)
    crps_mean = float(np.mean(crps_vec))

    crps_runs.append(crps_mean)
    print(f"[rep {rep+1:02d}/{N_REP}] CRPS = {crps_mean:.6f}")

crps_runs = np.array(crps_runs, dtype=np.float64)

# Summary over repetitions; the SD is the sample standard deviation (ddof=1).
print("\n=== Table 4.4 style reporting (STDK only) ===")
print(f"CRPS mean = {crps_runs.mean():.6f}")
print(f"CRPS SD   = {crps_runs.std(ddof=1):.6f}")


2026-01-29 23:14:16.961880: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-29 23:14:17.264464: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-01-29 23:14:17.264544: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-01-29 23:14:17.312597: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-01-29 23:14:17.411631: I tensorflow/core/platform/cpu_feature_guar

Fixed + Uniform | obs_ratio=0.1
Observed sites: 1000 / 10000
Observed samples: 100000 / 1000000 (10.0%)
Test samples: 900000 / 1000000 (90.0%)


2026-01-29 23:14:27.632023: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2026-01-29 23:14:27.965888: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


Epoch 1/350
Epoch 2/350
Epoch 3/350
Epoch 4/350
Epoch 5/350
Epoch 6/350
Epoch 7/350
Epoch 8/350
Epoch 9/350
Epoch 10/350
Epoch 11/350
Epoch 12/350
Epoch 13/350
Epoch 14/350
Epoch 15/350
Epoch 16/350
Epoch 17/350
Epoch 18/350
Epoch 19/350
Epoch 20/350
Epoch 21/350
Epoch 22/350
Epoch 23/350
Epoch 24/350
Epoch 25/350
Epoch 26/350
Epoch 27/350
Epoch 28/350
Epoch 29/350
Epoch 30/350
Epoch 31/350
Epoch 32/350
Epoch 33/350
Epoch 34/350
Epoch 35/350
Epoch 36/350
Epoch 37/350
Epoch 38/350
Epoch 39/350
Epoch 40/350
Epoch 41/350
Epoch 42/350
Epoch 43/350
Epoch 44/350
Epoch 45/350
Epoch 46/350
Epoch 47/350
Epoch 48/350
Epoch 49/350
Epoch 50/350
Epoch 51/350
Epoch 52/350
Epoch 53/350
Epoch 54/350
Epoch 55/350
Epoch 56/350
Epoch 57/350
Epoch 58/350
Epoch 59/350
Epoch 60/350
Epoch 61/350
Epoch 62/350
Epoch 63/350
Epoch 64/350
Epoch 65/350
Epoch 66/350
Epoch 67/350
Epoch 68/350
Epoch 69/350
Epoch 70/350
Epoch 71/350
Epoch 72/350
Epoch 73/350
Epoch 74/350
Epoch 75/350
Epoch 76/350
Epoch 77/350
Epoch 78