In [None]:
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
import numpy as np
from statannot import add_stat_annotation
from sklearn import metrics

In [None]:
import os

# The input table is expected in the current working directory.
current_dir = os.getcwd()
file_name = "on_ms_white_british_2502.tsv"

# Build the path once with os.path.join so that the path that is checked and
# the path that is stored are the same string on every platform.
candidate_path = os.path.join(current_dir, file_name)

if os.path.exists(candidate_path):
    file_path = candidate_path
else:
    # file_path is left undefined on purpose: it has to be set by hand
    # (see the commented-out assignment in the next cell) before loading.
    print("File not accessbile, please define file_path manually")

In [None]:
#file_path = "/slade/home/pl450/ON_UKBB/1601_cph_analysis/on_ms_genpop_1801.tsv"

In [None]:
from datetime import date

# Date stamp (MM_DD_YY) that is appended to every saved figure name.
today = date.today().strftime("%m_%d_%y")

# Derive the population tag from the input file name; the first matching tag wins.
population = ""
for tag, message in (
    ("genpop", "File containging general popualtion is used"),
    ("white_british", "File containing white british is used"),
):
    if tag in file_name:
        population = tag
        print(message)
        break
else:
    # No known tag occurs in the file name, so population stays "".
    print("defining an appropriate string failed")


# Builds dated file names for the saved .png figures.
def create_png_label(png_descr):
    """Return the file name ``"<population>_<png_descr>_<today>.png"``.

    Reads the module-level ``population`` and ``today`` variables.

    Parameters
    ----------
    png_descr : object
        Short description of the figure; converted with ``str``.

    Returns
    -------
    str
        The assembled file name.

    Raises
    ------
    ValueError
        If ``population`` is empty. Previously this case printed a message
        and then failed with ``UnboundLocalError`` on the return statement.
    """
    if not population:
        raise ValueError("Population not defined")
    return "{}_{}_{}.png".format(population, png_descr, today)


In [None]:
# Load the tab-separated input table located at file_path into the main DataFrame.
data = pd.read_csv(file_path, sep = "\t", low_memory=False)

In [None]:
# Study groups (rows selected on the ON_group column) together with the
# legend label and plot colour used for each of them.

control_label = "Controls"
control_colour = 'lightgreen'
controls = data.loc[data["ON_group"] == "Controls"]

ON_label = "ON only"
ON_colour = "indianred"
ON_group = data.loc[data["ON_group"] == "ON only"]

MS_label = "MS only"
MS_colour = "steelblue"
MS_group = data.loc[data["ON_group"] == "MS only"]

ON_and_MS_label = "MS-ON"
ON_and_MS_colour = "darkmagenta"
ON_and_MS_group = data.loc[data["ON_group"] == "MS-ON"]


In [None]:
# Histogram + KDE of the MS-GRS for all four groups.
# NB: the MS-GRS column is called 'full_expanded'.
# NOTE(review): sns.distplot is deprecated in recent seaborn releases — confirm the pinned version.

font = dict(family='sans-serif', weight='light', size=11)
plt.rc('font', **font)

kwargs = {
    'kde_kws': {'linewidth': 2},
    'hist_kws': {'alpha': .2},
    'rug_kws': {'alpha': 0.8, 'linewidth': .9, 'height': 0.2},
}
plt.figure(figsize=(10, 7), dpi=300)

# Series, legend labels and colours, listed in plotting order
groups = [controls.full_expanded, ON_group.full_expanded, ON_and_MS_group.full_expanded, MS_group.full_expanded]
labels = [control_label, ON_label, ON_and_MS_label, MS_label]
colours = [control_colour, ON_colour, ON_and_MS_colour, MS_colour]

for series, name, shade in zip(groups, labels, colours):
    # The legend entry carries the group size (row count, missing values included).
    sns.distplot(series, color=shade, label=f"{name}, n={len(series)}", bins=20, **kwargs)

plt.xlabel("MS-GRS", fontdict=font)
plt.title('MS-GRS distribution in 4 groups')
plt.legend()

plt.savefig(create_png_label("4_group_hist"))

In [None]:
# Same figure restricted to the ON-only and MS-ON groups
# (no positive or negative controls), with rug marks enabled.

plt.figure(figsize=(10, 7), dpi=300)

groups_two = [ON_group.full_expanded, ON_and_MS_group.full_expanded]
labels_two = [ON_label, ON_and_MS_label]
colours_two = [ON_colour, ON_and_MS_colour]

kwargs = {
    'kde_kws': {'linewidth': 2},
    'hist_kws': {'alpha': .2},
    'rug_kws': {'alpha': 0.7, 'linewidth': .75, 'height': 0.05},
}

for series, name, shade in zip(groups_two, labels_two, colours_two):
    sns.distplot(series, color=shade, rug=True, label=f"{name}, n={len(series)}", bins=22, **kwargs)

plt.xlabel("MS-GRS", fontdict=font)
plt.legend()
plt.ylim([0, 0.40])
plt.title('MS-GRS in MS-ON vs ON Only', fontdict=font)

