In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, KFold
import statsmodels.api as sm
from sklearn.metrics import accuracy_score, roc_auc_score, recall_score, precision_score, confusion_matrix, log_loss
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.preprocessing import QuantileTransformer, PowerTransformer
from sklearn.feature_selection import SelectKBest, SelectFromModel, RFE, RFECV, SelectPercentile, SelectFpr, SelectFdr, SelectFwe
from sklearn.feature_selection import chi2, f_classif, mutual_info_classif, SelectFdr
from sklearn.ensemble import RandomForestClassifier
from boruta import BorutaPy
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from sksurv.metrics import concordance_index_censored
from scipy.stats import pearsonr
import lifelines as ll
# from lifelines.utils.sklearn_adapter import sklearn_adapter
# CoxRegression = sklearn_adapter(ll.CoxPHFitter, event_col = 'event')
import sys
sys.path.append('/odinn/users/thjodbjorge/Python_functions/')
import Predict_functions as pf
from Calculate_score import calculate_metrics, make_class_table
from R_functions import R_pROC,R_pROC_compareROC,R_pROC_compareROC_boot, R_pROC_AUC, R_timeROC, R_timeROC_CI, R_timeROC_pval, R_NRIbin,R_NRIcens,R_NRIcensipw, R_censROC, R_hoslem, R_Greenwood_Nam

# Sample-level measurements and the phenotype/outcome table, both keyed by Barcode2d.
raw_data = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/raw_with_info.csv', index_col='Barcode2d')
pn_info = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/pn_info_Mor/pn_info_Mor_event.csv', index_col='Barcode2d')
# Per-probe annotation keyed by SeqId (GeneName is used for plot labels below).
probe_info = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/probe_info.csv', index_col='SeqId')

# Probes to exclude: the curated skip list plus the non-protein probes that were included.
skip_listed = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/probes_to_skip.txt')['probe']
nopro = pd.read_csv('/odinn/users/thjodbjorge/Proteomics/Data/no_protein_probes.txt', header=None)[0]
probes_to_skip = set(skip_listed) | set(nopro)

In [None]:
# Output locations; sub-folder names are appended to `folder` (e.g. folder + plots).
folder = '/odinn/users/thjodbjorge/Proteomics/Mortality2/'
feat_folder, pred_folder, plots = 'Features2/', 'Predictions3/', 'Plots3/'
save_plot = True

# Endpoint(s) to analyse. Other supported values:
# 'Cdeath', 'Gdeath', 'Ideath', 'Jdeath', 'Otherdeath'.
endpoints = ['death']
time_to_event = pn_info['time_to_death']
no_event_before = pn_info['no_death_before']
# ICD_group value that defines each cause-specific death endpoint.
CAUSE_SPECIFIC_ICD = {'Cdeath': 'C', 'Gdeath': 'G', 'Ideath': 'I', 'Jdeath': 'J'}


def death_event_mask(endpoint):
    """Return the per-sample event indicator for a death endpoint.

    'death' is pn_info.event_death as-is; the keys of CAUSE_SPECIFIC_ICD restrict
    it to deaths with that ICD_group; 'Otherdeath' keeps deaths whose ICD_group
    is none of those groups (missing ICD_group counts as "other").

    Raises:
        ValueError: for an unknown endpoint name, instead of silently leaving
            the previous endpoint's events in `use_event`.
    """
    if endpoint == 'death':
        return pn_info.event_death
    if endpoint in CAUSE_SPECIFIC_ICD:
        return pn_info.event_death & (pn_info.ICD_group == CAUSE_SPECIFIC_ICD[endpoint])
    if endpoint == 'Otherdeath':
        return pn_info.event_death & ~pn_info.ICD_group.isin(list(CAUSE_SPECIFIC_ICD.values()))
    raise ValueError('Unknown endpoint: {}'.format(endpoint))


# After the loop, `endpoint` and `use_event` refer to the LAST endpoint;
# the cells below rely on both.
for endpoint in endpoints:
    use_event = death_event_mask(endpoint)
    print(use_event.sum())

# y[h - 1] flags samples with an event within h time units (h = 1..18), so e.g.
# y[4] is "event within 5". Units follow time_to_event (presumably years — confirm).
y = [use_event & (time_to_event <= horizon) for horizon in range(1, 19)]

# 10-fold splitter without shuffling. random_state is omitted: it has no effect
# when shuffle=False, and scikit-learn >= 0.24 raises a ValueError if it is set.
kf = KFold(n_splits=10, shuffle=False)
# Fixed, seeded 70/30 train/test split over all samples in pn_info.
I_train_main, I_test_main = train_test_split(pn_info.index, train_size=0.7, random_state=10)


