In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, KFold
import statsmodels.api as sm
from sklearn.metrics import accuracy_score, roc_auc_score, recall_score, precision_score, confusion_matrix, log_loss
from sklearn.metrics import precision_recall_curve, roc_curve
import matplotlib.pyplot as plt
import pickle
from sklearn.utils import resample
from sksurv.metrics import concordance_index_censored
import lifelines as ll
import sys
sys.path.append('/odinn/users/thjodbjorge/Python_functions/')
from Calculate_score import calculate_metrics, make_class_table
from R_functions import R_pROC,R_pROC_compareROC,R_pROC_compareROC_boot, R_pROC_AUC, R_timeROC, R_timeROC_CI, R_timeROC_pval, R_NRIbin,R_NRIcens,R_NRIcensipw, R_censROC, R_hoslem, R_Greenwood_Nam
from matplotlib import gridspec

# Per-probe annotation, keyed by SeqId.
probe_info = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/probe_info.csv', index_col='SeqId')

# Participant information with mortality events, keyed by sample barcode.
pn_info = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/pn_info_Mor/pn_info_Mor_event.csv', index_col='Barcode2d')

# Probes excluded from the analysis: the explicit skip list plus the
# non-protein probes that were included on the assay.
skip_listed = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/probes_to_skip.txt')['probe']
nopro = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/no_protein_probes.txt', header=None)[0]
probes_to_skip = set(skip_listed) | set(nopro)

In [None]:
# Everything below is read from / written under `folder`.
folder = '/odinn/users/thjodbjorge/Proteomics/Mortality2/'
feat_folder = 'Features2/'
pred_folder = 'Predictions3/'          # original test-set predictions
pred_folder_update = 'Predictions4/'   # refreshed predictions, merged on top of Predictions3
plots = 'Plots6_final_plots_in_paper/'

# Switches
save_plot = True              # write figures and source-data files to disk
updated_predictions = True    # merge the Predictions4 files over Predictions3

endpoints = ['death']
# endpoints = ['Neoplasm','Nervous','Circulatory','Respiratory','Other','death']
time_to_event = pn_info.time_to_death
no_event_before = pn_info.no_death_before

# Cause-specific endpoints restrict all-cause deaths to a single Cause_of_death value,
# whose label equals the endpoint name.
cause_specific_endpoints = {'Neoplasm', 'Nervous', 'Circulatory', 'Respiratory', 'Other'}
for endpoint in endpoints:
    if endpoint == 'death':
        use_event = pn_info.event_death
    elif endpoint in cause_specific_endpoints:
        use_event = pn_info.event_death & (pn_info.Cause_of_death == endpoint)
    else:
        # Fail loudly instead of leaving `use_event` undefined or stale.
        raise ValueError('Unknown endpoint: {}'.format(endpoint))
# NOTE: only the last endpoint's `use_event` survives the loop; later cells use it
# together with the leftover `endpoint` name (e.g. in prediction file names).

# y[k] flags participants with an event within k+1 years (k = 0..17).
y = [use_event & (time_to_event <= years) for years in range(1, 19)]

# With shuffle=False, random_state has no effect, and scikit-learn >= 0.24 raises
# ValueError if it is passed anyway. Omitting it keeps the folds unchanged.
kf = KFold(n_splits=10, shuffle=False)
I_train_main, I_test_main = train_test_split(pn_info.index, train_size=0.7, random_state=10)

# Per-dataset sample sets saved by an earlier step. pickle can execute code on load,
# so only read files produced by this project.
with open(folder + "{}_keep_samples.pkl".format('Mor'), 'rb') as file:
    keep_samples_dict = pickle.load(file)

In [None]:
# Font sizes (pt) for the print-sized paper figures.
# Larger on-screen alternative: VERY_SMALL 12, SMALL 14, MEDIUM 16, BIGGER 16.
VERY_SMALL_SIZE = 5
SMALL_SIZE = 6
MEDIUM_SIZE = 7
BIGGER_SIZE = 7

for _rc_group, _rc_kwargs in [
    ('font', {'size': SMALL_SIZE}),            # default text size
    ('axes', {'titlesize': BIGGER_SIZE}),      # axes titles
    ('axes', {'labelsize': MEDIUM_SIZE}),      # x / y labels
    ('xtick', {'labelsize': SMALL_SIZE}),      # tick labels
    ('ytick', {'labelsize': SMALL_SIZE}),
    ('legend', {'fontsize': VERY_SMALL_SIZE}),
    ('figure', {'titlesize': BIGGER_SIZE}),    # figure suptitle
]:
    plt.rc(_rc_group, **_rc_kwargs)