plt.savefig(create_png_label('on_vs_ms_hist'))

In [None]:
# Builds the label -> colour palette used by the seaborn categorical plots

def create_palette(labels, colours):
    """Pair each label with the colour at the same position.

    Pairing stops at the shorter of the two sequences; if a label occurs
    more than once, its last colour is kept.

    Returns
    -------
    dict
        ``{label: colour}`` in the order of ``labels``.
    """
    return dict(zip(labels, colours))


In [None]:
# Four-group violin plot with significance brackets.
# Layout is tuned for a 6x6 figure and 4 groups.

from statannot import add_stat_annotation

plt.figure(figsize=[6, 6], dpi=300, facecolor=None)

upd_pal = create_palette(labels, colours)

ax = sns.violinplot(
    data=data,
    x='ON_group',
    y='full_expanded',
    order=upd_pal.keys(),
    palette=upd_pal,
    inner='box',
)
ax.set_xlabel('Group', fontsize=12)
ax.set_ylabel('MS-GRS')

# Every second artist in ax.collections gets its own opacity
# (presumably these are the violin bodies — verify against the seaborn version in use).
violin_alphas = [0.85, 0.85, 0.85, 0.8]
for body, opacity in zip(ax.collections[::2], violin_alphas):
    body.set_alpha(opacity)

# Pairwise Mann-Whitney tests, drawn as star annotations inside the axes.
compared_pairs = [("MS-ON", "MS only"), ("MS-ON", "ON only"), ("ON only", "Controls")]
add_stat_annotation(
    ax,
    data=data,
    x='ON_group',
    y='full_expanded',
    order=upd_pal.keys(),
    box_pairs=compared_pairs,
    test='Mann-Whitney',
    text_format='star',
    loc='inside',
    verbose=2,
)

#ax.text(-0.2,-2.7,'n=481K')
#ax.text(0.8,-1.5,'n=421')
#ax.text(1.8,-1.5,'n=266')
#ax.text(2.8,-1.5,'n=2103')

plt.tight_layout()

plt.savefig(create_png_label('4_group_violin'))

In [None]:
# Replace the Sex column by a single indicator column, Sex_Female
# (both dummies are created, then Sex_Male is dropped).
data = pd.get_dummies(data, columns=['Sex'], drop_first=False).drop(columns='Sex_Male')

In [None]:
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedKFold


# Repeated K-fold cross-validation of a logistic-regression classifier

def crossval_logreg_rocauc(x, y, n_splits=3, n_repeats=10, random_state=None):
    """Cross-validate a logistic regression and collect ROC curves and AUCs.

    Parameters
    ----------
    x : pandas.DataFrame
        Predictor columns, one row per sample.
    y : pandas.DataFrame or pandas.Series
        Outcome aligned row-by-row with ``x``; class ``1`` is the positive class.
    n_splits, n_repeats : int
        Passed to ``RepeatedKFold`` (defaults reproduce the previous 3 x 10 scheme).
    random_state : int or None
        Seed for ``RepeatedKFold``. ``None`` keeps the previous unseeded
        behaviour; pass an integer for reproducible folds.

    Returns
    -------
    tuple
        ``(fpr_df, tpr_df, avg_auc, auc_arr)``: two DataFrames with one
        ``take_<i>`` column per fold plus an ``avg`` column, the mean AUC
        over all folds, and the list of per-fold AUCs.
    """
    fpr_dict = {}
    tpr_dict = {}
    auc_arr = []

    rkf = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=random_state)

    # rkf.split yields POSITIONAL row indices, so rows must be selected with
    # .iloc; label-based .loc only gave the same rows for a default RangeIndex.
    for index, (train_index, test_index) in enumerate(rkf.split(x)):
        X_train, X_test = x.iloc[train_index], x.iloc[test_index]
        y_train = np.ravel(y.iloc[train_index])
        y_test = np.ravel(y.iloc[test_index])

        # Fit on the training fold, predict positive-class probabilities on the test fold
        logreg = LogisticRegression(max_iter=700).fit(X_train, y_train)
        pos_prob = logreg.predict_proba(X_test)[:, 1]

        # ROC curve and AUC of this fold
        fpr1, tpr1, _ = roc_curve(y_test, pos_prob, pos_label=1)
        fold_key = 'take_{i}'.format(i=index)
        fpr_dict[fold_key] = fpr1
        tpr_dict[fold_key] = tpr1
        auc_arr.append(roc_auc_score(y_test, pos_prob))

    # One column per fold. Folds can yield a different number of ROC points;
    # shorter columns are padded with NaN, which the row-wise mean skips.
    # NOTE(review): "avg" averages values by row position, not at matched thresholds.
    fpr_df = pd.DataFrame({k: pd.Series(v) for k, v in fpr_dict.items()})
    fpr_df["avg"] = fpr_df.mean(axis=1)

    tpr_df = pd.DataFrame({k: pd.Series(v) for k, v in tpr_dict.items()})
    tpr_df["avg"] = tpr_df.mean(axis=1)

    # Mean ROC-AUC over all folds and repeats
    avg_auc = np.mean(auc_arr)

    return (fpr_df, tpr_df, avg_auc, auc_arr)

