In [1]:
import sys
sys.path.append('../')
sys.path.append('../nn_survival_analysis')
from nn_survival_analysis.general_utils import *
from nn_survival_analysis.model_utils import *
from nn_survival_analysis.losses import *
from nn_survival_analysis.models import *
from nn_survival_analysis.other_nn_models import *
from nn_survival_analysis.time_invariant_surv import *
from nn_survival_analysis.time_variant_surv import *
from nn_survival_analysis.traditional_models import *
import scipy

# define sigmoid function - will be handy later
sigmoid = lambda z : 1 / (1 + np.exp(-z))

config_file_path = '../nn_survival_analysis/config.json'

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


In [2]:
flchain = pd.read_csv("../resources/other_data/FLCHAIN.csv")
flchain.rename(columns = {'futime':'time_to_event'} , inplace = True)

In [3]:
def train_test_splitter(df , test_size = 0.2 , val_size = 0.25 , duration_col = 'futime' , event_col = 'death'):
    df_test = df.sample(frac=test_size)
    df_train = df.drop(df_test.index)

    df_val = df_train.sample(frac=val_size)
    df_train = df_train.drop(df_val.index)

    return df_train , df_val , df_test

In [4]:
# Create an instance of the OneHotEncoder class
encoder1 = LabelEncoder()
encoder2 = LabelEncoder()

# Fit and transform the data
encoded_sex_data = encoder1.fit_transform(flchain['sex'])
encoded_sex_data = pd.DataFrame(encoded_sex_data.reshape(-1 , 1) , columns = ['sex']) 
print(f'encoded_sex_data.shape {encoded_sex_data.shape}')

# Fit and transform the data
encoded_chap_data = encoder2.fit_transform(flchain['chapter'])
encoded_chap_data = pd.DataFrame(encoded_chap_data.reshape(-1 , 1) , columns = ['chapter'])
print(f'encoded_chap_data.shape {encoded_chap_data.shape}')

flchain_mod = pd.DataFrame(
    pd.concat(
        [
            encoded_sex_data , 
            encoded_chap_data , 
            flchain[['age' , 'sample.yr' , 'kappa' , 'lambda' , 'flc.grp' , 'creatinine' , 'mgus']] , 
            flchain[['time_to_event' , 'death']]
        ] ,
        axis = 1
        )
    )

print(f'flchain_mod.shape {flchain_mod.shape}')

encoded_sex_data.shape (7874, 1)
encoded_chap_data.shape (7874, 1)
flchain_mod.shape (7874, 11)


In [5]:
def impute(df):
    _imputer = SimpleImputer(strategy='mean')
    df_imp = pd.DataFrame(_imputer.fit_transform(df) , columns = df.columns)
    return df_imp

def scale(df):
    _scaler = StandardScaler()
    scale_cols = ['age' , 'sample.yr' , 'kappa' , 'lambda' , 'flc.grp' , 'creatinine' , 'mgus']
    unscaled_cols = [col for col in df.columns if col not in scale_cols]
    scaled = pd.DataFrame(_scaler.fit_transform(df[scale_cols]) , columns = scale_cols)
    return pd.concat([scaled , df[unscaled_cols]] , axis = 1)

In [6]:
df_train , df_test , df_val = train_test_splitter(flchain_mod)
preprocess = lambda df: scale(impute(df))
# preprocess train
x_train = preprocess(df_train)

# preprocess test
x_test = preprocess(df_test)

# preprocess val
x_val = preprocess(df_val)

#### TIS

In [9]:
%%time

# Get configs
with open(config_file_path, "r") as file:
        configs = json.load(file)

#-----------------------------------------------------------------------------------------------
# instantiate - Time Invariant Survival
tis = Time_Invariant_Survival(
        configs = configs, 
        train_data = x_train,
        test_data = x_test, 
        val_data = x_val
)

# fit
tis.fit(verbose = True)
mean_ , up_ , low_ , y_test_dur , y_test_event = tis.predict() # Visualize -> tis.visualize(mean_ , up_ , low_ , _from = 40 , _to = 50 )
tis_cindex , tis_ibs = tis.evaluation(mean_ , y_test_dur , y_test_event, plot = False)
print(tis_cindex , tis_ibs)

training cluster 0
Epoch 50: Training Loss: -0.3624409, Val Loss: -0.3709950
Epoch 100: Training Loss: -0.4037809, Val Loss: -0.3759740
Epoch 150: Training Loss: -0.3937853, Val Loss: -0.3873606
Epoch 200: Training Loss: -0.4089943, Val Loss: -0.3910943
Epoch 250: Training Loss: -0.3807949, Val Loss: -0.3982277
Epoch 300: Training Loss: -0.3711786, Val Loss: -0.3942183
Epoch 350: Training Loss: -0.3999810, Val Loss: -0.4037854
Epoch 400: Training Loss: -0.4269630, Val Loss: -0.4099655
Epoch 450: Training Loss: -0.4170119, Val Loss: -0.4138110
Epoch 500: Training Loss: -0.3885909, Val Loss: -0.4136161
Epoch 550: Training Loss: -0.4063324, Val Loss: -0.4166683
Epoch 600: Training Loss: -0.4347412, Val Loss: -0.4206513
shapes : (1575, 1575, 1575, 1575)
0.9256305351717776 0.11718056272459135
CPU times: total: 8min 45s
Wall time: 1min 56s


