In [1]:
# ============================================================
# Paper-aligned Space-Time.DeepKriging (Table 1 replication)
# - Point prediction with MSE loss
# - 10-fold cross validation (45,000 train / 5,000 test per fold)
# - Proper validation split inside training set (early stopping)
# Output:
#   - Per-fold MSPE_nonstat and MSPE_z
#   - Mean/Std/SEmean across 10 folds
# ============================================================

import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import EarlyStopping


# --------------------
# Global settings
# --------------------
# Reproducibility: one master seed fans out to every RNG the pipeline uses.
SEED = 2024
# Exported for any child processes. NOTE(review): the running interpreter
# reads PYTHONHASHSEED only at startup, so this assignment does not change
# hash randomization inside this process itself.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)          # stdlib `random`
np.random.seed(SEED)       # NumPy legacy global RNG
tf.random.set_seed(SEED)   # TensorFlow global seed

# Input data location and file names.
DATA_DIR = Path("../synthetic_ds")
LOC_FILE = "LOC_50000_univariate_spacetime_matern_stationary_1"
Z_FILE = "Z1_50000_univariate_spacetime_matern_stationary_1"

# Optimisation hyper-parameters: max epochs, mini-batch size,
# early-stopping patience (epochs) and Adam learning rate.
EPOCHS, BATCH_SIZE, PATIENCE, LR = 350, 512, 30, 1e-3

# Cross-validation layout: K_FOLDS folds over N_TOTAL rows, each fold
# holding out TEST_SIZE rows; VAL_FRAC of the training rows is reserved
# for early stopping.
N_TOTAL = 50_000
TRAIN_SIZE, TEST_SIZE = 45_000, 5_000
VAL_FRAC = 0.1
K_FOLDS = 10


# --------------------
# Basis functions
# --------------------
def wendland_c2(d):
    """Evaluate a compactly supported Wendland radial function elementwise.

    Returns ``(1 - r)**6 * (35 r**2 + 18 r + 3) / 3`` where ``0 <= r <= 1``
    and exactly 0 everywhere else. NaN entries fail both comparisons, so
    they also map to 0.

    NOTE(review): this polynomial is the one usually labelled Wendland
    phi_{3,2} (C^4 smooth); the function name says C2 -- confirm intent.

    Parameters
    ----------
    d : np.ndarray
        Scaled distances (distance / support radius), any shape.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``d``.
    """
    in_support = np.logical_and(d >= 0.0, d <= 1.0)
    r = d[in_support]

    taper = (1.0 - r) ** 6
    poly = 35.0 * r**2 + 18.0 * r + 3.0

    result = np.zeros_like(d, dtype=np.float32)
    result[in_support] = taper * poly / 3.0
    return result


def build_space_basis(s_xy, grid_sizes=(5, 9, 12), theta_scale=2.5):
    """Evaluate multi-resolution Wendland basis functions at 2-D locations.

    For each resolution ``g`` in ``grid_sizes`` a regular ``g x g`` grid of
    knots is laid over [0, 1]^2. Every location is scored against every
    knot with ``wendland_c2(dist / theta)``, where ``theta`` is
    ``theta_scale`` times the knot spacing, so coarser grids get wider
    support.

    Parameters
    ----------
    s_xy : np.ndarray of shape (n, 2)
        Spatial coordinates. Knots span [0, 1], so coordinates are
        presumably pre-scaled to the unit square (not checked here).
    grid_sizes : iterable of int
        Knots per axis for each resolution level. A level with ``g < 2``
        has no defined spacing, so its spacing falls back to 1.0, matching
        ``build_time_basis``.
    theta_scale : float
        Support radius measured in units of knot spacing.

    Returns
    -------
    np.ndarray of shape (n, sum(g * g for g in grid_sizes)), float32
        Basis columns concatenated level by level. Within a level, knots
        are ordered as ``np.meshgrid(knots, knots)`` raveled (x varies
        fastest).
    """
    n = s_xy.shape[0]
    cols = []

    for g in grid_sizes:
        knots_1d = np.linspace(0.0, 1.0, g, dtype=np.float32)
        kx, ky = np.meshgrid(knots_1d, knots_1d)
        knots = np.column_stack([kx.ravel(), ky.ravel()]).astype(np.float32)

        # A single knot (g == 1) has no neighbour spacing; the original
        # `1 / (g - 1)` raised ZeroDivisionError here.
        spacing = 1.0 / float(g - 1) if g >= 2 else 1.0
        theta = theta_scale * spacing

        nb = knots.shape[0]
        phi = np.zeros((n, nb), dtype=np.float32)

        # One column per knot: scaled Euclidean distance -> Wendland value.
        for j in range(nb):
            d = np.linalg.norm(s_xy - knots[j], axis=1) / (theta + 1e-12)
            phi[:, j] = wendland_c2(d)

        cols.append(phi)

    return np.concatenate(cols, axis=1)


