In [1]:
# Standard imports
import os
import pandas as pd
import numpy as np
import plotnine as pn
# sksurv imports
from sksurv.util import Surv as surv_util
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
# sklearn imports for data preprocessing
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import make_column_selector, ColumnTransformer
from sklearn.model_selection import train_test_split
from sksurv.metrics import concordance_index_censored as concordance
# SurvSet package imports
from SurvSet.data import SurvLoader

In [None]:
def stratified_group_split(
    df: pd.DataFrame,
    group_col: str = 'pid',
    stratify_col: str = 'event',
    test_frac: float = 0.2,
    seed: int | None = None,
):
    '''
    Split ``df`` into train/test rows so that no group lands on both sides.

    Whole groups (distinct ``group_col`` values) are assigned to one side, and
    that assignment is stratified on a per-group 0/1 flag which is 1 when any
    row of the group has a truthy ``stratify_col`` value.

    Parameters
    ----------
    df : rows to split; must contain ``group_col`` and ``stratify_col``.
    group_col : column identifying the group (e.g. a patient ID).
    stratify_col : column whose per-group "any" flag drives the stratification.
    test_frac : forwarded to ``train_test_split`` as ``test_size``.
    seed : forwarded to ``train_test_split`` as ``random_state``.

    Returns
    -------
    (df_train, df_test) : the rows of ``df`` whose group was assigned to the
    training side and to the test side, respectively.
    '''
    # Reduce to one row per group carrying the 0/1 "any" flag
    per_group = df.groupby(group_col)[stratify_col].any()
    per_group = per_group.astype(int).reset_index()

    # Assign whole groups to a side, stratifying on the per-group flag
    train_groups, test_groups = train_test_split(
        per_group,
        test_size=test_frac,
        random_state=seed,
        stratify=per_group[stratify_col],
    )

    # Map the group assignment back onto the original rows
    in_train = df[group_col].isin(train_groups[group_col])
    in_test = df[group_col].isin(test_groups[group_col])
    return df[in_train], df[in_test]


def bootstrap_concordance_index(
    df: pd.DataFrame,
    id_col: str = "id",
    event_col: str = "event",
    time_col: str = "time",
    score_col: str = "score",
    time2_col: str | None = None,
    n_bs: int = 1000,
    alpha: float = 0.05,
    is_td: bool = False,
    ) -> pd.DataFrame:
    """
    Wrapper on the concordance_index_timevarying function to perform bootstrap resampling for confidence intervals.
    """
    # Input checks
    expected_cols = [id_col, event_col, time_col, score_col]
    if is_td:
        expected_cols.append(time2_col)
    missing_cols = np.setdiff1d(expected_cols, df.columns)
    assert len(missing_cols) == 0, f"Missing required columns: {missing_cols.tolist()}"
    assert isinstance(n_bs, int) and n_bs > 0, "n_bs must be a positive integer."
    assert 0 < alpha < 1, "alpha must be between 0 and 1."
    assert is_td in [True, False], "is_td must be a boolean value."
    # Subset to only the required columns
    df = df[expected_cols].copy()
    # Set up the storage holder and final DataFrame slice
    holder_bs = np.zeros(n_bs)
    # Get baseline result and then loop over the bootstrap samples
    if is_td:
        # Baseline
        conc_test = concordance_index_timevarying(df, id_col, time_col, time2_col, event_col, score_col)
        # Bootstrap
        for j in range(n_bs):
            res_bs = df.groupby(event_col).sample(frac=1,replace=True,random_state=j)
            conc_bs = concordance_index_timevarying(res_bs, id_col, time_col, time2_col, event_col, score_col)
            holder_bs[j] = conc_bs
    else:
        # Baseline
        conc_test = concordance(df[event_col].astype(bool), df[time_col], df[score_col])[0]
        # Bootstrap
        for j in range(n_bs):
            res_bs = df.groupby(event_col).sample(frac=1,replace=True,random_state=j)
            conc_bs = concordance(res_bs[event_col].astype(bool), res_bs[time_col], res_bs[score_col])[0]
            holder_bs[j] = conc_bs
    # Add on the baseline result and empirical confidence intervals
    lb, ub = np.quantile(holder_bs, [alpha,1-alpha])
    holder_cindex = pd.DataFrame(np.atleast_2d((conc_test, lb, ub)), columns=['cindex', 'lb', 'ub'])
    return holder_cindex


