# A PROGNOSTIC MODEL FOR MYELODYSPLASTIC SYNDROMES BASED ON DTA MUTATIONAL BURDEN: DEVELOPMENT AND VALIDATION

In [24]:
# ------------------------------------------------------------------
# GLOBAL WARNING CONTROL
# ------------------------------------------------------------------
import warnings
warnings.filterwarnings("ignore")


# Basic packages

In [None]:
import numpy as np
import pandas as pd
#import i2bmi

In [None]:
# Disable truncation of rows, cell contents and columns in DataFrame displays.
for _display_option in ("display.max_rows", "display.max_colwidth", "display.max_columns"):
    pd.set_option(_display_option, None)

# Overall survival (OS)

# Data Loading and Preprocessing

## Load pre-processed data

In [None]:
### Path definitions

from pathlib import Path

# Current working directory (assumed to be the folder holding this notebook)
HERE = Path.cwd()

# The project root is one level above HERE
PROJECT_ROOT = HERE.parent

# Output folders under the project root
OUTPUT, FIGURES = (PROJECT_ROOT / sub for sub in ("output", "figures"))

In [None]:
### Load the cohort CSV written by the data pre-processing script
COHORT_CSV = OUTPUT / "mds_dta_cohort_os.csv"
df = pd.read_csv(COHORT_CSV)

In [None]:
df.head(n=5)

In [None]:
### Split data into features (X) and survival outcomes (y)
target_cols = ['os_months', 'os_status']
y = df.loc[:, target_cols]            # survival time + event status
X = df.drop(columns=target_cols)      # every remaining column is a predictor

# Splitting Data into Training and Testing Sets

This is the most important step for avoiding data leakage; it is carried out with the `train_test_split` function from the `sklearn.model_selection` module.

In [None]:
### Split training and testing sets (80/20, reproducible, stratified on event status)
from sklearn.model_selection import train_test_split

split_kwargs = dict(test_size=0.2, random_state=42, shuffle=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    stratify=y['os_status'],  # keeps the event rate comparable between train and test
    **split_kwargs,
)

In [None]:
X_train.shape[0], X_test.shape[0]

# Standardization/ Data Scaler (z-score)

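Standardization centers each feature to mean 0 and scales it to standard deviation 1. Many ML algorithms perform better when features are on a comparable scale (standardization does not, by itself, make a feature normally distributed), so scaling is used to improve performance, particularly for penalized and distance-based models. \
Once the training data were established, we applied scaling using `StandardScaler`, fitted on the training set only.
Scaling was performed before imputation because MICE, which we use for imputation, can be affected by the magnitude of the variables.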

In [None]:
### Initialize the scaler and choose the columns to scale: every float64 column plus a few
### integer count/flag columns listed explicitly (binary dummies are NOT scaled, even though numeric)
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
cols_to_scale = X.select_dtypes(include=['float64']).columns.tolist()
extra_cols = ['asxl1_only_counts', 'dta_non_asxl1_counts', 'asxl1_mixed_counts', 'truncating_variant', 'asxl1_truncating_variant', 'dnmt3a_truncating_variant', 'tet2_truncating_variant', 'n_truncating_variant', 'pathogenic_asxl1', 'pathogenic_dnmt3a', 'pathogenic_tet2']

# De-duplicate while preserving order. A repeated name (e.g. an extra column that is also
# float64, or a name typed twice in `extra_cols`) would otherwise yield a duplicated column
# in the selection, duplicate entries in `scaler.feature_names_in_`, and a column written
# twice on assignment.
all_cols_to_scale = list(dict.fromkeys(cols_to_scale + extra_cols))

# Explicit copies, so the column assignments below act on standalone frames rather than
# on objects pandas may flag as slices of `X` (SettingWithCopyWarning is silenced globally).
X_train = X_train.copy()
X_test = X_test.copy()

# Fit the scaler *only* on the training data ...
scaler.fit(X_train[all_cols_to_scale])

# ... then apply the same train-fitted transform to both sets
X_train[all_cols_to_scale] = scaler.transform(X_train[all_cols_to_scale])
X_test[all_cols_to_scale] = scaler.transform(X_test[all_cols_to_scale])  # trained on train, applied to test

In [None]:
X_train.head(n=5)

# Imputation

The dataset has missing values (NaN). To manage this (the data have already been split, which avoids data leakage):
1. CADD and phyloP scores will most likely be missing when truncating_variant = 1, since these are novel frameshifts not reported in databases (ClinVar and gnomAD).
2. We therefore perform imputation using the MICE approach with a Random Forest regressor as the estimator.
3. As the very first step, missingness indicators were added for every predictor with `NaN` (already done).

In [None]:
### MICE imputation for remaining missing values (Train set only)


from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
import pandas as pd

# Columns with at least one NaN in the training set (these are the ones overwritten below)
cols_with_missing = X_train.columns[X_train.isna().any()].tolist()

# Every column serves as a predictor inside the imputation models
predictor_cols = list(X_train.columns)

# MICE (IterativeImputer) with a Random Forest regressor as the per-feature estimator
rf_estimator = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
mice_imputer = IterativeImputer(
    estimator=rf_estimator,
    max_iter=10,
    sample_posterior=False,
    random_state=42,
)

# Fit on TRAIN only; the result is a bare ndarray ...
data_imputed_array_train = mice_imputer.fit_transform(X_train[predictor_cols])

# ... so wrap it back into a DataFrame aligned with X_train
X_train_imputed_full = pd.DataFrame(
    data_imputed_array_train, index=X_train.index, columns=predictor_cols
)

# Replace only the columns that actually had gaps; complete columns keep their original values/dtypes
X_train[cols_with_missing] = X_train_imputed_full[cols_with_missing]

In [None]:
### MICE imputation for testing set

# Apply the TRAIN-fitted imputer to the test set (no refitting -> no leakage)
data_imputed_array_test = mice_imputer.transform(X_test[predictor_cols])

# Convert output array back to DataFrame (full imputed test matrix)
X_test_imputed = pd.DataFrame(
    data_imputed_array_test,
    columns=predictor_cols,
    index=X_test.index
)

# Overwrite only the columns that had missing values in TRAIN *or* in TEST.
# Using just the train-derived `cols_with_missing` would leave NaN in any column that is
# complete in train but has gaps in test, and those NaN would reach the models downstream.
# (The imputer is built with the default skip_complete=False, so it fits an estimator for
# every feature and imputes such columns too.)
cols_missing_test = X_test.columns[X_test.isnull().any()].tolist()
cols_to_fill_test = list(dict.fromkeys(cols_with_missing + cols_missing_test))
X_test[cols_to_fill_test] = X_test_imputed[cols_to_fill_test]

# Invariant: no NaN may survive imputation
assert not X_test[predictor_cols].isnull().any().any(), "NaN remain in X_test after imputation"

# Machine Learning Models
Because the outcome is time-to-event (survival analysis), we considered the following survival models:
1. **CoxPH** is a traditional statistical analysis that struggles with high-dimensional data and multicollinearity. However `CoxnetSurvivalAnalysis` is a machine learning model that predicts survival in high-dimensional data settings. It is less vulnerable to overfitting due to Elastic Net (L1 and L2) penalties (PMID: 37884606, 27065756)

2. **Random Survival Forest** is a nonparametric machine learning model that handles non-linear interactions. Other similar studies have used this model (see abstract)

3. **Gradient-Boosted Survival Trees**: "A gradient boosted model is similar to a Random Survival Forest, in the sense that it relies on multiple base learners to produce an overall prediction, but differs in how those are combined. While a Random Survival Forest fits a set of Survival Trees independently and then averages their predictions, a gradient boosted model is constructed sequentially in a greedy stagewise fashion." (scikit-survival)


# Hyperparameter tuning

Hyperparameter tuning (or optimization) is the process of finding the best set of hyperparameter values that maximize model performance (e.g., accuracy, AUC, F1, etc.) on validation data.

### Note: 
We ran into an issue with censoring (`time_train.max`). Every time point passed to the metric must lie in a region where the censoring Kaplan–Meier estimate (computed under the hood) is defined, i.e., up to the last censoring time in the training set. It also must lie within the follow-up of the test set, i.e., between the minimum and maximum observed test times. The error was ultimately raised by `integrated_brier_score`. We fixed it by having the IBS evaluation always construct a safe time grid from the current data instead of hard-coding percentiles.

In [None]:
### Prepare IBS evaluation function and time grid

import numpy as np

def prepare_ibs_evaluation(y_train_surv, y_test_surv, X_test, n_grid=200, eps=1e-3,
                           time_field="os_months"):
    """
    Prepare a test subset and time grid that are valid for integrated_brier_score.

    IBS restrictions handled here:
      - The censoring distribution comes from the TRAINING times, so only test
        patients with follow-up strictly below max(train time) are kept.
      - The time grid lies inside the overlap of the train times and the retained
        test times: it starts at the larger of the two minima and stays `eps` below
        the smaller of the two maxima.

    Parameters
    ----------
    y_train_surv, y_test_surv : structured survival arrays containing `time_field`.
    X_test : DataFrame (positional rows via .iloc) or array, rows aligned with y_test_surv.
    n_grid : number of grid points (np.linspace with endpoint=False).
    eps : margin subtracted from the upper end of the time window.
    time_field : name of the survival-time field (default "os_months", as before).

    Returns
    -------
    (X_test_ibs, y_test_surv_ibs, time_grid, max_train_time)

    Raises
    ------
    ValueError
        If fewer than 2 test patients are eligible, or the time window is empty.
    """
    time_train = np.asarray(y_train_surv[time_field])
    time_test_all = np.asarray(y_test_surv[time_field])

    # Max observed time in TRAIN (upper limit for the censoring estimate)
    max_train_time = time_train.max()

    # Keep only test patients with follow-up strictly below max_train_time
    ibs_mask = time_test_all < max_train_time
    n_ibs = int(ibs_mask.sum())
    if n_ibs < 2:
        raise ValueError(f"Too few IBS-eligible test patients: {n_ibs}.")

    # Structured-array subset and the matching feature rows (positional selection)
    y_test_surv_ibs = y_test_surv[ibs_mask]
    if hasattr(X_test, "iloc"):
        X_test_ibs = X_test.iloc[np.flatnonzero(ibs_mask)]
    else:
        X_test_ibs = X_test[ibs_mask]

    # Time grid inside the overlap between train and IBS-eligible test follow-up
    time_test_ibs = time_test_all[ibs_mask]
    t_min = max(time_train.min(), time_test_ibs.min())
    t_max = min(max_train_time, time_test_ibs.max()) - eps  # stay strictly below the maxima

    if t_max <= t_min:
        raise ValueError(f"Invalid IBS time window: t_min={t_min}, t_max={t_max}")

    time_grid = np.linspace(t_min, t_max, n_grid, endpoint=False)

    return X_test_ibs, y_test_surv_ibs, time_grid, max_train_time


### CoxNetSurvival
We use a simple `GridSearchCV`, as there are only two hyperparameters (alpha and the L1 ratio), so it is feasible to evaluate all combinations. \
To evaluate performance, we use Harrell's C-index and the Integrated Brier Score (IBS).

References: 
1. Haider H, Hoehn B, et al. Effective Ways to Build and Evaluate Individual Survival Distributions. Journal of Machine Learning Research 21 (2020) 1–63.
2. Ping Wang, Yan Li, and Chandan K. Reddy. 2019. Machine Learning for Survival Analysis: A Survey. ACM Comput. Surv. 51, 6, Article 110 (February 2019).
3. https://scikit-survival.readthedocs.io/en/latest/user_guide/evaluating-survival-models.html



### Note:
In the earlier analysis, all patients with zero survival duration were randomly assigned to the test set only, so `CoxnetSurvivalAnalysis` never encountered a non-positive time in model fitting. \
After feature engineering and column removal, the train/test distribution changed, placing some zero-time patients in the training folds. \
`CoxnetSurvivalAnalysis` requires strictly positive survival times and will crash (segfault) when encountering zero values during partial likelihood computation. \
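Clipping survival times at a small positive floor (0.01 months, i.e., replacing zero times with 0.01) prevents this issue. Because only the zero-time patients change, the effect on hazard ratios and model discrimination is expected to be negligible.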

In [None]:
### Avoid zero months by clipping at small value and avoid kernel crashes

# Floor survival times at `epsilon` months (zero times crashed CoxnetSurvivalAnalysis)
epsilon = 0.01

y_train = y_train.assign(os_months=y_train["os_months"].clip(lower=epsilon))
y_test = y_test.assign(os_months=y_test["os_months"].clip(lower=epsilon))


### Note:
Some correlated features remained even after dropping fold-constant ones.
`CoxnetSurvivalAnalysis` is extremely sensitive to correlation (weights blow up at small alpha and small tol (1e−7))

So the solution was: increase regularization + relax the solver tolerance

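The following settings were used as guidance. They are attributed to the `sksurv` authors (TODO: add the exact source). Note that the Coxnet cell below uses `max_iter=20000`, which is above the limit listed here:
- alpha_min ≥ 0.01 for medium-sized datasets
- tol ≥ 1e−6
- max_iter ≤ 2000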

In [None]:
### Coxnet model with hyperparameter tuning, C-index and IBS evaluation

from sklearn.model_selection import KFold, GridSearchCV
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, integrated_brier_score
from sksurv.util import Surv
import numpy as np
import time

# Custom scoring function - C-index for GridSearchCV
def cindex_scorer(estimator, X, y_struct):
    """GridSearchCV scorer: Harrell's C-index of `estimator.predict(X)` against `y_struct`.

    `y_struct` is a structured survival array with fields "os_status" (event
    indicator) and "os_months" (time); predictions are treated as risk scores.
    """
    risk_scores = estimator.predict(X)
    event_indicator = y_struct["os_status"]
    event_time = y_struct["os_months"]
    result = concordance_index_censored(event_indicator, event_time, risk_scores)
    return result[0]  # first element is the concordance index itself

# Model
cox = CoxnetSurvivalAnalysis(
    fit_baseline_model=True,  # needed for predict_survival_function
    max_iter=20000,           # max iterations or steps for convergence
    tol=1e-6                  # tolerance for convergence (min change in coefficients)
)

# Parameter grid: 50 alphas x 5 l1_ratios = 250 candidates (x 5 folds = 1,250 fits).
# Each candidate's "alphas" MUST be a one-element list. Passing the whole path as one
# candidate ([alpha_path]) fits the entire path, and `predict()` (alpha=None) then scores
# with only the LAST alpha of that path. Every candidate would be evaluated at a single
# fixed alpha, so alpha would never actually be tuned. One alpha per candidate is the
# pattern used in the scikit-survival penalized-Cox user guide.
alpha_path = np.logspace(-2, 1, 50)
param_grid_coxnet = {
    "alphas": [[a] for a in alpha_path],    # intensity of regularization (one per candidate)
    "l1_ratio": [0.1, 0.3, 0.5, 0.7, 0.9],  # balance between L1 and L2 regularization
}

# Cross-validation setup (in the training set)
cv = KFold(n_splits=5, shuffle=True, random_state=42)

# Grid search setup
grid_search_coxnet = GridSearchCV(
    estimator=cox,
    param_grid=param_grid_coxnet,
    cv=cv,
    scoring=cindex_scorer,
    n_jobs=1,
    refit=True,
    verbose=1,
    error_score="raise"
)

## Prepare survival data for training and test (event indicator must be boolean)
y_train = y_train.copy()
y_train["os_status"] = y_train["os_status"].fillna(0).astype(int).astype(bool)

y_test = y_test.copy()
y_test["os_status"] = y_test["os_status"].fillna(0).astype(int).astype(bool)

y_train_surv = Surv.from_dataframe("os_status", "os_months", y_train)
y_test_surv  = Surv.from_dataframe("os_status", "os_months", y_test)

# Fit model
grid_search_coxnet.fit(X_train, y_train_surv)

print("Best params:", grid_search_coxnet.best_params_)
print(f"Best C-index on train (CV): {grid_search_coxnet.best_score_:.4f}")

# Refit on the whole training set with the single best alpha / l1_ratio
cox_best = grid_search_coxnet.best_estimator_


## Test metrics – C-index (full test set)

pred_test = cox_best.predict(X_test)
c_index_test = concordance_index_censored(
    y_test_surv["os_status"],
    y_test_surv["os_months"],
    pred_test
)[0]
print(f"C-index on test set (full): {c_index_test:.4f}")

### IBS on IBS-eligible subset (subsampled, with safe grid)

# Margin below the upper end of the follow-up window. This is deliberately its own
# constant. The global `epsilon` is the 0.01-month zero-time clipping floor, an unrelated
# setting, while `prepare_ibs_evaluation` and the RSF IBS cell use a 1e-3 margin.
IBS_EPS = 1e-3
IBS_N_GRID = 100    # 100 time points is enough
max_ibs_test = 300  # cap on IBS test size (speed)

# IBS-eligible subset using helper
X_test_cox_ibs, y_test_surv_ibs_cox, TIME_GRID_COX, max_train_time_cox = prepare_ibs_evaluation(
    y_train_surv, y_test_surv, X_test, n_grid=IBS_N_GRID, eps=IBS_EPS
)

print(f"Max train time (Cox): {max_train_time_cox:.2f}")
print(f"Max test time (all): {y_test_surv['os_months'].max():.2f}")
print(f"Max IBS-eligible test time (Cox): {y_test_surv_ibs_cox['os_months'].max():.2f}")
print(f"TIME_GRID_COX range (pre-subsample): {TIME_GRID_COX.min():.2f} to {TIME_GRID_COX.max():.2f}")
print(f"IBS-eligible test size (before subsampling): {len(y_test_surv_ibs_cox)}")

# Subsample IBS test set for speed
n_ibs_cox = len(y_test_surv_ibs_cox)

if n_ibs_cox > max_ibs_test:
    rng = np.random.default_rng(42)
    idx_sub_cox = rng.choice(n_ibs_cox, size=max_ibs_test, replace=False)

    if hasattr(X_test_cox_ibs, "iloc"):
        X_test_cox_ibs_sub = X_test_cox_ibs.iloc[idx_sub_cox]
    else:
        X_test_cox_ibs_sub = X_test_cox_ibs[idx_sub_cox]

    y_test_surv_ibs_sub_cox = y_test_surv_ibs_cox[idx_sub_cox]
    print(f"Subsampled IBS test size (Cox): {len(y_test_surv_ibs_sub_cox)} (from {n_ibs_cox})")
else:
    X_test_cox_ibs_sub = X_test_cox_ibs
    y_test_surv_ibs_sub_cox = y_test_surv_ibs_cox

# The subsample's follow-up range can be narrower than the full eligible set, so the grid
# is rebuilt for it. Every subsample time is already < max train time, so the helper keeps
# all rows and only recomputes the window (same formula as before, no duplicated logic).
_, _, TIME_GRID_COX_SUB, _ = prepare_ibs_evaluation(
    y_train_surv, y_test_surv_ibs_sub_cox, X_test_cox_ibs_sub, n_grid=IBS_N_GRID, eps=IBS_EPS
)
time_test_sub_cox = y_test_surv_ibs_sub_cox["os_months"]

print(f"TIME_GRID_COX_SUB range: {TIME_GRID_COX_SUB.min():.2f} to {TIME_GRID_COX_SUB.max():.2f}")
print(f"Cox subsample follow-up range: {time_test_sub_cox.min():.2f} to {time_test_sub_cox.max():.2f}")

# Predict survival functions on the subsample
t0 = time.time()
surv_fns_cox = list(cox_best.predict_survival_function(X_test_cox_ibs_sub))
print(f"Cox predict_survival_function done in {time.time() - t0:.2f} s")

# Evaluate survival on the new grid -> matrix (n_patients, n_grid)
t1 = time.time()
surv_probs_cox = np.asarray([fn(TIME_GRID_COX_SUB) for fn in surv_fns_cox])
print(f"Building surv_probs_cox done in {time.time() - t1:.2f} s")
print("surv_probs_cox shape:", surv_probs_cox.shape)

# IBS (censoring distribution estimated from the training outcomes)
t2 = time.time()
ibs_test = integrated_brier_score(
    y_train_surv,
    y_test_surv_ibs_sub_cox,
    surv_probs_cox,
    TIME_GRID_COX_SUB
)
print(f"integrated_brier_score (Cox) done in {time.time() - t2:.2f} s")

print(f"IBS on test set (IBS-eligible subset, subsampled): {ibs_test:.4f}")

results = {"c_index_test": c_index_test, "ibs_test": ibs_test}


## Bootstrap CoxSurvival

We decided to run bootstrap to be consistent with similar studies and provide uncertainty in the C-index score. \
Bootstrap: draw `n_test` rows (the number of rows in the test set) with replacement, and repeat this B = 1,000 times. Some rows may appear more than once and others not at all. The C-index and IBS are computed for each replicate, and the 2.5th and 97.5th percentiles give the 95% confidence interval.

In [None]:
### Bootstrap C-index (CoxNet)

from sksurv.metrics import concordance_index_censored
import numpy as np


# Point estimate on full test set
time_test  = y_test_surv["os_months"]
event_test = y_test_surv["os_status"].astype(bool)

# Linear-predictor risk scores are computed ONCE. Coxnet scores each row independently,
# so a bootstrap replicate's scores are just the matching rows of `risk_test`.
# Re-running predict() on every replicate gave the same numbers at 1,000x the cost.
risk_test = cox_best.predict(X_test)

cindex_test_cox = concordance_index_censored(
    event_test,
    time_test,
    risk_test
)[0]

print(f"C-index (test, CoxNet) point estimate: {cindex_test_cox:.4f}")


# Bootstrap 95% CI on full test set (percentile method)
rng = np.random.default_rng(123)
B = 1000

n_test = X_test.shape[0]
cindex_boot_cox = np.empty(B, dtype=float)

print(f"Bootstrap C-index on full test set ({n_test} patients) – CoxNet.")

for b in range(B):
    # Resample patients with replacement
    idx = rng.integers(0, n_test, size=n_test)

    cindex_boot_cox[b] = concordance_index_censored(
        event_test[idx],
        time_test[idx],
        risk_test[idx]
    )[0]

cindex_ci_cox = np.percentile(cindex_boot_cox, [2.5, 97.5])

print(
    f"C-index (test, CoxNet): {cindex_test_cox:.4f} | "
    f"95% CI [{cindex_ci_cox[0]:.4f}, {cindex_ci_cox[1]:.4f}]"
)


In [None]:
### Bootstrap IBS on subsampled IBS-eligible set (Cox) --- no C-index in this code

# Bootstrap setup
rng = np.random.default_rng(123)  # reproducible bootstrap
B = 1000                          # number of bootstrap resamples

ibs_boot = []

time_train  = y_train_surv["os_months"]
event_train = y_train_surv["os_status"]

censor_times = time_train[~event_train]
if censor_times.size == 0:
    raise ValueError("Cannot compute IBS/bootstrap: no censored observations in training data.")
max_train_time = time_train.max()
eps = 1e-3

n_ibs = X_test_cox_ibs_sub.shape[0]
print(f"Bootstrap IBS on {n_ibs} subsampled IBS-eligible test patients (Cox).")

for _ in range(B):
    # Resample within subsampled IBS set
    idx_ibs = rng.integers(0, n_ibs, size=n_ibs)

    Xb_ibs = X_test_cox_ibs_sub.iloc[idx_ibs]
    yb_surv = y_test_surv_ibs_sub_cox[idx_ibs]
    yb_event = yb_surv["os_status"]
    yb_time  = yb_surv["os_months"]

    # Safe time window for this bootstrap replicate
    t_min_b = max(time_train.min(), yb_time.min())
    t_max_b = min(max_train_time, yb_time.max()) - eps

    if t_max_b <= t_min_b:
        continue

    time_points_b = np.linspace(t_min_b, t_max_b, 100, endpoint=False)

    surv_fns_b = list(cox_best.predict_survival_function(Xb_ibs))
    surv_probs_b = np.asarray([fn(time_points_b) for fn in surv_fns_b])


# C-index


    ibs_b = integrated_brier_score(
        y_train_surv,
        yb_surv,
        surv_probs_b,
        time_points_b
    )
    ibs_boot.append(ibs_b)

ibs_boot = np.array(ibs_boot)
ibs_ci = np.percentile(ibs_boot, [2.5, 97.5])

print(f"IBS     (test, subsampled IBS subset): {ibs_test:.4f} | 95% CI [{ibs_ci[0]:.3f}, {ibs_ci[1]:.3f}]")


# Random Survival Forest
We implement hyperparameter tuning with Bayesian optimization (Optuna). This model captures high-dimensional and nonlinear relationships, so an exhaustive `GridSearch` would be very expensive, and Random Search may miss the best parameter setting. The Bayesian approach uses past results to model the parameter–performance relationship: it builds a surrogate model that learns this mapping and picks the next promising parameter set from it. Trials are therefore not random; each chooses the next parameter set where it expects the largest improvement.

### Key features of `optuna`
1. **Objective function**: The core of Optuna is the objective function, which encapsulates the model training and evaluation process. This function takes a trial object as an argument, allowing hyperparameter values for the current trial.
2. **Trial Object**: The trial object within the objective function provides methods like `suggest_float()`, `suggest_int()`, `suggest_categorical(`), etc., to define the search space for different types of hyperparameters.
3. **Study**: A `Study` object in Optuna manages the optimization process. You create a study, specify the optimization direction (maximize or minimize the objective function's return value), and then call `study.optimize()` with your objective function and the number of trials.
4. **Pruners**: `optuna` includes pruners that can stop unpromising trials early based on intermediate evaluation scores, saving computational resources.
5. **Samplers**: `optuna` offers various samplers, such as Tree-structured Parzen Estimator (TPE) and Random Search, to intelligently explore the hyperparameter space. TPE, for instance, adapts its search strategy based on the performance of previous trials, focusing on more promising regions. The Tree-structured Parzen Estimator (TPE) is a widely used Bayesian optimization algorithm, and Optuna's default sampler, for efficiently finding the optimal hyperparameters of a machine learning model. TPE intelligently learns from the results of past trials to propose more promising hyperparameters for the next evaluation, significantly outperforming random or grid search methods, especially for complex or computationally expensive problems
6. **Visualization**: `optuna` provides tools for visualizing the optimization process, including plots for hyperparameter importance, optimization history, and parallel coordinate plots to understand the relationships between hyperparameters and performance.

In [None]:
### in Terminal:
# mamba activate pydev
# mamba install -c conda-forge optuna -y
# python -m ipykernel install --user --name pydev --display-name "Python (pydev)"

In [None]:
#### Random Survival Forest with Optuna hyperparameter tuning (C-index objective)

import numpy as np, pandas as pd, warnings, optuna
from sklearn.feature_selection import VarianceThreshold
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv
from sksurv.metrics import concordance_index_censored, integrated_brier_score

warnings.filterwarnings("ignore", category=UserWarning)


##   Data & feature filtering

# Floor survival times at `epsilon` months (zero times previously crashed the kernel)
epsilon = 0.01
for _frame in (y_train, y_test):
    _frame["os_months"] = _frame["os_months"].clip(lower=epsilon)

# Drop near-constant features; variances are learned on TRAIN only (no leakage)
vt = VarianceThreshold(threshold=0.02)
X_train_rsf = vt.fit_transform(X_train)
X_test_rsf = vt.transform(X_test)

# Survival outcomes as structured arrays for scikit-survival
y_train_surv, y_test_surv = (
    Surv.from_dataframe("os_status", "os_months", frame) for frame in (y_train, y_test)
)

##   Optuna objective (C-index)

# Objective function for Optuna
def objective(trial, n_splits=5):
    """Optuna objective for RSF: mean K-fold CV Harrell's C-index on the TRAINING set.

    The held-out test set is deliberately NOT used here. Selecting hyperparameters by
    their test C-index would leak test information into model selection and make the
    later "test" C-index optimistically biased. Folds are drawn from X_train_rsf /
    y_train_surv only. The running mean is reported after each fold, so the study's
    pruner can stop unpromising trials early. Lower `n_splits` to trade accuracy for speed.

    Returns
    -------
    float : mean validation C-index across the folds.
    """
    from sklearn.model_selection import KFold

    params = {
        "n_estimators": trial.suggest_int("n_estimators", 400, 1200),
        "max_depth": trial.suggest_int("max_depth", 6, 18),
        "min_samples_split": trial.suggest_int("min_samples_split", 2, 15),
        "min_samples_leaf": trial.suggest_int("min_samples_leaf", 2, 8),
        "max_features": trial.suggest_float("max_features", 0.3, 0.8),
        "bootstrap": True,
    }

    def _rows(data, idx):
        # Positional row selection for both DataFrames and ndarrays
        return data.iloc[idx] if hasattr(data, "iloc") else data[idx]

    folds = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_scores = []
    for step, (tr_idx, va_idx) in enumerate(folds.split(X_train_rsf)):
        # Train RSF model with suggested hyperparameters on the training folds
        rsf = RandomSurvivalForest(
            n_jobs=-1,
            random_state=42,
            oob_score=False,
            **params  # inject chosen hyperparameters
        )
        rsf.fit(_rows(X_train_rsf, tr_idx), y_train_surv[tr_idx])

        # Evaluate C-index on the validation fold
        y_va = y_train_surv[va_idx]
        pred = rsf.predict(_rows(X_train_rsf, va_idx))
        fold_scores.append(
            concordance_index_censored(y_va["os_status"], y_va["os_months"], pred)[0]
        )

        # Report the running mean and handle pruning
        trial.report(float(np.mean(fold_scores)), step=step)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return float(np.mean(fold_scores))

# Optuna study: multivariate TPE sampler (first 25 trials random) + Hyperband pruner,
# maximizing whatever `objective` returns (a C-index).
tpe_sampler = optuna.samplers.TPESampler(seed=42, multivariate=True, n_startup_trials=25)
study = optuna.create_study(
    direction="maximize",
    sampler=tpe_sampler,
    pruner=optuna.pruners.HyperbandPruner(),
)

# Run optimization
study.optimize(objective, n_trials=150, show_progress_bar=True)

best_trial_rsf = study.best_trial
print("\nBest Trial:", best_trial_rsf.number)
print("Best Params:", study.best_params)
print(f"Best C-index (Optuna objective): {study.best_value:.4f}")


In [None]:
### Refit best RSF parameters on full train

best_params = study.best_params  # hyperparameters chosen by the Optuna study above

# Single model (seed 42, out-of-bag score enabled); later cells use it for survival functions / IBS
rsf_best = RandomSurvivalForest(n_jobs=-1, random_state=42, oob_score=True, **best_params)
rsf_best.fit(X_train_rsf, y_train_surv)

# Seed ensemble: same hyperparameters, different random seeds; averaging the risk scores
# reduces run-to-run variance of the test C-index.
# NOTE(review): the C-index below comes from this ensemble, while the IBS cells use only
# `rsf_best`, so the two reported metrics describe slightly different models.
seeds = [21, 42, 63, 84]
rsf_models = []
for seed in seeds:
    member = RandomSurvivalForest(random_state=seed, n_jobs=-1, **best_params)
    rsf_models.append(member.fit(X_train_rsf, y_train_surv))

member_preds = np.stack([m.predict(X_test_rsf) for m in rsf_models])
preds = member_preds.mean(axis=0)  # averaged ensemble risk score per test patient

# Evaluate on test set (C-index)
c_index_test = concordance_index_censored(
    y_test_surv["os_status"], y_test_surv["os_months"], preds
)[0]

print(f"\nRSF – C-index on test set (full): {c_index_test:.4f}")


In [None]:
### IBS on IBS-eligible subset (subsampled, with safe grid) for RSF

from sksurv.metrics import integrated_brier_score
import time
import numpy as np

# Margin below the upper end of each follow-up window. It has its own name so this cell
# does NOT overwrite the global `epsilon` (the 0.01-month zero-time clipping floor that
# other cells read); reassigning it here made results depend on cell execution order.
IBS_EPS = 1e-3
IBS_N_GRID = 100
max_ibs_test = 300  # cap on IBS test size (speed)

# Start with IBS-eligible subset (grid is recomputed for the subsample below)
X_test_rsf_ibs, y_test_surv_ibs_rsf, TIME_GRID_RSF, max_train_time_rsf = prepare_ibs_evaluation(
    y_train_surv, y_test_surv, X_test_rsf, n_grid=IBS_N_GRID, eps=IBS_EPS
)

print(f"Max train time (RSF): {max_train_time_rsf:.2f}")
print(f"Max test time (all): {y_test_surv['os_months'].max():.2f}")
print(f"Max IBS-eligible test time (RSF): {y_test_surv_ibs_rsf['os_months'].max():.2f}")
print(f"TIME_GRID_RSF range (pre-subsample): {TIME_GRID_RSF.min():.2f} to {TIME_GRID_RSF.max():.2f}")
print(f"IBS-eligible test size (before subsampling): {len(y_test_surv_ibs_rsf)}")

# Subsample IBS test set to at most `max_ibs_test` patients
n_ibs = len(y_test_surv_ibs_rsf)

if n_ibs > max_ibs_test:
    rng = np.random.default_rng(42)
    idx_sub = rng.choice(n_ibs, size=max_ibs_test, replace=False)
    if hasattr(X_test_rsf_ibs, "iloc"):
        X_test_rsf_ibs_sub = X_test_rsf_ibs.iloc[idx_sub]
    else:
        X_test_rsf_ibs_sub = X_test_rsf_ibs[idx_sub]
    y_test_surv_ibs_sub = y_test_surv_ibs_rsf[idx_sub]
    print(f"Subsampled IBS test size: {len(y_test_surv_ibs_sub)} (from {n_ibs})")
else:
    X_test_rsf_ibs_sub = X_test_rsf_ibs
    y_test_surv_ibs_sub = y_test_surv_ibs_rsf

# Recompute a VALID time grid for *this subsample*. All subsample times are already below
# the max train time, so the helper keeps every row and only rebuilds the window.
_, _, TIME_GRID_RSF_SUB, _ = prepare_ibs_evaluation(
    y_train_surv, y_test_surv_ibs_sub, X_test_rsf_ibs_sub, n_grid=IBS_N_GRID, eps=IBS_EPS
)
time_test_sub = y_test_surv_ibs_sub["os_months"]

print(f"TIME_GRID_RSF_SUB range: {TIME_GRID_RSF_SUB.min():.2f} to {TIME_GRID_RSF_SUB.max():.2f}")
print(f"Subsample follow-up range: {time_test_sub.min():.2f} to {time_test_sub.max():.2f}")

# Predict survival functions on subsample
t0 = time.time()
surv_fns_rsf = rsf_best.predict_survival_function(X_test_rsf_ibs_sub)
print(f"predict_survival_function done in {time.time() - t0:.2f} s")

# Evaluate survival on new grid -> matrix (n_patients, n_grid)
t1 = time.time()
surv_probs_rsf = np.asarray([fn(TIME_GRID_RSF_SUB) for fn in surv_fns_rsf])
print(f"Building surv_probs_rsf done in {time.time() - t1:.2f} s")
print("surv_probs_rsf shape:", surv_probs_rsf.shape)

# IBS on subsampled IBS-eligible set
t2 = time.time()
ibs_test_rsf = integrated_brier_score(
    y_train_surv,
    y_test_surv_ibs_sub,
    surv_probs_rsf,
    TIME_GRID_RSF_SUB
)
print(f"integrated_brier_score done in {time.time() - t2:.2f} s")

print("\nFinal RSF Results:")
print(f"Best Parameters: {best_params}")
print(f"C-index on test set (full): {c_index_test:.4f}")
print(f"IBS on test set (IBS-eligible subset, subsampled): {ibs_test_rsf:.4f}")

results_rsf = {"c_index_test": c_index_test, "ibs_test": ibs_test_rsf}


## Bootstrap Random Survival Forest

We decided to run bootstrap to be consistent with similar studies and provide uncertainty in the C-index score. \
Bootstrap: draw `n_test` rows (the number of rows in the test set) with replacement, and repeat this B = 1,000 times. Some rows may appear more than once and others not at all. The C-index and IBS are computed for each replicate, and the 2.5th and 97.5th percentiles give the 95% confidence interval.

In [None]:
from sklearn.utils import resample
import numpy as np

### Bootstrap 95% CIs on test (RSF; model fixed, no re-tuning)

B = 1000                         # number of bootstrap resamples
rng = np.random.default_rng(42)  # random seed (shared by both bootstraps, in order)

# ---- C-index bootstrap on the FULL test set (ensemble risk scores `preds`) ----
n_test = X_test_rsf.shape[0]
y_event_full = y_test_surv["os_status"]
y_time_full  = y_test_surv["os_months"]

c_index_boot = []
for _ in range(B):
    idx = rng.integers(0, n_test, size=n_test)  # resample patients with replacement
    c_index_boot.append(
        concordance_index_censored(y_event_full[idx], y_time_full[idx], preds[idx])[0]
    )

c_index_boot = np.array(c_index_boot)
c_index_ci = np.percentile(c_index_boot, [2.5, 97.5])


# ---- IBS bootstrap on the SAME subsampled IBS-eligible set as the point estimate ----
# `ibs_test_rsf` was computed on the (<= 300-patient) subsample with a 100-point grid.
# Resampling that same set with the same grid size keeps the CI about the reported
# estimate. Resampling the full eligible set with a 200-point grid measures a different
# quantity whenever more than 300 patients are eligible.
ibs_boot = []
n_skipped = 0       # replicates whose safe time window collapsed
n_grid_boot = 100

# Training times define the safe window (censoring distribution comes from train)
time_train = y_train_surv["os_months"]
max_train_time = time_train.max()
eps = 1e-3

n_ibs = X_test_rsf_ibs_sub.shape[0]
print(f"Bootstrap IBS on {n_ibs} subsampled IBS-eligible RSF test patients.")

# Each patient's survival function depends only on that patient's features, so predict
# them once and index per replicate instead of re-running the forest 1,000 times.
surv_fns_all = np.asarray(rsf_best.predict_survival_function(X_test_rsf_ibs_sub))

for _ in range(B):
    # Resample within the subsampled IBS-eligible set
    idx_ibs = rng.integers(0, n_ibs, size=n_ibs)

    yb_surv = y_test_surv_ibs_sub[idx_ibs]  # matching survival outcomes (structured array)
    yb_time = yb_surv["os_months"]

    # Safe time window for THIS bootstrap sample
    t_min_b = max(time_train.min(), yb_time.min())
    t_max_b = min(max_train_time, yb_time.max()) - eps
    if t_max_b <= t_min_b:
        n_skipped += 1
        continue

    time_points_b = np.linspace(t_min_b, t_max_b, n_grid_boot, endpoint=False)
    surv_probs_b = np.asarray([fn(time_points_b) for fn in surv_fns_all[idx_ibs]])

    ibs_boot.append(integrated_brier_score(
        y_train_surv,  # training data for the censoring estimate
        yb_surv,       # THIS bootstrap sample
        surv_probs_b,
        time_points_b
    ))

ibs_boot = np.array(ibs_boot)
if ibs_boot.size == 0:
    raise ValueError("No valid IBS bootstrap replicates: every time window collapsed.")
if n_skipped:
    print(f"Skipped {n_skipped} of {B} bootstrap replicates with an empty time window.")
ibs_ci = np.percentile(ibs_boot, [2.5, 97.5])


# Print RSF results + bootstrap CIs

print(f"RSF C-index (test, full): {c_index_test:.4f}")
print(f"  95% CI: [{c_index_ci[0]:.4f}, {c_index_ci[1]:.4f}]")
print(f"RSF IBS (test, IBS subset): {ibs_test_rsf:.4f}")
print(f"  95% CI: [{ibs_ci[0]:.4f}, {ibs_ci[1]:.4f}]")


## Gradient Boosted Survival Tree (GBST)
We implement hyperparameter tuning with Bayesian optimization (`optuna`) for the same reason as for `RandomSurvivalForest`: given the high-dimensional, nonlinear relationships among the features, Optuna can find good hyperparameters efficiently, resulting in better performance.

In [None]:
### Gradient Boosted Survival Tree (GBST) with Optuna Tuning

import optuna, numpy as np
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, integrated_brier_score
from sksurv.util import Surv
import time

# Survival objects (reuse same y_train, y_test): structured arrays whose fields
# are named after the "os_status" (event) and "os_months" (time) columns
y_train_surv, y_test_surv = (
    Surv.from_dataframe("os_status", "os_months", outcome_frame)
    for outcome_frame in (y_train, y_test)
)


## Optuna objective for GBST (C-index)

# Number of folds used to score each trial. Tuning uses the TRAINING set only,
# so the held-out test set is untouched until the final evaluation below.
# (Each trial fits N_CV_FOLDS_GBST models; lower this or n_trials if too slow.)
N_CV_FOLDS_GBST = 5


def objective_gbst(trial):
    """Optuna objective: mean cross-validated C-index of a GBST on the training set.

    Hyperparameters are scored with stratified K-fold cross-validation on
    (X_train, y_train_surv) only. The test set is never used for model
    selection, so the test C-index reported after refitting stays unbiased.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample hyperparameters and to report per-fold scores.
        Reporting lets the study's pruner stop unpromising trials early.

    Returns
    -------
    float
        Mean Harrell's C-index over the validation folds (higher is better).

    Raises
    ------
    optuna.TrialPruned
        If the study's pruner decides the trial is unpromising after a fold.
    """
    from sklearn.model_selection import StratifiedKFold

    params = {
        "n_estimators": trial.suggest_int("n_estimators", 100, 2000),
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
        "max_depth": trial.suggest_int("max_depth", 2, 6),
        "min_samples_split": trial.suggest_int("min_samples_split", 2, 10),
        "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 6),
        "subsample": trial.suggest_float("subsample", 0.6, 1.0),
        "max_features": trial.suggest_float("max_features", 0.3, 1.0),
    }

    # Stratify folds on the event indicator so each fold keeps roughly the overall event rate
    events = y_train_surv["os_status"].astype(bool)
    times = y_train_surv["os_months"].astype(float)
    cv = StratifiedKFold(n_splits=N_CV_FOLDS_GBST, shuffle=True, random_state=42)

    fold_scores = []
    for fold, (fit_idx, val_idx) in enumerate(cv.split(np.zeros(len(events)), events)):
        if hasattr(X_train, "iloc"):
            X_fit, X_val = X_train.iloc[fit_idx], X_train.iloc[val_idx]
        else:
            X_fit, X_val = X_train[fit_idx], X_train[val_idx]

        # Train GBST model with suggested hyperparameters on the fitting folds
        gbst = GradientBoostingSurvivalAnalysis(random_state=42, **params)
        gbst.fit(X_fit, y_train_surv[fit_idx])

        # Evaluate C-index on the held-out validation fold
        pred = gbst.predict(X_val)
        fold_scores.append(
            concordance_index_censored(events[val_idx], times[val_idx], pred)[0]
        )

        # Report the running mean so the study's pruner can stop weak trials early
        trial.report(float(np.mean(fold_scores)), step=fold)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return float(np.mean(fold_scores))


## Optuna study and optimization

# Optuna study setup and optimization (maximize mean CV C-index on the training set)
study_gbst = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=42, multivariate=True, n_startup_trials=20),
    pruner=optuna.pruners.HyperbandPruner()
)
study_gbst.optimize(objective_gbst, n_trials=120, show_progress_bar=True)

print("\nBest Trial (GBST):", study_gbst.best_trial.number)
print("Best Params (GBST):", study_gbst.best_params)
print(f"Best C-index (CV, GBST): {study_gbst.best_value:.4f}")


# -----------------------------------------------------------------------------------

## Refit best GBST parameters on full train (optimal hyperparameters found by Optuna)

# A fresh GBST trained on the whole training set with the winning hyperparameters;
# fit() returns the estimator, so construction and fitting are chained
gbst_best = GradientBoostingSurvivalAnalysis(
    random_state=42, **study_gbst.best_params
).fit(X_train, y_train_surv)

# Risk scores for every test patient (FULL test set)
pred_test_gbst = gbst_best.predict(X_test)

# Harrell's C-index on the full test set: keep only the first element of the result tuple
c_index_test_gbst, *_ = concordance_index_censored(
    y_test_surv["os_status"], y_test_surv["os_months"], pred_test_gbst
)

print(f"GBST – C-index on test set (full): {c_index_test_gbst:.4f}")

# ----------------------------------------------------------------------------------

## IBS on IBS-eligible subset (subsampled, with safe grid)
# Computes the GBST integrated Brier score (IBS) on the IBS-eligible test patients
# returned by prepare_ibs_evaluation. They are subsampled to at most `max_ibs_test`
# patients, and the IBS is evaluated over a time grid recomputed for that subsample.

epsilon = 1e-3  # small margin so times are strictly inside follow-up

# IBS-eligible subset using your helper
# NOTE(review): prepare_ibs_evaluation is defined elsewhere; judging from its use here it
# returns (X subset, structured-y subset, time grid, max training time) — confirm.
X_test_gbst_ibs, y_test_surv_ibs_gbst, TIME_GRID_GBST, max_train_time_gbst = prepare_ibs_evaluation(
    y_train_surv, y_test_surv, X_test, n_grid=100  # 100 points is enough
)

# Diagnostics: follow-up ranges and the helper's grid (superseded below by the subsample grid)
print(f"Max train time (GBST): {max_train_time_gbst:.2f}")
print(f"Max test time (all): {y_test_surv['os_months'].max():.2f}")
print(f"Max IBS-eligible test time (GBST): {y_test_surv_ibs_gbst['os_months'].max():.2f}")
print(f"TIME_GRID_GBST range (pre-subsample): {TIME_GRID_GBST.min():.2f} to {TIME_GRID_GBST.max():.2f}")
print(f"IBS-eligible test size (before subsampling): {len(y_test_surv_ibs_gbst)}")

# Subsample IBS test set (e.g. to 300 patients)
max_ibs_test = 300
n_ibs_gbst = len(y_test_surv_ibs_gbst)

# Subsample IBS test set to at most 300 pts (without replacement, fixed seed).
# Note: this rebinds the global `rng`.
if n_ibs_gbst > max_ibs_test:
    rng = np.random.default_rng(42)
    idx_sub_gbst = rng.choice(n_ibs_gbst, size=max_ibs_test, replace=False)

    # Get subsampled X_test and y_test_surv (DataFrame or array input)
    if hasattr(X_test_gbst_ibs, "iloc"):
        X_test_gbst_ibs_sub = X_test_gbst_ibs.iloc[idx_sub_gbst]
    else:
        X_test_gbst_ibs_sub = X_test_gbst_ibs[idx_sub_gbst]

    # Subsampled survival outcomes (same row positions as the features)
    y_test_surv_ibs_sub_gbst = y_test_surv_ibs_gbst[idx_sub_gbst]
    print(f"Subsampled IBS test size (GBST): {len(y_test_surv_ibs_sub_gbst)} (from {n_ibs_gbst})")
else:
    # Small enough already: use the full IBS-eligible set
    X_test_gbst_ibs_sub = X_test_gbst_ibs
    y_test_surv_ibs_sub_gbst = y_test_surv_ibs_gbst

# Follow-up times used to bound the grid
time_train = y_train_surv["os_months"]
time_test_sub_gbst = y_test_surv_ibs_sub_gbst["os_months"]

# Recompute a VALID time grid for *this subsample*: start at the larger minimum follow-up,
# end `epsilon` below the smaller maximum follow-up
t_min = max(time_train.min(), time_test_sub_gbst.min())
t_max = min(time_train.max(), time_test_sub_gbst.max()) - epsilon  # strictly < max(test_time)

# Fail loudly rather than integrate over an empty window
if t_max <= t_min:
    raise ValueError(f"Invalid IBS time window after subsample (GBST): t_min={t_min}, t_max={t_max}")

# n_grid evenly spaced points over [t_min, t_max) (endpoint excluded)
n_grid = 100
TIME_GRID_GBST_SUB = np.linspace(t_min, t_max, n_grid, endpoint=False)

print(f"TIME_GRID_GBST_SUB range: {TIME_GRID_GBST_SUB.min():.2f} to {TIME_GRID_GBST_SUB.max():.2f}")
print(f"GBST subsample follow-up range: {time_test_sub_gbst.min():.2f} to {time_test_sub_gbst.max():.2f}")

# Predict survival functions on the subsample (timed)
t0 = time.time()
surv_fns_gbst = gbst_best.predict_survival_function(X_test_gbst_ibs_sub)
print(f"GBST predict_survival_function done in {time.time() - t0:.2f} s")

# Evaluate survival on the new grid -> matrix of shape (n_patients, n_grid)
t1 = time.time()
surv_probs_gbst = np.asarray([fn(TIME_GRID_GBST_SUB) for fn in surv_fns_gbst])
print(f"Building surv_probs_gbst done in {time.time() - t1:.2f} s")
print("surv_probs_gbst shape:", surv_probs_gbst.shape)

# IBS on subsampled IBS-eligible set (training outcomes passed as the first argument)
t2 = time.time()
ibs_test_gbst = integrated_brier_score(
    y_train_surv,
    y_test_surv_ibs_sub_gbst,
    surv_probs_gbst,
    TIME_GRID_GBST_SUB
)
print(f"integrated_brier_score (GBST) done in {time.time() - t2:.2f} s")

# Headline GBST test metrics, kept for later comparison with other models
results_gbst = dict(
    c_index_test_gbst=c_index_test_gbst,
    ibs_test_gbst=ibs_test_gbst,
)

# Output summary
for _summary_line in (
    "\nGradient Boosted Survival Tree (GBST) Results:",
    f"Best parameters: {study_gbst.best_params}",
    f"C-index on test set (full): {c_index_test_gbst:.4f}",
    f"IBS on test set (IBS-eligible subset, subsampled): {ibs_test_gbst:.4f}",
):
    print(_summary_line)


## Bootstrap for `GradientBoostingSurvivalAnalysis`

In [None]:
## Bootstrap C-index on FULL test set (GBST) + IBS on the IBS-eligible subset (GBST)
# Percentile-bootstrap 95% CIs for the GBST test metrics. The C-index uses the full
# test set. The IBS uses the same (subsampled) IBS-eligible patients as the GBST
# IBS point estimate computed in the previous cell.

from sksurv.metrics import concordance_index_censored, integrated_brier_score
import numpy as np

# Bootstrap setup  
B = 1000
rng = np.random.default_rng(42)

# ---------------- C-index bootstrap on FULL test set ----------------
n_test = X_test.shape[0]
cidx_boot_gbst = []

y_event_te = y_test_surv["os_status"]
y_time_te  = y_test_surv["os_months"]

# A tree ensemble scores each row independently, so predicting once and resampling
# the scores is equivalent to re-predicting every bootstrap sample (and far cheaper)
pred_full_gbst = np.asarray(gbst_best.predict(X_test))

for _ in range(B):
    idx = rng.integers(0, n_test, size=n_test)
    cidx_b = concordance_index_censored(y_event_te[idx], y_time_te[idx], pred_full_gbst[idx])[0]
    cidx_boot_gbst.append(cidx_b)

cidx_boot_gbst = np.array(cidx_boot_gbst)
cidx_ci_gbst = np.percentile(cidx_boot_gbst, [2.5, 97.5])

# ---------------- IBS bootstrap on the IBS-eligible subset ----------------
eps_gbst = 1e-3                                   # keep grid strictly inside follow-up
n_grid_boot_gbst = 100                            # same grid size as the GBST point estimate
time_train_gbst = y_train_surv["os_months"]
n_ibs_sub_gbst = len(y_test_surv_ibs_sub_gbst)

# Survival functions predicted once; bootstrap samples index into them
surv_fns_boot_gbst = gbst_best.predict_survival_function(X_test_gbst_ibs_sub)

ibs_boot_gbst = []
for _ in range(B):
    idx_ibs = rng.integers(0, n_ibs_sub_gbst, size=n_ibs_sub_gbst)
    yb_surv = y_test_surv_ibs_sub_gbst[idx_ibs]
    yb_time = yb_surv["os_months"]

    # Safe time window for THIS bootstrap sample (skip if it collapses)
    t_min_b = max(time_train_gbst.min(), yb_time.min())
    t_max_b = min(time_train_gbst.max(), yb_time.max()) - eps_gbst
    if t_max_b <= t_min_b:
        continue

    time_points_b = np.linspace(t_min_b, t_max_b, n_grid_boot_gbst, endpoint=False)
    surv_probs_b = np.asarray([surv_fns_boot_gbst[i](time_points_b) for i in idx_ibs])

    ibs_boot_gbst.append(
        integrated_brier_score(y_train_surv, yb_surv, surv_probs_b, time_points_b)
    )

ibs_boot_gbst = np.array(ibs_boot_gbst)
# Guard against every resample being skipped (no valid window)
ibs_ci_gbst = (
    np.percentile(ibs_boot_gbst, [2.5, 97.5]) if ibs_boot_gbst.size else np.array([np.nan, np.nan])
)

# Print GBST results + bootstrap CIs (GBST variables, not the RSF ones)

print(f"GBST C-index (test, full): {c_index_test_gbst:.4f}")
print(f"  95% CI: [{cidx_ci_gbst[0]:.4f}, {cidx_ci_gbst[1]:.4f}]")
print(f"GBST IBS (test, IBS subset): {ibs_test_gbst:.4f}")
print(f"  95% CI: [{ibs_ci_gbst[0]:.4f}, {ibs_ci_gbst[1]:.4f}]  ({ibs_boot_gbst.size}/{B} valid resamples)")

# SHAP
SHAP, which stands for SHapley Additive exPlanations, is a unified framework to explain how each feature contributes to a model’s prediction, for any ML model (tree models, random forests, XGBoost, neural nets, etc.). \
SHAP explains:

1. **Global importance**
Which features matter overall \
Ranked importance \
Summary plots 

2. **Local importance**
Why the model made a prediction for this specific patient \
Direction + magnitude of contribution 

3. **Interaction effects**
How two features together change risk 

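4. **Consistent feature importance**
If a model changes so that a feature contributes more (or no less) to its predictions, that feature's SHAP value never decreases \
(unlike split-count or impurity-based random forest importances, which lack this guarantee)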

In [None]:
# Install SHAP if needed (via mamba or pip in Terminal)
import shap
import matplotlib.pyplot as plt


### Prepare DataFrames for SHAP Analysis

SHAP plots are only interpretable when each column carries its feature name, so we rebuild named DataFrames from any NumPy arrays where needed.

In [None]:
### Rebuild DataFrames with feature names for SHAP

# Feature names before and after the VarianceThreshold `vt` (fitted on X_train earlier).
# NOTE: `vt` is the RSF feature filter (see the RSF SHAP section); GBST is trained on
# X_train directly, so `feature_names` describes the RSF input space only.
if isinstance(X, pd.DataFrame):
    original_feature_names = X.columns.tolist()
else:
    # Fallback when X has no column labels: generic names for every original column
    original_feature_names = [f"feat_{i}" for i in range(X.shape[1])]

# vt.get_support() returns a boolean mask over the original columns
selected_features_mask = vt.get_support()
feature_names = [
    name for name, keep in zip(original_feature_names, selected_features_mask) if keep
]


def _with_feature_names(data):
    """Return `data` as a DataFrame with the best-matching column names.

    DataFrames are copied unchanged. Arrays are labelled with the full original
    names when their width matches the unfiltered feature set, with the
    VarianceThreshold-selected names when it matches the filtered set, and with
    generic ``feat_i`` names otherwise (avoids a column-count mismatch error).
    """
    if isinstance(data, pd.DataFrame):
        return data.copy()
    n_cols = data.shape[1]
    if n_cols == len(original_feature_names):
        columns = original_feature_names
    elif n_cols == len(feature_names):
        columns = feature_names
    else:
        columns = [f"feat_{i}" for i in range(n_cols)]
    return pd.DataFrame(data, columns=columns)


# Convert X_train and X_test to DataFrames if they're numpy arrays
X_train_df = _with_feature_names(X_train)
X_test_df = _with_feature_names(X_test)

print(f"X_train_df shape: {X_train_df.shape}")
print(f"X_test_df shape: {X_test_df.shape}")
# Width of the frames actually explained below (GBST input)
print(f"Number of features: {X_train_df.shape[1]}")
print(f"Features removed by VarianceThreshold (RSF input only): {len(original_feature_names) - len(feature_names)}")

### SHAP Analysis for Gradient Boosted Survival Tree (GBST)

We use SHAP to explain the GBST model's predictions. SHAP values quantify how much each feature contributes to each patient's predicted risk score. \
Since GBST achieved higher performance, we focus on this algorithm.

In [None]:
### SHAP for GBST Model

def _gbst_risk_score(batch):
    """GBST risk scores for a NumPy array or DataFrame batch (model callable for SHAP)."""
    as_array = batch if isinstance(batch, np.ndarray) else batch.values
    return gbst_best.predict(as_array)


# Baseline (background) rows: up to 100 training patients, fixed seed for reproducibility
n_background = min(100, X_train_df.shape[0])
X_bg_gbst = X_train_df.sample(n=n_background, random_state=42)

# Rows to explain: up to 200 test patients, same seed
n_to_explain = min(200, X_test_df.shape[0])
X_explain_gbst = X_test_df.sample(n=n_to_explain, random_state=42)

print(f"Background sample: {X_bg_gbst.shape}")
print(f"Explanation sample: {X_explain_gbst.shape}")

# Model-agnostic explainer over the GBST risk-score function
explainer_gbst = shap.Explainer(_gbst_risk_score, X_bg_gbst)

print("Computing SHAP values for GBST... (this may take a few minutes)")
shap_values_gbst = explainer_gbst(X_explain_gbst)

print(f"SHAP values shape: {shap_values_gbst.values.shape}")

### SHAP Summary Plots for GBST

Global feature importance visualizations showing which features most impact survival predictions.

In [None]:
### SHAP Summary Plots - Bar Plot (Feature Importance) - GBST

# Global importance: mean |SHAP| per feature, top 20 only. show=False keeps the
# figure open so it can be resized and relabelled before display.
shap.summary_plot(
    shap_values_gbst.values, X_explain_gbst,
    plot_type="bar", show=False, max_display=20,
)

fig = plt.gcf()
fig.set_size_inches(8, 6)

# Title and axis label
plt.title("GBST: Top 20 Features by Mean |SHAP| (Impact on Survival Risk)", fontsize=14, pad=18)
plt.xlabel("Mean(|SHAP value|)", fontsize=12)

# Uniform tick-label size on both axes
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

# Reserve the top 4% of the canvas for the title
plt.tight_layout(rect=[0, 0, 1, 0.96])

# Save high-res figure (disabled)
#plt.savefig(FIGURES / "shap_gbst_bar_os.tiff", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
### SHAP Beeswarm Plot (Feature Impact with Direction) - GBST

plt.figure(figsize=(10, 10))
# Per-patient SHAP distribution for the 20 most influential features
shap.summary_plot(shap_values_gbst.values, X_explain_gbst, show=False, max_display=20)
plt.title("GBST: Feature Impact on Survival Risk", fontsize=14, pad=20)
plt.tight_layout()
#plt.savefig(FIGURES / "shap_gbst_beeswarm_os.tiff", dpi=300, bbox_inches="tight")
plt.show()

# Reading guide for the beeswarm above
for _guide_line in (
    "\nInterpretation:",
    "- Each dot is a patient",
    "- Red = high feature value, Blue = low feature value",
    "- Right side (positive SHAP) = increases risk",
    "- Left side (negative SHAP) = decreases risk",
):
    print(_guide_line)

### SHAP dependence plots: ASXL1 versus non-ASXL1
To better understand the separate contributions of ASXL1 and non-ASXL1 DTA mutations, we plot their SHAP dependence separately.

In [None]:
### SHAP Dependence Plot for ASXL1 Burden - GBST

# Sanity check: report any burden columns absent from the explained rows
needed_cols = ["asxl1_count", "dta_non_asxl1_counts", "asxl1"]
missing = [col for col in needed_cols if col not in X_explain_gbst.columns]
print("Missing columns:", missing)   # should print [] ideally

# GBST SHAP matrix shared by the dependence-plot cells: (n_samples, n_features)
shap_vals = shap_values_gbst.values

# ASXL1 burden: main effect only (no colouring by an interacting feature)
plt.figure(figsize=(7, 5))
shap.dependence_plot("asxl1_count", shap_vals, X_explain_gbst, interaction_index=None, show=False)

plt.title("GBST – SHAP Dependence: ASXL1 Burden", fontsize=14, pad=12)
plt.xlabel("ASXL1 mutational burden (standardized, z-score)", fontsize=12)
plt.ylabel("SHAP value (impact on log survival risk)", fontsize=12)
plt.tight_layout()
#plt.savefig(FIGURES / "dep_asxl1_count_os.tiff", dpi=300, bbox_inches="tight")
plt.show()


In [None]:
### SHAP Dependence Plot for Non-ASXL1 DTA Burden - GBST

# Main-effect dependence plot for 'dta_non_asxl1_counts' (no interaction colouring);
# shap_vals is the GBST SHAP matrix from the ASXL1 dependence cell
plt.figure(figsize=(7, 5))
shap.dependence_plot(
    "dta_non_asxl1_counts",
    shap_vals,
    X_explain_gbst,
    interaction_index=None,
    show=False
)

ax = plt.gca()

# Ticks every 0.5, spanning the current x-range rounded outward to the nearest 0.5
xmin, xmax = ax.get_xlim()
xmin = np.floor(xmin * 2) / 2.0
xmax = np.ceil(xmax * 2) / 2.0
ax.set_xticks(np.arange(xmin, xmax + 0.25, 0.5))

plt.title("GBST – SHAP Dependence: Non-ASXL1 DTA Burden", fontsize=14, pad=12)
plt.xlabel("Non-ASXL1 DTA mutational burden (standardized, z-score)", fontsize=12)
plt.ylabel("SHAP value (impact on log survival risk)", fontsize=12)
# Same layout step as the ASXL1 dependence plot so labels are not clipped
plt.tight_layout()
#plt.savefig(FIGURES / "dep_dta_non_asxl1_counts_os.tiff", dpi=300, bbox_inches="tight")
plt.show()


In [None]:
### SHAP Dependence Plot for TET2 DTA Burden - GBST

# Main-effect dependence plot for 'tet2_count' (no interaction colouring);
# shap_vals is the GBST SHAP matrix from the ASXL1 dependence cell
plt.figure(figsize=(7, 5))
shap.dependence_plot(
    "tet2_count",
    shap_vals,
    X_explain_gbst,
    interaction_index=None,
    show=False
)

ax = plt.gca()

# Ticks every 0.5, spanning the current x-range rounded outward to the nearest 0.5
xmin, xmax = ax.get_xlim()
xmin = np.floor(xmin * 2) / 2.0
xmax = np.ceil(xmax * 2) / 2.0
ax.set_xticks(np.arange(xmin, xmax + 0.25, 0.5))

plt.title("GBST – SHAP Dependence: TET2 DTA Burden", fontsize=14, pad=12)
plt.xlabel("TET2 DTA mutational burden (standardized, z-score)", fontsize=12)
plt.ylabel("SHAP value (impact on log survival risk)", fontsize=12)
# Same layout step as the ASXL1 dependence plot so labels are not clipped
plt.tight_layout()
#plt.savefig(FIGURES / "dep_tet2_count_os.tiff", dpi=300, bbox_inches="tight")
plt.show()


### SHAP Analysis for Random Survival Forest (RSF)
Analyzed for comparison with GBST

In [None]:
### Rebuild RSF dataframes in the reduced feature space for SHAP (without refit)
# To avoid the error 'X has 45 features, but RandomSurvivalForest is expecting 37 features as input'


# X_train: original training DataFrame (before VarianceThreshold)
# vt: the VarianceThreshold used for RSF
# X_train_rsf, X_test_rsf: the arrays RSF was actually trained on
# NOTE(review): assumes X_*_rsf rows are in the same order as X_train / X_test
# (e.g. produced by vt.transform) — a length mismatch raises below.

if isinstance(X_train, pd.DataFrame):
    original_feature_names = X_train.columns
    selected_mask = vt.get_support()
    selected_features = original_feature_names[selected_mask]

    # Keep the original row labels (as the GBST SHAP frames do) so the RSF and GBST
    # SHAP exports can be joined on patient rows; np.asarray avoids label-based
    # reindexing should X_*_rsf ever be a DataFrame
    X_train_rsf_df = pd.DataFrame(np.asarray(X_train_rsf), columns=selected_features, index=X_train.index)
    X_test_rsf_df  = pd.DataFrame(np.asarray(X_test_rsf),  columns=selected_features, index=X_test.index)
else:
    # Fallback if X_train wasn't a DataFrame
    X_train_rsf_df = pd.DataFrame(X_train_rsf)
    X_test_rsf_df  = pd.DataFrame(X_test_rsf)

In [None]:
### SHAP for RSF Model

# Background (baseline) rows capped at 100 training patients and rows to explain
# capped at 200 test patients, both with a fixed seed
X_bg_rsf, X_explain_rsf = (
    frame.sample(n=min(cap, frame.shape[0]), random_state=42)
    for frame, cap in ((X_train_rsf_df, 100), (X_test_rsf_df, 200))
)

print(f"Background sample:  {X_bg_rsf.shape}")
print(f"Explanation sample: {X_explain_rsf.shape}")


def _rsf_risk_score(batch):
    """RSF risk scores for a NumPy array or DataFrame batch (model callable for SHAP)."""
    return rsf_best.predict(batch if isinstance(batch, np.ndarray) else batch.values)


# Model-agnostic explainer over the RSF risk-score function
explainer_rsf = shap.Explainer(_rsf_risk_score, X_bg_rsf)

print("Computing SHAP values for RSF... (this may take a few minutes)")
shap_values_rsf = explainer_rsf(X_explain_rsf)

print(f"SHAP values shape: {shap_values_rsf.values.shape}")

In [None]:
### RSF SHAP Summary Plots

# Bar plot: mean |SHAP| ranking, top 20 features
plt.figure(figsize=(10, 8))
shap.summary_plot(
    shap_values_rsf.values,
    X_explain_rsf,
    plot_type="bar",
    show=False,
    max_display=20)

plt.title("RSF: Top 20 Features by Mean Absolute SHAP Value", fontsize=14, pad=20)
plt.tight_layout()
#plt.savefig(FIGURES/"shap_rsf_bar_os.tiff", dpi=300, bbox_inches="tight")
plt.show()

# Beeswarm plot: per-patient SHAP values for the top 25 features
plt.figure(figsize=(10, 10))
shap.summary_plot(
    shap_values_rsf.values,
    X_explain_rsf,
    show=False,
    max_display=25)

# Title label fixed: the model is RSF (Random Survival Forest), not "RFS"
plt.title("RSF: Feature Impact on Survival Risk", fontsize=14, pad=20)
plt.tight_layout()
#plt.savefig(FIGURES/"shap_rsf_beeswarm_os.tiff", dpi=300, bbox_inches="tight")
plt.show()

## Comparison: GBST and RSF

### Feature Importance Comparison: GBST vs RSF

Compare which features are most important across different models.

In [None]:
### Compare Feature Importance Across Models (Normalized)


## Normalize SHAP values
# Each patient's |SHAP| row is scaled to sum to 1, so GBST and RSF, whose raw SHAP
# magnitudes need not share a scale, are compared as per-patient attribution shares.
# Rows with zero total attribution stay all-zero instead of producing NaN (0/0).

# GBST normalized importance
gbst_raw = np.abs(shap_values_gbst.values)
gbst_row_sums = gbst_raw.sum(axis=1, keepdims=True)
gbst_norm = np.divide(gbst_raw, gbst_row_sums,
                      out=np.zeros_like(gbst_raw, dtype=float), where=gbst_row_sums > 0)
gbst_importance = pd.DataFrame({
    'feature': X_explain_gbst.columns,
    'gbst_importance': gbst_norm.mean(axis=0)
}).sort_values('gbst_importance', ascending=False)

# RSF normalized importance
rsf_raw = np.abs(shap_values_rsf.values)
rsf_row_sums = rsf_raw.sum(axis=1, keepdims=True)
rsf_norm = np.divide(rsf_raw, rsf_row_sums,
                     out=np.zeros_like(rsf_raw, dtype=float), where=rsf_row_sums > 0)
rsf_importance = pd.DataFrame({
    'feature': X_explain_rsf.columns,
    'rsf_importance': rsf_norm.mean(axis=0)
}).sort_values('rsf_importance', ascending=False)


## Merge and compare (outer merge: a feature missing from one model counts as 0 there)
comparison = gbst_importance.merge(rsf_importance, on='feature', how='outer').fillna(0)
comparison['avg_importance'] = (comparison['gbst_importance'] + comparison['rsf_importance']) / 2
comparison = comparison.sort_values('avg_importance', ascending=False)

## Display top 20
print("Top 20 Features by Average *Normalized* SHAP Importance (GBST + RSF):\n")
print(comparison.head(20).to_string(index=False))

## Plot: paired horizontal bars per feature, most important at the top
top20 = comparison.head(20)
fig, ax = plt.subplots(figsize=(12, 8))

x = np.arange(len(top20))
width = 0.35

ax.barh(x - width/2, top20['gbst_importance'], width, label='GBST', alpha=0.8)
ax.barh(x + width/2, top20['rsf_importance'], width, label='RSF', alpha=0.8)

ax.set_yticks(x)
ax.set_yticklabels(top20['feature'])
ax.invert_yaxis()
ax.set_xlabel('Mean Normalized SHAP Value', fontsize=12)
ax.set_title('Feature Importance Comparison in OS: GBST vs RSF (Normalized)', fontsize=14, pad=20)
ax.legend()

plt.tight_layout()
#plt.savefig(FIGURES/"shap_model_comparison_os_normalized.tiff", dpi=300, bbox_inches="tight")
plt.show()


In [None]:
### Compare Feature Importance Across Models

# Calculate mean absolute SHAP values for each model (mean raw values)
gbst_importance = pd.DataFrame({
    'feature': X_explain_gbst.columns,
    'gbst_importance': np.abs(shap_values_gbst.values).mean(axis=0)
}).sort_values('gbst_importance', ascending=False)

rsf_importance = pd.DataFrame({
    'feature': X_explain_rsf.columns,
    'rsf_importance': np.abs(shap_values_rsf.values).mean(axis=0)
}).sort_values('rsf_importance', ascending=False)

# Merge and compare. Outer merge + fillna(0), as in the normalized comparison, so
# features the RSF never saw (removed by its VarianceThreshold) are kept with 0
# instead of being silently dropped by an inner merge.
comparison = gbst_importance.merge(rsf_importance, on='feature', how='outer').fillna(0)
comparison['avg_importance'] = (comparison['gbst_importance'] + comparison['rsf_importance']) / 2
comparison = comparison.sort_values('avg_importance', ascending=False)

# Display top 20 features
print("Top 20 Features by Average SHAP Importance (GBST + RSF):\n")
print(comparison.head(20).to_string(index=False))

# Visualize comparison: paired horizontal bars per feature, most important at the top
top20 = comparison.head(20)
fig, ax = plt.subplots(figsize=(12, 8))

x = np.arange(len(top20))
width = 0.35

ax.barh(x - width/2, top20['gbst_importance'], width, label='GBST', alpha=0.8)
ax.barh(x + width/2, top20['rsf_importance'], width, label='RSF', alpha=0.8)

ax.set_yticks(x)
ax.set_yticklabels(top20['feature'])
ax.invert_yaxis()
ax.set_xlabel('Mean |SHAP Value|', fontsize=12)
ax.set_title('Feature Importance Comparison in OS: GBST vs RSF', fontsize=14, pad=20)
ax.legend()

plt.tight_layout()
#plt.savefig(FIGURES/"shap_model_comparison_os.tiff", dpi=300, bbox_inches="tight")
plt.show()


### Waterfall plot

In [None]:
### SHAP Waterfall Plot - High Risk Patient (GBST)

# Find a high-risk patient (highest predicted risk score among the explained rows);
# the index is a row position shared by X_explain_gbst and shap_values_gbst
gbst_risks = gbst_best.predict(X_explain_gbst.values)
high_risk_idx = np.argmax(gbst_risks)


# Define custom font size
CUSTOM_FONTSIZE = 10

plt.figure(figsize=(10, 8))
shap.waterfall_plot(shap_values_gbst[high_risk_idx], max_display=15, show=False)

# Access the current axes and iterate through ALL text objects to set font size
ax = plt.gca()

# For feature names, values, SHAP values, etc.
for text in ax.texts:
    text.set_fontsize(CUSTOM_FONTSIZE)

# For x and y tick labels (might be needed if ax.texts doesn't cover them)
ax.tick_params(axis='x', labelsize=CUSTOM_FONTSIZE)
ax.tick_params(axis='y', labelsize=CUSTOM_FONTSIZE * 0.9) # Slightly smaller for long feature names

# For x and y axis labels
ax.set_xlabel(ax.get_xlabel(), fontsize=CUSTOM_FONTSIZE)
ax.set_ylabel(ax.get_ylabel(), fontsize=CUSTOM_FONTSIZE)

# Set the title with desired font size and padding
plt.title(f"GBST: High-Risk Patient Explanation (Risk Score: {gbst_risks[high_risk_idx]:.3f})", 
          fontsize=CUSTOM_FONTSIZE + 2, 
          pad=40) # Increase padding

# No trailing plt.tick_params(labelsize=10) here: it would reset the smaller
# y tick labels set above back to full size.
plt.tight_layout()
#plt.savefig(FIGURES/"shap_waterfall_high_risk_gbst_os.tiff", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
### SHAP Waterfall Plot - High Risk Patient (RSF)

# Find a high-risk patient (highest predicted risk score among the explained rows);
# the index is a row position shared by X_explain_rsf and shap_values_rsf
rsf_risks = rsf_best.predict(X_explain_rsf.values)
high_risk_idx = np.argmax(rsf_risks)


# Define custom font size
CUSTOM_FONTSIZE = 10

plt.figure(figsize=(10, 8))
shap.waterfall_plot(shap_values_rsf[high_risk_idx], max_display=15, show=False)

#  Access the current axes and iterate through ALL text objects to set font size
ax = plt.gca()

# For feature names, values, SHAP values, etc.
for text in ax.texts:
    text.set_fontsize(CUSTOM_FONTSIZE)

# For x and y tick labels (might be needed if ax.texts doesn't cover them)
ax.tick_params(axis='x', labelsize=CUSTOM_FONTSIZE)
ax.tick_params(axis='y', labelsize=CUSTOM_FONTSIZE * 0.9) # Slightly smaller for long feature names

# For x and y axis labels
ax.set_xlabel(ax.get_xlabel(), fontsize=CUSTOM_FONTSIZE)
ax.set_ylabel(ax.get_ylabel(), fontsize=CUSTOM_FONTSIZE)

# Set the title with desired font size and padding
plt.title(f"RSF: High-Risk Patient Explanation (Risk Score: {rsf_risks[high_risk_idx]:.3f})", 
          fontsize=CUSTOM_FONTSIZE + 2, 
          pad=40) # Increase padding

# No trailing plt.tick_params(labelsize=10) here: it would reset the smaller
# y tick labels set above back to full size.
plt.tight_layout()
# Path fixed: the figures folder constant is FIGURES (not "FIGUREs")
#plt.savefig(FIGURES/"shap_waterfall_high_risk_rsf_os.tiff", dpi=300, bbox_inches="tight")
plt.show()

### Export SHAP Values for Further Analysis

In [None]:
### Export SHAP Values and Feature Importance

def _shap_frame(explanation, explained_rows):
    """Per-patient SHAP matrix labelled with the explained rows' columns and index."""
    return pd.DataFrame(
        explanation.values,
        columns=explained_rows.columns,
        index=explained_rows.index,
    )


# GBST SHAP values
shap_df_gbst = _shap_frame(shap_values_gbst, X_explain_gbst)
#shap_df_gbst.to_csv(OUTPUT/"shap_values_gbst_os.csv")

# RSF SHAP values
shap_df_rsf = _shap_frame(shap_values_rsf, X_explain_rsf)
#shap_df_rsf.to_csv(OUTPUT/"shap_values_rsf_os.csv")

# Feature importance comparison (disabled export)
#comparison.to_csv(OUTPUT/"shap_feature_importance_comparison_os.csv", index=False)


# END