In [5]:
# Standard imports
import os
import pandas as pd
import numpy as np
import plotnine as pn
# sksurv imports
from sksurv.util import Surv as surv_util
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
# sklearn imports for data preprocessing
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import make_column_selector, ColumnTransformer
from sklearn.model_selection import train_test_split
from sksurv.metrics import concordance_index_censored as concordance
# SurvSet package imports
from SurvSet.data import SurvLoader

In [6]:
def stratified_group_split(df: pd.DataFrame, 
                           group_col: str ='pid', 
                           stratify_col: str ='event', 
                           test_frac: float = 0.2, 
                           seed: int | None = None):
    '''
    Function to make sure that patient IDs don't show up in both training and test sets. And then also balance the event rate across the two sets.
    '''
    # Step 1: Collapse to group-level (one row per patient)
    group_df = df.groupby(group_col)[stratify_col].any().astype(int).reset_index()
    # Step 2: Stratified split on the group level
    group_train, group_test = train_test_split(
        group_df,
        stratify=group_df[stratify_col],
        test_size=test_frac,
        random_state=seed
    )
    # Step 3: Merge back to original data
    df_train = df[df[group_col].isin(group_train[group_col])]
    df_test = df[df[group_col].isin(group_test[group_col])]
    return df_train, df_test


def bootstrap_concordance_index(
    df: pd.DataFrame,
    id_col: str = "id",
    event_col: str = "event",
    time_col: str = "time",
    score_col: str = "score",
    time2_col: str | None = None,
    n_bs: int = 1000,
    alpha: float = 0.05,
    is_td: bool = False,
    ) -> pd.DataFrame:
    """
    Wrapper on the concordance_index_timevarying function to perform bootstrap resampling for confidence intervals.
    """
    # Input checks
    expected_cols = [id_col, event_col, time_col, score_col]
    if is_td:
        expected_cols.append(time2_col)
    missing_cols = np.setdiff1d(expected_cols, df.columns)
    assert len(missing_cols) == 0, f"Missing required columns: {missing_cols.tolist()}"
    assert isinstance(n_bs, int) and n_bs > 0, "n_bs must be a positive integer."
    assert 0 < alpha < 1, "alpha must be between 0 and 1."
    assert is_td in [True, False], "is_td must be a boolean value."
    # Subset to only the required columns
    df = df[expected_cols].copy()
    # Set up the storage holder and final DataFrame slice
    holder_bs = np.zeros(n_bs)
    # Get baseline result and then loop over the bootstrap samples
    if is_td:
        # Baseline
        conc_test = concordance_index_timevarying(df, id_col, time_col, time2_col, event_col, score_col)
        # Bootstrap
        for j in range(n_bs):
            res_bs = df.groupby(event_col).sample(frac=1,replace=True,random_state=j)
            conc_bs = concordance_index_timevarying(res_bs, id_col, time_col, time2_col, event_col, score_col)
            holder_bs[j] = conc_bs
    else:
        # Baseline
        conc_test = concordance(df[event_col].astype(bool), df[time_col], df[score_col])[0]
        # Bootstrap
        for j in range(n_bs):
            res_bs = df.groupby(event_col).sample(frac=1,replace=True,random_state=j)
            conc_bs = concordance(res_bs[event_col].astype(bool), res_bs[time_col], res_bs[score_col])[0]
            holder_bs[j] = conc_bs
    # Add on the baseline result and empirical confidence intervals
    lb, ub = np.quantile(holder_bs, [alpha,1-alpha])
    holder_cindex = pd.DataFrame(np.atleast_2d((conc_test, lb, ub)), columns=['cindex', 'lb', 'ub'])
    return holder_cindex


In [12]:
###############################
# --- (1) PARAMETER SETUP --- #

# Save to the examples directory
dir_base = os.getcwd()
dir_sim = os.path.join(dir_base, 'results_survset')
if not os.path.isdir(dir_sim):
    os.mkdir(dir_sim)
print('Figure will saved here: %s' % dir_sim)

