In [2]:
import sys
sys.path.append('../')
sys.path.append('../nn_survival_analysis')
from nn_survival_analysis.general_utils import *
from nn_survival_analysis.model_utils import *
from nn_survival_analysis.losses import *
from nn_survival_analysis.models import *
from nn_survival_analysis.other_nn_models import *
from nn_survival_analysis.time_invariant_surv import *
from nn_survival_analysis.time_variant_surv import *
from nn_survival_analysis.traditional_models import *
import scipy
from pycox.datasets import support
from sklearn_pandas import DataFrameMapper
from pycox.preprocessing.feature_transforms import OrderedCategoricalLong

# define sigmoid function - will be handy later
sigmoid = lambda z : 1 / (1 + np.exp(-z))

config_file_path = '../nn_survival_analysis/config.json'

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [3]:
# Get configs
with open('../nn_survival_analysis/config.json', "r") as file:
    configs = json.load(file)

df_train = support.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

cols_standardize =  ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']
cols_leave = ['x1', 'x4', 'x5']
cols_categorical =  ['x2', 'x3', 'x6']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
categorical = [(col, OrderedCategoricalLong()) for col in cols_categorical]

x_mapper = DataFrameMapper(standardize + leave + categorical)

x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

get_target = lambda df: (df['duration'].values, df['event'].values)
y_train_dur , y_train_event = get_target(df_train)
y_test_dur , y_test_event = get_target(df_test)
y_val_dur , y_val_event = get_target(df_val)

_df_train = np.concatenate([x_train , y_train_dur.reshape(-1 , 1) , y_train_event.reshape(-1 , 1)] , axis = 1)
_df_test = np.concatenate([x_test , y_test_dur.reshape(-1 , 1) , y_test_event.reshape(-1 , 1)] , axis = 1)
_df_val = np.concatenate([x_val , y_val_dur.reshape(-1 , 1) , y_val_event.reshape(-1 , 1)] , axis = 1)

_df_train = pd.DataFrame(_df_train , columns = df_train.columns)
_df_test = pd.DataFrame(_df_test , columns = df_test.columns)
_df_val = pd.DataFrame(_df_val , columns = df_val.columns)

_df_train.rename(columns = {'duration':'time_to_event' , 'event':'death'} , inplace = True)
_df_test.rename(columns = {'duration':'time_to_event' , 'event':'death'} , inplace = True)
_df_val.rename(columns = {'duration':'time_to_event' , 'event':'death'} , inplace = True)

#### TIS

In [4]:
%%time
#-----------------------------------------------------------------------------------------------
# instantiate - Time Invariant Survival
tis = Time_Invariant_Survival(
        configs = configs, 
        train_data = _df_train,
        test_data = _df_test, 
        val_data = _df_val
)

# fit
tis.fit(verbose = True)
mean_ , up_ , low_ , y_test_dur , y_test_event = tis.predict() # Visualize -> tis.visualize(mean_ , up_ , low_ , _from = 40 , _to = 50 )
tis_cindex , tis_ibs = tis.evaluation(mean_ , y_test_dur , y_test_event, plot = False)
print(tis_cindex , tis_ibs)  

training cluster 0
Epoch 50: Training Loss: 0.0020189, Val Loss: -0.0074056
Epoch 100: Training Loss: -0.0140832, Val Loss: -0.0080518
Epoch 150: Training Loss: -0.0136744, Val Loss: -0.0078765
Epoch 200: Training Loss: -0.0150570, Val Loss: -0.0120581
Epoch 250: Training Loss: -0.0139139, Val Loss: -0.0091328
Epoch 300: Training Loss: -0.0033458, Val Loss: -0.0154889
Epoch 350: Training Loss: -0.0062397, Val Loss: -0.0097667
Epoch 400: Training Loss: 0.0045828, Val Loss: -0.0111534
Epoch 450: Training Loss: -0.0209305, Val Loss: -0.0097705
Epoch 500: Training Loss: -0.0274025, Val Loss: -0.0131506
Epoch 550: Training Loss: -0.0050691, Val Loss: -0.0100244
Epoch 600: Training Loss: -0.0206057, Val Loss: -0.0102631
shapes : (1775, 1775, 1775, 1775)
0.5909738587770982 0.22665645619362926
CPU times: total: 4min 24s
Wall time: 1min 10s


#### Other Fitters

In [5]:
%%time
#-----------------------------------------------------------------------------------------------
# instantiate - PyCox
pyc = PYC(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val, num_durations = 10)