In [None]:
# Four line styles, repeated to give 16 entries.
line_cycle = ['-', '--', '-.', ':'] * 4
# Eight colour-blind-friendly colours (the Okabe-Ito set), black first.
color_cycle = ["#000000", "#CC79A7", "#0072B2", "#D55E00", "#009E73", "#E69F00", "#56B4E9", "#F0E442"]

In [None]:
# Quick visual check of the palette: one dot per colour along the diagonal.
for position, colour in enumerate(color_cycle):
    plt.scatter(position, position, color=colour)

In [None]:
from matplotlib.font_manager import findfont, FontProperties, findSystemFonts

# Show which font file the default sans-serif family resolves to;
# findSystemFonts() lists every installed font if more detail is needed.
sans_serif = FontProperties(family=['sans-serif'])
font = findfont(sans_serif)
print(font)

### Select dataset and k


In [None]:
# Both cohorts are plotted. k indexes y, so k_plot = 4 means "event within 5 years".
datasets = ['Old_18105', 'Old_60105']
k = k_plot = 4
plot_folder = ''    # sub-folder of `plots` for saved files ('' = save directly in plots)
cm = 1 / 2.54       # one centimetre in inches, for figsize

In [None]:
18 * cm  # width in inches of an 18 cm wide figure

## Load test predictions

In [None]:

# Collect the test-set predictions of every dataset into one dict keyed like
# '<dataset>_y<k>_<model>', and store each dataset's train/test indices.
I_train_dict = {}
I_test_dict = {}
pred_test_dict_all = {}
for dataset in datasets:
    # Start empty, so a missing file can never silently reuse the previous
    # dataset's predictions.
    pred_test_dict = {}
    try:
        with open(folder + pred_folder + "{}_{}_test_prediction.pkl".format(endpoint, dataset), 'rb') as file:
            pred_test_dict = pickle.load(file)
    except FileNotFoundError:
        print('No test predictions')

    if updated_predictions:
        print('Include updated predictiions')
        with open(folder + pred_folder_update + "{}_{}_test_prediction.pkl".format(endpoint, dataset), 'rb') as file:
            pred_test_dict_update = pickle.load(file)

        # Preview the baseline predictions before and after the merge (k_plot, not
        # the mutable global k, so re-running after later cells shows the same key).
        baseline_key = '{}_y{}_baseline2_lr'.format(dataset, k_plot)
        if baseline_key in pred_test_dict:
            print(pred_test_dict[baseline_key][:5])
        # Updated predictions override the originals key by key.
        for key, value in pred_test_dict_update.items():
            print(key)
            pred_test_dict[key] = value
        if baseline_key in pred_test_dict:
            print(pred_test_dict[baseline_key][:5])

    keep_samples = keep_samples_dict[dataset]
    I_train_dict[dataset] = I_train_main.intersection(keep_samples)
    I_test_dict[dataset] = I_test_main.intersection(keep_samples)

    pred_test_dict_all.update(pred_test_dict)

# From here on, pred_test_dict holds the predictions of all datasets.
pred_test_dict = pred_test_dict_all


### Figure 1

In [None]:


# Score every model at every horizon, then compute ROC curves at the k_plot horizon.
K = list(range(15))                     # indices into y (0..14)
K_true = [k_idx + 1 for k_idx in K]     # the matching horizons in years (1..15)

# Model key templates, in the order used by every Figure 1 panel.
model_key_templates = ['{}_y{}_agesex_lr', '{}_y{}_baseline2_lr',
                       '{}_y{}_agesexGDF15_lr', '{}_y{}_agesexprotein_l1']
score_dict = {}
roc_dict = {}
for dataset in datasets:
    I_test = I_test_dict[dataset]
    # scores[j, i, :] holds calculate_metrics output for horizon K[j] and model i.
    # It is assumed to return 8 values; column 0 is plotted as the AUC.
    scores = np.zeros([len(K), len(model_key_templates), 8])
    for j, k in enumerate(K):
        keys = [template.format(dataset, k) for template in model_key_templates]
        # Every model is compared against the baseline model at the same horizon.
        baseline = pred_test_dict[keys[1]]
        for i, key in enumerate(keys):
            scores[j, i] = calculate_metrics(pred_test_dict[key], baseline, y[k][I_test])
    score_dict[dataset] = scores

    # ROC curves at the horizon shown in the right-hand panels.
    k = k_plot
    roc_key = []
    for template in model_key_templates:
        pred = pd.DataFrame(pred_test_dict[template.format(dataset, k)], index=I_test)[0]
        fpr, tpr, th = roc_curve(y[k][I_test], pred)
        print(th[1:].max())   # highest finite threshold (th[0] is a sentinel)
        roc_key.append([fpr, tpr, th])
    roc_dict[dataset] = roc_key

