In [1]:
import sys
sys.path.append('../')
sys.path.append('../nn_survival_analysis')
from nn_survival_analysis.general_utils import *
from nn_survival_analysis.model_utils import *
from nn_survival_analysis.losses import *
from nn_survival_analysis.models import *
from nn_survival_analysis.other_nn_models import *
from nn_survival_analysis.time_invariant_surv import *
from nn_survival_analysis.time_variant_surv import *
from nn_survival_analysis.traditional_models import *
import scipy
from pycox.datasets import support
from sklearn_pandas import DataFrameMapper
from pycox.preprocessing.feature_transforms import OrderedCategoricalLong

# define sigmoid function - will be handy later
sigmoid = lambda z : 1 / (1 + np.exp(-z))

config_file_path = '../nn_survival_analysis/config.json'

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
# Get configs
with open(config_file_path , "r") as file:
    configs = json.load(file)

df_train = support.read_df()
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

cols_standardize =  ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']
cols_leave = ['x1', 'x4', 'x5']
cols_categorical =  ['x2', 'x3', 'x6']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
categorical = [(col, OrderedCategoricalLong()) for col in cols_categorical]

x_mapper = DataFrameMapper(standardize + leave + categorical)

x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

get_target = lambda df: (df['duration'].values, df['event'].values)
y_train_dur , y_train_event = get_target(df_train)
y_test_dur , y_test_event = get_target(df_test)
y_val_dur , y_val_event = get_target(df_val)

_df_train = np.concatenate([x_train , y_train_dur.reshape(-1 , 1) , y_train_event.reshape(-1 , 1)] , axis = 1)
_df_test = np.concatenate([x_test , y_test_dur.reshape(-1 , 1) , y_test_event.reshape(-1 , 1)] , axis = 1)
_df_val = np.concatenate([x_val , y_val_dur.reshape(-1 , 1) , y_val_event.reshape(-1 , 1)] , axis = 1)

_df_train = pd.DataFrame(_df_train , columns = df_train.columns)
_df_test = pd.DataFrame(_df_test , columns = df_test.columns)
_df_val = pd.DataFrame(_df_val , columns = df_val.columns)

_df_train.rename(columns = {'duration':'time_to_event' , 'event':'death'} , inplace = True)
_df_test.rename(columns = {'duration':'time_to_event' , 'event':'death'} , inplace = True)
_df_val.rename(columns = {'duration':'time_to_event' , 'event':'death'} , inplace = True)

Dataset 'support' not locally available. Downloading...
Done


#### TIS

In [3]:
%%time
#-----------------------------------------------------------------------------------------------
# instantiate - Time Invariant Survival
tis = Time_Invariant_Survival(
        configs = configs, 
        train_data = _df_train,
        test_data = _df_test, 
        val_data = _df_val
)

# fit
tis.fit(verbose = True)
mean_ , up_ , low_ , y_test_dur , y_test_event = tis.predict() # Visualize -> tis.visualize(mean_ , up_ , low_ , _from = 40 , _to = 50 )
tis_cindex , tis_ibs = tis.evaluation(mean_ , y_test_dur , y_test_event, plot = False)
print(tis_cindex , tis_ibs)  

training cluster 0
Epoch 50: Training Loss: 0.0037825, Val Loss: 0.0154994
Epoch 100: Training Loss: 0.0062350, Val Loss: 0.0096503
Epoch 150: Training Loss: 0.0158692, Val Loss: 0.0138773
Epoch 200: Training Loss: 0.0118470, Val Loss: 0.0066545
Epoch 250: Training Loss: 0.0166348, Val Loss: 0.0098582
Epoch 300: Training Loss: 0.0111556, Val Loss: 0.0113556
Epoch 350: Training Loss: 0.0069933, Val Loss: 0.0111245
Epoch 400: Training Loss: -0.0001434, Val Loss: 0.0129846
shapes : (1775, 1775, 1775, 1775)
0.562298916354898 0.21533423290168918
CPU times: total: 4min 29s
Wall time: 1min 3s


#### Other Fitters

In [4]:
%%time
#-----------------------------------------------------------------------------------------------
# instantiate - PyCox
pyc = PYC(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val, num_durations = 10)

# fit
pyc.fit()

# eval
pyc_cindex , pyc_ibs = pyc.eval()
        