In [None]:
# In-sample ROC/AUC: the model is fitted and scored on the same data (no cross-validation)

def empirical_fpr_tpr(x, y):
    """Fit a logistic regression on ``(x, y)`` and score it on the same ``x``.

    Returns
    -------
    tuple
        ``(fpr, tpr, auc)`` for the positive class ``1``.
    """
    model = LogisticRegression(max_iter=700)
    model.fit(x, np.ravel(y))
    positive_prob = model.predict_proba(x)[:, 1]

    fpr_emp, tpr_emp, _thresholds = roc_curve(y, positive_prob, pos_label=1)

    return (fpr_emp, tpr_emp, roc_auc_score(y, positive_prob))

In [None]:

# Mean-impute missing TDI values.
# The result is assigned back to the column: an inplace fillna on the
# attribute-accessed column (data.TDI.fillna(..., inplace=True)) is a chained
# assignment that does not reliably update `data` under pandas Copy-on-Write.
data["TDI"] = data["TDI"].fillna(data["TDI"].mean())
#data = data[data.TDI.notna()]
#data = data[data.BMI.notna()]
#data = data[data.Sex.notna()]
#data = data[data.MS_any.notna()]

In [None]:
# Outcome and predictor sets for each model:
# Null, Full, MS-GRS only, HLA-GRS only, non-HLA-GRS only

covariates = ["Sex_Female", "TDI", "PC1", "PC2", "PC3", "PC4", "enrol_age"]

# Outcome (kept as a one-column frame)
y = data.loc[:, ["MS_any"]]

# Covariates only
x_null = data.loc[:, covariates]

# MS-GRS followed by the covariates
x_full = data.loc[:, ['full_expanded'] + covariates]

# Single-score models
x_grs = data.loc[:, ['full_expanded']]
x_hlagrs = data.loc[:, ['ten_full_hla']]
x_nonhlagrs = data.loc[:, ['expanded_nonhla_grs']]


In [None]:
# Cross-validated results: each *_crossval is the tuple
# (fpr_df, tpr_df, avg_auc, auc_arr) returned by crossval_logreg_rocauc.
# The fold splitter is unseeded, so the numbers depend on the call order below.
null_crossval = crossval_logreg_rocauc(x_null,y)
full_crossval = crossval_logreg_rocauc(x_full, y)
grs_crossval = crossval_logreg_rocauc(x_grs, y)

# In-sample results: each *_empirical is the tuple (fpr, tpr, auc)
# returned by empirical_fpr_tpr.
null_empirical = empirical_fpr_tpr(x_null, y)
full_empirical = empirical_fpr_tpr(x_full, y)
grs_empirical = empirical_fpr_tpr(x_grs, y)


In [None]:
# Same two evaluations for the HLA-only and non-HLA-only scores:
# *_crossval is (fpr_df, tpr_df, avg_auc, auc_arr), *_empirical is (fpr, tpr, auc).
hlagrs_crossval = crossval_logreg_rocauc(x_hlagrs, y)
nonhlagrs_crossval = crossval_logreg_rocauc(x_nonhlagrs, y)

hlagrs_empirical = empirical_fpr_tpr(x_hlagrs, y)
nonhlagrs_empirical = empirical_fpr_tpr(x_nonhlagrs, y)

In [None]:
# Element 2 of every result tuple is the AUC: the mean cross-validated AUC for
# *_crossval, the in-sample AUC for *_empirical.
main_results = (full_crossval, full_empirical, null_crossval, null_empirical, grs_crossval, grs_empirical)
print(*(result[2] for result in main_results))

hla_results = (hlagrs_crossval, hlagrs_empirical, nonhlagrs_crossval, nonhlagrs_empirical)
print(*(result[2] for result in hla_results))

In [None]:
font = dict(family='sans-serif', weight='light', size=11)
plt.rc('font', **font)

plt.figure(figsize=[6, 6], dpi=180, facecolor=None)

# ROC curves for the three single-score models.
# The plotted curve is the in-sample one; the AUC shown in the legend is the
# mean cross-validated AUC.
roc_specs = [
    (grs_empirical, grs_crossval, 'MS-GRS (AUC={i})', 'seagreen'),
    (hlagrs_empirical, hlagrs_crossval, 'HLA-GRS (AUC={i})', 'darkorange'),
    (nonhlagrs_empirical, nonhlagrs_crossval, 'non-HLA-GRS (AUC={i})', 'darkturquoise'),
]
for empirical, crossval, label_template, line_colour in roc_specs:
    plt.plot(empirical[0], empirical[1],
             label=label_template.format(i=round(crossval[2], 3)),
             linewidth=1.2, color=line_colour)

# Diagonal reference line
plt.plot([0, 1], [0, 1], color='dimgrey', linewidth=1.2, linestyle='--')

plt.grid(which='minor', alpha=0.05)
plt.grid(which='major', alpha=0.1)

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(prop={'size': 7})
plt.savefig(create_png_label('roc_hlavsnonhla_models'))