In [None]:
# Figure 1: AUC vs. prediction horizon (left) and ROC curves at the k_plot horizon
# (right), one row per dataset. The plotted values are also collected as source data
# for the CSV export cells.
source_AUC_vs_time = {}
source_AUC_vs_time['Timepoints'] = K_true

source_ROC = {}

name_keys = ['Age+sex', 'Baseline', 'Age+sex+GDF15', 'Age+sex+Protein']
labels = ['a)', 'b)', 'c)', 'd)']
fig = plt.figure(figsize=[18 * cm, 14 * cm])
gs = gridspec.GridSpec(2, 2, width_ratios=[1.9, 1])

j = 0
for dataset in datasets:
    scores = score_dict[dataset]
    roc_key = roc_dict[dataset]

    # Left panel: AUC (metric column 0) of each model across horizons.
    ax = plt.subplot(gs[j])
    for i, key in enumerate(name_keys):
        source_AUC_vs_time[dataset + key] = scores[:, i, 0]
        ax.plot(K_true, scores[:, i, 0], linestyle=line_cycle[-i - 1], color=color_cycle[i], linewidth=0.8)
    ax.set_ylabel('AUC')
    ax.set_xlabel('Event within years')
    ax.set_xticks(K_true)
    ax.set_xlim(1, 15)
    ax.legend(name_keys)
    ax.grid()
    ax.text(0, ax.get_ylim()[1] + 0.001, labels[j], fontsize=11)
    j = j + 1

    # Right panel: ROC curves at horizon k_plot. The colour comes from color_cycle
    # only; the old 'C{i}' fmt string clashed with color= and was ignored anyway.
    ax = plt.subplot(gs[j])
    for i, (fpr, tpr, _th) in enumerate(roc_key):
        ax.plot(fpr, tpr, linestyle=line_cycle[-i - 1], color=color_cycle[i], linewidth=0.8)
        source_ROC[dataset + '_' + str(i)] = [fpr, tpr]
    ax.set_aspect('equal')
    # K = 0..14, so row k of `scores` is horizon index k.
    k = k_plot
    ax.legend(['{} (AUC = {:0.3f})'.format(name, scores[k, i, 0]) for i, name in enumerate(name_keys)])
    ax.set_xlabel('False positive rate')
    ax.set_ylabel('True positive rate')
    ax.text(-0.2, 1.1, labels[j], fontsize=11)
    j = j + 1

# tight_layout recomputes all spacing, so no separate subplots_adjust is needed.
plt.tight_layout(pad=1)
if save_plot:
    plt.savefig(folder + plots + plot_folder + 'Figure_1.png', dpi=400)
plt.show()

In [None]:
# Figure 1 source data: one CSV per ROC curve, with columns fpr and tpr.
for curve_name, (fpr_values, tpr_values) in source_ROC.items():
    print(curve_name, fpr_values.shape)
    roc_table = pd.DataFrame({'fpr': fpr_values, 'tpr': tpr_values})
    roc_table.to_csv(folder + plots + plot_folder + 'ROC_{}.csv'.format(curve_name))

In [None]:
# Figure 1 source data: horizons plus one AUC column per dataset/model.
auc_vs_time_table = pd.DataFrame(source_AUC_vs_time)
auc_vs_time_table.to_csv(folder + plots + plot_folder + 'AUC_vs_time.csv')

### Figure 2

In [None]:
# Figure 2 uses the test participants of Old_60105 aged 60-79 at sample collection.
dataset = 'Old_60105'
I_test = I_test_dict[dataset]
age_at_sample = pn_info.loc[I_test, 'Age_at_sample_collection_2']
age = (age_at_sample >= 60) & (age_at_sample < 80)
high_risk = age

In [None]:
# Describe the subgroup: its size, event counts and rates within 5, 10 and 18 years
# (y[4], y[9], y[-1]), and the mean and std of age.
subgroup = I_test[high_risk]
n_subgroup = len(subgroup)
print(n_subgroup)
for horizon_idx in (4, 9, -1):
    n_events = y[horizon_idx][subgroup].sum()
    print(n_events)
    print(n_events / n_subgroup)
subgroup_age = pn_info.loc[subgroup, 'Age_at_sample_collection_2']
print(subgroup_age.mean())
print(subgroup_age.std())

