In [3]:
import json
import math
import numbers
import sys
import warnings

import numpy as np
import pandas as pd

# The project package lives one directory up; these paths must be added before importing it.
sys.path.append('../')
sys.path.append('../nn_survival_analysis')
from nn_survival_analysis.general_utils import *
from nn_survival_analysis.model_utils import *
from nn_survival_analysis.losses import *
from nn_survival_analysis.models import *
from nn_survival_analysis.other_nn_models import *
from nn_survival_analysis.time_invariant_surv import *
from nn_survival_analysis.time_variant_surv import *
from nn_survival_analysis.traditional_models import *
import scipy
from pycox.datasets import metabric
from sklearn_pandas import DataFrameMapper

# Path to the JSON experiment configuration that the model cells load.
config_file_path = '../nn_survival_analysis/config.json'


def sigmoid(z):
    """Logistic sigmoid, 1 / (1 + exp(-z)), evaluated elementwise through numpy."""
    return 1 / (1 + np.exp(-z))

In [4]:
def train_test_splitter(df , test_size = 0.2 , val_size = 0.25 , duration_col = 'futime' , event_col = 'death', random_state = None):
    """Randomly partition `df` into train, validation and test sets.

    The test set is drawn first as `test_size` of all rows. The validation set is then
    drawn as `val_size` of the *remaining* rows, and whatever is left is the training
    set (defaults: 60% / 20% / 20%).

    Parameters
    ----------
    df : pandas.DataFrame to split. Its index labels must be unique, because drawn rows
        are removed from the pool with `DataFrame.drop` (by label).
    test_size : fraction of `df` assigned to the test set.
    val_size : fraction of the non-test rows assigned to the validation set.
    duration_col, event_col : currently unused; kept so existing calls keep working.
    random_state : None (default) keeps pandas' default sampling (numpy's global RNG),
        exactly as before. An int seeds one private RandomState shared by both draws.
        A RandomState / Generator is passed straight through to `DataFrame.sample`.

    Returns
    -------
    (df_train, df_val, df_test) -- note the order: validation comes before test.
    """
    if isinstance(random_state, (int, np.integer)):
        # Use one RNG for both draws so they continue a single stream instead of
        # restarting from the same seed twice.
        random_state = np.random.RandomState(random_state)

    df_test = df.sample(frac=test_size, random_state=random_state)
    df_train = df.drop(df_test.index)

    df_val = df_train.sample(frac=val_size, random_state=random_state)
    df_train = df_train.drop(df_val.index)

    return df_train , df_val , df_test

In [5]:
# Load METABRIC and split it into train / validation / test partitions.
# Seeding numpy's global RNG makes the DataFrame.sample calls inside train_test_splitter
# reproducible, so every model below is compared on the same split from run to run.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)

df = metabric.read_df()
# train_test_splitter returns (train, val, test); any other unpacking order silently
# swaps the validation and test sets.
df_train , df_val , df_test = train_test_splitter(df)

# Standardize cols_standardize; pass cols_leave through untransformed (None transformer).
cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']
cols_leave = ['x4', 'x5', 'x6', 'x7']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)

# Fit the scalers on the training split only, then reuse them for val/test.
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')


def get_target(frame):
    """Return the (duration, event) value arrays of a METABRIC split."""
    return frame['duration'].values, frame['event'].values


y_train_dur , y_train_event = get_target(df_train)
y_test_dur , y_test_event = get_target(df_test)
y_val_dur , y_val_event = get_target(df_val)

In [6]:
# Assemble labeled DataFrames (mapped features + survival targets) for each split.
# DataFrameMapper emits its columns in declaration order (cols_standardize, then cols_leave),
# so the feature labels must follow that order. Reusing df's own x0..x8 order would attach
# the wrong names to the columns holding x4..x8.
feature_names = cols_standardize + cols_leave


def to_survival_frame(split_df, features, feature_names):
    """Combine mapped features with a split's survival targets.

    Parameters
    ----------
    split_df : DataFrame holding the split's `duration` and `event` columns.
    features : array produced by `x_mapper` for `split_df` (rows in `split_df` order).
    feature_names : column labels for `features`, in mapper output order.

    Returns
    -------
    DataFrame of float32 feature columns followed by `time_to_event` and `death`.
    """
    feature_frame = pd.DataFrame(features.astype('float32'), columns=feature_names)
    # Discard the sampled (shuffled) index so the targets align positionally with the features.
    targets = (
        split_df[['duration', 'event']]
        .reset_index(drop=True)
        .rename(columns={'duration': 'time_to_event', 'event': 'death'})
    )
    return pd.concat([feature_frame, targets], axis=1)


x_train = to_survival_frame(df_train, x_mapper.fit_transform(df_train), feature_names)
x_val = to_survival_frame(df_val, x_mapper.transform(df_val), feature_names)
x_test = to_survival_frame(df_test, x_mapper.transform(df_test), feature_names)

# A misaligned concat would add rows; make sure each frame has one row per sampled record.
assert len(x_train) == len(df_train)
assert len(x_val) == len(df_val)
assert len(x_test) == len(df_test)

#### Time-Invariant Survival (TIS)

In [10]:
%%time

# Get configs
with open(config_file_path, "r") as file:
        configs = json.load(file)
        
#-----------------------------------------------------------------------------------------------
# instantiate - Time Invariant Survival
tis = Time_Invariant_Survival(
        configs = configs, 
        train_data = x_train,
        test_data = x_test, 
        val_data = x_val
)

# fit
tis.fit(verbose = True)
mean_ , up_ , low_ , y_test_dur , y_test_event = tis.predict() # Visualize -> tis.visualize(mean_ , up_ , low_ , _from = 40 , _to = 50 )
tis_cindex , tis_ibs = tis.evaluation(mean_ , y_test_dur , y_test_event, plot = False)
print(tis_cindex , tis_ibs)