In [3]:
###############################
# --- (1) PARAMETER SETUP --- #

# Run parameters (all consumed by the dataset loop further down)
alpha = 0.1        # tail level handed to bootstrap_concordance_index
n_bs = 250         # number of bootstrap samples
seed = 1234        # random seed for the train/test split
test_frac = 0.3    # share of groups held out for testing

# Results folder lives under the current working directory; create it on first use
dir_base = os.getcwd()
dir_sim = os.path.join(dir_base, 'results_survset')
if not os.path.isdir(dir_sim):
    os.mkdir(dir_sim)
print('Figure will saved here: %s' % dir_sim)


#####################################
# --- (2) ENCODER/MODEL/LOADER --- #

# (i) Feature transformer: columns are routed by name prefix
sel_fac = make_column_selector(pattern='^fac\\_')
sel_num = make_column_selector(pattern='^num\\_')
# "fac_" columns: one-hot encode, keeping every level and ignoring unseen ones
enc_fac = Pipeline(steps=[
    ('ohe', OneHotEncoder(drop=None, sparse_output=False, handle_unknown='ignore')),
])
# "num_" columns: median-impute, then standardise
enc_num = Pipeline(steps=[
    ('impute', SimpleImputer(strategy='median')),
    ('scale', StandardScaler()),
])
# Combine both and have transform() hand back a DataFrame
enc_df = ColumnTransformer(transformers=[
    ('ohe', enc_fac, sel_fac),
    ('s', enc_num, sel_num),
])
enc_df.set_output(transform='pandas')

# (ii) Survival-target helper and dataset loader
senc = surv_util()
loader = SurvLoader()

# (iii) Model to evaluate (penalised Cox alternative kept for reference)
# model = CoxnetSurvivalAnalysis(normalize=True)
model = RandomSurvivalForest(n_estimators=100, random_state=123)

Figure will saved here: /mnt/evafs/faculty/home/hbaniecki/survshapiq/experiments/results_survset


In [None]:
##################################
# --- (3) LOOP OVER DATASETS --- #
# For every non-time-varying dataset: split by patient, encode features, fit
# `model`, score the held-out rows and bootstrap the test concordance index.
# The per-dataset results are then stacked and written to CSV.

# (i) Initialize results holder and loop over datasets
n_ds = len(loader.df_ds)
holder_cindex = []
for i, r in loader.df_ds.iterrows():
    is_td, ds = r['is_td'], r['ds']
    # Time-varying datasets are skipped, so is_td is always falsy below
    if is_td:
        continue
    # The counter uses the row label i and the total over all datasets (skipped ones included);
    # presumably loader.df_ds has a default 0..n-1 index - verify if the numbering looks off
    print('Dataset %s (%i of %i)' % (ds, i+1, n_ds))
    df = loader.load_dataset(ds)['df']
    # Split based on both the event rate and unique IDs
    df_train, df_test = stratified_group_split(df=df, group_col='pid', 
                            stratify_col='event', test_frac=test_frac, seed=seed)
    assert not df_train['pid'].isin(df_test['pid']).any(), \
        'Training and test sets must not overlap in patient IDs.'
    # Fit encoder (the same enc_df object is re-fit from scratch on every dataset)
    enc_df.fit(df_train)
    # Transform data
    X_train = enc_df.transform(df_train)
    # Sanity check on the encoded column names: split each name on a run of one or two
    # underscores and require the second piece to be 'fac' or 'num'
    assert X_train.columns.str.split('\\_{1,2}', expand=True).to_frame(False)[1].isin(['fac','num']).all(), 'Expected feature names to be prefixed with "fac_" or "num_"'
    X_test = enc_df.transform(df_test)
    # Set up Surv object for static model and fit (the same model object is re-fit per dataset)
    So_train = senc.from_arrays(df_train['event'].astype(bool), df_train['time'])
    model.fit(X=X_train, y=So_train)
    # Get test prediction
    scores_test = model.predict(X_test)
    # Prepare test data for concordance calculation
    res_test = df_test[['pid','event','time']].assign(scores=scores_test)
    # Generate results and bootstrap concordance index
    # 'time2' is passed as time2_col, but bootstrap_concordance_index only reads that
    # column when is_td is truthy, which never happens here
    res_cindex = bootstrap_concordance_index(res_test, 'pid', 'event', 'time', 'scores', 'time2', n_bs, alpha, is_td=is_td)
    res_cindex.insert(0, 'ds', ds) 
    res_cindex.insert(1, 'is_td', is_td)  # Add dataset and type
    holder_cindex.append(res_cindex)