In [None]:
# Figure 2: Kaplan-Meier curves for the subgroup, stratified by each model's predicted
# risk of death within k+1 = 10 years into the 0-5%, 5-20%, 20-80%, 80-95% and
# 95-100% quantile strata.
survival_dict = {}
k = 9
keys = ['{}_y{}_agesex_lr'.format(dataset, k), '{}_y{}_baseline2_lr'.format(dataset, k),
        '{}_y{}_agesexGDF15_lr'.format(dataset, k), '{}_y{}_agesexprotein_l1'.format(dataset, k)]
name_keys = ['Age+sex', 'Baseline', 'Age+sex+GDF15', 'Age+sex+Protein']
fig = plt.figure(figsize=[18 * cm, 12 * cm])

# Interior quantile cut points between the five risk strata.
risk_cut_points = [0.05, 0.2, 0.8, 0.95]
# Legend order matches plotting order: highest-risk stratum first.
strata_labels = ['95%-100%', '80%-95%', '20%-80%', '5%-20%', '0%-5%']

split_groups = {}
for j, key in enumerate(keys):
    pred = pred_test_dict[key][high_risk]
    # Stratum 1 = lowest 5% ... stratum 5 = highest 5%. Only the interior cut points
    # go to np.digitize. With the 0% and 100% quantiles as outer edges, the maximum
    # prediction (x == last edge) falls into an extra bin 6 and is dropped from
    # every curve.
    risk_bins = np.digitize(pred, np.quantile(pred, risk_cut_points)) + 1

    fig.add_subplot(2, 2, j + 1)
    split_group = []
    for i in range(5, 0, -1):
        kmf = ll.fitters.kaplan_meier_fitter.KaplanMeierFitter()
        ind = I_test[high_risk][risk_bins == i]
        split_group.append([time_to_event[ind], use_event[ind]])
        kmf.fit(time_to_event[ind], use_event[ind])
        kmf.plot(loc=slice(0, 16), color=color_cycle[i - 1], linewidth=0.5)
        survival_dict[key + '_' + str(i)] = kmf.survival_function_
        if i == 5:
            # Mark and label the 5- and 10-year survival of the highest-risk stratum.
            for t in (5, 10):
                surv_t = kmf.predict(t)
                plt.scatter(t, surv_t, color='r', zorder=10, s=1)
                plt.annotate('{:0.2f}'.format(surv_t), (t, surv_t), (t + 0.2, surv_t))
    split_groups[key] = split_group
    plt.legend(strata_labels)
    plt.axis([0, 16, 0, 1.05])
    plt.title(name_keys[j])
    plt.ylabel('Survival')
    plt.xlabel('Time in years')
    plt.grid(True)

# tight_layout recomputes all spacing, so no separate subplots_adjust is needed.
plt.tight_layout(pad=1)
if save_plot:
    plt.savefig(folder + plots + plot_folder + 'Figure_2.pdf', dpi=400)
plt.show()


In [None]:
# Figure 2 source data: one Kaplan-Meier survival function CSV per model and risk
# stratum (keys are '<model key>_<stratum 1..5>').
for key, survival_function in survival_dict.items():
    survival_function.to_csv(folder + plots + plot_folder + 'survival_{}.csv'.format(key))

### Figure 3

In [None]:
# Figure 3: predicted probability of death within k+1 years for participants who died
# within that horizon (all causes, then per cause) and for those who did not. The four
# models are shown side by side.
prob_death_dict = {}
dataset = 'Old_18105'
I_test = I_test_dict[dataset]
k = 4
print(k)
pred_as = pred_test_dict['{}_y{}_agesex_lr'.format(dataset, k)]
pred_baseline = pred_test_dict['{}_y{}_baseline2_lr'.format(dataset, k)]
pred_gdf = pred_test_dict['{}_y{}_agesexGDF15_lr'.format(dataset, k)]
pred_pro = pred_test_dict['{}_y{}_agesexprotein_l1'.format(dataset, k)]

cases = y[k][I_test]              # boolean Series over I_test: event within k+1 years
pn_info_test = pn_info.loc[I_test]

group_list = []            # Age+sex+Protein
group_list_as = []         # Age+sex
group_list_baseline = []   # Baseline
group_list_gdf = []        # Age+sex+GDF15
pn_info_list = []


