In [1]:
from general_utils import *

# NN
from pycox.datasets import metabric
from pycox.models import LogisticHazard
import torchtuples as tt
from auton_survival.models.dsm import DeepSurvivalMachines

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
# Get configs
with open(config_file_path, "r") as file:
    configs = json.load(file)

# Read the pickled DataFrames
with open('../05_preprocessing_emr_data/data/x_train.pickle', 'rb') as file:
    x_train = pickle.load(file)
with open('../05_preprocessing_emr_data/data/x_test.pickle', 'rb') as file:
    x_test = pickle.load(file)
with open('../05_preprocessing_emr_data/data/x_val.pickle', 'rb') as file:
    x_val = pickle.load(file)

# Read the pickled DataFrame
with open('../05_preprocessing_emr_data/data/consolidated_pat_tbl.pickle', 'rb') as file:
    consolidated_pat_tbl = pickle.load(file)

In [3]:
class nn_fitter:
    '''
    simple wrapper class for NN models
    '''
    def __init__(self , configs , train_data , test_data , val_data , num_durations):
        self.configs = configs
        self.train_data = train_data
        self.test_data = test_data
        self.val_data = val_data
        # some aux vars
        self.num_durations = num_durations
        self.labtrans = LogisticHazard.label_transform(self.num_durations)

        # targets
        self.y_train = self.labtrans.fit_transform(*get_target(self.train_data))
        self.y_val = self.labtrans.transform(*get_target(self.val_data))
        self.out_features = self.labtrans.out_features

        # state var
        self.fitted = False
        self.fitter = None
    
    
    def eval(self):
        if not self.fitted:
            raise Exception('Model not fitted yet!')

        # _surv = self.fitter.predict_surv_df(self.test_data.iloc[: , :-2].to_numpy().astype('float32'))
        ev = EvalSurv(pd.DataFrame(self._surv), self.test_data['time_to_event'].to_numpy(), self.test_data['death'].to_numpy(), censor_surv='km')

        # get time grid
        time_grid_div = self.configs['time_invariant']['eval']['time_grid_div']
        time_grid = np.linspace(self.test_data['time_to_event'].min(), self.test_data['time_to_event'].max(), time_grid_div)
        # get metrics
        cindex = ev.concordance_td('antolini')
        ibs = ev.integrated_brier_score(time_grid)
        return cindex , ibs

In [4]:
class pycox_fitter(nn_fitter):
    '''
    simple class for pycox logisitc hazards model
    '''
    def __init__(self , configs , train_data , test_data , val_data , num_durations):
        super(pycox_fitter , self).__init__(configs , train_data , test_data , val_data , num_durations)

    def fit(self):
        in_features = self.train_data.iloc[: , :-2].shape[1]
        num_nodes = [256,256]

        batch_norm = True
        dropout = 0.5

        train = (self.train_data.iloc[: , :-2].to_numpy().astype('float32'), self.y_train)
        val = (self.val_data.iloc[: , :-2].to_numpy().astype('float32'), self.y_val)

        net = tt.practical.MLPVanilla(in_features, num_nodes, self.out_features, batch_norm, dropout)

        model = LogisticHazard(net, tt.optim.Adam(0.002), duration_index=self.labtrans.cuts)

        batch_size = 256
        epochs = 500
        callbacks = [tt.cb.EarlyStopping()]

        log = model.fit(self.train_data.iloc[:,:-2].to_numpy().astype('float32'), self.y_train, batch_size, epochs, callbacks, val_data=val)
        
        # assign the fitted model to a class attr
        self.fitter = model
        # change state var
        self.fitted = True
        
        # predict
        self._surv = self.fitter.predict_surv_df(self.test_data.iloc[: , :-2].to_numpy().astype('float32'))


