In [55]:
matplotlib_style = 'fivethirtyeight'
import matplotlib.pyplot as plt; plt.style.use(matplotlib_style)
import numpy as np
import tensorflow as tf
import os
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from sksurv.linear_model.coxph import BreslowEstimator, CoxPHSurvivalAnalysis
matplotlib_style = 'fivethirtyeight'
import matplotlib.pyplot as plt; plt.style.use(matplotlib_style)
from sklearn.model_selection import train_test_split
from utility.training import get_data_loader, scale_data, split_time_event
from utility.survival import calculate_event_times, predict_median_survival_time, predict_mean_survival_time
import paths as pt
import pandas as pd
from utility.model import load_mlp_model, load_sota_model
from utility.survival import compute_survival_function

DATASET_NAME = "SEER"

# Load data
dl = get_data_loader(DATASET_NAME).load_data()
X, y = dl.get_data()
num_features, cat_features = dl.get_features()

# Split data in train, valid and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_train, X_valid, y_train, y_valid  = train_test_split(X_train, y_train, test_size=0.25, random_state=0)

# Scale data
X_train, X_valid, X_test = scale_data(X_train, X_valid, X_test, cat_features, num_features)

# Make time/event split
t_train, e_train = split_time_event(y_train)
t_test, e_test = split_time_event(y_test)

# Make event times
event_times = calculate_event_times(t_train, e_train)

# Load MLP model
n_input_dims = X_train.shape[1:]
model = load_sota_model(DATASET_NAME, "cox")

# Select only test samples where event occurs
test_idx = list(np.where(y_test['event'] == True)[0])
X_test = X_test.iloc[test_idx]
y_test = y_test[test_idx]

# Compute surv func
surv_preds = compute_survival_function(model, X_train, X_test, e_train, t_train, event_times)
surv_preds = pd.DataFrame(surv_preds.mean(axis=0), columns=event_times)



FileNotFoundError: [Errno 2] No such file or directory: 'c:\\users\\au475271\\onedrive - aarhus universitet\\desktop\\baysurv\\models\\seer_cox.joblib'

In [None]:
# Plot surv funcs
'''
styles = ('-', '--')
plt.figure(dpi=80)
for i, surv_fn in enumerate(test_surv_fn):
    plt.step(event_times, surv_fn(event_times), where="post", label=str(i))
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05),
          ncol=5, fancybox=True, shadow=True)
plt.ylabel("Probability of survival $S(t)$")
plt.xlabel("Time $t$")
plt.grid(True)
plt.show()
'''

'\nstyles = (\'-\', \'--\')\nplt.figure(dpi=80)\nfor i, surv_fn in enumerate(test_surv_fn):\n    plt.step(event_times, surv_fn(event_times), where="post", label=str(i))\nplt.legend(loc=\'upper center\', bbox_to_anchor=(0.5, 1.05),\n          ncol=5, fancybox=True, shadow=True)\nplt.ylabel("Probability of survival $S(t)$")\nplt.xlabel("Time $t$")\nplt.grid(True)\nplt.show()\n'

In [None]:
from scipy.integrate import trapezoid
def compute_survival_times(risk_scores, t_train, e_train, seed):
    # https://pubmed.ncbi.nlm.nih.gov/15724232/
    rnd = np.random.RandomState(seed)

    # generate survival time
    mean_survival_time = t_train[e_train].mean()
    baseline_hazard = 1. / mean_survival_time
    scale = baseline_hazard * np.exp(risk_scores)
    u = rnd.uniform(low=0, high=1, size=risk_scores.shape[0])
    t = -np.log(u) / scale

    return t

def trapz_wrapper(X):
    """Helper function
    TBC
    """
    # NOTE: y - values | x - index
    return trapezoid(X.values, X.columns)

trapz_times = np.array(trapz_wrapper(surv_preds))

In [None]:
surv_preds.shape

(34, 260)

In [None]:
risk_scores = model.predict(X_test)
samples = list()
for i in range(0, 500): # sample 500 times
    samples.append(compute_survival_times(risk_scores, t_train, e_train, i))
simulated_times = np.mean(samples, axis=0)
simulated_times