def _add_group(mask):
    """Append the four models' predictions and the participant rows selected by `mask`.

    `mask` is a boolean Series aligned with I_test. It selects positionally from the
    prediction arrays (presumably arrays in I_test order, as elsewhere in this
    notebook) and by label from pn_info_test.
    """
    group_list.append(pred_pro[mask])
    group_list_as.append(pred_as[mask])
    group_list_baseline.append(pred_baseline[mask])
    group_list_gdf.append(pred_gdf[mask])
    pn_info_list.append(pn_info_test[mask])


# All-cause deaths within the horizon.
_add_group(cases)

# Deaths within the horizon, split by cause.
groups = ['Neoplasm', 'Nervous', 'Circulatory', 'Respiratory', 'Other']
for g in groups:
    ind = cases & (pn_info_test['Cause_of_death'] == g)
    _add_group(ind)
    print(np.sum(ind))

# Participants without an event within the horizon.
_add_group(~cases)
print(np.sum(~cases))
groups.append('Ctrl')


fig = plt.figure(figsize=[18 * cm, 10 * cm])

flierprops = dict(markersize=1)
boxprops_pro = dict(color=color_cycle[3], linewidth=2)
boxprops_as = dict(color=color_cycle[0], linewidth=2)
boxprops_gdf = dict(color=color_cycle[2], linewidth=2)
boxprops_baseline = dict(color=color_cycle[1], linewidth=2)

# Groups are 3 units apart. Within a group the boxes sit at offsets 0 (protein),
# 0.6 (GDF15), 1.2 (baseline) and 1.8 (age+sex).
bp1 = plt.boxplot(group_list, positions=np.arange(1, 3 * len(group_list), 3), boxprops=boxprops_pro, flierprops=flierprops)
bp2 = plt.boxplot(group_list_as, positions=np.arange(2.8, 3 * len(group_list), 3), boxprops=boxprops_as, flierprops=flierprops)
bp3 = plt.boxplot(group_list_gdf, positions=np.arange(1.6, 3 * len(group_list), 3), boxprops=boxprops_gdf, flierprops=flierprops)
bp4 = plt.boxplot(group_list_baseline, positions=np.arange(2.2, 3 * len(group_list), 3), boxprops=boxprops_baseline, flierprops=flierprops)

# Legend in left-to-right box order.
plt.legend([bp1['boxes'][0], bp3['boxes'][0], bp4['boxes'][0], bp2['boxes'][0]],
           ['Age+sex+Protein', 'Age+sex+GDF15', 'Baseline', 'Age+sex'], loc='upper right')

plt.ylabel('Predicted probability of \n death within {} years'.format(k + 1))
plt.xlabel('Cause of death \n# participants')

# Tick label = group name plus the number of participants in the group.
tick_names = ['All-cause \n', 'Neoplasms \n', 'Nervous s. \n', 'Circulatory s. \n',
              'Respiratory s. \n', 'Other \n', 'Alive \n']
labels = [name + str(info.shape[0]) for name, info in zip(tick_names, pn_info_list)]
# Centre each tick under its four boxes: mean offset (0 + 0.6 + 1.2 + 1.8) / 4 = 0.9.
plt.xticks(np.arange(1, 3 * len(group_list), 3) + 0.9, labels=labels)
plt.grid(axis='y')
plt.tight_layout(pad=1)
if save_plot:
    plt.savefig(folder + plots + plot_folder + 'Figure_3.png', dpi=400)
plt.show()

In [None]:
# Figure 3 source data: one CSV per group with the four models' predicted probabilities.
# Rows are shuffled, presumably so the file order reveals nothing about participant
# order. The shuffle is seeded so re-running the notebook writes identical files.
SHUFFLE_SEED = 0
names = ['All-cause', 'Neoplasms', 'Nervous', 'Circulatory', 'Respiratory', 'Other', 'Alive']
for i, name in enumerate(names):
    df = pd.DataFrame([group_list_as[i], group_list_baseline[i], group_list_gdf[i], group_list[i]],
                      index=['Age+sex', 'Baseline', 'GDF15', 'Protein']).transpose()
    df = df.sample(frac=1, random_state=SHUFFLE_SEED).reset_index(drop=True)
    df.to_csv(folder + plots + plot_folder + 'Pred_prob_death_{}.csv'.format(name))


In [None]:
# Number of respiratory deaths within 18 years (y[-1]) in the test set, then in the
# training set, of the currently selected `dataset`.
I_test = I_test_dict[dataset]
I_train = I_train_dict[dataset]
for sample_index in (I_test, I_train):
    events_in_split = pn_info.loc[sample_index][y[-1][sample_index]]
    print((events_in_split['Cause_of_death'] == 'Respiratory').sum())