# Concordance empirical alpha level
alpha = 0.1
# Number of bootstrap samples
n_bs = 250
# Set the random seed
seed = 1234
# Percentage of data to use for testing
test_frac = 0.3


#####################################
# --- (2) ENCODER/MODEL/LOADER --- #

# (i) Set up feature transformer pipeline
enc_fac = Pipeline(steps=[('ohe', OneHotEncoder(drop=None,sparse_output=False, handle_unknown='ignore'))])
sel_fac = make_column_selector(pattern='^fac\\_')
enc_num = Pipeline(steps=[('impute', SimpleImputer(strategy='median')), 
                        ('scale', StandardScaler())])
sel_num = make_column_selector(pattern='^num\\_')
# Combine both
enc_df = ColumnTransformer(transformers=[('ohe', enc_fac, sel_fac),('s', enc_num, sel_num)])
enc_df.set_output(transform='pandas')  # Ensure output is a DataFrame

# (ii) Run on datasets
senc = surv_util()
loader = SurvLoader()

# (iii) Set up the models
# model = CoxnetSurvivalAnalysis(normalize=True)
model = RandomSurvivalForest(n_estimators=100, random_state=123)


##################################
# --- (3) LOOP OVER DATASETS --- #

# (i) Initialize results holder and loop over datasets
n_ds = len(loader.df_ds)
holder_cindex = []
for i, r in loader.df_ds.iterrows():
    is_td, ds = r['is_td'], r['ds']
    if is_td:
        continue
    print('Dataset %s (%i of %i)' % (ds, i+1, n_ds))
    df = loader.load_dataset(ds)['df']
    # Split based on both the event rate and unique IDs
    df_train, df_test = stratified_group_split(df=df, group_col='pid', 
                            stratify_col='event', test_frac=test_frac, seed=seed)
    assert not df_train['pid'].isin(df_test['pid']).any(), \
        'Training and test sets must not overlap in patient IDs.'
    # Fit encoder
    enc_df.fit(df_train)
    # Transform data
    X_train = enc_df.transform(df_train)
    assert X_train.columns.str.split('\\_{1,2}', expand=True).to_frame(False)[1].isin(['fac','num']).all(), 'Expected feature names to be prefixed with "fac_" or "num_"'
    X_test = enc_df.transform(df_test)
    # Set up Surv object for static model and fit
    So_train = senc.from_arrays(df_train['event'].astype(bool), df_train['time'])
    model.fit(X=X_train, y=So_train)
    # Get test prediction
    scores_test = model.predict(X_test)
    # Prepare test data for concordance calculation
    res_test = df_test[['pid','event','time']].assign(scores=scores_test)
    # Generate results and bootstrap concordance index
    res_cindex = bootstrap_concordance_index(res_test, 'pid', 'event', 'time', 'scores', 'time2', n_bs, alpha, is_td=is_td)
    res_cindex.insert(0, 'ds', ds) 
    res_cindex.insert(1, 'is_td', is_td)  # Add dataset and type
    holder_cindex.append(res_cindex)

# (ii) Merge results
df_cindex = pd.concat(holder_cindex, ignore_index=True, axis=0)
ds_ord = df_cindex.sort_values('cindex')['ds'].values
df_cindex['ds'] = pd.Categorical(df_cindex['ds'], ds_ord)

path_df = os.path.join(dir_sim, 'rsf_cindex.csv')
df_cindex.to_csv(path_df, index=False)

############################
# --- (4) PLOT RESULTS --- #

# (i) Plot concordance index
gg_cindex = (pn.ggplot(df_cindex, pn.aes(y='cindex',x='ds', color='is_td')) + 
    pn.theme_bw() + pn.coord_flip() + 
    pn.geom_point(size=2) + 
    pn.geom_linerange(pn.aes(ymin='lb', ymax='ub')) + 
    pn.labs(y='Concordance') + 
    pn.scale_color_discrete(name='Time-varying covariates') +
    pn.geom_hline(yintercept=0.5,linetype='--', color='black') + 
    pn.theme(axis_title_y=pn.element_blank()))