print(f'PyCox: cindex {pyc_cindex} , ibs {pyc_ibs}')
#-----------------------------------------------------------------------------------------------
# instantiate - Deep Survival Machines
dsm = DSM(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val, num_durations = 10)

# fit
dsm.fit()

# eval
dsm_cindex , dsm_ibs = dsm.eval()
       
print(f'Deep Survival Machines: cindex {dsm_cindex} , ibs {dsm_ibs}')

# -----------------------------------------------------------------------------------------------
# instantiate - CPH
cph = CPH(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val)

# fit
cph.fit()
# eval
cph_cindex , cph_ibs = cph.eval(fitter_is_rsf = False)
        
print(f'Cox Proportional Hazards: cindex {cph_cindex} , ibs {cph_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - AFT
aft = AFT(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val)

# fit
aft.fit()
# eval
aft_cindex , aft_ibs = aft.eval(fitter_is_rsf = False)
        
print(f'Weibull Accelerated Failure Time: cindex {aft_cindex} , ibs {aft_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - RSF
rsf = RSF(configs = configs, train_data = _df_train, test_data = _df_test, val_data = _df_val)

# fit
rsf.fit()
# eval
rsf_cindex , rsf_ibs = rsf.eval(fitter_is_rsf = True)


print(f'Random Survival Forest: cindex {rsf_cindex} , ibs {rsf_ibs}')

0:	[0s / 0s],		train_loss: 2.4396,	val_loss: 2.1060
1:	[0s / 0s],		train_loss: 2.0634,	val_loss: 1.7398
2:	[0s / 0s],		train_loss: 1.6757,	val_loss: 1.4191
3:	[0s / 0s],		train_loss: 1.3892,	val_loss: 1.2767
4:	[0s / 0s],		train_loss: 1.2741,	val_loss: 1.2329
5:	[0s / 0s],		train_loss: 1.2275,	val_loss: 1.2189
6:	[0s / 0s],		train_loss: 1.2112,	val_loss: 1.2136
7:	[0s / 0s],		train_loss: 1.1984,	val_loss: 1.2083
8:	[0s / 1s],		train_loss: 1.1863,	val_loss: 1.2029
9:	[0s / 1s],		train_loss: 1.1856,	val_loss: 1.2010
10:	[0s / 1s],		train_loss: 1.1770,	val_loss: 1.1975
11:	[0s / 1s],		train_loss: 1.1664,	val_loss: 1.1925
12:	[0s / 1s],		train_loss: 1.1689,	val_loss: 1.1903
13:	[0s / 1s],		train_loss: 1.1623,	val_loss: 1.1907
14:	[0s / 1s],		train_loss: 1.1579,	val_loss: 1.1909
15:	[0s / 1s],		train_loss: 1.1515,	val_loss: 1.1908
16:	[0s / 2s],		train_loss: 1.1457,	val_loss: 1.1896
17:	[0s / 2s],		train_loss: 1.1472,	val_loss: 1.1932
18:	[0s / 2s],		train_loss: 1.1521,	val_loss: 1.1892
19:

 12%|█▏        | 1234/10000 [00:01<00:11, 733.30it/s]
 50%|█████     | 50/100 [00:06<00:06,  8.10it/s]
 12%|█▏        | 1234/10000 [00:01<00:12, 720.09it/s]
 19%|█▉        | 19/100 [00:03<00:13,  6.12it/s]
 12%|█▏        | 1234/10000 [00:01<00:12, 712.92it/s]
 20%|██        | 20/100 [00:02<00:11,  6.98it/s]
 12%|█▏        | 1234/10000 [00:01<00:13, 655.21it/s]
 19%|█▉        | 19/100 [00:03<00:13,  6.16it/s]


shapes : (1775, 1775, 1775, 1775)
Deep Survival Machines: cindex 0.589253120859095 , ibs 0.2180058427814511
shapes : (1775, 1775, 1775, 1775)
Cox Proportional Hazards: cindex 0.5523647531912814 , ibs 0.19693319946615204
shapes : (1775, 1775, 1775, 1775)
Weibull Accelerated Failure Time: cindex 0.5504770035153679 , ibs 0.19698153867971116
shapes : (1775, 1775, 1775, 1775)
Random Survival Forest: cindex 0.6340819658466058 , ibs 0.18112291703494607
CPU times: total: 5min 11s
Wall time: 2min 27s