def build_time_basis(t_norm, H_list=(10, 15, 45)):
    """Evaluate multi-resolution Gaussian basis functions over time.

    For every level ``H`` in ``H_list``, ``H`` equally spaced centres are
    placed on [0, 1]. Each time point gets ``exp(-0.5 * ((t - c) / w)**2)``
    per centre ``c``, with bandwidth ``w`` equal to the centre spacing
    (1.0 when a level has fewer than two centres).

    Parameters
    ----------
    t_norm : np.ndarray of shape (n,)
        Time stamps, expected to be min-max normalised into [0, 1].
    H_list : iterable of int
        Number of centres per resolution level.

    Returns
    -------
    np.ndarray of shape (n, sum(H_list)), float32
        Gaussian features concatenated level by level.
    """
    blocks = []
    for n_centres in H_list:
        centres = np.linspace(0.0, 1.0, n_centres, dtype=np.float32)
        if n_centres >= 2:
            width = float(abs(centres[1] - centres[0]))
        else:
            width = 1.0
        scaled = (t_norm[:, None] - centres[None, :]) / (width + 1e-12)
        gauss = np.exp(-0.5 * (scaled**2))
        blocks.append(gauss.astype(np.float32))
    return np.concatenate(blocks, axis=1)


def remove_all_zero_columns(X):
    """Return ``X`` without the columns whose entries are all exactly zero.

    A column is kept if any entry compares unequal to 0; NaN compares
    unequal, so columns containing NaN are kept.

    Parameters
    ----------
    X : np.ndarray of shape (n, p)

    Returns
    -------
    np.ndarray of shape (n, p_kept)
    """
    nonzero_somewhere = np.any(X != 0, axis=0)
    return X[:, nonzero_somewhere]


# --------------------
# Model (MSE)
# --------------------
def build_model(input_dim):
    """Build and compile the fully connected regressor trained with MSE.

    Architecture (in creation order): 8 ReLU layers of width 100, then 4
    ReLU layers of width 50, then a single linear output unit. All kernels
    use the "random_normal" initializer. Only the first two layers carry an
    L1L2 kernel penalty (l1=1e-5, l2=1e-4), and both share one regularizer
    object. The model is compiled with Adam at the module-level ``LR`` and
    mean squared error loss.

    Parameters
    ----------
    input_dim : int
        Number of input features (columns of the basis embedding).

    Returns
    -------
    keras.Model
        A compiled ``Sequential`` model.
    """
    penalty = regularizers.L1L2(l1=1e-5, l2=1e-4)
    init = "random_normal"

    # Regularised entry layers.
    stack = [
        Dense(100, activation="relu", kernel_initializer=init,
              kernel_regularizer=penalty, input_dim=input_dim),
        Dense(100, activation="relu", kernel_initializer=init,
              kernel_regularizer=penalty),
    ]
    # Unregularised hidden trunk: six more width-100, then four width-50.
    stack += [Dense(100, activation="relu", kernel_initializer=init)
              for _ in range(6)]
    stack += [Dense(50, activation="relu", kernel_initializer=init)
              for _ in range(4)]
    # Linear scalar head.
    stack.append(Dense(1, kernel_initializer=init))

    net = Sequential(stack)
    net.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LR),
        loss="mse",
    )
    return net


