### Let's try Cox survival analysis here with a reduced dataset

In [None]:
import pickle
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Set seaborn's default color palette to "Set2".
sns.set_palette("Set2")

In [None]:
# Load the initial cleaned data.
# Forward slashes keep the relative path portable (they also resolve on
# Windows) and match the '../models/...' paths used for the pickled feature
# lists. The CSV is parsed once instead of twice.
# df2 keeps every column, including 'visit_interval' and 'target'.
df2 = pd.read_csv("../data/processed/prelim_clean2.csv", index_col=[0])
# df is the same table without the two outcome columns.
df = df2.drop(['visit_interval', 'target'], axis=1)

In [None]:
# Load the two pickled feature-name collections.
# `with` closes each file handle as soon as it has been read.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load .sav files that come from a trusted source.
with open('../models/01final_features_broad.sav', 'rb') as fh:
    broad_features = pickle.load(fh)
with open('../models/01final_features_res.sav', 'rb') as fh:
    res_features = pickle.load(fh)

In [None]:
# Restrict df to the columns named in each feature list.
# .copy() makes each subset an independent frame, so the outcome columns that
# later cells assign onto them do not trigger pandas' SettingWithCopyWarning.
broad_subset = df[df.columns.intersection(broad_features)].copy()
subset = df[df.columns.intersection(res_features)].copy()

In [None]:
from lifelines import CoxPHFitter
from lifelines.utils import k_fold_cross_validation

# Re-attach the duration and event columns (taken from df2) to the broad
# feature frame so lifelines can find them by name.
for outcome_col in ('visit_interval', 'target'):
    broad_subset[outcome_col] = df2[outcome_col]

cph = CoxPHFitter(penalizer=0.1)

# Optional 10-fold cross-validation, currently disabled:
# scores = k_fold_cross_validation(cph, broad_subset, duration_col='visit_interval', event_col='target', k=10, scoring_method="concordance_index")
# print(scores)

# Fit the penalized Cox model on the broad feature set and print its summary.
cph.fit(
    df=broad_subset,
    duration_col='visit_interval',
    event_col='target',
)
cph.print_summary()


In [None]:
# Re-attach the duration and event columns (taken from df2) to the reduced
# feature frame.
subset['visit_interval'] = df2['visit_interval']
subset['target'] = df2['target']

cph = CoxPHFitter(penalizer=0.01)
# 10-fold cross-validation scored by concordance index.
scores = k_fold_cross_validation(cph, subset, duration_col='visit_interval', event_col='target', k=10, scoring_method="concordance_index")
print(scores)

# Fit on the same frame that was just cross-validated. The previous version
# fitted on broad_subset here, so the fitted model did not correspond to the
# scores printed above.
cph.fit(df=subset, duration_col='visit_interval', event_col='target')
#cph.print_summary()

In [None]:
# from lifelines.utils.sklearn_adapter import sklearn_adapter

# from lifelines import CoxPHFitter
# from lifelines.datasets import load_rossi

# X = load_rossi().drop('week', axis=1) # keep as a dataframe
# Y = load_rossi().pop('week')

# CoxRegression = sklearn_adapter(CoxPHFitter, event_col='arrest')
# # CoxRegression is a class like the `LinearRegression` class or `SVC` class in scikit-learn

# sk_cph = CoxRegression(penalizer=1e-5)
# sk_cph.fit(X, Y)
# print(sk_cph)

# """
# SkLearnCoxPHFitter(alpha=0.05, penalizer=1e-5, strata=None, tie_method='Efron')
# """

# sk_cph.predict(X)
# sk_cph.score(X, Y)

## Cox with sksurv (just to test how the functions work on their own example data)

In [None]:
import numpy
from sksurv.datasets import load_gbsg2
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import integrated_brier_score
from sksurv.preprocessing import OneHotEncoder

In [None]:
# Load the GBSG2 example data set that ships with scikit-survival.
X, y = load_gbsg2()
# Replace each tgrade label by the length of its string, cast to int.
tgrade_as_int = X.loc[:, "tgrade"].map(len).astype(int)
X.loc[:, "tgrade"] = tgrade_as_int
# Encode the frame with sksurv's OneHotEncoder.
Xt = OneHotEncoder().fit_transform(X)

In [None]:
# Cox proportional hazards model with Efron tie handling, fitted on the
# encoded example data.
est = CoxPHSurvivalAnalysis(ties="efron")
est.fit(Xt, y)