# Per-dataset collections of samples to keep, looked up by dataset name below.
# `with` guarantees the file handle is closed (the original left it open).
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open(folder + "{}_keep_samples.pkl".format('Mor'), 'rb') as file:
    keep_samples_dict = pickle.load(file)

dataset = 'Old_18105'


In [None]:
do_prediction = True
if do_prediction:

    # Restrict the fixed train/test split to the samples kept for this dataset.
    keep_samples = keep_samples_dict[dataset]

    I_train = I_train_main.intersection(keep_samples)  # .intersection(have_prs)
    I_test = I_test_main.intersection(keep_samples)  # .intersection(have_prs)

    # Event counts per split for the current endpoint; y[h - 1] flags events within h.
    for split_name, split_idx in (('Training set', I_train), ('Test set', I_test)):
        print('{}: {}, {} within 15: {}, 10: {}, 5: {}, 2: {}'.format(
            split_name, len(split_idx), endpoint,
            y[14][split_idx].sum(), y[9][split_idx].sum(),
            y[4][split_idx].sum(), y[1][split_idx].sum()))

    # ### Select data and normalize

    # Log-transform raw_data columns 16 onward, minus the skipped probes.
    # NOTE(review): assumes the first 16 columns are sample metadata — confirm.
    X = np.log(raw_data.iloc[:, 16:].drop(probes_to_skip, axis=1))
    all_protein = X.columns

    # Covariates are assigned as Series so they align on the sample index, like
    # the other pn_info-derived columns below. Positional `.values` would silently
    # assign values to the wrong samples if pn_info were ordered differently.
    X['sex'] = pn_info['sex'] - 1
    X['age'] = pn_info['Age_at_sample_collection_2']
    X['age2'] = X['age'] ** 2
    X['agesex'] = X['age'] * X['sex']
    X['age2sex'] = X['age2'] * X['sex']
    agesex = ['age', 'sex', 'agesex', 'age2', 'age2sex']

    X['lnage'] = np.log(X['age'])
    X['lnage2'] = X['lnage'] ** 2

    X['CAD'] = ~pn_info.no_CAD_before
    X['MI'] = ~pn_info.no_MI_before
    X['ApoB'] = X['SeqId.2797-56']
    X['Smoker'] = pn_info['Smoker'].astype(int)
    X['diabetes'] = pn_info['T2D'].astype(int)
    X['HTN_treated'] = pn_info['HTN_treated'].astype(int)
    X['statin'] = pn_info['statin_estimate_unsure'].astype(int)

    X['GDF15'] = X['SeqId.4374-45'].copy()
    X['GDF152'] = X['GDF15'] ** 2

    def _add_train_mean_imputed(col):
        """Copy pn_info[col] into X, report its missing count, and fill the gaps
        with the mean over training rows only (keeps test data out of the fill)."""
        X[col] = pn_info[col]
        print('{}: {} missing'.format(col, X[col].isna().sum()))
        X[col] = X[col].fillna(X.loc[I_train, col].mean())

    _add_train_mean_imputed('bmi')
    _add_train_mean_imputed('Platelets')
    X['Platelets2'] = X['Platelets'] * X['Platelets']
    _add_train_mean_imputed('Creatinine')
    _add_train_mean_imputed('Triglycerides')

    # Age and sex interaction terms (<base>age / <base>sex).
    X['bmiage'] = X['bmi'] * X['age']
    X['bmisex'] = X['bmi'] * X['sex']
    interacted = ['ApoB', 'Smoker', 'diabetes', 'statin', 'CAD', 'MI', 'HTN_treated',
                  'GDF15', 'Platelets', 'Creatinine', 'Triglycerides']
    for base in interacted:
        X[base + 'age'] = X[base] * X['age']
    for base in interacted:
        X[base + 'sex'] = X[base] * X['sex']

    # Age-bin dummies (first bin dropped) with bin-specific age slopes.
    X = X.join(pd.get_dummies(pn_info['agebin'], drop_first=True, prefix='age'))
    X['ageage2'] = X['age'] * X['age_2.0']
    X['ageage3'] = X['age'] * X['age_3.0']
    X['ageage4'] = X['age'] * X['age_4.0']

    agebins = ['age_2.0', 'age_3.0', 'age_4.0', 'ageage2', 'ageage3', 'ageage4']
    agebinssex = [s + 'sex' for s in agebins]
    # Row-wise product with sex. .mul(axis=0) keeps numeric dtypes; a transpose
    # round-trip can turn a mixed bool/float frame (bool dummies) into object dtype.
    X = X.join(X[agebins].mul(X['sex'], axis=0).add_suffix('sex'))

    PRS = ['nonHDL_prs', 'HT_prs', 'CAD_prs', 'Cancer_prs', 'Stroke2_prs', 'alz_Jansen',
           'pgc_adhd_2017', 'PD_Nalls_2018', 'edu_160125', 'dep_2018', 'bpd_2018',
           'giant_bmi', 'schizo_clozuk', 'iq_2018', 'ipsych_pgc_aut_2017',
           'pgc_Anorexia_2019']
    X[PRS] = pn_info[PRS]

    # Named covariate sets used by the models.
    trad = ['ApoB', 'Smoker', 'diabetes', 'HTN_treated', 'statin', 'CAD', 'MI', 'bmi']
    tradage = ['ApoBage', 'Smokerage', 'diabetesage', 'CADage', 'MIage', 'HTN_treatedage', 'bmiage']
    tradsex = ['ApoBsex', 'Smokersex', 'diabetessex', 'CADsex', 'MIsex', 'HTN_treatedsex', 'bmisex']

    tradcoxR = ['Smoker', 'Smokersex', 'diabetes', 'diabetesage', 'HTN_treated', 'HTN_treatedage',
                'MI', 'MIage', 'CAD', 'bmi', 'bmiage', 'statin', 'statinage']
    tradextralog = ['Smokersex', 'diabetessex', 'CADage', 'MIage', 'HTN_treatedage']
    tradblood = ['Creatinine', 'Triglycerides', 'Platelets', 'Platelets2', 'Plateletsage',
                 'Creatinineage', 'Triglyceridessex']

    # Standardize both splits with training-set statistics only.
    X_train = X.loc[I_train]
    X_test = X.loc[I_test]

    train_mean = X_train.mean()
    train_std = X_train.std()

    X_train = (X_train - train_mean) / train_std
    X_test = (X_test - train_mean) / train_std

    # ## For survival analysis
    X_train['event'] = use_event[I_train]
    X_test['event'] = use_event[I_test]

    tte_train = time_to_event[I_train]
    tte_test = time_to_event[I_test]

    ysurv_train = pd.DataFrame()
    ysurv_train['event'] = use_event[I_train]
    ysurv_train['time_to_event'] = time_to_event[I_train]