# --------------------
# Load data & construct targets
# --------------------
loc = pd.read_csv(DATA_DIR / LOC_FILE, header=None, names=["x", "y", "t"])
z = pd.read_csv(DATA_DIR / Z_FILE, header=None, names=["z"])
# DataFrame.join is an index-aligned left join, so a shorter response file
# would silently produce NaN targets; refuse mismatched inputs instead.
if len(loc) != len(z):
    raise ValueError(
        f"Row count mismatch: {len(loc)} locations vs {len(z)} responses"
    )
df = loc.join(z)

# Deterministic mean trend mu(s), with s = t / 1000.
# NOTE(review): this rescaling differs from the min-max `t_norm` used for
# the time basis -- confirm the time scale the mean function expects.
s = df["t"].astype(np.float32) / 1000.0
mu = (
    2.0 * np.sin(15.0 * (s - 0.9))
    * np.cos(-37.0 * (s - 0.9) ** 4)
    + (s - 0.9) / 2.0
).to_numpy(dtype=np.float32)

# Observed response plus the trend gives the non-stationary training target.
z_obs = df["z"].to_numpy(dtype=np.float32)
y_nonstat = z_obs + mu


# --------------------
# Build embedding X = [phi_space, phi_time] (stacked)
# --------------------
s_xy = df[["x", "y"]].to_numpy(dtype=np.float32)

# Fail fast: validate the sample size before the O(n * n_knots) basis
# construction below (the embedding has one row per input row, so the
# count and message are the same as checking X_all afterwards).
if s_xy.shape[0] != N_TOTAL:
    raise ValueError(f"Expected N={N_TOTAL}, got {s_xy.shape[0]}")

# Min-max normalise time into [0, 1]; the epsilon guards a constant t.
t_raw = df["t"].to_numpy(dtype=np.float32)
t_norm = (t_raw - t_raw.min()) / (t_raw.max() - t_raw.min() + 1e-12)
t_norm = t_norm.astype(np.float32)

phi_space = build_space_basis(s_xy, grid_sizes=(5, 9, 12), theta_scale=2.5)
phi_time = build_time_basis(t_norm, H_list=(10, 15, 45))

# Stack space and time features and drop columns that are zero for every row
# (e.g. compactly supported knots that no location reaches).
X_all = remove_all_zero_columns(
    np.concatenate([phi_space, phi_time], axis=1)
).astype(np.float32)


# --------------------
# Splits: proper validation inside each fold
# --------------------
def train_val_split(train_idx, val_frac, seed):
    """Shuffle training indices and carve off a validation subset.

    Parameters
    ----------
    train_idx : np.ndarray of int
        Candidate training indices (not modified; a shuffled copy is used).
    val_frac : float
        Fraction of indices for validation; the count is truncated with
        ``int``.
    seed : int
        Seed for a private ``np.random.RandomState`` so the split is
        reproducible and independent of global RNG state.

    Returns
    -------
    (fit_idx, val_idx) : tuple of np.ndarray
        Disjoint index arrays that together cover ``train_idx``.
    """
    shuffled = np.random.RandomState(seed).permutation(train_idx)
    cut = int(len(shuffled) * val_frac)
    val_part, fit_part = shuffled[:cut], shuffled[cut:]
    return fit_part, val_part


def make_kfold_indices(n, k, seed):
    """Randomly partition ``range(n)`` into ``k`` near-equal folds.

    A private ``np.random.RandomState(seed)`` shuffles the indices, and
    ``np.array_split`` cuts them into ``k`` consecutive chunks. When ``n``
    is not divisible by ``k``, the leading folds get one extra element.

    Returns
    -------
    list of np.ndarray
        ``k`` disjoint index arrays whose union is ``range(n)``.
    """
    return np.array_split(np.random.RandomState(seed).permutation(n), k)