In [None]:
# ROC curves for MS-GRS only, MS-GRS + covariates and the null model.
# As in the previous figure: in-sample curves, cross-validated AUC in the legend.

plt.rc('font', **font)

plt.figure(figsize=[6, 6], dpi=180, facecolor=None)

roc_specs = [
    (grs_empirical, grs_crossval, 'MS-GRS Only (AUC={i})', 'seagreen'),
    (full_empirical, full_crossval, 'MS-GRS + Covariates (AUC={i})', 'indianred'),
    (null_empirical, null_crossval, 'Null model (AUC={i})', 'darkturquoise'),
]
for empirical, crossval, label_template, line_colour in roc_specs:
    plt.plot(empirical[0], empirical[1],
             label=label_template.format(i=round(crossval[2], 3)),
             linewidth=1.2, color=line_colour)

# Diagonal reference line
plt.plot([0, 1], [0, 1], color='dimgrey', linewidth=1.2, linestyle='--')

plt.grid(which='minor', alpha=0.05)
plt.grid(which='major', alpha=0.1)

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(prop={'size': 7})
plt.savefig(create_png_label('roc_auc_3_models'))

In [None]:
def stratify_kmf(data, cut_off_dict):
    """Split ``data`` into quantile bands of its ``cph_prediction`` column.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``cph_prediction`` column.
    cut_off_dict : dict
        ``{lower_quantile: upper_quantile}``, one entry per band, e.g.
        ``{0: .25, .25: .75, .75: 1}``.

    Returns
    -------
    list of pandas.DataFrame
        One frame per entry, in dictionary order. The band whose lower
        quantile is 0 includes its lower edge; all other bands exclude it.
        Every band includes its upper edge.

    Raises
    ------
    ValueError
        If a lower quantile is negative. Such an entry previously matched
        neither branch and re-appended the preceding band (or failed with
        ``UnboundLocalError`` when it came first).
    """
    scores = data.cph_prediction
    vars_list = []

    for lower_q, upper_q in cut_off_dict.items():
        if lower_q < 0:
            raise ValueError("Lower quantile must be >= 0, got {}".format(lower_q))

        lower_edge = scores.quantile(q=lower_q)
        upper_edge = scores.quantile(q=upper_q)

        # Only the lowest band is closed at its lower edge, so that the
        # minimum value is kept and no row lands in two bands.
        above_lower = scores >= lower_edge if lower_q == 0 else scores > lower_edge
        vars_list.append(data[above_lower & (scores <= upper_edge)])

    return vars_list

In [None]:
# Quantile bands as {lower_quantile: upper_quantile}, consumed by stratify_kmf and plot_km.
cut_off_tertiles = dict(zip((0, 1/3, 2/3), (1/3, 2/3, 1)))
cut_off_threequarts = dict(zip((0, 0.25, 0.75), (0.25, 0.75, 1)))
cut_off_quarts = dict(zip((0, 0.25, 0.5, 0.75), (0.25, 0.5, 0.75, 1)))

In [None]:
# Function used for Kaplan-Meier (KM) plotting

def plot_km(data, cut_off_dict, plot_label, save_label="Unnamed", figsize = [6,6]):
    """Draw, save and show one Kaplan-Meier curve per quantile band.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs the columns ``cph_prediction`` (used for banding),
        ``ON_to_MS_years`` (durations) and ``first_ON`` (event indicator).
    cut_off_dict : dict
        Must equal one of the module-level ``cut_off_quarts``,
        ``cut_off_threequarts`` or ``cut_off_tertiles``.
    plot_label : str
        Figure title.
    save_label : str
        Description passed to ``create_png_label`` for the output file name.
    figsize : list
        Figure size handed to ``plt.figure``.

    Raises
    ------
    ValueError
        If ``cut_off_dict`` is none of the supported dictionaries. Previously
        this case printed a message and then failed on the undefined ``end()``.
    """
    # Imported here so the function does not depend on a later cell having run.
    from lifelines import KaplanMeierFitter
    from lifelines.plotting import add_at_risk_counts

    # Legend labels for each supported banding scheme
    if cut_off_dict == cut_off_quarts:
        kmf_labels = ["1st qrt", "2nd qrt", "3rd qrt", '4th qrt']
    elif cut_off_dict == cut_off_threequarts:
        kmf_labels = ["1st quart", "2nd-3rd quart", "4th quart"]
    elif cut_off_dict == cut_off_tertiles:
        # Three bands: the labels previously read "2nd-3rd tertile" / "4th tertile".
        kmf_labels = ["1st tertile", "2nd tertile", "3rd tertile"]
    else:
        raise ValueError("Cut off dictioniary invalid")

    stratified_data = stratify_kmf(data, cut_off_dict)

    # One fitter per band
    models = [KaplanMeierFitter() for _ in kmf_labels]

    plt.figure(figsize= figsize, dpi=180, facecolor=None)

    for (group, group_label, kmf_model) in zip(stratified_data, kmf_labels, models):
        kmf_model.fit(group.ON_to_MS_years, group.first_ON, label = '{group}, n={i}'.format(
            group = group_label, i = len(group)))
        ax = kmf_model.plot(show_censors = True, censor_styles = {'ms': 3, 'marker': 'o'})

    ax.set_xlim([0.0, 40.0])
    ax.set_ylim([0.0, 1.0])

    plt.title(plot_label)
    plt.xlabel("MS-free survival in years")
    plt.ylabel("MS-free survival probability")

    # At-risk table underneath the axes, one row group per fitted model
    add_at_risk_counts(*models, fontsize= 'small')

    plt.tight_layout()

    plt.savefig(create_png_label(save_label))
    plt.show()

