In [1]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import gc
import random
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense
from tensorflow.keras import regularizers


BASE_SEED = 2024
os.environ["PYTHONHASHSEED"] = str(BASE_SEED)

DATA_DIR = Path("/home/wangxc1117/surface_air_temperature/data2024")

OBS_RATIO = 0.10
VAL_STATION_RATIO = 0.10

T_KEEP = 100
N_FILES = 5

GRID_SIZES = (5, 9, 12)
H_LIST = (10, 15, 45)

EPOCHS = 300
BATCH_SIZE = 512
LR = 1e-3
PATIENCE = 30

TAUS = [0.05, 0.25, 0.5, 0.75, 0.95]

N_REP = 10
SEED_OFFSET = 1000

STATIONS_PER_CHUNK = 128


tf.get_logger().setLevel("ERROR")
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
print("GPUs:", gpus)


ENGINES = ["netcdf4", "h5netcdf", "scipy"]
def open_dataset_robust(path: Path):
    last_err = None
    for eng in ENGINES:
        try:
            return xr.open_dataset(str(path), engine=eng, decode_cf=True, mask_and_scale=True)
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Failed: {path.name} -> {last_err}")


def pick_last_n_files(data_dir: Path, n_files: int):
    files = sorted([p for p in data_dir.iterdir() if p.is_file()])
    if len(files) == 0:
        raise RuntimeError(f"No files in {data_dir}")
    if len(files) < n_files:
        return files
    return files[-n_files:]


def wendland_c2_vec(d):
    out = np.zeros_like(d, dtype=np.float32)
    m = (d >= 0.0) & (d <= 1.0)
    dm = d[m]
    out[m] = ((1.0 - dm) ** 6) * (35.0 * dm**2 + 18.0 * dm + 3.0) / 3.0
    return out


def build_space_basis_fast(s_xy, grid_sizes=GRID_SIZES, theta_scale=2.5, chunk=8192):
    n = s_xy.shape[0]
    out_cols = []

    for g in grid_sizes:
        knots_1d = np.linspace(0.0, 1.0, g, dtype=np.float32)
        kx, ky = np.meshgrid(knots_1d, knots_1d)
        knots = np.column_stack([kx.ravel(), ky.ravel()]).astype(np.float32)
        K = knots.shape[0]

        spacing = 1.0 / (g - 1)
        theta = theta_scale * spacing

        phi = np.empty((n, K), dtype=np.float32)

        for a in range(0, n, chunk):
            b = min(a + chunk, n)
            X = s_xy[a:b]
            dx = X[:, 0:1] - knots[None, :, 0]
            dy = X[:, 1:2] - knots[None, :, 1]
            dist = np.sqrt(dx * dx + dy * dy) / (theta + 1e-12)
            phi[a:b] = wendland_c2_vec(dist)

        out_cols.append(phi)

    return np.concatenate(out_cols, axis=1)


def build_time_basis(t_norm, H_list=H_LIST):
    cols = []
    for H in H_list:
        knots = np.linspace(0.0, 1.0, H, dtype=np.float32)
        kappa = abs(knots[1] - knots[0]) if H >= 2 else 1.0
        diff = (t_norm[:, None] - knots[None, :]) / (kappa + 1e-12)
        cols.append(np.exp(-0.5 * diff**2).astype(np.float32))
    return np.concatenate(cols, axis=1)


def tilted_loss(tau):
    tau = float(tau)
    def loss(y_true, y_pred):
        e = y_true - y_pred
        return tf.reduce_mean(tf.maximum(tau * e, (tau - 1.0) * e))
    return loss


def build_model_multi_quantile(input_dim):
    reg = regularizers.L1L2(l1=1e-5, l2=1e-4)

    inp = keras.Input(shape=(input_dim,), name="X")
    x = Dense(100, activation="relu", kernel_initializer="random_normal", kernel_regularizer=reg)(inp)
    x = Dense(100, activation="relu", kernel_initializer="random_normal", kernel_regularizer=reg)(x)

    for _ in range(6):
        x = Dense(100, activation="relu", kernel_initializer="random_normal")(x)

    x = Dense(50, activation="relu", kernel_initializer="random_normal")(x)
    x = Dense(50, activation="relu", kernel_initializer="random_normal")(x)

    outputs = {}
    losses = {}
    for tau in TAUS:
        name = f"q{str(tau).replace('.','_')}"
        outputs[name] = Dense(1, kernel_initializer="random_normal", name=name)(x)
        losses[name] = tilted_loss(tau)

    model = keras.Model(inputs=inp, outputs=outputs, name="STDK_MultiQuantile")
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=LR), loss=losses)
    return model


