In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import pandas as pd
import numpy as np
import sys, os

from sklearn.model_selection import train_test_split, StratifiedKFold

#user defined
from class_SurvivalQuilts import SurvivalQuilts
from utils_eval import calc_metrics

# IMPORT DATASET
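In this tutorial, the preprocessed METABRIC dataset is used as a toy example.
To apply Survival Quilts to your own data, prepare it in the same form as the loading cell below: a feature table `X`, a one-column table of event times `T` (`event_time`), and a one-column table of event labels `Y` (`label`).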

In [20]:
#=================================================================#
##### USER-DEFINED FUNCTIONS
def f_get_Normalization(X, norm_mode):
    """Normalize each column of a 2-D feature matrix.

    Parameters
    ----------
    X : np.ndarray, shape (num_Patient, num_Feature)
        Feature matrix. Floating-point input is normalized in place and also
        returned. Non-floating input (e.g. int) is first converted to float,
        so normalized values are not truncated when written back.
    norm_mode : str
        'standard' -> zero mean, unit variance per column.
        'normal'   -> min-max scaling of each column to [0, 1].

    Returns
    -------
    np.ndarray
        The normalized matrix. Constant columns (zero std / zero range) are
        only shifted (mean- or min-subtracted) and therefore become all zeros,
        instead of producing NaN from a division by zero. For an unknown
        ``norm_mode`` an error message is printed and X is returned without
        normalization.

    Raises
    ------
    ValueError
        If X is not 2-D.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError("X must be a 2-D array, got shape {}".format(X.shape))
    # Writing float results into an int array would truncate them (0.5 -> 0).
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(float)

    if norm_mode == 'standard': #zero mean unit variance
        col_mean = X.mean(axis=0)
        col_std = X.std(axis=0)
        # Constant column: std is 0, so only center it (divide by 1).
        X[:] = (X - col_mean) / np.where(col_std != 0, col_std, 1.0)
    elif norm_mode == 'normal': #min-max normalization
        col_min = X.min(axis=0)
        col_range = X.max(axis=0) - col_min
        # Constant column: range is 0, so only shift it (result is all zeros).
        X[:] = (X - col_min) / np.where(col_range != 0, col_range, 1.0)
    else:
        print("INPUT MODE ERROR!")

    return X
#=================================================================#



##### DATASET SELECTION
# SEED only names the results folder below.
# NOTE(review): the train/test split further down uses its own fixed
# random_state=1234, not SEED — kept as-is so the printed results reproduce.
SEED      = 1111
data_mode = 'metabric'  #{'metabric', 'support'}

if data_mode == 'metabric':
    X       = pd.read_csv('./sample data/metabric_cleaned_features_final.csv')
    tmp     = pd.read_csv('./sample data/metabric_label.csv')
    T       = tmp[['event_time']]   # one-column DataFrame of times
    Y       = tmp[['label']]        # one-column DataFrame; label == 1 is used as "event" below
    time_interval_ = 10.
    SEED    = 4321
else:
    # Only the METABRIC loader exists in this notebook; fail here with a clear
    # message instead of a NameError on X/T/Y further down.
    raise ValueError("data_mode '{}' is not supported; only 'metabric' is implemented.".format(data_mode))

tmp_folder = './results/' + data_mode + ' (seed ' + str(SEED) + ')/'
os.makedirs(tmp_folder, exist_ok=True)

# eval_time_horizons can be selected based on one's interest.
# Here: the 25th/50th/75th percentiles (truncated to int) of the times of
# samples with label == 1. Quantiles are taken on a Series so each one is a
# scalar; calling int() on a one-element Series (what DataFrame.quantile
# returns) is deprecated in pandas.
event_times = T.iloc[:, 0][Y.iloc[:, 0] == 1]
eval_time_horizons = [int(q) for q in event_times.quantile([0.25, 0.50, 0.75])]


tr_X,te_X, tr_T,te_T, tr_Y,te_Y = train_test_split(X, T, Y, test_size=0.2, random_state=1234)

# TRAIN SURVIVAL QUILTS

In [4]:
# Fit Survival Quilts on the training split (its progress log appears in the cell output).
model_sq = SurvivalQuilts()
model_sq.train(tr_X, tr_T, tr_Y)

# save model
# Use a context manager so the file handle is flushed and closed.
# NOTE: unpickling executes arbitrary code — only load this file from a trusted source.
# NOTE(review): this path is shared across datasets/seeds (not tmp_folder), so
# a later run with another data_mode overwrites it.
import pickle
filename = './results/' + 'SurvivalQuilts.sav'
with open(filename, 'wb') as f_model:
    pickle.dump(model_sq, f_model)

initial training of underlying models...
CV.. 1/10
CV.. 2/10
CV.. 3/10
CV.. 4/10
CV.. 5/10
CV.. 6/10
CV.. 7/10
CV.. 8/10
CV.. 9/10
CV.. 10/10
TIME K = 0
[[0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]
 [0.         0.         0.         0.         0.         0.99999999]]
[[0.         0.         0.         0.         0.         0.99999999]]
[[-0.47267558]]
out_itr: 0 | BEST X: [[0.         0.         0.         0.         0. 

# TEST SURVIVAL QUILTS

In [22]:
# Predictions for the held-out test set: one column per evaluation horizon.
pred = model_sq.predict(te_X, eval_time_horizons)

# Score each horizon's prediction column against the test labels; the training
# times/labels are passed as well (presumably for censoring weights — TODO confirm).
for horizon_idx in range(len(eval_time_horizons)):
    horizon = eval_time_horizons[horizon_idx]
    horizon_metrics = calc_metrics(tr_T, tr_Y, te_T, te_Y, pred[:, horizon_idx], horizon)
    print(horizon_metrics)

(0.699212666584546, 0.14023461203864585)
(0.7212535232103512, 0.265499048715562)
(0.6769443449314548, 0.30659239804946314)