# (ii) Merge results
df_cindex = pd.concat(holder_cindex, ignore_index=True, axis=0)
# Make 'ds' a categorical whose category order follows ascending cindex
ds_ord = df_cindex.sort_values('cindex')['ds'].values
df_cindex['ds'] = pd.Categorical(df_cindex['ds'], ds_ord)

# Persist the results table next to the figures
path_df = os.path.join(dir_sim, 'rsf_cindex.csv')
df_cindex.to_csv(path_df, index=False)

In [None]:
############################
# --- (4) PLOT RESULTS --- #

# (i) Plot concordance index: one point per dataset with its bootstrap range,
# coloured by is_td, plus a dashed reference line at 0.5
gg_cindex = (
    pn.ggplot(df_cindex, pn.aes(y='cindex', x='ds', color='is_td'))
    + pn.theme_bw()
    + pn.coord_flip()
    + pn.geom_point(size=2)
    + pn.geom_linerange(pn.aes(ymin='lb', ymax='ub'))
    + pn.labs(y='Concordance')
    + pn.scale_color_discrete(name='Time-varying covariates')
    + pn.geom_hline(yintercept=0.5, linetype='--', color='black')
    + pn.theme(axis_title_y=pn.element_blank())
)

# (ii) Write the figure into the results folder
path_fig = os.path.join(dir_sim, 'rsf_index.png')
gg_cindex.save(path_fig, width=5, height=10)


print('~~~ The SurvSet.sim_run module was successfully executed ~~~')

### Analyse results

In [18]:
# Reload the concordance results from the CSV on disk. The path is relative,
# so it resolves against the current working directory.
df_cindex = pd.read_csv("results_survset/rsf_cindex.csv")

In [19]:
# Work on a copy of the loader's dataset table: later cells assign into `df`,
# and without the copy those writes would land in `loader.df_ds` itself.
df = loader.df_ds.copy()

In [20]:
# Option 1: Convert both to string
# Cast the "ds" key to str in both frames so the join on "ds" in the next cell
# compares keys of the same type. Both assignments overwrite the column in place.
df["ds"] = df["ds"].astype(str)
df_cindex["ds"] = df_cindex["ds"].astype(str)

In [26]:
# Attach the dataset metadata to each result row, keyed on "ds". The right join
# keeps every dataset present in df_cindex. "is_td" is dropped from the results
# side first, presumably because the metadata frame already carries it.
cindex_by_ds = df_cindex.set_index("ds").drop("is_td", axis=1)
temp = df.set_index("ds").join(cindex_by_ds, how="right")

In [42]:
# Show the datasets whose n_num is at least 10 and strictly below 30
at_least_10 = temp["n_num"] >= 10
below_30 = temp["n_num"] < 30
temp[at_least_10 & below_30]

Unnamed: 0_level_0,is_td,n,n_fac,n_ohe,n_num,cindex,lb,ub
ds,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
phpl04K8a,False,442,1,1,20,0.616717,0.567722,0.65919
rhc,False,5735,31,51,22,0.707549,0.695597,0.717386
Bergamaschi,False,82,0,0,10,0.709677,0.561937,0.848206
smarto,False,3873,9,17,17,0.690374,0.661732,0.721616
support2,False,9105,11,41,24,0.823342,0.819199,0.827998