array([238.3992795 , 208.70478054,  36.44637331, 314.5861069 ,
       106.06691095,  32.59273362, 188.47747349, 353.79618008,
       458.4596614 , 345.94208561,  51.29069509, 124.57658146,
       148.95560635,  70.04225079,  82.87057882, 189.20569261,
       109.96320753, 502.68529064, 149.67011791, 241.79839155,
       323.59582164,  28.94970501, 312.71309777, 400.1600163 ,
       126.40877781,  46.38812705, 113.67773611,  52.82494274,
        28.16884124,  13.48996112, 130.26522288,  73.15421421,
       939.60387349,  79.51241077])

In [None]:
from typing import Union
import pandas as pd

def qth_survival_time(q: float, model_or_survival_function) -> float:
    """
    Returns the time when a single survival function reaches the qth percentile, that is,
    solves  :math:`q = S(t)` for :math:`t`.

    Parameters
    ----------
    q: float
      value between 0 and 1.
    model_or_survival_function: Series, single-column DataFrame, or lifelines model


    See Also
    --------
    qth_survival_times, median_survival_times
    """
    from lifelines.fitters import UnivariateFitter

    if isinstance(model_or_survival_function, UnivariateFitter):
        return model_or_survival_function.percentile(q)
    elif isinstance(model_or_survival_function, pd.DataFrame):
        if model_or_survival_function.shape[1] > 1:
            raise ValueError(
                "Expecting a DataFrame (or Series) with a single column. Provide that or use utils.qth_survival_times."
            )
        return qth_survival_time(q, model_or_survival_function.T.squeeze())
    elif isinstance(model_or_survival_function, pd.Series):
        if model_or_survival_function.iloc[-1] > q:
            return np.inf
        return model_or_survival_function.index[(-model_or_survival_function).searchsorted([-q])[0]]
    else:
        raise ValueError(
            "Unable to compute median of object %s - should be a DataFrame, Series or lifelines univariate model"
            % model_or_survival_function
        )

def _to_1d_array(x) -> np.ndarray:
    v = np.atleast_1d(x)
    try:
        if v.shape[0] > 1 and v.shape[1] > 1:
            raise ValueError("Wrong shape (2d) given to _to_1d_array")
    except IndexError:
        pass
    return v

def qth_survival_times(q, survival_functions) -> Union[pd.DataFrame, float]:
    """
    Find the times when one or more survival functions reach the qth percentile.

    Parameters
    ----------
    q: float or array
      a float between 0 and 1 that represents the time when the survival function hits the qth percentile.
    survival_functions: a (n,d) DataFrame, Series, or NumPy array.
      If DataFrame or Series, will return index values (actual times)
      If NumPy array, will return indices.

    Returns
    -------
    float, or DataFrame
         if d==1, returns a float, np.inf if infinity.
         if d > 1, an DataFrame containing the first times the value was crossed.

    See Also
    --------
    qth_survival_time, median_survival_times
    """
    # pylint: disable=cell-var-from-loop,misplaced-comparison-constant,no-else-return
    q = _to_1d_array(q)
    q = pd.Series(q.reshape(q.size), dtype=float)

    if not ((q <= 1).all() and (0 <= q).all()):
        raise ValueError("q must be between 0 and 1")

    survival_functions = pd.DataFrame(survival_functions)

    if survival_functions.shape[1] == 1 and q.shape == (1,):
        q = q[0]
        # If you add print statements to `qth_survival_time`, you'll see it's called
        # once too many times. This is expected Pandas behavior
        # https://stackoverflow.com/questions/21635915/why-does-pandas-apply-calculate-twice
        return survival_functions.apply(lambda s: qth_survival_time(q, s)).iloc[0]
    else:
        d = {_q: survival_functions.apply(lambda s: qth_survival_time(_q, s)) for _q in q}
        survival_times = pd.DataFrame(d).T
        #  Typically, one would expect that the output should equal the "height" of q.
        #  An issue can arise if the Series q contains duplicate values. We solve
        #  this by duplicating the entire row.
        if q.duplicated().any():
            survival_times = survival_times.loc[q]

        return survival_times

In [None]:
#test_surv_fn = model.predict_survival_function(test_sample)
#surv_preds = np.row_stack([fn(event_times) for fn in test_surv_fn])