In [5]:
# instantiate - CPH
pyc = pycox_fitter(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
pyc.fit()

# eval
pyc_cindex , pyc_ibs = pyc.eval()

0:	[0s / 0s],		train_loss: 1.7261,	val_loss: 1.4946
1:	[0s / 0s],		train_loss: 1.5203,	val_loss: 1.3216
2:	[0s / 0s],		train_loss: 1.3096,	val_loss: 1.1028
3:	[0s / 0s],		train_loss: 1.0920,	val_loss: 0.8955
4:	[0s / 0s],		train_loss: 0.9154,	val_loss: 0.7668
5:	[0s / 0s],		train_loss: 0.8226,	val_loss: 0.7208
6:	[0s / 0s],		train_loss: 0.7808,	val_loss: 0.7026
7:	[0s / 0s],		train_loss: 0.7549,	val_loss: 0.6867
8:	[0s / 0s],		train_loss: 0.7291,	val_loss: 0.6825
9:	[0s / 0s],		train_loss: 0.7443,	val_loss: 0.6784
10:	[0s / 0s],		train_loss: 0.7350,	val_loss: 0.6818
11:	[0s / 0s],		train_loss: 0.7154,	val_loss: 0.6762
12:	[0s / 0s],		train_loss: 0.7077,	val_loss: 0.6751
13:	[0s / 1s],		train_loss: 0.7082,	val_loss: 0.6699
14:	[0s / 1s],		train_loss: 0.6932,	val_loss: 0.6702
15:	[0s / 1s],		train_loss: 0.6894,	val_loss: 0.6670
16:	[0s / 1s],		train_loss: 0.6955,	val_loss: 0.6731
17:	[0s / 1s],		train_loss: 0.6831,	val_loss: 0.6673
18:	[0s / 1s],		train_loss: 0.6907,	val_loss: 0.6661
19:

In [60]:
class dsm_fitter(nn_fitter):
    '''
    simple class for DSM model
    '''
    def __init__(self , configs , train_data , test_data , val_data , num_durations):
        super(dsm_fitter , self).__init__(configs , train_data , test_data , val_data , num_durations)

    def fit(self):
        times = list(self.labtrans.cuts)

        param_grid = {'k' : [3,4],
              'distribution' : ['LogNormal'],
              'learning_rate' : [1e-3],
              'layers' : [[100],[100,100]]
             }

        params = ParameterGrid(param_grid)

        models = []
        for param in params:
            model = DeepSurvivalMachines(k = param['k'],
                                        distribution = param['distribution'],
                                        layers = param['layers'])
            # The fit method is called to train the model
            model.fit(x_train.iloc[: , :-2].to_numpy(), x_train['time_to_event'].to_numpy(), x_train['death'].to_numpy() ,
                    iters = 100 , 
                    learning_rate = param['learning_rate']
                    )
            models.append(
                [
                    [
                        model.compute_nll(x_val.iloc[: , :-2].to_numpy(), x_val['time_to_event'].to_numpy(), x_val['death'].to_numpy()), 
                        model,
                        param
                    ]
                ]
            )
        best_model = min(models)
        model = best_model[0][1]
        param = best_model[0][2]
        self.best_param = param
        
        # assign the fitted model to a class attr
        self.fitter = model
        # change state var
        self.fitted = True
        
        # predict
        out_survival = model.predict_survival(self.test_data.iloc[: , :-2].to_numpy().astype('float64'), times)
        self._surv = out_survival.T    

In [64]:
# instantiate - CPH
dsm = dsm_fitter(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
dsm.fit()

# eval
dsm_cindex , dsm_ibs = dsm.eval()

  0%|          | 0/10000 [00:00<?, ?it/s]

 14%|█▍        | 1444/10000 [00:01<00:11, 775.82it/s]
  8%|▊         | 8/100 [00:00<00:11,  8.30it/s]
 14%|█▍        | 1444/10000 [00:01<00:10, 785.69it/s]
  7%|▋         | 7/100 [00:00<00:12,  7.20it/s]
 14%|█▍        | 1444/10000 [00:01<00:11, 777.46it/s]
 14%|█▍        | 14/100 [00:01<00:09,  8.81it/s]
 14%|█▍        | 1444/10000 [00:01<00:10, 785.93it/s]
  6%|▌         | 6/100 [00:00<00:14,  6.63it/s]


shapes : (1060, 1060, 1060, 1060)


In [65]:
dsm_cindex , dsm_ibs

(0.645945337535422, 0.27302147855965664)

In [66]:
dsm.best_param
# {'distribution': 'LogNormal', 'k': 4, 'layers': [100], 'learning_rate': 0.001}

{'distribution': 'LogNormal', 'k': 3, 'layers': [100], 'learning_rate': 0.001}