training cluster 0
Epoch 50: Training Loss: -0.0100063, Val Loss: -0.0123921
Epoch 100: Training Loss: -0.0147883, Val Loss: -0.0147190
Epoch 150: Training Loss: -0.0233001, Val Loss: -0.0218463
Epoch 200: Training Loss: -0.0243625, Val Loss: -0.0233700
Epoch 250: Training Loss: -0.0190197, Val Loss: -0.0248876
Epoch 300: Training Loss: -0.0085105, Val Loss: -0.0208942
Epoch 350: Training Loss: -0.0101741, Val Loss: -0.0218758
Epoch 400: Training Loss: -0.0064592, Val Loss: -0.0239173
Epoch 450: Training Loss: -0.0107445, Val Loss: -0.0200270
Epoch 500: Training Loss: -0.0230233, Val Loss: -0.0128368
Epoch 550: Training Loss: -0.0192515, Val Loss: -0.0248307
Epoch 600: Training Loss: -0.0164496, Val Loss: -0.0216110
shapes : (381, 381, 381, 381)
0.5890650135363126 0.3872190798043313
CPU times: total: 48.1 s
Wall time: 9.37 s


#### Baseline models: PyCox, Deep Survival Machines, Cox PH, Weibull AFT, Random Survival Forest

In [11]:
%%time
#-----------------------------------------------------------------------------------------------
# instantiate - PyCox
pyc = PYC(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
pyc.fit()

# eval
pyc_cindex , pyc_ibs = pyc.eval()
        
print(f'PyCox: cindex {pyc_cindex} , ibs {pyc_ibs}')
#-----------------------------------------------------------------------------------------------
# instantiate - Deep Survival Machines
dsm = DSM(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
dsm.fit()

# eval
dsm_cindex , dsm_ibs = dsm.eval()
       
print(f'Deep Survival Machines: cindex {dsm_cindex} , ibs {dsm_ibs}')

# -----------------------------------------------------------------------------------------------
# instantiate - CPH
cph = CPH(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
cph.fit()
# eval
cph_cindex , cph_ibs = cph.eval(fitter_is_rsf = False)
        
print(f'Cox Proportional Hazards: cindex {cph_cindex} , ibs {cph_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - AFT
aft = AFT(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
aft.fit()
# eval
aft_cindex , aft_ibs = aft.eval(fitter_is_rsf = False)
        
print(f'Weibull Accelerated Failure Time: cindex {aft_cindex} , ibs {aft_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - RSF
rsf = RSF(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
rsf.fit()
# eval
rsf_cindex , rsf_ibs = rsf.eval(fitter_is_rsf = True)


print(f'Random Survival Forest: cindex {rsf_cindex} , ibs {rsf_ibs}')

0:	[0s / 0s],		train_loss: 3.2039,	val_loss: 2.8919
1:	[0s / 0s],		train_loss: 2.9687,	val_loss: 2.8036
2:	[0s / 0s],		train_loss: 2.8297,	val_loss: 2.6876
3:	[0s / 0s],		train_loss: 2.7513,	val_loss: 2.5781
4:	[0s / 0s],		train_loss: 2.7182,	val_loss: 2.4893
5:	[0s / 0s],		train_loss: 2.5951,	val_loss: 2.3948
6:	[0s / 0s],		train_loss: 2.4830,	val_loss: 2.2898
7:	[0s / 0s],		train_loss: 2.3424,	val_loss: 2.1898
8:	[0s / 0s],		train_loss: 2.2193,	val_loss: 2.0819
9:	[0s / 0s],		train_loss: 2.0996,	val_loss: 1.9694
10:	[0s / 0s],		train_loss: 1.9376,	val_loss: 1.8410
11:	[0s / 0s],		train_loss: 1.8507,	val_loss: 1.7356
12:	[0s / 0s],		train_loss: 1.7816,	val_loss: 1.6492
13:	[0s / 0s],		train_loss: 1.6743,	val_loss: 1.5671
14:	[0s / 0s],		train_loss: 1.6069,	val_loss: 1.5124
15:	[0s / 0s],		train_loss: 1.5324,	val_loss: 1.4853
16:	[0s / 0s],		train_loss: 1.4910,	val_loss: 1.4638
17:	[0s / 0s],		train_loss: 1.4455,	val_loss: 1.4528
18:	[0s / 0s],		train_loss: 1.4347,	val_loss: 1.4614
19:

100%|██████████| 10000/10000 [00:10<00:00, 963.96it/s]
100%|██████████| 100/100 [00:02<00:00, 41.69it/s]
100%|██████████| 10000/10000 [00:10<00:00, 983.32it/s]
100%|██████████| 100/100 [00:02<00:00, 36.36it/s]
100%|██████████| 10000/10000 [00:10<00:00, 979.96it/s]
100%|██████████| 100/100 [00:02<00:00, 33.56it/s]
100%|██████████| 10000/10000 [00:10<00:00, 972.31it/s]
100%|██████████| 100/100 [00:03<00:00, 31.48it/s]


shapes : (381, 381, 381, 381)
Deep Survival Machines: cindex 0.0 , ibs nan
shapes : (381, 381, 381, 381)
Cox Proportional Hazards: cindex 0.607519918364143 , ibs 0.17498796741261524
shapes : (381, 381, 381, 381)
Weibull Accelerated Failure Time: cindex 0.6110129910907022 , ibs 0.17663067235510943
shapes : (381, 381, 381, 381)
Random Survival Forest: cindex 0.6314612033439303 , ibs 0.17305955357321945
CPU times: total: 2min 3s
Wall time: 1min
