Quick demo of the full DKAJ pipeline on the Framingham dataset:

1. Load and preprocess data
2. Fit SurvivalBoost
3. TUNA warmup of a base neural net using SurvivalBoost leaves
4. Train DKAJ on top of that warm-started net
5. (Optional) Summary fine-tuning
6. Evaluate using Brier scores, the integrated Brier score (IBS), and the time-dependent concordance index (Ctd)

This is meant as an easy-to-read usage example, not a full experiment runner.

In [1]:

import os
import numpy as np
import pandas as pd
import torch

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from datasets import load_dataset, LabTransformCR
from models import (
    SurvivalBoostWrap,
    hist_gradient_boosting_classifier_apply,
    create_base_neural_net_with_hypersphere,
    tuna_loss,
    DKAJ,
    DKAJSummary,
    DKAJSummaryLoss,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed the global NumPy and PyTorch RNGs once, up front, so that a
# "Restart Kernel -> Run All" reproduces the stochastic steps of this demo
# (np.random draws, network weight initialisation, shuffled DataLoaders).
# NOTE: SurvivalBoost is seeded separately through its own random_state, and
# some CUDA kernels may still be non-deterministic.
demo_random_seed = 0
np.random.seed(demo_random_seed)
torch.manual_seed(demo_random_seed)

# Set this flag if you want to run summary fine-tuning at the end
do_summary_finetune = True  # from dkaj10.finetune_summaries = 1

Using device: cuda


## Load and preprocess Framingham

In [2]:
simple_data_splitting_val_ratio = 0.2
fix_test_shuffle_train = True  # DEFAULT.fix_test_shuffle_train = 1

# load_dataset returns a 12-element bundle: raw train/test arrays, feature and
# event metadata, and the two preprocessing callables used further down.
framingham_bundle = load_dataset(
    "framingham",
    competing=True,
    fix_test_shuffle_train=fix_test_shuffle_train,
    random_seed_offset=0,
)
(
    X_full_train_raw_np, Y_full_train_np, D_full_train_np,
    X_test_raw_np, Y_test_np, D_test_np,
    features_before_preprocessing, features_after_preprocessing,
    events,
    train_test_split_prespecified,
    build_preprocessor_and_preprocess, apply_preprocessor,
) = framingham_bundle

# Carve a validation set out of the original training data, stratified on the
# event indicator so each split keeps the same event mix.
(
    X_train_raw_np, X_val_raw_np,
    Y_train_np, Y_val_np,
    D_train_np, D_val_np,
) = train_test_split(
    X_full_train_raw_np, Y_full_train_np, D_full_train_np,
    test_size=simple_data_splitting_val_ratio,
    stratify=D_full_train_np,
    random_state=0,
)

# The preprocessor is fitted on the training split only and then re-applied,
# unchanged, to the validation and test splits.
X_train_np, preprocessor = build_preprocessor_and_preprocess(X_train_raw_np)
X_val_np, X_test_np = (
    apply_preprocessor(raw_split, preprocessor)
    for raw_split in (X_val_raw_np, X_test_raw_np)
)

print("Framingham shapes:")
for split_label, split_array in (
    ("  X_train:", X_train_np),
    ("  X_val:  ", X_val_np),
    ("  X_test: ", X_test_np),
):
    print(split_label, split_array.shape)
print("  #events:", len(events))
print()

Framingham shapes:
  X_train: (2482, 21)
  X_val:   (621, 21)
  X_test:  (1331, 21)
  #events: 2



## Fit SurvivalBoost

In [None]:
# SurvivalBoost hyperparameters (the bracketed lists are the alternative values
# noted for each setting).
survboost_learning_rate = 0.05        # can choose from [0.01, 0.05, 0.1, 0.5]
survboost_n_iter = 200                # can choose from [20, 100, 200]
survboost_max_depth = 4               # can choose from [-1, 4, 8, 16]
survboost_n_time_grid_steps = 128     # can choose from [64, 128]
survboost_ipcw_strategy = "kaplan-meier"  # choose from ['alternating', 'kaplan-meier']
survboost_random_seed = 146561020

# Two-column target frame: event indicator first, observed duration second.
event_duration_rows = np.vstack([D_train_np, Y_train_np])
y_train_df = pd.DataFrame(event_duration_rows.T, columns=["event", "duration"])

survboost_settings = dict(
    n_iter=survboost_n_iter,
    learning_rate=survboost_learning_rate,
    max_depth=survboost_max_depth,
    n_time_grid_steps=survboost_n_time_grid_steps,
    ipcw_strategy=survboost_ipcw_strategy,
    random_state=survboost_random_seed,
)
survival_boost = SurvivalBoostWrap(**survboost_settings)
survival_boost.fit(X_train_np, y_train_df)

print("Fitted SurvivalBoost model.")
print()

100%|██████████| 200/200 [00:03<00:00, 65.67it/s]

Fitted SurvivalBoost model.






## TUNA warmup (learn neural net to mimic SurvivalBoost kernel)

In [None]:
# Hyperparameters for the TUNA warmup stage. The "can choose from" lists are the
# candidate values noted for each setting; the assigned value is the one this
# demo actually uses.
# n_random_times_per_data_point: how many random-time copies of every sample are
#   built before querying SurvivalBoost leaves.
# tuna_n_layers / tuna_n_nodes: number of hidden layers and their (shared) width
#   for the base neural net.
# tuna_squared_radius: forwarded as `squared_radius` when the base net is created.
n_random_times_per_data_point = 10    # can choose from [5, 10]
tuna_batch_size = 1024                # mini-batch size for both TUNA dataloaders
tuna_n_layers = 2                     # can choose from [2, 4]
tuna_n_nodes = 128                    # can choose from [64, 128]
tuna_squared_radius = 0.1             # can choose from [0.1]
tuna_learning_rate = 1e-3             # can choose from [0.01, 0.001]
tuna_epochs = 10                      # small for quick demo

In [5]:
# Build (time, x) features with random times, then compute SurvivalBoost leaves.
# A dedicated, seeded generator keeps the sampled times identical across re-runs
# of this cell, independent of any other use of NumPy's global RNG.
tuna_random_times_seed = 0
tuna_random_times_rng = np.random.default_rng(tuna_random_times_seed)
leaf_apply_n_threads = os.cpu_count() or 1


def _stack_with_random_times(X_np, t_min, t_max, n_repeats, rng):
    """Stack `n_repeats` copies of `X_np`, each prefixed with a random-time column.

    Every copy gets its own column of times drawn uniformly between `t_min` and
    `t_max` (one time per row). The copies are concatenated vertically, so the
    result has shape (n_repeats * n_samples, 1 + n_features) with the time in
    column 0.

    Raises:
        ValueError: if `n_repeats` is smaller than 1.
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")
    n_samples = X_np.shape[0]
    blocks = []
    for _ in range(n_repeats):
        random_times = rng.uniform(t_min, t_max, size=n_samples).reshape(-1, 1)
        blocks.append(np.hstack([random_times, X_np]))
    return np.vstack(blocks)


def _aggregate_leaves_over_times(leaves_np, n_repeats, n_samples):
    """Regroup repeat-major leaf rows into one row per sample.

    `leaves_np` holds `n_repeats` consecutive groups of `n_samples` rows (the
    layout produced by `_stack_with_random_times`). The result has shape
    (n_samples, n_repeats * leaf_dim): each sample's leaves for all of its
    random times, concatenated in repeat order.
    """
    return (
        leaves_np
        .reshape(n_repeats, n_samples, -1)
        .transpose(1, 0, 2)
        .reshape(n_samples, -1)
    )


n_train = X_train_np.shape[0]
n_val = X_val_np.shape[0]
n_times = n_random_times_per_data_point

# Random times are drawn from each split's own observed time range.
X_train_with_random_times_np = _stack_with_random_times(
    X_train_np, Y_train_np.min(), Y_train_np.max(), n_times, tuna_random_times_rng
)
X_val_with_random_times_np = _stack_with_random_times(
    X_val_np, Y_val_np.min(), Y_val_np.max(), n_times, tuna_random_times_rng
)

# Predict leaves for these (time, x) vectors, then
# aggregate leaves over random times: (n_samples, n_random_times * leaf_dim)
survival_boost_leaves_train_np = _aggregate_leaves_over_times(
    hist_gradient_boosting_classifier_apply(
        survival_boost.estimator_,
        X_train_with_random_times_np,
        n_threads=leaf_apply_n_threads,
    ),
    n_times,
    n_train,
)
survival_boost_leaves_val_np = _aggregate_leaves_over_times(
    hist_gradient_boosting_classifier_apply(
        survival_boost.estimator_,
        X_val_with_random_times_np,
        n_threads=leaf_apply_n_threads,
    ),
    n_times,
    n_val,
)

In [6]:
# Prepare TUNA dataloaders: map x -> aggregated leaf encoding
# All four arrays become float32 tensors that live on the selected device.
float32_on_device = dict(dtype=torch.float32, device=device)

X_train_t = torch.tensor(X_train_np, **float32_on_device)
X_val_t = torch.tensor(X_val_np, **float32_on_device)
leaves_train_t = torch.tensor(survival_boost_leaves_train_np, **float32_on_device)
leaves_val_t = torch.tensor(survival_boost_leaves_val_np, **float32_on_device)

# One (features, leaf-encoding) pair per sample.
tuna_train_data = [(x_row, leaf_row) for x_row, leaf_row in zip(X_train_t, leaves_train_t)]
tuna_val_data = [(x_row, leaf_row) for x_row, leaf_row in zip(X_val_t, leaves_val_t)]

# Only the training loader is shuffled; validation keeps a fixed order.
tuna_train_loader = DataLoader(tuna_train_data, shuffle=True, batch_size=tuna_batch_size)
tuna_val_loader = DataLoader(tuna_val_data, shuffle=False, batch_size=tuna_batch_size)

In [7]:
# Create base neural net on hypersphere
num_input_features = X_train_np.shape[1]

# tuna_n_layers hidden layers, each tuna_n_nodes wide.
tuna_hidden_layer_sizes = [tuna_n_nodes] * tuna_n_layers

base_neural_net = create_base_neural_net_with_hypersphere(
    num_input_features,
    tuna_hidden_layer_sizes,
    squared_radius=tuna_squared_radius,
)
base_neural_net = base_neural_net.to(device)

tuna_optimizer = torch.optim.Adam(
    base_neural_net.parameters(),
    lr=tuna_learning_rate,
)

In [None]:
print("=== TUNA warmup ===")
for epoch in range(tuna_epochs):
    # ---- optimisation pass over the training batches ----
    base_neural_net.train()
    train_loss, n_train_seen = 0.0, 0

    for X_batch, leaves_batch in tuna_train_loader:
        batch_len = X_batch.size(0)
        tuna_optimizer.zero_grad()
        loss = tuna_loss(base_neural_net(X_batch), leaves_batch, device=device)
        loss.backward()
        tuna_optimizer.step()
        # Weight each batch loss by its size so the epoch figure is a per-sample mean.
        train_loss += batch_len * float(loss)
        n_train_seen += batch_len

    train_loss = train_loss / max(1, n_train_seen)

    # ---- gradient-free pass over the validation batches ----
    base_neural_net.eval()
    val_loss, n_val_seen = 0.0, 0
    with torch.no_grad():
        for X_batch, leaves_batch in tuna_val_loader:
            batch_len = X_batch.size(0)
            loss = tuna_loss(base_neural_net(X_batch), leaves_batch, device=device)
            val_loss += batch_len * float(loss)
            n_val_seen += batch_len
    val_loss = val_loss / max(1, n_val_seen)

    print(f"[TUNA] epoch {epoch+1:03d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

print("Completed TUNA warmup.")
print()

=== TUNA warmup ===
[TUNA] epoch 001 | train_loss=0.1310 | val_loss=0.1201
[TUNA] epoch 002 | train_loss=0.1134 | val_loss=0.0936
[TUNA] epoch 003 | train_loss=0.0839 | val_loss=0.0589
[TUNA] epoch 004 | train_loss=0.0499 | val_loss=0.0290
[TUNA] epoch 005 | train_loss=0.0256 | val_loss=0.0222
[TUNA] epoch 006 | train_loss=0.0234 | val_loss=0.0263
[TUNA] epoch 007 | train_loss=0.0242 | val_loss=0.0185
[TUNA] epoch 008 | train_loss=0.0156 | val_loss=0.0122
[TUNA] epoch 009 | train_loss=0.0112 | val_loss=0.0132
[TUNA] epoch 010 | train_loss=0.0126 | val_loss=0.0142
Completed TUNA warmup.



## DKAJ training (with warm-started neural net)

In [None]:
# Hyperparameters for the DKAJ training stage. The "can choose from" lists are
# the candidate values noted for each setting; the assigned value is the one this
# demo actually uses.
dkaj_n_durations = 128     # can choose from [0, 64, 128]; 0 means "use all unique event times"
# Mini-batch size for the DKAJ dataloaders; also reused as the batch size for
# predict / predict_cif and for the summary fine-tuning loader.
dkaj_batch_size = 1024
dkaj_learning_rate = 1e-3  # can choose from [0.01, 0.001]
dkaj_alpha = 0.001         # can choose from [0, 0.001, 0.01]
dkaj_sigma = 1.0           # can choose from [0.1, 1]
dkaj_beta = 0.25           # can choose from [0.25, 0.5]
# Converted into DKAJ's `tau` argument as sqrt(-log(dkaj_min_kernel_weight)).
dkaj_min_kernel_weight = 1e-2
dkaj_max_epochs = 20           # small for a quick demo

# Passed to DKAJ as both `max_n_neighbors` and `dkn_max_n_neighbors`.
ANN_max_n_neighbors = 128  # ANN_max_n_neighbors = [128]


In [10]:
# Discretize times using LabTransformCR
if dkaj_n_durations != 0:
    # Fixed number of quantile-based time bins.
    label_transform = LabTransformCR(dkaj_n_durations, scheme="quantiles")
else:
    # Use all unique non-censored event times, up to a cap (512) as in run_dkaj
    mask = (D_train_np >= 1)
    unique_event_times = np.unique(Y_train_np[mask])
    n_unique_times = unique_event_times.shape[0]
    if n_unique_times <= 512:
        label_transform = LabTransformCR(unique_event_times)
    else:
        print(
            f"Trying to use all training unique event times, but there are {n_unique_times} "
            "unique event times. Using upper limit of 512 instead."
        )
        label_transform = LabTransformCR(512, scheme="quantiles")

# The transform is fitted on the training labels and merely applied to validation.
Y_train_discrete_np, D_train_discrete_np = label_transform.fit_transform(Y_train_np, D_train_np)
Y_val_discrete_np, D_val_discrete_np = label_transform.transform(Y_val_np, D_val_np)
time_grid_train_np = label_transform.cuts


def _dkaj_tensors(X_np, Y_discrete_np, D_discrete_np):
    """Return (features, discrete times, event indicators) as tensors on `device`.

    Features are float32, discrete times int64 and event indicators int32.
    """
    return (
        torch.tensor(X_np, dtype=torch.float32, device=device),
        torch.tensor(Y_discrete_np, dtype=torch.int64, device=device),
        torch.tensor(D_discrete_np, dtype=torch.int32, device=device),
    )


# Build dataloaders after discretization
X_train_t, Y_train_t, D_train_t = _dkaj_tensors(X_train_np, Y_train_discrete_np, D_train_discrete_np)
dkaj_train_data = list(zip(X_train_t, Y_train_t, D_train_t))

X_val_t, Y_val_t, D_val_t = _dkaj_tensors(X_val_np, Y_val_discrete_np, D_val_discrete_np)
dkaj_val_data = list(zip(X_val_t, Y_val_t, D_val_t))

dkaj_train_loader = DataLoader(dkaj_train_data, shuffle=True, batch_size=dkaj_batch_size)
dkaj_val_loader = DataLoader(dkaj_val_data, shuffle=False, batch_size=dkaj_batch_size)



In [11]:
# Instantiate DKAJ model on top of base_neural_net
# tau is derived from the minimum kernel weight w as sqrt(-log(w)).
dkaj_tau = np.sqrt(-np.log(dkaj_min_kernel_weight))

dkaj_settings = dict(
    alpha=dkaj_alpha,
    sigma=dkaj_sigma,
    beta=dkaj_beta,
    tau=dkaj_tau,
    max_n_neighbors=ANN_max_n_neighbors,
    dkn_max_n_neighbors=ANN_max_n_neighbors,
)
dkaj_model = DKAJ(base_neural_net, device=device, **dkaj_settings)

# The loss comes from the DKAJ model, while the optimizer updates the
# parameters of the warm-started base network.
dkaj_loss = dkaj_model.loss
dkaj_optimizer = torch.optim.Adam(
    base_neural_net.parameters(),
    lr=dkaj_learning_rate,
)

In [12]:
print("=== DKAJ training ===")
for epoch in range(dkaj_max_epochs):
    # ---- gradient updates on the training batches ----
    base_neural_net.train()
    train_loss, n_train_seen = 0.0, 0

    for X_batch, Y_batch, D_batch in dkaj_train_loader:
        batch_len = X_batch.size(0)
        dkaj_optimizer.zero_grad()
        embeddings = base_neural_net(X_batch)
        loss = dkaj_loss(embeddings, Y_batch, D_batch)
        loss.backward()
        dkaj_optimizer.step()

        # Size-weighted accumulation gives a per-sample mean at the end of the epoch.
        train_loss += batch_len * float(loss)
        n_train_seen += batch_len

    train_loss = train_loss / max(1, n_train_seen)

    # ---- refresh the model's reference data and rebuild the ANN index ----
    # Statement order is kept as-is: training data, then embeddings, then the
    # time grid, then the index build.
    train_labels_discrete = (
        Y_train_discrete_np.astype("int64"),
        D_train_discrete_np.astype("int32"),
    )
    dkaj_model.training_data = (X_train_np.astype("float32"), train_labels_discrete)
    dkaj_model.train_embeddings = dkaj_model.predict(
        X_train_np.astype("float32"),
        batch_size=dkaj_batch_size,
    )
    dkaj_model.duration_index = time_grid_train_np
    dkaj_model.build_ANN_index()

    # ---- simple validation loss, without gradients ----
    base_neural_net.eval()
    val_loss, n_val_seen = 0.0, 0
    with torch.no_grad():
        for X_batch, Y_batch, D_batch in dkaj_val_loader:
            batch_len = X_batch.size(0)
            embeddings = base_neural_net(X_batch)
            loss = dkaj_loss(embeddings, Y_batch, D_batch)
            val_loss += batch_len * float(loss)
            n_val_seen += batch_len
    val_loss = val_loss / max(1, n_val_seen)

    print(f"[DKAJ] epoch {epoch+1:03d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

print("Completed DKAJ training.")
print()

=== DKAJ training ===
[DKAJ] epoch 001 | train_loss=0.3418 | val_loss=0.3392
[DKAJ] epoch 002 | train_loss=0.3383 | val_loss=0.3361
[DKAJ] epoch 003 | train_loss=0.3358 | val_loss=0.3337
[DKAJ] epoch 004 | train_loss=0.3335 | val_loss=0.3319
[DKAJ] epoch 005 | train_loss=0.3312 | val_loss=0.3306
[DKAJ] epoch 006 | train_loss=0.3299 | val_loss=0.3295
[DKAJ] epoch 007 | train_loss=0.3286 | val_loss=0.3285
[DKAJ] epoch 008 | train_loss=0.3269 | val_loss=0.3275
[DKAJ] epoch 009 | train_loss=0.3255 | val_loss=0.3266
[DKAJ] epoch 010 | train_loss=0.3231 | val_loss=0.3256
[DKAJ] epoch 011 | train_loss=0.3220 | val_loss=0.3246
[DKAJ] epoch 012 | train_loss=0.3197 | val_loss=0.3236
[DKAJ] epoch 013 | train_loss=0.3175 | val_loss=0.3229
[DKAJ] epoch 014 | train_loss=0.3152 | val_loss=0.3221
[DKAJ] epoch 015 | train_loss=0.3132 | val_loss=0.3215
[DKAJ] epoch 016 | train_loss=0.3118 | val_loss=0.3209
[DKAJ] epoch 017 | train_loss=0.3105 | val_loss=0.3201
[DKAJ] epoch 018 | train_loss=0.3083 | val_

In [13]:
print("Computing CIFs on test set (demo)...")
# Cast the test features to float32, then ask for the result as NumPy on the CPU.
X_test_f32_np = X_test_np.astype("float32")
cifs = dkaj_model.predict_cif(
    X_test_f32_np,
    batch_size=dkaj_batch_size,
    to_cpu=True,
    numpy=True,
)
cifs_np = np.array(cifs)  # shape: (n_events, n_durations, n_test)
print(f"CIFs shape: {cifs_np.shape}")
print()

Computing CIFs on test set (demo)...
CIFs shape: (2, 113, 1331)



## Optional summary fine-tuning (DKAJSummary)

In [None]:
# Optional stage: fine-tune the DKAJ summary functions while the embedding
# network stays frozen (its forward pass runs under torch.no_grad()).
if do_summary_finetune:
    sumtune_learning_rate = 1e-3  # can choose from [0.01, 0.001, 0.0001]
    sumtune_max_epochs = 10       # smaller for demo

    # Use same alpha/sigma as main DKAJ
    sumtune_alpha = dkaj_alpha
    sumtune_sigma = dkaj_sigma

    # Reuse discrete labels
    Y_train_disc_np, D_train_disc_np = label_transform.transform(Y_train_np, D_train_np)

    X_train_t = torch.tensor(X_train_np, dtype=torch.float32, device=device)
    Y_train_t = torch.tensor(Y_train_disc_np, dtype=torch.int64, device=device)
    D_train_t = torch.tensor(D_train_disc_np, dtype=torch.int32, device=device)
    sumtune_train_data = list(zip(X_train_t, Y_train_t, D_train_t))
    sumtune_train_loader = DataLoader(sumtune_train_data, batch_size=dkaj_batch_size, shuffle=True)

    # Initialize summary network around current summary functions
    init_summary_functions = dkaj_model.get_summary_functions()
    summary_net = DKAJSummary(
        dkaj_model,
        init_summary_functions[0],
        init_summary_functions[1],
    ).to(device)
    summary_loss = DKAJSummaryLoss(sumtune_alpha, sumtune_sigma)
    summary_optimizer = torch.optim.Adam(summary_net.parameters(), lr=sumtune_learning_rate)

    print("=== Summary fine-tuning ===")
    for epoch in range(sumtune_max_epochs):
        summary_net.train()
        # The embedding network is only used as a frozen feature extractor here,
        # so pin it to eval mode explicitly instead of relying on whatever mode
        # an earlier cell or predict call left it in. This is done after
        # summary_net.train() in case that call also switches the base network
        # (NOTE(review): whether DKAJSummary registers it as a submodule is not
        # visible from this notebook).
        base_neural_net.eval()
        epoch_loss = 0.0
        n_train_seen = 0

        for X_batch, Y_batch, D_batch in sumtune_train_loader:
            # No gradients flow into the embedding network.
            with torch.no_grad():
                embeddings = base_neural_net(X_batch)
            event_hazards, overall_hazards = summary_net(embeddings)
            loss = summary_loss(
                event_hazards,
                overall_hazards,
                Y_batch,
                D_batch,
            )

            summary_optimizer.zero_grad()
            loss.backward()
            summary_optimizer.step()

            # Size-weighted accumulation gives a per-sample mean for the epoch.
            epoch_loss += float(loss) * X_batch.size(0)
            n_train_seen += X_batch.size(0)

        epoch_loss /= max(1, n_train_seen)

        # Update dkaj_model's summary functions from the summary net
        (
            exemplar_event_counts_new,
            exemplar_at_risk_counts_new,
            baseline_event_counts,
            baseline_at_risk_counts,
        ) = summary_net.get_exemplar_summary_functions_baseline_event_at_risk_counts()
        dkaj_model.load_summary_functions(
            exemplar_event_counts_new,
            exemplar_at_risk_counts_new,
            baseline_event_counts,
            baseline_at_risk_counts,
        )

        print(f"[Summary] epoch {epoch+1:03d} | loss={epoch_loss:.4f}")

    print("Completed summary fine-tuning.")
    print()

=== Summary fine-tuning ===
[Summary] epoch 001 | loss=0.2191
[Summary] epoch 002 | loss=0.2189
[Summary] epoch 003 | loss=0.2190
[Summary] epoch 004 | loss=0.2189
[Summary] epoch 005 | loss=0.2189
[Summary] epoch 006 | loss=0.2188
[Summary] epoch 007 | loss=0.2186
[Summary] epoch 008 | loss=0.2186
[Summary] epoch 009 | loss=0.2185
[Summary] epoch 010 | loss=0.2186
Completed summary fine-tuning.



## Evaluation

- Uses training data only to set the time grids
- Computes Brier scores at the 0.25/0.5/0.75 quantiles of the training event times
- Computes the integrated Brier score (IBS) from the earliest training event time up to the 90th percentile of training event times
- Computes the time-dependent concordance index (Ctd) per event and averaged over events

In [15]:
from lifelines import KaplanMeierFitter
from metrics import (
    neg_cindex_td,
    compute_brier_competing_multiple_times,
    compute_ibs_competing,
)

In [18]:
def _interp_cif_on_horizons(horizons, grid_times, cif_time_by_sample):
    """Linearly interpolate every sample's CIF curve at the requested horizons.

    Args:
        horizons: 1-D array of time points to evaluate at.
        grid_times: 1-D array of the times the CIF rows correspond to.
        cif_time_by_sample: array of shape (len(grid_times), n_samples); column
            j is the CIF curve of sample j.

    Returns:
        Array of shape (n_samples, len(horizons)).
    """
    n_samples = cif_time_by_sample.shape[1]
    cif_on_horizons = np.empty((n_samples, len(horizons)))
    for j in range(n_samples):
        cif_on_horizons[j, :] = np.interp(horizons, grid_times, cif_time_by_sample[:, j])
    return cif_on_horizons


print("Computing CIFs on test set (demo)...")
events_pretty_names = ['non-CVD Death', 'CVD Death']
cifs = dkaj_model.predict_cif(
    X_test_np.astype("float32"),
    batch_size=dkaj_batch_size,
    numpy=True,
    to_cpu=True,
)
cifs_np = np.array(cifs)  # shape: (n_events, n_durations, n_test)
print("CIFs shape:", cifs_np.shape)
print()

# The per-event loop below pairs display names with the first axis of cifs_np.
# Fail loudly on a mismatch rather than silently dropping an event from the
# averages (too few names) or hitting an IndexError mid-evaluation (too many).
if cifs_np.shape[0] != len(events_pretty_names):
    raise ValueError(
        f"events_pretty_names has {len(events_pretty_names)} entries but the predicted "
        f"CIFs cover {cifs_np.shape[0]} event types"
    )

# ---------------------------------------------------------------------------
# 6. Evaluation: Brier scores, Ctd, IBS (training-based horizons)
# ---------------------------------------------------------------------------

eval_horizon_quantiles = np.array([0.25, 0.5, 0.75])
ibs_n_horizon_points = 100
ibs_max_horizon_percentile = 0.9

# Use ONLY the training set to define time horizons
# Here we use the original full training split from load_dataset
# (i.e. before the train/validation split made in this notebook).
Y_all_train = Y_full_train_np
D_all_train = D_full_train_np
event_times_train = Y_all_train[D_all_train > 0]

eval_horizons = np.quantile(event_times_train, eval_horizon_quantiles)
tau = np.percentile(event_times_train, ibs_max_horizon_percentile * 100.0)
# IBS grid: evenly spaced from the earliest training event time up to tau.
ibs_integrate_horizons = np.linspace(
    event_times_train.min(),
    tau,
    ibs_n_horizon_points,
)

print("Evaluation horizons (quantiles based on training set):")
for q, t in zip(eval_horizon_quantiles, eval_horizons):
    print(f"  q={q:.2f} -> t={t:.4f}")
print("IBS integration from", float(event_times_train.min()), "to", float(tau))
print()

# Fit censoring distribution Kaplan–Meier on TRAINING split only
# (censoring, D == 0, is treated as the "event" for this fit).
censoring_kmf = KaplanMeierFitter()
censoring_kmf.fit(
    durations=Y_train_np,
    event_observed=(D_train_np == 0).astype(int),
)

test_eval_brier_scores_all_events = []  # shape -> (n_events, n_eval_horizons)
test_ibs_all_events = []                # shape -> (n_events,)
test_cindex_td_all_events = []          # shape -> (n_events,)

# Loop over each event type (1, 2, ...)
for e_idx_minus_1, event in enumerate(events_pretty_names):
    event_of_interest = e_idx_minus_1 + 1
    print(f"Evaluating event {event_of_interest} ({event})")

    # CIF for this event on test set: (n_durations, n_test)
    cif_event_test = cifs_np[e_idx_minus_1]  # (n_durations, n_test)
    n_durations, n_test = cif_event_test.shape

    # Interpolate CIF at eval_horizons and IBS horizons
    cif_eval_grid = _interp_cif_on_horizons(         # (n_test, n_eval_horizons)
        eval_horizons, dkaj_model.duration_index, cif_event_test
    )
    cif_ibs_grid = _interp_cif_on_horizons(          # (n_test, n_ibs_points)
        ibs_integrate_horizons, dkaj_model.duration_index, cif_event_test
    )

    # --- Brier scores at the 0.25 / 0.5 / 0.75 quantile times ---
    eval_brier_scores = compute_brier_competing_multiple_times(
        cif_values_grid=cif_eval_grid,     # (n_samples, n_timepoints)
        censoring_kmf=censoring_kmf,
        Y_test=Y_test_np,
        D_test=D_test_np,
        event_of_interest=event_of_interest,
        time_horizons=eval_horizons,
    )

    # --- Integrated Brier Score (IBS) up to 90th percentile ---
    ibs = compute_ibs_competing(
        cif_values_grid=cif_ibs_grid,
        censoring_kmf=censoring_kmf,
        Y_test=Y_test_np,
        D_test=D_test_np,
        event_of_interest=event_of_interest,
        time_horizons=ibs_integrate_horizons,
    )

    # --- Time-dependent concordance index (Antolini) ---
    # Follow run_dkaj.py: use -CIF as "survival-like" scores on the IBS time grid
    surv_for_ctd = -cif_ibs_grid.T  # shape (n_times, n_samples)
    cindex_td = -neg_cindex_td(
        Y_test_np,
        (D_test_np == event_of_interest).astype(int),
        (surv_for_ctd, ibs_integrate_horizons),
        exact=False,
    )

    test_eval_brier_scores_all_events.append(eval_brier_scores)
    test_ibs_all_events.append(ibs)
    test_cindex_td_all_events.append(cindex_td)

# Convert to arrays for easy summarization
test_eval_brier_scores_all_events = np.array(test_eval_brier_scores_all_events)  # (n_events, n_eval_horizons)
test_ibs_all_events = np.array(test_ibs_all_events)                              # (n_events,)
test_cindex_td_all_events = np.array(test_cindex_td_all_events)                  # (n_events,)

# Aggregate across events with equal weights (the criterion used in the paper
# for early stopping when evaluated on the validation set)
avg_IBS = test_ibs_all_events.mean()
avg_Brier_per_horizon = test_eval_brier_scores_all_events.mean(axis=0)
avg_Ctd = test_cindex_td_all_events.mean()

print("\n=== Test-set metrics (averaged over events) ===")
print(f"Average IBS (0 to 90th percentile): {avg_IBS:.4f}")
for q, t, b in zip(eval_horizon_quantiles, eval_horizons, avg_Brier_per_horizon):
    print(f"Brier score at q={q:.2f} (t={t:.4f}): {b:.4f}")
print(f"Average time-dependent concordance (Antolini Ctd): {avg_Ctd:.4f}")

print("\nPer-event Ctd:")
for event_name, ctd in zip(events_pretty_names, test_cindex_td_all_events):
    print(f"  {event_name}: {ctd:.4f}")

Computing CIFs on test set (demo)...
CIFs shape: (2, 113, 1331)

Evaluation horizons (quantiles based on training set):
  q=0.25 -> t=2231.2500
  q=0.50 -> t=4639.0000
  q=0.75 -> t=6751.5000
IBS integration from 1.0 to 7942.000000000002

Evaluating event 1 (non-CVD Death)
Evaluating event 2 (CVD Death)

=== Test-set metrics (averaged over events) ===
Average IBS (0 to 90th percentile): 0.0838
Brier score at q=0.25 (t=2231.2500): 0.0513
Brier score at q=0.50 (t=4639.0000): 0.0956
Brier score at q=0.75 (t=6751.5000): 0.1370
Average time-dependent concordance (Antolini Ctd): 0.6699

Per-event Ctd:
  non-CVD Death: 0.6386
  CVD Death: 0.7012
