In [1]:
from general_utils import *

# Traditional
from lifelines import CoxPHFitter
from lifelines import WeibullAFTFitter

# Tree-Based
from sksurv.ensemble import RandomSurvivalForest

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
# Get configs
with open(config_file_path, "r") as file:
    configs = json.load(file)
# Get configs
with open(config_file_path, "r") as file:
    configs = json.load(file)

# Read the pickled DataFrames
with open('../05_preprocessing_emr_data/data/x_train.pickle', 'rb') as file:
    x_train = pickle.load(file)
with open('../05_preprocessing_emr_data/data/x_test.pickle', 'rb') as file:
    x_test = pickle.load(file)
with open('../05_preprocessing_emr_data/data/x_val.pickle', 'rb') as file:
    x_val = pickle.load(file)

# Read the pickled DataFrame
with open('../05_preprocessing_emr_data/data/consolidated_pat_tbl.pickle', 'rb') as file:
    consolidated_pat_tbl = pickle.load(file)

In [23]:
class _traditional_fitter:
    '''
    simple wrapper class for cox proportional hazards
    '''
    def __init__(self , configs , train_data , test_data , val_data):
        self.train_data = train_data
        self.test_data = test_data
        self.val_data = val_data
        self.configs = configs
        # state var
        self.fitted = False
        self.fitter = None

    def eval(self , fitter_is_rsf = False):
        if not self.fitted:
            raise Exception('Model not fitted yet!')
            
        # get evaluation - with a tweak for RSF. not the best. but works for now.
        if not fitter_is_rsf:
            # predict
            _surv = self.fitter.predict_survival_function(self.test_data.iloc[: , :-2])
            ev = EvalSurv(pd.DataFrame(_surv), self.test_data['time_to_event'].to_numpy(), self.test_data['death'].to_numpy(), censor_surv='km')
        else:
            # predict
            _surv = self.fitter.predict_survival_function(self.test_data.iloc[: , :-2] , return_array = True)
            ev = EvalSurv(pd.DataFrame(_surv.T), self.test_data['time_to_event'].to_numpy(), self.test_data['death'].to_numpy(), censor_surv='km')
        
        # get time grid
        time_grid_div = self.configs['time_invariant']['eval']['time_grid_div']
        time_grid = np.linspace(self.test_data['time_to_event'].min(), self.test_data['time_to_event'].max(), time_grid_div)
        # get metrics
        cindex = ev.concordance_td('antolini')
        ibs = ev.integrated_brier_score(time_grid)
        return cindex , ibs

In [24]:
class CPH(_traditional_fitter):
    def __init__(self , configs , train_data , test_data , val_data):
        super(CPH , self).__init__(configs , train_data , test_data , val_data)

    def fit(self):
        # init CPH
        cph = CoxPHFitter(penalizer = 0.1)
        # fit
        cph.fit(self.train_data, duration_col='time_to_event', event_col='death', fit_options = {'step_size':0.1})
        # assign the fitted model to a class attr
        self.fitter = cph
        # change state var
        self.fitted = True

In [25]:
class AFT(_traditional_fitter):
    def __init__(self , configs , train_data , test_data , val_data):
        super(AFT , self).__init__(configs , train_data , test_data , val_data)

    def fit(self):
        # init AFT
        aft = WeibullAFTFitter(penalizer = 0.01)
        eps = 1e-8
        self.train_data['time_to_event'] = self.train_data['time_to_event'] + eps
        # fit
        aft.fit(self.train_data, duration_col='time_to_event', event_col='death')

        # assign the fitted model to a class attr
        self.fitter = aft
        # change state var
        self.fitted = True   

In [26]:
class RSF(_traditional_fitter):
    def __init__(self , configs , train_data , test_data , val_data):
        super(RSF , self).__init__(configs , train_data , test_data , val_data)

    def fit(self):
        # Train - Create a structured array
        y_train = np.array([(x, y) for x, y in zip(self.train_data['death'].astype('bool') , self.train_data['time_to_event'])],
                                    dtype=[('death', bool) , ('time_to_event', int)])

        # init RSF
        rsf = RandomSurvivalForest(
            n_estimators=100, min_samples_split=10, min_samples_leaf=15, n_jobs=-1, oob_score = True
        )
        rsf.fit(self.train_data.iloc[: , :-2], y_train)

        # assign the fitted model to a class attr
        self.fitter = rsf
        # change state var
        self.fitted = True   

In [27]:
cph = CPH(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
cph.fit()
# eval
cph.eval(fitter_is_rsf = False)

shapes : (1060, 1060, 1060, 1060)


(0.6953971238645091, 0.17313085627507496)

In [28]:
aft = AFT(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
aft.fit()
# eval
aft.eval(fitter_is_rsf = False)

shapes : (1060, 1060, 1060, 1060)


(0.6953838819884001, 0.17419123792832816)

In [29]:
rsf = RSF(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
rsf.fit()
# eval
rsf.eval(fitter_is_rsf = True)

shapes : (1060, 1060, 1060, 1060)


(0.6906300484652665, 0.20274623694935368)

In [None]:
# # PyCox
# from pycox.datasets import metabric
# from pycox.models import LogisticHazard
# import torchtuples as tt

# # Deep Survival Machines
# from auton_survival.models.dsm import DeepSurvivalMachines