# fit
pyc.fit()

# eval
pyc_cindex , pyc_ibs = pyc.eval()
        
print(f'PyCox: cindex {pyc_cindex} , ibs {pyc_ibs}')
#-----------------------------------------------------------------------------------------------
# instantiate - Deep Survival Machines
dsm = DSM(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val, num_durations = 10)

# fit
dsm.fit()

# eval
dsm_cindex , dsm_ibs = dsm.eval()
       
print(f'Deep Survival Machines: cindex {dsm_cindex} , ibs {dsm_ibs}')

# -----------------------------------------------------------------------------------------------
# instantiate - CPH
cph = CPH(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val)

# fit
cph.fit()
# eval
cph_cindex , cph_ibs = cph.eval(fitter_is_rsf = False)
        
print(f'Cox Proportional Hazards: cindex {cph_cindex} , ibs {cph_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - AFT
aft = AFT(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val)

# fit
aft.fit()
# eval
aft_cindex , aft_ibs = aft.eval(fitter_is_rsf = False)
        
print(f'Weibull Accelerated Failure Time: cindex {aft_cindex} , ibs {aft_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - RSF
rsf = RSF(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val)

# fit
rsf.fit()
# eval
rsf_cindex , rsf_ibs = rsf.eval(fitter_is_rsf = True)


print(f'Random Survival Forest: cindex {rsf_cindex} , ibs {rsf_ibs}')

0:	[0s / 0s],		train_loss: 2.4025,	val_loss: 2.0611
1:	[0s / 0s],		train_loss: 2.0385,	val_loss: 1.7303
2:	[0s / 0s],		train_loss: 1.6756,	val_loss: 1.4173
3:	[0s / 0s],		train_loss: 1.3905,	val_loss: 1.2589
4:	[0s / 0s],		train_loss: 1.2738,	val_loss: 1.2196
5:	[0s / 0s],		train_loss: 1.2288,	val_loss: 1.2049
6:	[0s / 0s],		train_loss: 1.2128,	val_loss: 1.1995
7:	[0s / 0s],		train_loss: 1.1993,	val_loss: 1.1954
8:	[0s / 1s],		train_loss: 1.1851,	val_loss: 1.1895
9:	[0s / 1s],		train_loss: 1.1788,	val_loss: 1.1902
10:	[0s / 1s],		train_loss: 1.1708,	val_loss: 1.1882
11:	[0s / 1s],		train_loss: 1.1779,	val_loss: 1.1852
12:	[0s / 1s],		train_loss: 1.1544,	val_loss: 1.1854
13:	[0s / 1s],		train_loss: 1.1602,	val_loss: 1.1857
14:	[0s / 1s],		train_loss: 1.1702,	val_loss: 1.1853
15:	[0s / 1s],		train_loss: 1.1506,	val_loss: 1.1816
16:	[0s / 1s],		train_loss: 1.1550,	val_loss: 1.1879
17:	[0s / 1s],		train_loss: 1.1548,	val_loss: 1.1849
18:	[0s / 2s],		train_loss: 1.1519,	val_loss: 1.1836
19:

 12%|█▏        | 1173/10000 [00:01<00:11, 760.11it/s]
 51%|█████     | 51/100 [00:05<00:05,  8.54it/s]
 12%|█▏        | 1173/10000 [00:01<00:11, 770.12it/s]
 14%|█▍        | 14/100 [00:02<00:12,  6.80it/s]
 12%|█▏        | 1173/10000 [00:01<00:11, 758.15it/s]
 46%|████▌     | 46/100 [00:06<00:07,  7.63it/s]
 12%|█▏        | 1173/10000 [00:01<00:11, 758.18it/s]
 14%|█▍        | 14/100 [00:02<00:13,  6.19it/s]


shapes : (1775, 1775, 1775, 1775)
Deep Survival Machines: cindex 0.5864674406684972 , ibs 0.24281393212331515
shapes : (1775, 1775, 1775, 1775)
Cox Proportional Hazards: cindex 0.5765378804391086 , ibs 0.209033969777541
shapes : (1775, 1775, 1775, 1775)
Weibull Accelerated Failure Time: cindex 0.5755321293580378 , ibs 0.21012863114671942
shapes : (1775, 1775, 1775, 1775)
Random Survival Forest: cindex 0.6277225763593447 , ibs 0.19519703490214144
CPU times: total: 4min 38s
Wall time: 2min 17s