In [None]:
from lifelines import CoxPHFitter
cph = CoxPHFitter(penalizer=0.0001)
data = pd.concat([pd.DataFrame(X_train),
                  pd.Series(y_train['time'], name="Survival_time"),
                  pd.Series(y_train['event'], name="Event")], axis=1)
cph.fit(data, duration_col="Survival_time", event_col="Event")

  nonnumeric_cols = [col for (col, dtype) in df.dtypes.iteritems() if dtype.name == "category" or dtype.kind not in "biuf"]


<lifelines.CoxPHFitter: fitted with 300 total observations, 164 right-censored observations>

In [None]:
lifelines_trapz_times = cph.predict_expectation(X_test)

In [None]:
lifelines_median_times = cph.predict_median(X_test)

In [None]:
lifelines_percentile_times = cph.predict_percentile(X_test, 0.5)

In [None]:
surv_times = cph.predict_survival_function(X_test)
surv_times[surv_times < 0.05]

Unnamed: 0,0,3,11,14,19,20,22,29,31,32,...,77,78,79,86,87,89,91,95,96,98
1.0,,,,,,,,,,,...,,,,,,,,,,
2.0,,,,,,,,,,,...,,,,,,,,,,
4.0,,,,,,,,,,,...,,,,,,,,,,
5.0,,,,,,,,,,,...,,,,,,,,,,
6.0,,,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2175.0,,,1.474621e-04,,,8.150704e-05,,,,,...,,1.311034e-03,4.576703e-02,3.363756e-03,1.355980e-05,1.716425e-11,,1.710651e-02,,1.945544e-02
2178.0,,,1.474621e-04,,,8.150704e-05,,,,,...,,1.311034e-03,4.576703e-02,3.363756e-03,1.355980e-05,1.716425e-11,,1.710651e-02,,1.945544e-02
2190.0,,,1.474621e-04,,,8.150704e-05,,,,,...,,1.311034e-03,4.576703e-02,3.363756e-03,1.355980e-05,1.716425e-11,,1.710651e-02,,1.945544e-02
2353.0,0.044896,0.034197,1.239029e-09,,1.585202e-03,3.122400e-10,0.024174,,,,...,3.250315e-03,1.991266e-07,7.693069e-04,1.780113e-06,4.826492e-12,9.397889e-26,3.237833e-03,7.807559e-05,,1.052985e-04


In [None]:
real_times = y_test['time']

In [None]:
event_times