In [None]:
# Show which dataset key (into keep_samples_dict) this analysis uses.
dataset


In [None]:
# Horizon index into y: y[k] flags events with time_to_event <= k + 1, so k = 4 -> within 5.
k = k_plot = 4
y_train, y_test = y[k][I_train], y[k][I_test]

In [None]:
# Precomputed univariate results for this endpoint, dataset and horizon index k.
uni_file = 'Univariate2/{}_{}_uni_y{}_table.csv'.format(endpoint, dataset, k)
uni = pd.read_csv(folder + uni_file, index_col='new_ind')

In [None]:
# Univariate results ordered by log_beta, ascending.
uni.sort_values(by='log_beta', ascending=True)

In [None]:
# Regress every protein on the age/sex terms (all rows of X) and keep the
# standardized residuals: age- and sex-adjusted protein levels.
corr_col = agesex
design = sm.add_constant(X[corr_col])
corr_pred = sm.OLS(X[all_protein], design).fit().predict(design)
# Relabel so the subtraction below aligns column-by-column with X[all_protein]
# (presumably predict() returns generic column labels for 2-D endog — confirm).
corr_pred.columns = all_protein
adjusted = X[all_protein] - corr_pred
corr_pro = adjusted / adjusted.std()

In [None]:
# Pearson correlation (and p-value) of each age/sex-adjusted protein with GDF15
# (SeqId.4374-45), restricted to the kept samples. Results are collected first
# and the frame is built once, so both columns are float (not object) and no
# slow row-by-row .loc assignment is needed. GDF15 itself is included (r = 1).
corr_kept = corr_pro.loc[keep_samples]
gdf15_adj = corr_kept['SeqId.4374-45']
corr_rows = []
for p in all_protein:
    r, pval = pearsonr(gdf15_adj, corr_kept[p])
    corr_rows.append((r, pval))
corr_res = pd.DataFrame(corr_rows, columns=['correlation', 'p-value'], index=all_protein)

In [None]:
# Proteins ordered by their adjusted correlation with GDF15, most negative first.
corr_res.sort_values(by='correlation', ascending=True)

## Correlations among the top-ranked proteins (by univariate `log_pval`)

In [None]:
# Rank proteins by ascending log_pval; "org" names are the raw uni index values,
# the plain names carry the 'SeqId.' prefix used for X / corr_pro columns.
ranked = uni.sort_values('log_pval').index

orgtop10 = ranked[:10]
top10 = ['SeqId.' + p for p in orgtop10]

orgtop50 = ranked[:50]
top50 = ['SeqId.' + p for p in orgtop50]