def crps_from_quantiles_weighted(y_true, preds_dict, taus=TAUS):
    y = np.asarray(y_true).reshape(-1).astype(np.float64)

    taus = np.asarray(taus, dtype=np.float64)
    taus = np.sort(taus)

    w = np.zeros_like(taus)
    w[0] = 0.5 * (taus[1] - taus[0])
    w[-1] = 0.5 * (taus[-1] - taus[-2])
    w[1:-1] = 0.5 * (taus[2:] - taus[:-2])

    check_sum = np.zeros_like(y, dtype=np.float64)
    for tau, wk in zip(taus, w):
        key = f"q{str(tau).replace('.','_')}"
        q = np.asarray(preds_dict[key]).reshape(-1).astype(np.float64)
        e = y - q
        check = np.maximum(tau * e, (tau - 1.0) * e)
        check_sum += wk * check

    return 2.0 * check_sum


def crossing_rate_from_preds(preds_dict, taus=TAUS):
    taus = np.asarray(sorted(taus), dtype=np.float64)
    keys = [f"q{str(t).replace('.','_')}" for t in taus]
    Q = np.concatenate([np.asarray(preds_dict[k]).reshape(-1, 1).astype(np.float64) for k in keys], axis=1)
    dQ = np.diff(Q, axis=1)
    mono = np.all(dQ >= 0.0, axis=1)
    return float(1.0 - np.mean(mono))


def build_Xy_stationchunk_vectorized(station_ids, Yz, phi_space, phi_time, Ds, D, T):
    station_ids = np.asarray(station_ids, dtype=np.int64)
    n_st = station_ids.shape[0]

    space_part = phi_space[station_ids].astype(np.float32)
    y_chunk = Yz[station_ids].astype(np.float32)

    X_space = np.repeat(space_part[:, None, :], T, axis=1)
    X_time = np.repeat(phi_time[None, :, :].astype(np.float32), n_st, axis=0)

    X = np.concatenate([X_space, X_time], axis=2).reshape(n_st * T, D).astype(np.float32)
    y = y_chunk.reshape(-1).astype(np.float32)

    return X, y


def make_stream_dataset(station_ids, Yz, phi_space, phi_time, Ds, D, T, batch_size, stations_per_chunk, taus, shuffle, seed):
    station_ids = np.asarray(station_ids, dtype=np.int64)
    n_st = len(station_ids)

    def gen():
        rng = np.random.RandomState(seed)
        while True:
            if shuffle:
                order = rng.permutation(n_st)
            else:
                order = np.arange(n_st, dtype=np.int64)

            for a in range(0, n_st, stations_per_chunk):
                b = min(a + stations_per_chunk, n_st)
                chunk_ids = station_ids[order[a:b]]

                Xc, yc = build_Xy_stationchunk_vectorized(chunk_ids, Yz, phi_space, phi_time, Ds, D, T)

                n_rows = Xc.shape[0]
                for r0 in range(0, n_rows, batch_size):
                    r1 = min(r0 + batch_size, n_rows)
                    Xb = Xc[r0:r1]
                    yb = yc[r0:r1]
                    y_dict = {f"q{str(t).replace('.','_')}": yb for t in taus}
                    yield Xb, y_dict

    output_signature = (
        tf.TensorSpec(shape=(None, D), dtype=tf.float32),
        {f"q{str(t).replace('.','_')}": tf.TensorSpec(shape=(None,), dtype=tf.float32) for t in taus}
    )

    ds = tf.data.Dataset.from_generator(gen, output_signature=output_signature)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds


files = pick_last_n_files(DATA_DIR, N_FILES)
print("[FILES] using last:", len(files))
print("first:", files[0].name)
print("last :", files[-1].name)

ds0 = open_dataset_robust(files[0])
lat = ds0["lat"].values.astype(np.float32)
lon = ds0["lon"].values.astype(np.float32)
n_lat, n_lon = len(lat), len(lon)
S = n_lat * n_lon
n_time_per_file = int(ds0.sizes["time"])
ds0.close()