In [None]:
# Treat a missing first_ON flag as 0.
# Assigned back to the column: an inplace fillna on data.first_ON is a chained
# assignment that does not reliably update `data` under pandas Copy-on-Write.
data["first_ON"] = data["first_ON"].fillna(0)

# Analysis cohort: rows with ON_any == 1, excluding first_MS == 1,
# simult_MS_ON == 1 and died == 1, and requiring a known ON_to_MS_years.
on_to_ms_data = data[
    (data.ON_any == 1)
    & (data.first_MS != 1)
    & (data.simult_MS_ON != 1)
    & (data.died != 1)
    & data.ON_to_MS_years.notna()
]

# For the unadjusted KM plots the MS-GRS itself plays the role of the score
# that stratify_kmf reads, hence the rename to cph_prediction.
kmf_data = on_to_ms_data.loc[:, ['full_expanded', 'first_ON', 'ON_to_MS_years',
                                 "Sex_Female", 'age18to50', 'age_ON']].rename(
    columns={"full_expanded": "cph_prediction"})

In [None]:
from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts

# Unadjusted KM curves: lowest quarter, middle half and top quarter of the MS-GRS.
plot_km(
    kmf_data,
    cut_off_threequarts,
    plot_label='Unadjusted KM, by MS-GRS quartile',
    save_label='unadjusted_KM_mergedquartiles',
)

In [None]:
# Work on an independent copy. A plain assignment would alias kmf_data, so the
# relabelling and the column rename below would also change kmf_data and break
# the later plot_km(kmf_data, ...) call, which needs a numeric first_ON and a
# cph_prediction column.
kmf_unadj_data = kmf_data.copy()

# Readable labels. Whole columns are rebuilt rather than writing strings into
# a numeric column through .loc.
kmf_unadj_data['Sex'] = np.where(kmf_unadj_data.Sex_Female == 1, 'female', 'male')
kmf_unadj_data['first_ON'] = kmf_unadj_data['first_ON'].replace({0: 'ON only', 1: 'MS-ON'})

kmf_strat = stratify_kmf(kmf_unadj_data, cut_off_threequarts)

# Counts of age18to50 values per sex within each band
for i in kmf_strat:
    print(i.groupby('Sex')['age18to50'].value_counts())

# Restore the original column name for the plot that follows (copy only).
kmf_unadj_data = kmf_unadj_data.rename(columns={"cph_prediction": "full_expanded"})

In [None]:
# Regression plot of age at ON diagnosis (age_ON) against the MS-GRS, coloured by first_ON
sns.lmplot(
    data=kmf_unadj_data,
    x="full_expanded",
    y="age_ON",
    hue="first_ON",
)

In [None]:
# Unadjusted KM curves, one per MS-GRS quartile.
plot_km(
    kmf_data,
    cut_off_quarts,
    plot_label='Unadjusted KM, by MS-GRS quartile',
    save_label='unadjusted_KM_quartiles',
    figsize=[5, 6],
)

In [None]:
from lifelines import CoxPHFitter


# Cox proportional-hazards model with the MS-GRS, sex and the age18to50 flag.
cph_columns = ['full_expanded', 'first_ON', 'ON_to_MS_years', 'Sex_Female', 'age18to50']
cph_data = on_to_ms_data.loc[:, cph_columns]

#cph_data = pd.get_dummies(cph_data, columns=['smoking_status'], drop_first=True)

cph = CoxPHFitter()
cph.fit(
    cph_data,
    duration_col='ON_to_MS_years',
    event_col='first_ON',
)
cph.print_summary()
cph.plot()

# Partial hazard of every row; plot_km / stratify_kmf band on this column.
cph_data['cph_prediction'] = cph.predict_partial_hazard(cph_data)


In [None]:
# Enrolment-age summary for the first_ON == 0 rows of the cohort.
# enrol_age_years is not among cph_data's columns, so it is read from
# on_to_ms_data (the frame cph_data was selected from), with a mask built on
# that same frame instead of the unaligned `data`.
on_to_ms_data.loc[on_to_ms_data.first_ON == 0, 'enrol_age_years'].describe()

In [None]:
# Check the model assumptions on the model columns only (prediction column removed).
cph_data_test = cph_data.drop(columns='cph_prediction')
cph.check_assumptions(
    cph_data_test,
    p_value_threshold=0.05,
    show_plots=True,
)

