In [1]:
import sys
sys.path.append('../')
sys.path.append('../nn_survival_analysis')
from nn_survival_analysis.general_utils import *
from nn_survival_analysis.model_utils import *
from nn_survival_analysis.losses import *
from nn_survival_analysis.models import *
from nn_survival_analysis.other_nn_models import *
from nn_survival_analysis.time_invariant_surv import *
from nn_survival_analysis.time_variant_surv import *
from nn_survival_analysis.traditional_models import *
import scipy
from pycox.datasets import metabric
from sklearn_pandas import DataFrameMapper 

def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z)).

    Accepts anything ``np.exp`` accepts (scalars or arrays) and is applied
    element-wise for array input.
    """
    return 1 / (1 + np.exp(-z))


# Location of the JSON configuration file, relative to this notebook
config_file_path = '../nn_survival_analysis/config.json'

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
def train_test_splitter(df , test_size = 0.2 , val_size = 0.25 , duration_col = 'futime' , event_col = 'death' , random_state = None):
    """Randomly split a DataFrame into train, validation and test frames.

    The test rows are sampled first as a fraction of ``df``; the validation
    rows are then sampled as a fraction of what remains, so ``val_size`` is
    relative to the non-test rows, not to the full frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to split. Rows are removed by index label (``DataFrame.drop``),
        so index labels should be unique; with duplicated labels every row
        sharing a sampled label is dropped from the training frame.
    test_size : float
        Fraction of ``df`` placed in the test frame.
    val_size : float
        Fraction of the remaining (non-test) rows placed in the validation frame.
    duration_col, event_col : str
        Accepted for interface compatibility; not used by the split itself.
    random_state : int, numpy random generator/state, or None
        Forwarded to ``DataFrame.sample`` for both draws. ``None`` (default)
        keeps the original, unseeded behavior.

    Returns
    -------
    (df_train, df_val, df_test) : tuple of pandas.DataFrame
        NOTE the order: train, then validation, then test.
    """
    # Hold out the test rows first, then carve validation out of the remainder.
    df_test = df.sample(frac = test_size , random_state = random_state)
    df_train = df.drop(df_test.index)

    df_val = df_train.sample(frac = val_size , random_state = random_state)
    df_train = df_train.drop(df_val.index)

    return df_train , df_val , df_test

In [3]:
df = metabric.read_df()
# train_test_splitter returns (train, val, test) in that order; the previous
# unpacking (train, test, val) silently swapped the validation and test frames.
df_train , df_val , df_test = train_test_splitter(df)

# Columns that get standardized vs. columns passed through untouched
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

# The mapper is built from standardize + leave, in that order
x_mapper = DataFrameMapper(standardize + leave)

# Fit the scalers on the training frame only, then reuse them for val/test
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

# Pull the (duration, event) arrays out of each split
get_target = lambda df: (df['duration'].values, df['event'].values)
y_train_dur , y_train_event = get_target(df_train)
y_test_dur , y_test_event = get_target(df_test)
y_val_dur , y_val_event = get_target(df_val)

Dataset 'metabric' not locally available. Downloading...
Done


In [4]:
# DataFrameMapper emits its output columns in the order the transformers were
# registered (standardize + leave), which is NOT the column order of the source
# frame. Labelling the output with df.iloc[:, :-2].columns therefore attached
# the wrong names to the pass-through columns and to 'x8'; use the mapper order.
feature_cols = cols_standardize + cols_leave

x_train = pd.DataFrame(x_mapper.fit_transform(df_train).astype('float32') , columns = feature_cols)
x_val = pd.DataFrame(x_mapper.transform(df_val).astype('float32') , columns = feature_cols)
x_test = pd.DataFrame(x_mapper.transform(df_test).astype('float32') , columns = feature_cols)

# Append the target columns; the index is reset so rows align positionally with
# the freshly built (0..n-1 indexed) feature frames.
target_cols = ['duration' , 'event']
x_train = pd.concat([x_train , df_train[target_cols].reset_index(drop = True)] , axis = 1)
x_test = pd.concat([x_test , df_test[target_cols].reset_index(drop = True)] , axis = 1)
x_val = pd.concat([x_val , df_val[target_cols].reset_index(drop = True)] , axis = 1)

# Rename the targets to the column names used in the rest of this notebook
target_renames = {'duration':'time_to_event' , 'event': 'death'}
x_train = x_train.rename(columns = target_renames)
x_test = x_test.rename(columns = target_renames)
x_val = x_val.rename(columns = target_renames)

#### TIS

In [5]:
%%time

# Read the JSON configuration from disk
with open(config_file_path, "r") as file:
    configs = json.load(file)

# Time Invariant Survival model, built from the prepared train / test / val frames
tis = Time_Invariant_Survival(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# Train (verbose = True is passed through to fit)
tis.fit(verbose = True)

# Predict. NOTE: this rebinds y_test_dur / y_test_event to whatever predict()
# returns, replacing any earlier values held under those names.
# To visualize -> tis.visualize(mean_ , up_ , low_ , _from = 40 , _to = 50 )
mean_ , up_ , low_ , y_test_dur , y_test_event = tis.predict()

# Evaluate without plotting and report the two returned metrics
tis_cindex , tis_ibs = tis.evaluation(mean_ , y_test_dur , y_test_event, plot = False)
print(tis_cindex , tis_ibs)

training cluster 0
Epoch 50: Training Loss: 0.0230443, Val Loss: 0.0212367
Epoch 100: Training Loss: 0.0138489, Val Loss: 0.0092648
Epoch 150: Training Loss: 0.0160065, Val Loss: 0.0028864
Epoch 200: Training Loss: 0.0128226, Val Loss: 0.0023874
Epoch 250: Training Loss: 0.0011836, Val Loss: 0.0033392
Epoch 300: Training Loss: 0.0131717, Val Loss: 0.0142880
Epoch 350: Training Loss: -0.0037746, Val Loss: -0.0052630
Epoch 400: Training Loss: -0.0079549, Val Loss: -0.0073843
shapes : (381, 381, 381, 381)
0.621729578589974 0.4333007458885405
CPU times: total: 2min
Wall time: 29 s


#### Other Fitters

In [6]:
%%time
# Every fitter below receives the same configuration and data splits
shared_kwargs = dict(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# --- PyCox ------------------------------------------------------------------------------------
pyc = PYC(**shared_kwargs, num_durations = 10)
pyc.fit()
pyc_cindex , pyc_ibs = pyc.eval()
print(f'PyCox: cindex {pyc_cindex} , ibs {pyc_ibs}')

# --- Deep Survival Machines -------------------------------------------------------------------
dsm = DSM(**shared_kwargs, num_durations = 10)
dsm.fit()
dsm_cindex , dsm_ibs = dsm.eval()
print(f'Deep Survival Machines: cindex {dsm_cindex} , ibs {dsm_ibs}')

# --- CPH --------------------------------------------------------------------------------------
cph = CPH(**shared_kwargs)
cph.fit()
cph_cindex , cph_ibs = cph.eval(fitter_is_rsf = False)
print(f'Cox Proportional Hazards: cindex {cph_cindex} , ibs {cph_ibs}')

# --- AFT --------------------------------------------------------------------------------------
aft = AFT(**shared_kwargs)
aft.fit()
aft_cindex , aft_ibs = aft.eval(fitter_is_rsf = False)
print(f'Weibull Accelerated Failure Time: cindex {aft_cindex} , ibs {aft_ibs}')

# --- RSF (the only fitter evaluated with fitter_is_rsf = True) --------------------------------
rsf = RSF(**shared_kwargs)
rsf.fit()
rsf_cindex , rsf_ibs = rsf.eval(fitter_is_rsf = True)
print(f'Random Survival Forest: cindex {rsf_cindex} , ibs {rsf_ibs}')

0:	[0s / 0s],		train_loss: 3.1690,	val_loss: 3.0904
1:	[0s / 0s],		train_loss: 2.9098,	val_loss: 2.8373
2:	[0s / 0s],		train_loss: 2.8725,	val_loss: 2.6685
3:	[0s / 0s],		train_loss: 2.7517,	val_loss: 2.5517
4:	[0s / 0s],		train_loss: 2.6576,	val_loss: 2.4476
5:	[0s / 0s],		train_loss: 2.5985,	val_loss: 2.3644
6:	[0s / 0s],		train_loss: 2.4759,	val_loss: 2.2753
7:	[0s / 0s],		train_loss: 2.3749,	val_loss: 2.1635
8:	[0s / 0s],		train_loss: 2.2515,	val_loss: 2.0344
9:	[0s / 0s],		train_loss: 2.1028,	val_loss: 1.9076
10:	[0s / 0s],		train_loss: 2.0265,	val_loss: 1.7955
11:	[0s / 0s],		train_loss: 1.8949,	val_loss: 1.7005
12:	[0s / 0s],		train_loss: 1.7600,	val_loss: 1.6043
13:	[0s / 1s],		train_loss: 1.7098,	val_loss: 1.5283
14:	[0s / 1s],		train_loss: 1.6254,	val_loss: 1.4700
15:	[0s / 1s],		train_loss: 1.5257,	val_loss: 1.4418
16:	[0s / 1s],		train_loss: 1.5549,	val_loss: 1.4226
17:	[0s / 1s],		train_loss: 1.5107,	val_loss: 1.4113
18:	[0s / 1s],		train_loss: 1.4427,	val_loss: 1.3921
19:

  6%|▌         | 610/10000 [00:02<00:36, 255.94it/s]
 14%|█▍        | 14/100 [00:02<00:16,  5.14it/s]
  6%|▌         | 610/10000 [00:02<00:37, 248.00it/s]
 16%|█▌        | 16/100 [00:03<00:16,  5.01it/s]
  6%|▌         | 610/10000 [00:02<00:34, 273.96it/s]
  8%|▊         | 8/100 [00:02<00:23,  3.97it/s]
  6%|▌         | 610/10000 [00:02<00:41, 226.43it/s]
 18%|█▊        | 18/100 [00:03<00:17,  4.59it/s]


shapes : (381, 381, 381, 381)
Deep Survival Machines: cindex 0.6403699782248237 , ibs 0.34104402019157726
shapes : (381, 381, 381, 381)
Cox Proportional Hazards: cindex 0.6597878418602793 , ibs 0.17181585111251882
shapes : (381, 381, 381, 381)
Weibull Accelerated Failure Time: cindex 0.6624048584613541 , ibs 0.1705047102713802
shapes : (381, 381, 381, 381)
Random Survival Forest: cindex 0.6482410052540104 , ibs 0.1706838736598975
CPU times: total: 1min 21s
Wall time: 46.2 s
