### **Experiments: Time-Invariant (TIS) vs. Time-Variant (TVS) Survival Models**

In [1]:
import json
import os
import pickle
import sys

# Make the project package importable from this notebook's directory.
# `json` is imported explicitly because the config cell calls json.load; previously it
# was only available indirectly (e.g. via a star import), which breaks on a fresh kernel.
# The membership check keeps re-runs of this cell from growing sys.path with duplicates.
for _extra_path in ('../', '../nn_survival_analysis'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

In [2]:
# Project modules, pulled in with star imports. Names used later in the notebook
# (Time_Invariant_Survival, Time_Variant_Survival, and json at the config-loading cell)
# are therefore not traceable to a specific module from this cell alone.
# TODO(review): replace with explicit imports, e.g.
#   from nn_survival_analysis.time_invariant_surv import Time_Invariant_Survival
#   from nn_survival_analysis.time_variant_surv import Time_Variant_Survival
# (module of origin presumed from the file names — confirm before switching).
from nn_survival_analysis.general_utils import *
from nn_survival_analysis.model_utils import *
from nn_survival_analysis.losses import *
from nn_survival_analysis.models import *
from nn_survival_analysis.other_nn_models import *
from nn_survival_analysis.time_invariant_surv import *
from nn_survival_analysis.time_variant_surv import *
from nn_survival_analysis.traditional_models import *

Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


### **Load Data for Time-Invariant Survival (TIS)**

In [5]:
# Load the pre-split feature tables (train / test / val) written by the preprocessing step.
# NOTE: pickle.load can execute arbitrary code — only load files produced by our own pipeline.
_tis_frames = []
for _pickle_path in (
    '../05_preprocessing_emr_data/data/x_train.pickle',
    '../05_preprocessing_emr_data/data/x_test.pickle',
    '../05_preprocessing_emr_data/data/x_val.pickle',
):
    with open(_pickle_path, 'rb') as file:
        _tis_frames.append(pickle.load(file))
x_train, x_test, x_val = _tis_frames

# Configuration dict passed to both the TIS and TVS models below.
with open('../nn_survival_analysis/config.json', "r") as file:
    configs = json.load(file)

### **Testing Time-Invariant Performance**

In [6]:
# --- Time-Invariant Survival (TIS): fit on x_train (x_val for validation), evaluate on x_test ---
tis = Time_Invariant_Survival(
    configs=configs,
    train_data=x_train,
    test_data=x_test,
    val_data=x_val,
)

# NOTE(review): no random seed is set in this notebook; if training is stochastic,
# the losses and metrics below will vary between runs.
tis.fit(verbose=True)

# predict() returns five values, unpacked by position; their roles are inferred from
# the names (mean / upper / lower prediction, test durations, test event indicators).
(mean_, up_, low_,
 y_test_dur, y_test_event) = tis.predict()
# To plot a slice of test subjects: tis.visualize(mean_, up_, low_, _from=40, _to=50)

tis_cindex, tis_ibs = tis.evaluation(mean_, y_test_dur, y_test_event, plot=False)

training cluster 0
Epoch 50: Training Loss: 0.0052754, Val Loss: 0.0063630
Epoch 100: Training Loss: 0.0035944, Val Loss: 0.0030867
Epoch 150: Training Loss: 0.0013075, Val Loss: 0.0006926
Epoch 200: Training Loss: -0.0007793, Val Loss: 0.0005737
Epoch 250: Training Loss: -0.0004497, Val Loss: -0.0002188
Epoch 300: Training Loss: -0.0022442, Val Loss: 0.0007392
Epoch 350: Training Loss: -0.0104226, Val Loss: -0.0011012
Epoch 400: Training Loss: -0.0033959, Val Loss: -0.0000453
shapes : (1060, 1060, 1060, 1060)


In [7]:
# TIS test-set metrics: (C-index, IBS).
(tis_cindex, tis_ibs)

(0.6957719495276475, 0.39952222311343927)

### **Load Data for Time-Variant Survival (TVS)**

In [8]:
# Load the time-variant ("reshape_tv") feature pickles, then the target pickles,
# in train / test / val order for each.
# NOTE: pickle.load can execute arbitrary code — only load files produced by our own pipeline.
_tvs_objects = []
for _pickle_path in (
    '../05_preprocessing_emr_data/data/x_train_reshape_tv.pickle',
    '../05_preprocessing_emr_data/data/x_test_reshape_tv.pickle',
    '../05_preprocessing_emr_data/data/x_val_reshape_tv.pickle',
    '../05_preprocessing_emr_data/data/y_train.pickle',
    '../05_preprocessing_emr_data/data/y_test.pickle',
    '../05_preprocessing_emr_data/data/y_val.pickle',
):
    with open(_pickle_path, 'rb') as file:
        _tvs_objects.append(pickle.load(file))
(x_train_reshape_tv, x_test_reshape_tv, x_val_reshape_tv,
 y_train, y_test, y_val) = _tvs_objects

### **Testing Time-Variant Performance**

In [9]:
# --- Time-Variant Survival (TVS): same configs as TIS, time-variant inputs plus separate targets ---
tvs = Time_Variant_Survival(
    configs=configs,
    x_train_reshape_tv=x_train_reshape_tv,
    x_test_reshape_tv=x_test_reshape_tv,
    x_val_reshape_tv=x_val_reshape_tv,
    y_train=y_train,
    y_test=y_test,
    y_val=y_val,
)

# NOTE(review): no random seed is set in this notebook; if training is stochastic,
# the losses and metrics below will vary between runs.
tvs.fit(verbose=True)

# predict() returns five values, unpacked by position (roles inferred from the names).
# NOTE: this rebinds mean_, up_, low_, y_test_dur, y_test_event, replacing the TIS
# predictions from the earlier cell; only tis_cindex / tis_ibs survive from that run.
mean_, up_, low_, y_test_dur, y_test_event = tvs.predict()
# To plot a slice of test subjects (assuming Time_Variant_Survival exposes visualize()
# like the TIS model — confirm): tvs.visualize(mean_, up_, low_, _from=40, _to=50)

tvs_cindex, tvs_ibs = tvs.evaluation(mean_, y_test_dur, y_test_event, plot=False)

Epoch 50: Training Loss: -0.0025976, Val Loss: 0.0015660
Epoch 100: Training Loss: 0.0000589, Val Loss: 0.0010896
shapes : (1180, 1180, 1180, 1180)


In [10]:
# TVS test-set metrics: (C-index, IBS).
(tvs_cindex, tvs_ibs)

(0.7335313759085885, 0.40075786613809444)