blocks = []
time_blocks = []
for p in files:
    ds = open_dataset_robust(p)
    blocks.append(ds["TLML"].values.astype(np.float32).reshape(n_time_per_file, S))
    time_blocks.append(pd.to_datetime(ds["time"].values))
    ds.close()

Y_5days = np.concatenate(blocks, axis=0)
t_5days = np.concatenate(time_blocks, axis=0)

if Y_5days.shape[0] < T_KEEP:
    raise ValueError(f"Total steps={Y_5days.shape[0]} < T_KEEP={T_KEEP}")

Y_last = Y_5days[-T_KEEP:, :]
t_last = t_5days[-T_KEEP:]
Y = Y_last.T.astype(np.float32)
T = Y.shape[1]

print("Y (S,T_KEEP):", Y.shape)
print("t_last range:", t_last[0], "->", t_last[-1])

lat_grid, lon_grid = np.meshgrid(lat, lon, indexing="ij")
lat_pts = lat_grid.reshape(-1).astype(np.float32)
lon_pts = lon_grid.reshape(-1).astype(np.float32)

lat_n = (lat_pts - float(lat_pts.min())) / (float(lat_pts.max() - lat_pts.min()) + 1e-12)
lon_n = (lon_pts - float(lon_pts.max() - lon_pts.min()) + 1e-12)
s_xy = np.column_stack([lat_n, lon_n]).astype(np.float32)

t_idx = np.arange(T, dtype=np.float32)
t_norm = (t_idx - t_idx.min()) / (t_idx.max() - t_idx.min() + 1e-12)

phi_space_full = build_space_basis_fast(s_xy)
phi_time_full = build_time_basis(t_norm)

space_keep = (phi_space_full != 0).any(axis=0)
time_keep = (phi_time_full != 0).any(axis=0)

phi_space = phi_space_full[:, space_keep].astype(np.float32)
phi_time = phi_time_full[:, time_keep].astype(np.float32)

Ds = phi_space.shape[1]
Dt = phi_time.shape[1]
D = Ds + Dt

print("Ds:", Ds, "Dt:", Dt, "D:", D)


cross_list = []
crps_z_list = []
crps_raw_list = []

