In [1]:
import sys
sys.path.append('../')
sys.path.append('../nn_survival_analysis')
from nn_survival_analysis.general_utils import *
from nn_survival_analysis.model_utils import *
from nn_survival_analysis.losses import *
from nn_survival_analysis.models import *
from nn_survival_analysis.other_nn_models import *
from nn_survival_analysis.time_invariant_surv import *
from nn_survival_analysis.time_variant_surv import *
from nn_survival_analysis.traditional_models import *
import scipy

# define sigmoid function - will be handy later
sigmoid = lambda z : 1 / (1 + np.exp(-z))

config_file_path = '../nn_survival_analysis/config.json'

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
flchain = pd.read_csv("../resources/other_data/FLCHAIN.csv")
flchain.rename(columns = {'futime':'time_to_event'} , inplace = True)

In [3]:
def train_test_splitter(df , test_size = 0.2 , val_size = 0.25 , duration_col = 'futime' , event_col = 'death'):
    df_test = df.sample(frac=test_size)
    df_train = df.drop(df_test.index)

    df_val = df_train.sample(frac=val_size)
    df_train = df_train.drop(df_val.index)

    return df_train , df_val , df_test

In [4]:
# Create an instance of the OneHotEncoder class
encoder1 = LabelEncoder()
encoder2 = LabelEncoder()

# Fit and transform the data
encoded_sex_data = encoder1.fit_transform(flchain['sex'])
encoded_sex_data = pd.DataFrame(encoded_sex_data.reshape(-1 , 1) , columns = ['sex']) 
print(f'encoded_sex_data.shape {encoded_sex_data.shape}')

# Fit and transform the data
encoded_chap_data = encoder2.fit_transform(flchain['chapter'])
encoded_chap_data = pd.DataFrame(encoded_chap_data.reshape(-1 , 1) , columns = ['chapter'])
print(f'encoded_chap_data.shape {encoded_chap_data.shape}')

flchain_mod = pd.DataFrame(
    pd.concat(
        [
            encoded_sex_data , 
            encoded_chap_data , 
            flchain[['age' , 'sample.yr' , 'kappa' , 'lambda' , 'flc.grp' , 'creatinine' , 'mgus']] , 
            flchain[['time_to_event' , 'death']]
        ] ,
        axis = 1
        )
    )

print(f'flchain_mod.shape {flchain_mod.shape}')

encoded_sex_data.shape (7874, 1)
encoded_chap_data.shape (7874, 1)
flchain_mod.shape (7874, 11)


In [5]:
def impute(df):
    _imputer = SimpleImputer(strategy='mean')
    df_imp = pd.DataFrame(_imputer.fit_transform(df) , columns = df.columns)
    return df_imp

def scale(df):
    _scaler = StandardScaler()
    scale_cols = ['age' , 'sample.yr' , 'kappa' , 'lambda' , 'flc.grp' , 'creatinine' , 'mgus']
    unscaled_cols = [col for col in df.columns if col not in scale_cols]
    scaled = pd.DataFrame(_scaler.fit_transform(df[scale_cols]) , columns = scale_cols)
    return pd.concat([scaled , df[unscaled_cols]] , axis = 1)

In [6]:
df_train , df_test , df_val = train_test_splitter(flchain_mod)
preprocess = lambda df: scale(impute(df))
# preprocess train
x_train = preprocess(df_train)

# preprocess test
x_test = preprocess(df_test)

# preprocess val
x_val = preprocess(df_val)

#### TIS

In [7]:
%%time

# Get configs
with open(config_file_path, "r") as file:
        configs = json.load(file)

#-----------------------------------------------------------------------------------------------
# instantiate - Time Invariant Survival
tis = Time_Invariant_Survival(
        configs = configs, 
        train_data = x_train,
        test_data = x_test, 
        val_data = x_val
)

# fit
tis.fit(verbose = True)
mean_ , up_ , low_ , y_test_dur , y_test_event = tis.predict() # Visualize -> tis.visualize(mean_ , up_ , low_ , _from = 40 , _to = 50 )
tis_cindex , tis_ibs = tis.evaluation(mean_ , y_test_dur , y_test_event, plot = False)
print(tis_cindex , tis_ibs)

training cluster 0
Epoch 50: Training Loss: -0.2409167, Val Loss: -0.2741865
Epoch 100: Training Loss: -0.2657355, Val Loss: -0.3048589
Epoch 150: Training Loss: -0.2934957, Val Loss: -0.3156575
Epoch 200: Training Loss: -0.3343136, Val Loss: -0.3268056
Epoch 250: Training Loss: -0.3424828, Val Loss: -0.3402769
Epoch 300: Training Loss: -0.3323167, Val Loss: -0.3330546
Epoch 350: Training Loss: -0.3458107, Val Loss: -0.3494644
Epoch 400: Training Loss: -0.3638918, Val Loss: -0.3583965
shapes : (1575, 1575, 1575, 1575)
0.9245094591214367 0.10560674525871205
CPU times: total: 13min 40s
Wall time: 2min 40s