In [None]:
# KM curves banded on the GRS Cox prediction: lowest quarter, middle half, top quarter.
plot_km(
    cph_data,
    cut_off_threequarts,
    plot_label='GRS Cox model, by quartile',
    save_label='GRS_cox_mergequart',
)

In [None]:
# KM curves banded on the GRS Cox prediction, one per quartile.
plot_km(
    cph_data,
    cut_off_quarts,
    plot_label='GRS Cox model, by quartile',
    save_label='GRS_cox_4quart',
    figsize=[5, 6],
)

In [None]:
# KM curves banded on the GRS Cox prediction, one per tertile.
plot_km(
    cph_data,
    cut_off_tertiles,
    plot_label="GRS Cox model, by tertile",
    save_label='coxgrs_tertile',
)

In [None]:
# Tertile bands of the GRS Cox prediction, as a list of DataFrames.
cph_grs_tert = stratify_kmf(cph_data, cut_off_dict=cut_off_tertiles)

In [None]:
# Bottom fifth, middle three fifths and top fifth of the GRS Cox prediction.
cut_off_threequint = dict(zip((0, 0.2, 0.8), (0.2, 0.8, 1)))
three_quint_groups = stratify_kmf(cph_data, cut_off_dict=cut_off_threequint)


In [None]:


label_strat = ['Low Risk', 'Interm. Risk', 'High Risk']

# Readable labels for the descriptive plots. Whole columns are rebuilt instead
# of writing strings into a numeric column through .loc, which relies on
# silent upcasting (deprecated in pandas 2.1).
cph_data['Sex'] = np.where(cph_data.Sex_Female == 1, 'female', 'male')
cph_data['first_ON'] = cph_data['first_ON'].replace({0: 'ON only', 1: 'MS-ON'})


    

In [None]:
# Describing results: counts of Sex per age18to50 value within each quartile
# of the GRS Cox prediction.

stratified_cph_quarts = stratify_kmf(cph_data, cut_off_quarts)
for quartile_frame in stratified_cph_quarts:
    print(quartile_frame.groupby('age18to50')['Sex'].value_counts())

In [None]:
# 2-D histogram of the GRS Cox prediction by sex, coloured by first_ON.
plt.figure(figsize=[6, 3], dpi=300)
sns.histplot(
    data=cph_data,
    x='cph_prediction',
    y="Sex",
    hue='first_ON',
    hue_order=['ON only', 'MS-ON'],
    palette='Set2',
    binwidth=0.009,
)

plt.xlabel('Cox predicted risk, GRS model')
plt.ylabel('Sex')
plt.savefig(create_png_label("coxgrs_histbandplot"))

In [None]:
# Boxen plot of the GRS Cox prediction by sex and first_ON, with a rug of the
# individual predictions along the y axis.
plt.figure(figsize=[6, 4], dpi=300, facecolor=None)

two_group_palette = create_palette(labels_two, colours_two)

sns.boxenplot(
    data=cph_data,
    x="Sex",
    y="cph_prediction",
    hue="first_ON",
    hue_order=['ON only', 'MS-ON'],
    palette=two_group_palette,
)

sns.rugplot(
    data=cph_data,
    y="cph_prediction",
    hue="first_ON",
    palette=two_group_palette,
    lw=0.8,
    alpha=.5,
)
plt.ylabel('COX predicted risk of MS, GRS model', fontdict=font)
plt.legend()
plt.tight_layout()
plt.ylim([-.5, 5.1])

plt.savefig(create_png_label('cphgrs_boxen'))

In [None]:
# Stacked histogram of enrolment age by first_ON.
# cph_data has no enrol_age_years column, so it is joined in by row index from
# on_to_ms_data (the frame cph_data was selected from).
sns.histplot(
    data=cph_data.join(on_to_ms_data['enrol_age_years']),
    x="enrol_age_years",
    hue='first_ON',
    multiple='stack',
)

In [None]:
# Null Cox model: sex, the age18to50 flag and PC1, without the MS-GRS.
# The cohort is selected as before, except that ON_to_MS_years must be >= 0.
null_cph_rows = (
    (data.ON_any == 1)
    & (data.first_MS != 1)
    & (data.simult_MS_ON != 1)
    & (data.died != 1)
    & (data.ON_to_MS_years >= 0)
)
null_cph_columns = ['first_ON', 'ON_to_MS_years', "Sex_Female", 'age18to50', "PC1"]
null_cph_data = data.loc[null_cph_rows, null_cph_columns]

null_cph = CoxPHFitter()
null_cph.fit(
    null_cph_data,
    duration_col='ON_to_MS_years',
    event_col='first_ON',
)
null_cph.print_summary()
null_cph.plot()

# Partial hazard of every row; plot_km / stratify_kmf band on this column.
null_cph_data['cph_prediction'] = null_cph.predict_partial_hazard(null_cph_data)

In [None]:
# KM curves banded on the null-model prediction: lowest quarter, middle half, top quarter.
plot_km(
    null_cph_data,
    cut_off_threequarts,
    'Null CPH, by quartile',
    "coxnull_mergedquartile",
)