path_fig = os.path.join(dir_sim, 'rsf_index.png')
gg_cindex.save(path_fig, height=10, width=5)


print('~~~ The SurvSet.sim_run module was successfully executed ~~~')

Figure will saved here: /Users/hbaniecki/GitHub/sophhan/survshapiq/experiments/results_survset
Dataset hdfail (1 of 77)
Dataset e1684 (2 of 77)
Dataset phpl04K8a (3 of 77)
Dataset prostate (4 of 77)
Dataset cancer (5 of 77)
Dataset DBCD (6 of 77)
Dataset LeukSurv (7 of 77)
Dataset Dialysis (8 of 77)
Dataset Melanoma (9 of 77)
Dataset d.oropha.rec (10 of 77)
Dataset pharmacoSmoking (11 of 77)
Dataset zinc (12 of 77)
Dataset nki70 (13 of 77)
Dataset burn (14 of 77)
Dataset breast (15 of 77)
Dataset stagec (16 of 77)
Dataset ovarian (17 of 77)
Dataset actg (18 of 77)
Dataset rhc (19 of 77)
Dataset rdata (20 of 77)
Dataset vlbw (22 of 77)
Dataset grace (23 of 77)
Dataset TRACE (24 of 77)
Dataset whas500 (25 of 77)
Dataset dataDIVAT1 (26 of 77)
Dataset Bergamaschi (27 of 77)
Dataset AML_Bull (28 of 77)
Dataset prostateSurvival (29 of 77)
Dataset nwtco (30 of 77)
Dataset divorce (31 of 77)
Dataset dataDIVAT3 (32 of 77)
Dataset uis (33 of 77)
Dataset colon (34 of 77)
Dataset glioma (35 of 77)



Dataset Aids2 (38 of 77)
Dataset Z243 (39 of 77)
Dataset veteran (40 of 77)
Dataset chop (41 of 77)
Dataset wpbc (42 of 77)
Dataset ova (43 of 77)
Dataset micro.censure (44 of 77)
Dataset MCLcleaned (45 of 77)




Dataset cost (46 of 77)
Dataset gse1992 (49 of 77)
Dataset smarto (50 of 77)




Dataset NSBCD (51 of 77)
Dataset retinopathy (52 of 77)
Dataset support2 (53 of 77)
Dataset pbc (54 of 77)
Dataset Pbc3 (55 of 77)
Dataset UnempDur (57 of 77)
Dataset cgd (58 of 77)
Dataset acath (59 of 77)
Dataset scania (60 of 77)
Dataset GBSG2 (62 of 77)
Dataset gse3143 (64 of 77)
Dataset Unemployment (65 of 77)
Dataset gse4335 (66 of 77)
Dataset rott2 (67 of 77)
Dataset DLBCL (68 of 77)
Dataset dataOvarian1 (69 of 77)
Dataset diabetes (71 of 77)
Dataset vdv (72 of 77)
Dataset hepatoCellular (73 of 77)




Dataset flchain (74 of 77)
Dataset Framingham (75 of 77)
Dataset dataDIVAT2 (77 of 77)




~~~ The SurvSet.sim_run module was successfully executed ~~~


In [16]:
loader.df_ds.join(df_cindex, index="ds")

TypeError: DataFrame.join() got an unexpected keyword argument 'index'

In [13]:
df_cindex

Unnamed: 0,ds,is_td,cindex,lb,ub
0,hdfail,False,0.877454,0.868224,0.886905
1,e1684,False,0.540407,0.483285,0.596818
2,phpl04K8a,False,0.616717,0.567722,0.659190
3,prostate,False,0.625828,0.589569,0.662380
4,cancer,False,0.608749,0.551812,0.671120
...,...,...,...,...,...
64,vdv,False,0.600000,0.513661,0.711585
65,hepatoCellular,False,0.759885,0.704579,0.823440
66,flchain,False,0.934313,0.931904,0.937108
67,Framingham,False,0.703529,0.688968,0.718210