#### Other Fitters

In [8]:
%%time
#-----------------------------------------------------------------------------------------------
# instantiate - PyCox
pyc = PYC(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
pyc.fit()

# eval
pyc_cindex , pyc_ibs = pyc.eval()
        
print(f'PyCox: cindex {pyc_cindex} , ibs {pyc_ibs}')
#-----------------------------------------------------------------------------------------------
# instantiate - Deep Survival Machines
dsm = DSM(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
dsm.fit()

# eval
dsm_cindex , dsm_ibs = dsm.eval()
       
print(f'Deep Survival Machines: cindex {dsm_cindex} , ibs {dsm_ibs}')

# -----------------------------------------------------------------------------------------------
# instantiate - CPH
cph = CPH(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
cph.fit()
# eval
cph_cindex , cph_ibs = cph.eval(fitter_is_rsf = False)
        
print(f'Cox Proportional Hazards: cindex {cph_cindex} , ibs {cph_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - AFT
aft = AFT(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
aft.fit()
# eval
aft_cindex , aft_ibs = aft.eval(fitter_is_rsf = False)
        
print(f'Weibull Accelerated Failure Time: cindex {aft_cindex} , ibs {aft_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - RSF
rsf = RSF(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
rsf.fit()
# eval
rsf_cindex , rsf_ibs = rsf.eval(fitter_is_rsf = True)


print(f'Random Survival Forest: cindex {rsf_cindex} , ibs {rsf_ibs}')

0:	[0s / 0s],		train_loss: 4.3521,	val_loss: 2.9553
1:	[0s / 0s],		train_loss: 3.4812,	val_loss: 2.6144
2:	[0s / 0s],		train_loss: 2.6314,	val_loss: 1.8869
3:	[0s / 0s],		train_loss: 1.7667,	val_loss: 1.1986
4:	[0s / 1s],		train_loss: 1.2120,	val_loss: 0.8345
5:	[0s / 1s],		train_loss: 0.9533,	val_loss: 0.7522
6:	[0s / 1s],		train_loss: 0.8554,	val_loss: 0.6804
7:	[0s / 1s],		train_loss: 0.8221,	val_loss: 0.7595
8:	[0s / 2s],		train_loss: 0.8317,	val_loss: 0.7555
9:	[0s / 2s],		train_loss: 0.7980,	val_loss: 0.6709
10:	[0s / 2s],		train_loss: 0.7718,	val_loss: 0.6405
11:	[0s / 2s],		train_loss: 0.7617,	val_loss: 0.6620
12:	[0s / 2s],		train_loss: 0.7467,	val_loss: 0.6597
13:	[0s / 3s],		train_loss: 0.7391,	val_loss: 0.6364
14:	[0s / 3s],		train_loss: 0.7096,	val_loss: 0.6433
15:	[0s / 3s],		train_loss: 0.7111,	val_loss: 0.6191
16:	[0s / 3s],		train_loss: 0.6867,	val_loss: 0.6163
17:	[0s / 3s],		train_loss: 0.6918,	val_loss: 0.6395
18:	[0s / 4s],		train_loss: 0.6962,	val_loss: 0.6100
19:

100%|██████████| 10000/10000 [00:47<00:00, 211.76it/s]
100%|██████████| 100/100 [00:30<00:00,  3.31it/s]
100%|██████████| 10000/10000 [00:18<00:00, 540.38it/s]
100%|██████████| 100/100 [00:16<00:00,  6.15it/s]
100%|██████████| 10000/10000 [00:20<00:00, 492.70it/s]
100%|██████████| 100/100 [00:18<00:00,  5.40it/s]
100%|██████████| 10000/10000 [00:19<00:00, 517.61it/s]
100%|██████████| 100/100 [00:17<00:00,  5.74it/s]


shapes : (1575, 1575, 1575, 1575)
Deep Survival Machines: cindex 0.0 , ibs nan
shapes : (1575, 1575, 1575, 1575)
Cox Proportional Hazards: cindex 0.8963734176024679 , ibs 0.07483696698342991
shapes : (1575, 1575, 1575, 1575)
Weibull Accelerated Failure Time: cindex 0.8930647254669601 , ibs 0.07575229651062905
shapes : (1575, 1575, 1575, 1575)
Random Survival Forest: cindex 0.9335308391174313 , ibs 0.07106165489894868
CPU times: total: 22min 15s
Wall time: 6min 3s