array([0.000e+00, 1.000e+00, 2.000e+00, 4.000e+00, 5.000e+00, 6.000e+00,
       7.000e+00, 1.000e+01, 1.100e+01, 1.400e+01, 1.800e+01, 1.900e+01,
       2.000e+01, 2.200e+01, 2.600e+01, 3.100e+01, 3.200e+01, 3.300e+01,
       3.700e+01, 4.200e+01, 4.600e+01, 5.200e+01, 5.700e+01, 6.000e+01,
       6.100e+01, 6.200e+01, 6.400e+01, 6.900e+01, 7.600e+01, 8.100e+01,
       8.800e+01, 9.300e+01, 9.700e+01, 1.010e+02, 1.130e+02, 1.160e+02,
       1.170e+02, 1.180e+02, 1.290e+02, 1.350e+02, 1.370e+02, 1.430e+02,
       1.450e+02, 1.510e+02, 1.690e+02, 2.000e+02, 2.260e+02, 2.330e+02,
       2.350e+02, 2.690e+02, 2.740e+02, 2.870e+02, 2.950e+02, 2.970e+02,
       3.120e+02, 3.130e+02, 3.280e+02, 3.540e+02, 3.580e+02, 3.590e+02,
       3.680e+02, 3.710e+02, 3.760e+02, 3.820e+02, 3.850e+02, 3.860e+02,
       3.900e+02, 3.970e+02, 4.030e+02, 4.050e+02, 4.110e+02, 4.120e+02,
       4.180e+02, 4.190e+02, 4.210e+02, 4.220e+02, 4.240e+02, 4.330e+02,
       4.370e+02, 4.400e+02, 4.420e+02, 4.450e+02, 

In [None]:
X_test

Unnamed: 0,age,bmi,diasbp,hr,los,sysbp,afb_1,av3_1,chf_1,cvd_1,gender_1,miord_1,mitype_1,sho_1
0,0.837597,-1.37095,1.430397,-0.627332,-0.934115,1.939212,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
3,0.700809,0.246288,-0.352711,-0.42488,1.422777,-0.240611,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0
11,2.205475,-0.753922,-0.868874,0.020515,-0.698426,-0.832277,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
14,0.632416,-0.123824,-0.399635,-0.789294,-0.462737,-0.271751,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0
19,1.247961,1.195894,-0.681179,0.46591,0.008642,-0.365172,0.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0
20,1.042779,0.116397,-1.244265,-1.396652,0.244331,0.413336,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0
22,1.042779,-0.732064,-0.305787,1.437682,4.486737,-0.302892,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
29,-0.051523,1.022396,0.257299,-0.019975,1.422777,-0.458593,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0
31,-0.325099,0.150198,-0.775026,0.708853,-0.227047,-0.552014,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0
32,0.427234,0.431668,0.398071,0.222968,0.008642,-0.240611,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0


In [None]:
from tools.bnn_isd_evaluator import BaseEvaluator
eval = BaseEvaluator(surv_preds, event_times,
                     y_test['time'], y_test['event'],
                     y_train['time'], y_train['event'])
median_times = eval.predict_time_from_curve(predict_median_survival_time)

In [None]:
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error

max_time = np.max(y['time'])
res = pd.DataFrame([real_times, trapz_times, simulated_times, lifelines_trapz_times,
                    lifelines_median_times,lifelines_percentile_times, median_times]).T
res.columns = ['Real', 'Trapz', 'Sim', 'LL-Trapz', 'LL-Med', 'LL-Per', 'Median']
res['LL-Med'].replace([np.inf, -np.inf], max_time, inplace=True)
res['LL-Per'].replace([np.inf, -np.inf], max_time, inplace=True)

In [None]:
res

Unnamed: 0,Real,Trapz,Sim,LL-Trapz,LL-Med,LL-Per,Median
0,192.0,1199.291085,238.39928,1197.239095,1217.0,1217.0,1213.166894
1,359.0,1132.272639,208.704781,1135.733865,1065.0,1065.0,1052.712293
2,169.0,117.530163,36.446373,116.271826,22.0,22.0,20.34599
3,17.0,1385.192685,314.586107,1385.948585,1579.0,1579.0,1576.995506
4,422.0,657.526588,106.066911,657.8571,358.0,358.0,355.868283
5,7.0,106.253252,32.592734,102.532892,19.0,19.0,18.959118
6,953.0,1049.305832,188.477473,1063.127681,903.0,903.0,852.189363
7,259.0,1459.006992,353.79618,1461.072544,1926.0,1926.0,1924.790859
8,1506.0,1619.406943,458.459661,1620.07719,2160.0,2160.0,2159.466616
9,187.0,1457.644389,345.942086,1459.244116,1926.0,1926.0,1924.669628


In [None]:
print("MSE")
print(mean_squared_error(res['Real'], res['Trapz']))
print(mean_squared_error(res['Real'], res['Sim']))
print(mean_squared_error(res['Real'], res['LL-Trapz']))
print(mean_squared_error(res['Real'], res['LL-Per']))
print(mean_squared_error(res['Real'], res['Median']))
print(mean_squared_error(res['Real'], [71]*len(X_test)))
print()
print("MAE")
print(mean_absolute_error(res['Real'], res['Trapz']))
print(mean_absolute_error(res['Real'], res['Sim']))
print(mean_absolute_error(res['Real'], res['LL-Trapz']))
print(mean_absolute_error(res['Real'], res['LL-Per']))
print(mean_absolute_error(res['Real'], res['Median']))
print(mean_absolute_error(res['Real'], [71]*len(X_test)))


MSE
711268.8357561055
388221.8978025328
713117.1852119731
921878.1470588235
915653.2628061634
451190.76470588235

MAE
625.3600396459185
364.8368025426435
626.1032321103253
652.1470588235294
649.1855421565002
372.47058823529414