In [None]:
# One predicted survival step function per sample.
survs = est.predict_survival_function(Xt)
# Evaluation grid: every day from day 365 up to and including day 1825.
times = numpy.arange(365, 1826)
# The step functions accept an array of time points, so each one is evaluated
# on the whole grid in a single call instead of once per day in a nested
# Python loop. Result shape: (n_samples, len(times)).
preds = numpy.asarray([fn(times) for fn in survs])

In [None]:
# Display the first predicted survival function.
survs[0]

In [None]:
# Integrated Brier score of the Cox predictions over the `times` grid.
# The same data `y` is used both for the censoring estimate and as test set.
score = integrated_brier_score(
    survival_train=y,
    survival_test=y,
    estimate=preds,
    times=times,
)
print(score)

# survival forest with sksurv

In [None]:
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

In [None]:
# Structured outcome array built from df2: event indicator 'target' and
# time 'visit_interval'.
y = Surv.from_dataframe('target', 'visit_interval', df2)
# `subset` had 'visit_interval' and 'target' assigned onto it for the Cox fit
# earlier in the notebook. Drop them so the forest is not trained on the very
# outcome it is supposed to predict (target leakage). errors='ignore' keeps
# this working if the columns were never added.
features = subset.drop(columns=['visit_interval', 'target'], errors='ignore')
X_train, X_test, y_train, y_test = train_test_split(features, y, test_size=0.35, random_state=1)


In [None]:
# Random survival forest: 100 trees, depth capped at 12, at most 145 leaf
# nodes per tree.
forest_params = dict(n_estimators=100, max_depth=12, max_leaf_nodes=145)
rsf = RandomSurvivalForest(**forest_params)
rsf.fit(X_train, y_train)

# Score on the held-out split
# (presumably the concordance index; confirm in the scikit-survival docs).
test_score = rsf.score(X_test, y_test)
print(test_score)
#pickle.dump(rsf, open('../models/test.sav', 'wb'))

In [None]:
# Display the parameters of the fitted forest.
rsf.get_params()

### Assessing model

In [None]:
# Predicted survival functions for the held-out samples.
surv = rsf.predict_survival_function(X_test)

# Output of rsf.predict for the held-out samples, wrapped in a Series.
predicted = rsf.predict(X_test)
surv_p = pd.Series(predicted)

In [None]:
# Display the Series of rsf.predict outputs for the test split.
surv_p

In [None]:
# Time points stored on the fitted forest.
times = rsf.event_times_
# Survival and cumulative-hazard predictions for the test split, requested as
# plain arrays (return_array=True) rather than function objects.
survs = rsf.predict_survival_function(
    X_test,
    return_array=True,
)
haz = rsf.predict_cumulative_hazard_function(
    X_test,
    return_array=True,
)

In [None]:
# 400 consecutive integer time points: 6, 7, ..., 405.
# (original author's note: 6.0; 2504)
grid_start = 6
grid_length = 400
time = np.arange(grid_start, grid_start + grid_length)

In [None]:
from sksurv.metrics import integrated_brier_score, brier_score

# The columns of `survs` correspond to rsf.event_times_, not to the daily grid
# in `time`. Passing the two together only runs when their lengths happen to
# match, and it scores each probability at the wrong time point.
# Instead, evaluate every predicted step function on the `time` grid itself.
# return_array=False is explicit so step functions are returned regardless of
# the library's default.
surv_fns = rsf.predict_survival_function(X_test, return_array=False)
surv_on_grid = np.asarray([fn(time) for fn in surv_fns])
# Integrated Brier score over `time`; y_train is passed as the training
# outcome and y_test as the test outcome.
integrated_brier_score(y_train, y_test, surv_on_grid, time)

In [None]:
from sksurv.metrics import concordance_index_ipcw,cumulative_dynamic_auc
# concordance_index_ipcw expects a risk score, where a higher value means a
# higher risk of the event. The previous version passed survs[:,0], a survival
# probability, which orders the samples in the opposite direction.
# rsf.predict returns the forest's risk score for each test sample.
concordance_index_ipcw(y_train, y_test, rsf.predict(X_test))

In [None]:
BIGGER_SIZE = 15

plt.rc('font', size=BIGGER_SIZE)          # controls default text sizes
#plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=BIGGER_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=BIGGER_SIZE)    # legend fontsize
#plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# Time-dependent Brier score at the first (up to) 399 entries of `times`.
# brier_score accepts a whole array of time points, so one call replaces the
# per-time-point loop. This also stops the loop from overwriting the global
# `time` grid and `score`. min() guards against fewer than 399 time points.
n_points = min(399, len(times))
t, s = brier_score(y_train, y_test, survs[:, :n_points], times[:n_points])