In [None]:
# Adjusted correlation with GDF15 for the ten proteins with the smallest log_pval.
corr_res.loc[top10]

In [None]:
# Number of proteins with adjusted correlation > 0.5 with GDF15.
# NOTE: GDF15 itself (r = 1) is part of all_protein and is counted here.
corr_res['correlation'].gt(0.5).sum()

In [None]:
# Adjusted correlations among the ten top-ranked proteins (kept samples only).
# seaborn is imported with the other libraries at the top of the notebook;
# the matrix is computed once and reused for the table and the heatmap.
top10_corr = corr_pro.loc[keep_samples, top10].corr()
display(top10_corr)
sns.heatmap(top10_corr);

In [None]:
VERY_SMALL_SIZE = 12
SMALL_SIZE = 14
MEDIUM_SIZE = 16
BIGGER_SIZE = 16

# (rc group, settings) pairs, applied in this order.
_rc_settings = [
    ('font', dict(size=SMALL_SIZE)),             # default text size
    ('axes', dict(titlesize=BIGGER_SIZE)),       # axes title
    ('axes', dict(labelsize=MEDIUM_SIZE)),       # x and y labels
    ('xtick', dict(labelsize=SMALL_SIZE)),       # x tick labels
    ('ytick', dict(labelsize=SMALL_SIZE)),       # y tick labels
    ('legend', dict(fontsize=VERY_SMALL_SIZE)),  # legend text
    ('figure', dict(titlesize=BIGGER_SIZE)),     # figure title
]
for group, settings in _rc_settings:
    plt.rc(group, **settings)

In [None]:
# Top-10 adjusted protein correlation heatmap, labelled by gene name.
top10_genes = probe_info.loc[orgtop10].GeneName
fig, ax = plt.subplots(figsize=[10, 8])
sns.heatmap(corr_pro.loc[keep_samples, top10].corr(), vmax=1,
            xticklabels=top10_genes, yticklabels=top10_genes,
            cmap='viridis', cbar_kws={'label': 'Correlation'}, ax=ax)
# Honour the notebook-level save_plot switch (previously the file was always written).
if save_plot:
    fig.savefig(folder + plots + 'Top10proteincorrelationheatmap.png', bbox_inches='tight')

In [None]:
# Pairwise adjusted correlations of the top-50 proteins (kept samples only).
corr_pro.loc[keep_samples, top50].corr()

In [None]:
# Top-50 adjusted protein correlation heatmap, labelled by gene name.
# NOTE(review): this figure is not saved; add a savefig with a Top50 filename if needed.
top50_genes = probe_info.loc[orgtop50].GeneName
fig, ax = plt.subplots(figsize=[15, 13])
sns.heatmap(corr_pro.loc[keep_samples, top50].corr(), vmax=1,
            xticklabels=top50_genes, yticklabels=top50_genes,
            cmap='viridis', cbar_kws={'label': 'Correlation'}, ax=ax)

In [None]:
# Log-transformed protein measurements (unadjusted, unstandardized) for all samples.
X.loc[:, all_protein]

In [None]:
# Unadjusted pairwise protein correlations over every row of X.
correlations = X.loc[:, all_protein].corr()

In [None]:
# Average off-diagonal correlation. Each column mean includes the diagonal 1, so
# subtract 1/n and rescale by n/(n-1) to average over the other n-1 proteins.
# n is taken from the matrix instead of the hard-coded 4905, which was only
# correct for one particular probe count.
n_prot = correlations.shape[0]
((correlations.mean() - 1 / n_prot) * (n_prot / (n_prot - 1))).mean()

In [None]:
# Toy check of the leave-one-out mean used above: dropping the last value (2)
# from the mean of six values should give the mean of the remaining five (1.0).
toy = np.array([1, 1, 1, 1, 1, 2])
(toy.mean() - toy[-1] / toy.size) * toy.size / (toy.size - 1)

In [None]:
# Pairwise correlations of the age/sex-adjusted proteins over all rows.
correlations_corr = corr_pro.loc[:, all_protein].corr()

In [None]:
# Average off-diagonal correlation of the adjusted proteins: remove the diagonal 1
# from each column mean and rescale to n-1 entries, with n taken from the matrix
# rather than the hard-coded 4905.
n_prot = correlations_corr.shape[0]
((correlations_corr.mean() - 1 / n_prot) * (n_prot / (n_prot - 1))).mean()

In [None]:
# Same average off-diagonal correlation, restricted to the kept samples; n is
# taken from the matrix rather than the hard-coded 4905.
correlations_corr_ks = corr_pro.loc[keep_samples, all_protein].corr()
n_prot = correlations_corr_ks.shape[0]
((correlations_corr_ks.mean() - 1 / n_prot) * (n_prot / (n_prot - 1))).mean()