#### Other Fitters

In [10]:
%%time
#-----------------------------------------------------------------------------------------------
# instantiate - PyCox
pyc = PYC(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
pyc.fit()

# eval
pyc_cindex , pyc_ibs = pyc.eval()
        
print(f'PyCox: cindex {pyc_cindex} , ibs {pyc_ibs}')
#-----------------------------------------------------------------------------------------------
# instantiate - Deep Survival Machines
dsm = DSM(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val, num_durations = 10)

# fit
dsm.fit()

# eval
dsm_cindex , dsm_ibs = dsm.eval()
       
print(f'Deep Survival Machines: cindex {dsm_cindex} , ibs {dsm_ibs}')

# -----------------------------------------------------------------------------------------------
# instantiate - CPH
cph = CPH(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
cph.fit()
# eval
cph_cindex , cph_ibs = cph.eval(fitter_is_rsf = False)
        
print(f'Cox Proportional Hazards: cindex {cph_cindex} , ibs {cph_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - AFT
aft = AFT(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
aft.fit()
# eval
aft_cindex , aft_ibs = aft.eval(fitter_is_rsf = False)
        
print(f'Weibull Accelerated Failure Time: cindex {aft_cindex} , ibs {aft_ibs}')

#-----------------------------------------------------------------------------------------------
# instantiate - RSF
rsf = RSF(configs = configs, train_data = x_train, test_data = x_test, val_data = x_val)

# fit
rsf.fit()
# eval
rsf_cindex , rsf_ibs = rsf.eval(fitter_is_rsf = True)


print(f'Random Survival Forest: cindex {rsf_cindex} , ibs {rsf_ibs}')

0:	[0s / 0s],		train_loss: 4.3831,	val_loss: 3.2037
1:	[0s / 0s],		train_loss: 3.5251,	val_loss: 2.6299
2:	[0s / 0s],		train_loss: 2.6921,	val_loss: 1.9053
3:	[0s / 0s],		train_loss: 1.7268,	val_loss: 1.2249
4:	[0s / 0s],		train_loss: 1.2080,	val_loss: 0.8478
5:	[0s / 0s],		train_loss: 0.9761,	val_loss: 0.7567
6:	[0s / 0s],		train_loss: 0.8800,	val_loss: 0.7093
7:	[0s / 0s],		train_loss: 0.8411,	val_loss: 0.6674
8:	[0s / 0s],		train_loss: 0.8069,	val_loss: 0.6898
9:	[0s / 0s],		train_loss: 0.7869,	val_loss: 0.6856
10:	[0s / 0s],		train_loss: 0.7547,	val_loss: 0.6592
11:	[0s / 1s],		train_loss: 0.7491,	val_loss: 0.6258
12:	[0s / 1s],		train_loss: 0.7635,	val_loss: 0.6127
13:	[0s / 1s],		train_loss: 0.7453,	val_loss: 0.6143
14:	[0s / 1s],		train_loss: 0.7055,	val_loss: 0.6260
15:	[0s / 1s],		train_loss: 0.7014,	val_loss: 0.6018
16:	[0s / 1s],		train_loss: 0.7028,	val_loss: 0.6101
17:	[0s / 1s],		train_loss: 0.7031,	val_loss: 0.5999
18:	[0s / 1s],		train_loss: 0.7040,	val_loss: 0.6110
19:

100%|██████████| 10000/10000 [00:13<00:00, 724.38it/s]
100%|██████████| 100/100 [00:09<00:00, 10.01it/s]
100%|██████████| 10000/10000 [00:13<00:00, 761.94it/s]
100%|██████████| 100/100 [00:11<00:00,  9.07it/s]
100%|██████████| 10000/10000 [00:13<00:00, 761.10it/s]
100%|██████████| 100/100 [00:10<00:00,  9.12it/s]
100%|██████████| 10000/10000 [00:13<00:00, 740.42it/s]
100%|██████████| 100/100 [00:12<00:00,  8.14it/s]


shapes : (1575, 1575, 1575, 1575)
Deep Survival Machines: cindex 0.0 , ibs nan
shapes : (1575, 1575, 1575, 1575)
Cox Proportional Hazards: cindex 0.8955044690129665 , ibs 0.08126015193916596
shapes : (1575, 1575, 1575, 1575)
Weibull Accelerated Failure Time: cindex 0.894211261440437 , ibs 0.08147364695364535
shapes : (1575, 1575, 1575, 1575)
Random Survival Forest: cindex 0.9308624350444722 , ibs 0.07124332232588824
CPU times: total: 11min 42s
Wall time: 3min 11s