# --------------------
# 10-fold CV evaluation (Table 1 protocol)
# --------------------
def run_kfold_cv_mspe(X_all, y_nonstat, z_obs, mu,
                      k_folds=10, train_size=45000, test_size=5000,
                      val_frac=0.1, base_seed=2024):
    """Run k-fold cross-validation of the MLP and report per-fold MSPEs.

    Rows are shuffled once (seed ``base_seed + 10000``) and split into
    ``k_folds`` folds. For fold ``k``, the other folds form the training
    pool, and a ``val_frac`` share of that pool (seed
    ``base_seed + 20000 + k``) is held out for early stopping on
    ``val_loss``. A fresh ``build_model`` network, seeded with
    ``base_seed + 30000 + k``, is fit to ``y_nonstat`` and scored on
    fold ``k``.

    Parameters
    ----------
    X_all : np.ndarray of shape (n, p)
        Basis embedding for every row.
    y_nonstat, z_obs, mu : np.ndarray of shape (n,)
        Training target, stationary response and mean trend.
    k_folds : int
        Number of folds (must be >= 2).
    train_size, test_size : int
        Exact row counts each fold must have; a mismatch raises.
    val_frac : float
        Fraction of each training pool used for validation.
    base_seed : int
        Offset from which all split and model seeds are derived.

    Returns
    -------
    (mspe_nonstat_arr, mspe_z_arr) : tuple of np.ndarray, float64, shape (k_folds,)

    Raises
    ------
    ValueError
        If ``k_folds < 2``, or a fold does not have exactly ``train_size``
        training rows and ``test_size`` test rows.

    Notes
    -----
    If ``y_nonstat == z_obs + mu`` (as the module constructs it), then
    ``z_obs - (y_hat - mu) == y_nonstat - y_hat``. MSPE_z therefore equals
    MSPE_nonstat up to float rounding; the two are not independent metrics.
    """
    if k_folds < 2:
        # Fewer than two folds leaves no training pool and makes the
        # ddof=1 standard deviation undefined.
        raise ValueError(f"k_folds must be >= 2 for cross-validation, got {k_folds}")

    folds = make_kfold_indices(len(X_all), k_folds, base_seed + 10000)

    mspe_nonstat_list = []
    mspe_z_list = []

    for k, te in enumerate(folds):
        # Training pool = every other fold, kept in fold order.
        tr_full = np.concatenate(folds[:k] + folds[k + 1:], axis=0)

        if tr_full.shape[0] != train_size or te.shape[0] != test_size:
            raise ValueError(f"Fold size mismatch: train={tr_full.shape[0]}, test={te.shape[0]}")

        fit_idx, val_idx = train_val_split(tr_full, val_frac, base_seed + 20000 + k)

        # Release Keras global state left by the previous fold's model so
        # memory does not accumulate across folds. This must happen before
        # re-seeding so the seeds below govern this fold's initialisation.
        keras.backend.clear_session()

        fold_seed = base_seed + 30000 + k
        tf.random.set_seed(fold_seed)
        np.random.seed(fold_seed)
        random.seed(fold_seed)

        model = build_model(X_all.shape[1])
        model.fit(
            X_all[fit_idx], y_nonstat[fit_idx],
            validation_data=(X_all[val_idx], y_nonstat[val_idx]),
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            verbose=1,
            callbacks=[EarlyStopping(
                monitor="val_loss",
                patience=PATIENCE,
                restore_best_weights=True
            )]
        )

        y_hat = model.predict(X_all[te], verbose=0).ravel()
        # Recover the stationary-part prediction by removing the known trend.
        z_hat = y_hat - mu[te]

        mspe_nonstat = mean_squared_error(y_nonstat[te], y_hat)
        mspe_z = mean_squared_error(z_obs[te], z_hat)

        mspe_nonstat_list.append(mspe_nonstat)
        mspe_z_list.append(mspe_z)

        print(
            f"[Fold {k + 1:02d}/{k_folds}] "
            f"MSPE_nonstat={mspe_nonstat:.6f} | "
            f"MSPE_z={mspe_z:.6f}"
        )

    mspe_nonstat_arr = np.array(mspe_nonstat_list, dtype=np.float64)
    mspe_z_arr = np.array(mspe_z_list, dtype=np.float64)

    # Sample std across folds and the standard error of the fold mean.
    std_nonstat = mspe_nonstat_arr.std(ddof=1)
    std_z = mspe_z_arr.std(ddof=1)
    se_nonstat = std_nonstat / np.sqrt(k_folds)
    se_z = std_z / np.sqrt(k_folds)

    # Header reflects the actual fold count (previously hard-coded "10").
    print(f"\n=== Summary ({k_folds}-fold CV) ===")
    print(
        f"Mean MSPE_nonstat = {mspe_nonstat_arr.mean():.6f} "
        f"(Std={std_nonstat:.6f}, SEmean={se_nonstat:.6f})"
    )
    print(
        f"Mean MSPE_z       = {mspe_z_arr.mean():.6f} "
        f"(Std={std_z:.6f}, SEmean={se_z:.6f})"
    )

    return mspe_nonstat_arr, mspe_z_arr