In [None]:
# KM curves banded on the null-model prediction, one per quartile.
plot_km(
    null_cph_data,
    cut_off_quarts,
    'Null CPH, by quartile',
    "coxnull_quartile",
    [5, 6],
)

In [None]:
# KM curves banded on the null-model prediction, one per tertile.
plot_km(
    null_cph_data,
    cut_off_tertiles,
    'Null CPH, by tertile',
    "coxnull_tertile",
)

In [None]:

label_strat = ['Low Risk', 'Interm. Risk', 'High Risk']

# Readable labels for null_cph_data, derived from its OWN columns.
# The masks were previously taken from cph_data, a different frame whose
# first_ON column already holds strings at this point, so first_ON was never
# relabelled here and the row masks depended on the two indexes matching.
null_cph_data['Sex'] = np.where(null_cph_data.Sex_Female == 1, 'female', 'male')
null_cph_data['first_ON'] = null_cph_data['first_ON'].replace({0: 'ON only', 1: 'MS-ON'})


In [None]:
# Counts of first_ON per sex within each quartile of the null-model prediction.
stratified_cph_quarts = stratify_kmf(null_cph_data, cut_off_quarts)
for quartile_frame in stratified_cph_quarts:
    print(quartile_frame.groupby('Sex')['first_ON'].value_counts())

In [None]:
# Readable labels. Whole columns are rebuilt instead of writing strings into a
# numeric column through .loc (silent upcasting is deprecated in pandas 2.1).
# replace() leaves values that are already labelled untouched.
null_cph_data['Sex'] = np.where(null_cph_data.Sex_Female == 1, 'female', 'male')
null_cph_data['first_ON'] = null_cph_data['first_ON'].replace({1: 'MS-ON', 0: 'ON only'})

# Boxen plot of the null-model prediction by sex and first_ON, with a rug of
# the individual predictions along the y axis.
plt.figure(figsize= [6, 4], dpi=300, facecolor=None)

two_group_palette = create_palette(labels_two, colours_two)

sns.boxenplot(x="Sex", y="cph_prediction", hue='first_ON', palette=two_group_palette,
              data=null_cph_data, hue_order=['ON only', 'MS-ON'])

sns.rugplot(data=null_cph_data, y="cph_prediction", hue="first_ON", lw=0.6, alpha=.5,
            palette=two_group_palette)
plt.ylabel('Predicted risk, Null Cox model')
plt.legend()
plt.tight_layout()
plt.ylim([-.5, 6.2])

plt.savefig(create_png_label('null_boxen'))

In [None]:
from lifelines.statistics import logrank_test

In [None]:
# Likelihood-ratio test of the GRS Cox model against the null Cox model

# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that scipy.stats is loaded.
import scipy.stats

# LR statistic = 2 * (LL_full - LL_null), referred to a chi-squared distribution
LR_statistic = -2*(null_cph.log_likelihood_-cph.log_likelihood_)

# Degrees of freedom = number of extra coefficients in the full model when the
# two models are nested and fitted on the same rows.
# NOTE(review): the null model includes PC1, which the GRS model does not, and
# the two cohorts use different ON_to_MS_years filters (>= 0 vs not-missing),
# so the models are not nested as written — confirm before reporting this p-value.
p_val = scipy.stats.chi2.sf(LR_statistic, 1)

print("LLR-test MS-GRS vs NULL: ", p_val)

In [None]:
#The remaining cells explore other models, reusing the same code and the same likelihood-ratio test

In [None]:
# Cox model with the MS-GRS, sex, the age18to50 flag and enrolment age.
best_cph_rows = (
    (data.ON_any == 1)
    & (data.first_MS != 1)
    & (data.simult_MS_ON != 1)
    & (data.died != 1)
    & (data.ON_to_MS_years >= 0)
)
best_cph_columns = ['full_expanded', 'first_ON', 'ON_to_MS_years', "Sex_Female", 'age18to50', 'enrol_age_years']
best_cph_data = data.loc[best_cph_rows, best_cph_columns]

best_cph = CoxPHFitter()
best_cph.fit(
    best_cph_data,
    duration_col='ON_to_MS_years',
    event_col='first_ON',
)
best_cph.print_summary()
best_cph.plot()

# Partial hazard of every row; plot_km / stratify_kmf band on this column.
best_cph_data['cph_prediction'] = best_cph.predict_partial_hazard(best_cph_data)


In [None]:
# KM curves banded on the enrolment-age model: lowest quarter, middle half, top quarter.
plot_km(
    best_cph_data,
    cut_off_threequarts,
    plot_label='Age at Enrol + proposed CPH model',
    save_label='bestCPH_threequarts',
)

In [None]:
# KM curves banded on the enrolment-age model, one per quartile.
plot_km(
    best_cph_data,
    cut_off_quarts,
    plot_label='Age at Enrol + proposed CPH model',
    save_label='bestCPH_quarts',
    figsize=[5, 6],
)