plt.plot(t, s)
# NOTE(review): 0.194 is a hard-coded reference level, presumably the
# integrated Brier score computed earlier; confirm before reusing.
plt.axhline(0.194, linestyle="--")
plt.xlabel("days from enrollment")
plt.ylabel("Prediction error")
#plt.grid(True)


In [None]:
# First field of every record in y_test, collected into a list.
# (y was built with 'target' as its first field.)
ix = [record[0] for record in y_test]
sum(ix)  # count of truthy entries (not displayed; only the last line is)
len(ix)  # number of test samples

In [None]:
# Average predicted survival curve for test samples whose indicator in `ix`
# is True (nh) and for those where it is False (yh).
event_mask = np.array(ix)
nh = survs[event_mask].mean(axis=0)
yh = survs[~event_mask].mean(axis=0)

SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 15

# Apply the large font size to default text, axis labels, tick labels and
# the legend.
for rc_group, rc_option in (
    ('font', 'size'),
    ('axes', 'labelsize'),
    ('xtick', 'labelsize'),
    ('ytick', 'labelsize'),
    ('legend', 'fontsize'),
):
    plt.rc(rc_group, **{rc_option: BIGGER_SIZE})

times = rsf.event_times_
plt.plot(times, nh)
# Second curve and legend are currently switched off:
#plt.plot(times, yh)
#plt.legend(['AD','MCI'])

plt.xlabel("days from enrollment")
plt.ylabel("Survival Probability")
plt.xlim([0, 2000])

# Hide the top and right spines, keep the left and bottom ones.
sns.despine(
    top=True,
    right=True,
    left=False,
    bottom=False,
)

In [None]:
# Average predicted curve for test samples whose indicator in `ix` is True
# (nh) and for those where it is False (yh).
nh = np.mean(haz[ix, :], axis=0)
yh = np.mean(haz[~np.array(ix), :], axis=0)

times = rsf.event_times_
plt.plot(times, nh)
plt.plot(times, yh)

plt.xlabel("days from enrollment")
# `haz` comes from predict_cumulative_hazard_function, so the curves are
# cumulative hazards; the previous "Hazard rate" label was misleading.
plt.ylabel("Cumulative hazard")
# NOTE(review): the legend assumes the True group is 'AD' and the False group
# is 'MCI' - confirm against the encoding of 'target'.
plt.legend(['AD', 'MCI'])
plt.xlim([0, 500])

plt.grid(True)


In [None]:
# Individual predicted cumulative-hazard curves for the test samples at
# positions 1..49 whose indicator in `ix` is True.
# NOTE(review): the range starts at 1, so sample 0 is never drawn - confirm
# that this is intended.
times = rsf.event_times_
# min() keeps the loop inside the data when there are fewer than 50 samples.
for i in range(1, min(50, len(ix))):
    if ix[i]:
        plt.plot(times, haz[i, :])

# Axis setup does not depend on the sample, so it is applied once.
plt.xlabel("days from enrollment")
# `haz` holds cumulative hazards, hence the corrected label.
plt.ylabel("Cumulative hazard")
plt.xlim([300, 500])
plt.grid(True)

In [None]:
#pickle.dump(rsf, open('../models/surv_model.sav', 'wb'))

In [None]:
# Draw one survival step curve per test sample.
# Depending on the scikit-survival version, `surv` (predicted without
# return_array) may hold step-function objects, which plt.step cannot draw as
# y-values. Request plain arrays explicitly so each row lines up with
# rsf.event_times_.
surv_curves = rsf.predict_survival_function(X_test, return_array=True)
for i, curve in enumerate(surv_curves):
    plt.step(rsf.event_times_, curve, where="post", label=str(i))
plt.ylabel("Survival probability")
plt.xlabel("Time in days")
plt.grid(True)
plt.legend()

In [None]:
# Survival step curve of the first test sample.
# `surv[0]` may be a step-function object rather than an array (depends on the
# scikit-survival version), so the first row is predicted again as a plain
# array that lines up with rsf.event_times_.
first_curve = rsf.predict_survival_function(X_test.iloc[:1], return_array=True)[0]
plt.step(rsf.event_times_, first_curve, where="post")