# Run the full cross-validation protocol with the module-level configuration;
# the per-fold MSPE arrays stay bound at module level for later inspection.
mspe_nonstat_arr, mspe_z_arr = run_kfold_cv_mspe(
    X_all,
    y_nonstat,
    z_obs,
    mu,
    base_seed=SEED,
    k_folds=K_FOLDS,
    val_frac=VAL_FRAC,
    train_size=TRAIN_SIZE,
    test_size=TEST_SIZE,
)


2026-01-28 12:18:08.389164: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-28 12:18:08.419575: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-01-28 12:18:08.419610: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-01-28 12:18:08.420350: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-01-28 12:18:08.425181: I tensorflow/core/platform/cpu_feature_guar

Epoch 1/350
Epoch 2/350
Epoch 3/350
Epoch 4/350
Epoch 5/350
Epoch 6/350
Epoch 7/350
Epoch 8/350
Epoch 9/350
Epoch 10/350
Epoch 11/350
Epoch 12/350
Epoch 13/350
Epoch 14/350
Epoch 15/350
Epoch 16/350
Epoch 17/350
Epoch 18/350
Epoch 19/350
Epoch 20/350
Epoch 21/350
Epoch 22/350
Epoch 23/350
Epoch 24/350
Epoch 25/350
Epoch 26/350
Epoch 27/350
Epoch 28/350
Epoch 29/350
Epoch 30/350
Epoch 31/350
Epoch 32/350
Epoch 33/350
Epoch 34/350
Epoch 35/350
Epoch 36/350
Epoch 37/350
Epoch 38/350
Epoch 39/350
Epoch 40/350
Epoch 41/350
Epoch 42/350
Epoch 43/350
Epoch 44/350
Epoch 45/350
Epoch 46/350
Epoch 47/350
Epoch 48/350
Epoch 49/350
Epoch 50/350
Epoch 51/350
Epoch 52/350
Epoch 53/350
Epoch 54/350
Epoch 55/350
Epoch 56/350
Epoch 57/350
Epoch 58/350
Epoch 59/350
Epoch 60/350
Epoch 61/350
Epoch 62/350
Epoch 63/350
Epoch 64/350
Epoch 65/350
Epoch 66/350
Epoch 67/350
Epoch 68/350
Epoch 69/350
Epoch 70/350
Epoch 71/350
Epoch 72/350
Epoch 73/350
Epoch 74/350
Epoch 75/350
Epoch 76/350
Epoch 77/350
Epoch 78