In [None]:
# KM curves banded on the enrolment-age model, one per tertile.
plot_km(
    best_cph_data,
    cut_off_tertiles,
    plot_label='Age at Enrol + proposed CPH model by tertiles',
    save_label='bestCPH_terts',
)

In [None]:
label_strat = ['Low Risk', 'Interm. Risk', 'High Risk']

# Readable labels. Whole columns are rebuilt instead of writing strings into a
# numeric column through .loc (silent upcasting is deprecated in pandas 2.1).
best_cph_data['Sex'] = np.where(best_cph_data.Sex_Female == 1, 'female', 'male')
best_cph_data['first_ON'] = best_cph_data['first_ON'].replace({0: 'ON only', 1: 'MS-ON'})

# Counts of age18to50 values per sex within each quartile of the prediction
stratified_cph_quarts = stratify_kmf(best_cph_data, cut_off_quarts)
for i in stratified_cph_quarts:
    print(i.groupby('Sex')['age18to50'].value_counts())

In [None]:
# Regression plot of the Cox prediction against enrolment age, coloured by first_ON.
sns.lmplot(
    data=best_cph_data,
    x='enrol_age_years',
    y='cph_prediction',
    hue='first_ON',
)

In [None]:
# Distribution of enrolment age, coloured by sex.
sns.displot(
    data=best_cph_data,
    x='enrol_age_years',
    hue='Sex',
)

In [None]:
# 2-D histogram of the MS-GRS against enrolment age, coloured by first_ON.
sns.histplot(
    data=best_cph_data,
    x='enrol_age_years',
    y='full_expanded',
    hue='first_ON',
    alpha=0.7,
)

In [None]:
# Null counterpart of the enrolment-age model: sex, the age18to50 flag,
# enrolment age and PC1, without the MS-GRS.
bestnull_cph_rows = (
    (data.ON_any == 1)
    & (data.first_MS != 1)
    & (data.simult_MS_ON != 1)
    & (data.died != 1)
    & (data.ON_to_MS_years >= 0)
)
bestnull_cph_columns = ['first_ON', 'ON_to_MS_years', "Sex_Female", 'age18to50', 'enrol_age_years', 'PC1']
bestnull_cph_data = data.loc[bestnull_cph_rows, bestnull_cph_columns]

bestnull_cph = CoxPHFitter()
bestnull_cph.fit(
    bestnull_cph_data,
    duration_col='ON_to_MS_years',
    event_col='first_ON',
)
bestnull_cph.print_summary()
bestnull_cph.plot()

# Partial hazard of every row; plot_km / stratify_kmf band on this column.
bestnull_cph_data['cph_prediction'] = bestnull_cph.predict_partial_hazard(bestnull_cph_data)



In [None]:
# KM curves banded on the null enrolment-age model, one per quartile.
plot_km(
    bestnull_cph_data,
    cut_off_quarts,
    plot_label='best NULL',
    save_label='bestnull_mergedquarts',
    figsize=[5, 6],
)

In [None]:
label_strat = ['Low Risk', 'Interm. Risk', 'High Risk']

# Readable labels for bestnull_cph_data, derived from its OWN columns (the
# 'male' mask was previously taken from best_cph_data). Whole columns are
# rebuilt instead of writing strings into a numeric column through .loc.
bestnull_cph_data['Sex'] = np.where(bestnull_cph_data.Sex_Female == 1, 'female', 'male')
bestnull_cph_data['first_ON'] = bestnull_cph_data['first_ON'].replace({0: 'ON only', 1: 'MS-ON'})

# Counts of age18to50 values per sex within each quartile of the prediction
for i in stratify_kmf(bestnull_cph_data, cut_off_quarts):
    print(i.groupby('Sex')['age18to50'].value_counts())

In [None]:
# Likelihood-ratio test of the enrolment-age model against its null counterpart

# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that scipy.stats is loaded.
import scipy.stats

# LR statistic = 2 * (LL_full - LL_null), referred to a chi-squared distribution
LR_statistic = -2*(bestnull_cph.log_likelihood_-best_cph.log_likelihood_)

# Degrees of freedom = number of extra coefficients in the full model when the
# two models are nested and fitted on the same rows.
# NOTE(review): the null model includes PC1, which the full model does not, so
# the models are not nested as written — confirm before reporting this p-value.
# NOTE(review): whether this test is appropriate for the Cox partial
# likelihood should be confirmed against the lifelines documentation.
p_val = scipy.stats.chi2.sf(LR_statistic, 1)

print(p_val)

In [None]:
# Log-likelihood of the null enrolment-age model.
# (`log_likelihood_LR_statistic` is not an attribute of the fitter; the name
# was two identifiers fused together.)
bestnull_cph.log_likelihood_

In [None]:
# Log-likelihood of the null Cox model (sex, age18to50, PC1).
null_cph.log_likelihood_

In [None]:
# Log-likelihood of the enrolment-age Cox model that includes the MS-GRS.
best_cph.log_likelihood_

In [None]:
# Log-likelihood of the GRS Cox model (MS-GRS, sex, age18to50).
cph.log_likelihood_