for rep in range(1, N_REP + 1):
    rep_seed = BASE_SEED + SEED_OFFSET + rep

    random.seed(rep_seed)
    np.random.seed(rep_seed)
    tf.random.set_seed(rep_seed)

    split_rs = np.random.RandomState(rep_seed)
    n_obs = int(np.round(OBS_RATIO * S))
    obs_sites = np.sort(split_rs.choice(S, size=n_obs, replace=False))

    is_obs = np.zeros(S, dtype=bool)
    is_obs[obs_sites] = True
    miss_sites = np.where(~is_obs)[0]

    n_val = max(1, int(np.round(VAL_STATION_RATIO * n_obs)))
    perm = split_rs.permutation(n_obs)
    val_sites = np.sort(obs_sites[perm[:n_val]])
    train_sites = np.sort(obs_sites[perm[n_val:]])

    y_train_raw = Y[train_sites].reshape(-1).astype(np.float32)
    y_mu = float(np.mean(y_train_raw))
    y_sd = float(np.std(y_train_raw) + 1e-12)
    Yz = (Y - y_mu) / y_sd

    train_rows = int(len(train_sites) * T)
    val_rows = int(len(val_sites) * T)
    steps_per_epoch = int(np.ceil(train_rows / BATCH_SIZE))
    val_steps = int(np.ceil(val_rows / BATCH_SIZE))

    n_chunks_miss = int(np.ceil(len(miss_sites) / STATIONS_PER_CHUNK))

    print(
        f"[Rep {rep:02d}/{N_REP}] "
        f"train={len(train_sites)} val={len(val_sites)} miss={len(miss_sites)} | "
        f"steps/epoch={steps_per_epoch} val_steps={val_steps} miss_chunks={n_chunks_miss}"
    )

    keras.backend.clear_session()

    train_ds = make_stream_dataset(
        station_ids=train_sites,
        Yz=Yz,
        phi_space=phi_space,
        phi_time=phi_time,
        Ds=Ds,
        D=D,
        T=T,
        batch_size=BATCH_SIZE,
        stations_per_chunk=STATIONS_PER_CHUNK,
        taus=TAUS,
        shuffle=True,
        seed=rep_seed + 11,
    )

    val_ds = make_stream_dataset(
        station_ids=val_sites,
        Yz=Yz,
        phi_space=phi_space,
        phi_time=phi_time,
        Ds=Ds,
        D=D,
        T=T,
        batch_size=BATCH_SIZE,
        stations_per_chunk=STATIONS_PER_CHUNK,
        taus=TAUS,
        shuffle=False,
        seed=rep_seed + 17,
    )

    model = build_model_multi_quantile(D)

    model.fit(
        train_ds,
        epochs=EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_ds,
        validation_steps=val_steps,
        verbose=1,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=PATIENCE,
                restore_best_weights=True
            )
        ]
    )

    sum_crps_z = 0.0
    sum_crps_raw = 0.0
    sum_cross = 0.0
    count = 0

    for ci in range(n_chunks_miss):
        a = ci * STATIONS_PER_CHUNK
        b = min((ci + 1) * STATIONS_PER_CHUNK, len(miss_sites))
        chunk_sites = miss_sites[a:b]

        X_chunk, y_true_z = build_Xy_stationchunk_vectorized(chunk_sites, Yz, phi_space, phi_time, Ds, D, T)
        preds_z = model.predict(X_chunk, batch_size=BATCH_SIZE, verbose=0)

        crps_vec_z = crps_from_quantiles_weighted(y_true_z, preds_z, taus=TAUS)
        sum_crps_z += float(np.sum(crps_vec_z))

        y_true_raw = (y_true_z.astype(np.float64) * y_sd + y_mu)
        preds_raw = {}
        for tau in TAUS:
            k = f"q{str(tau).replace('.','_')}"
            preds_raw[k] = (np.asarray(preds_z[k]).reshape(-1).astype(np.float64) * y_sd + y_mu)

        crps_vec_raw = crps_from_quantiles_weighted(y_true_raw, preds_raw, taus=TAUS)
        sum_crps_raw += float(np.sum(crps_vec_raw))

        sum_cross += crossing_rate_from_preds(preds_z, taus=TAUS) * float(y_true_z.size)
        count += int(y_true_z.size)

        del X_chunk
        gc.collect()

        if (ci + 1) % 50 == 0:
            print(f"[Rep {rep:02d}] eval chunk {ci+1}/{n_chunks_miss}")

    cross_list.append(sum_cross / count)
    crps_z_list.append(sum_crps_z / count)
    crps_raw_list.append(sum_crps_raw / count)

    print(f"[Rep {rep:02d}/{N_REP}] Crossing rate = {cross_list[-1]:.6f}")

    del model
    gc.collect()


cross_arr = np.asarray(cross_list, dtype=np.float64)
crps_z_arr = np.asarray(crps_z_list, dtype=np.float64)
crps_raw_arr = np.asarray(crps_raw_list, dtype=np.float64)

print("\n=== Summary over repetitions (resampled station splits) ===")
print(f"Crossing rate mean = {float(np.mean(cross_arr)):.6f}")
print(f"Crossing rate SD   = {float(np.std(cross_arr, ddof=1)):.6f}" if N_REP > 1 else "Crossing rate SD   = NA")
print(f"CRPS_z   mean = {float(np.mean(crps_z_arr)):.6f} | SD = {float(np.std(crps_z_arr, ddof=1)):.6f}" if N_REP > 1 else f"CRPS_z   mean = {float(np.mean(crps_z_arr)):.6f} | SD = NA")
print(f"CRPS_raw mean = {float(np.mean(crps_raw_arr)):.6f} | SD = {float(np.std(crps_raw_arr, ddof=1)):.6f}" if N_REP > 1 else f"CRPS_raw mean = {float(np.mean(crps_raw_arr)):.6f} | SD = NA")


GPUs: []
[FILES] using last: 5
first: M2T1NXFLX.5.12.4%3AMERRA2_400.tavg1_2d_flx_Nx.20241227.nc4.dap.nc4
last : M2T1NXFLX.5.12.4%3AMERRA2_400.tavg1_2d_flx_Nx.20241231.nc4.dap.nc4
Y (S,T_KEEP): (207936, 100)
t_last range: 2024-12-27T20:30:00.000000000 -> 2024-12-31T23:30:00.000000000
Ds: 0 Dt: 70 D: 70
[Rep 01/10] train=18715 val=2079 miss=187142 | steps/epoch=3656 val_steps=407 miss_chunks=1463
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
