# Hypertrophic Cardiomyopathy Genes Cross-Validation
##### Selin Kubali
##### 12/13/2023
## Goal
Find out whether we can distinguish the HCM risk of the bottom 25% and the top 25% of carriers, ranked by predicted hazard score, among carriers of missense and deleterious variants in key hypertrophic cardiomyopathy (HCM)-related genes.

#### How the code functions
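Use cross-validation to fit a Cox proportional-hazards (Cox-PH) model and predict a hazard score for each carrier. Then isolate the bottom 25% and top 25% of carriers by hazard score and test whether HCM differs significantly between them. The code below runs this comparison as a log-rank test (`logrank_test` from lifelines) at each percentile threshold from 1 to 100, then applies a Bonferroni correction to the resulting p-values.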
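As a rough illustration of the quartile split described above, the sketch below uses invented hazard scores and HCM flags; the helper name `quartile_groups` and all values are assumptions for illustration, not part of the notebook. It only compares raw event rates and does not perform a significance test.

```python
import numpy as np

def quartile_groups(hazard):
    """Return boolean masks for the bottom 25% and top 25% of hazard scores."""
    hazard = np.asarray(hazard, dtype=float)
    q25, q75 = np.percentile(hazard, [25, 75])
    bottom_mask = hazard <= q25   # carriers at or below the 25th percentile
    top_mask = hazard >= q75      # carriers at or above the 75th percentile
    return bottom_mask, top_mask

# Synthetic example: eight carriers with made-up hazard scores and HCM status
hazard = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
is_hcm = np.array([0, 0, 0, 0, 0, 1, 0, 1])

bottom_mask, top_mask = quartile_groups(hazard)
bottom_rate = is_hcm[bottom_mask].mean()  # fraction with HCM in the bottom group
top_rate = is_hcm[top_mask].mean()        # fraction with HCM in the top group
print("bottom HCM rate:", bottom_rate, "top HCM rate:", top_rate)
```

With real data, the two groups' durations and event flags would then go into a survival comparison rather than a plain rate comparison.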

Cross-validation folds are built by splitting the variants rather than the individual carriers. Each fold therefore holds an equal number of variants, and all carriers of a given variant land in the same fold, which prevents overfitting on high-frequency variants.
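A minimal sketch of variant-level splitting, under stated assumptions: the carrier array is synthetic, the helper `variant_folds` is hypothetical, and the folds are cut with NumPy's `array_split` on a seeded permutation instead of the scikit-learn `KFold` that the notebook uses.

```python
import numpy as np

def variant_folds(n_variants, n_splits=5, seed=1):
    """Shuffle variant indices and cut them into n_splits near-equal folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_variants), n_splits)

# Hypothetical carriers: one entry per carrier, giving the index of the variant carried.
# Variant 0 is common (six carriers); variants 1-9 have one carrier each.
carrier_var_index = np.array([0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

folds = variant_folds(n_variants=10)
splits = []
for test_variants in folds:
    # A carrier goes to the test set if and only if their variant is in the test fold
    test_mask = np.isin(carrier_var_index, test_variants)
    splits.append((carrier_var_index[~test_mask], carrier_var_index[test_mask]))

for train_vars, test_vars in splits:
    print("test variants:", sorted(set(test_vars.tolist())),
          "| train carriers:", len(train_vars), "| test carriers:", len(test_vars))
```

Every fold here contains two variants, but the fold that receives variant 0 holds seven carriers while the others hold two. Equal variant counts do not imply equal carrier counts, and no variant ever appears in both the training and the test set of a split.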

#### Inputs
Lifelines files - from running generate_extracts_gnomAD.ipynb on UKBiobank in Cassa Lab Shared Project/selected_genes/hcm/notebooks. Stored in Cassa Lab Shared Project/selected_genes/hcm/lifelines_data. 
Variant data files - from running vep_processing.ipynb on UKBiobank in Cassa Lab Shared Project/selected_genes/hcm/notebooks. Stored in Cassa Lab Shared Project/selected_genes/hcm/parsed_vep_files

#### Note
Three HCM-related genes (DES, PLN, TTR) were excluded because they had too few cases for the model to converge: each had 2 or fewer HCM cases among carriers of missense or deleterious variants.
PTPN11 and TNNI3 also have few HCM cases among carriers of missense or deleterious variants, which may harm convergence.

In [2]:
import pandas as pd
import numpy as np
from lifelines import CoxPHFitter
from sklearn.model_selection import KFold
from statsmodels.stats.multitest import multipletests
from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt
from lifelines.statistics import logrank_test


In [44]:
genes = ["ALPK3"]
thresholds_list =  list(range(1, 101))
p_vals = {}

def cross_val(gene):
    cph = CoxPHFitter()

    # load lifelines file
    file_name=gene+'.csv'
    lifelines_data = pd.read_csv("/Users/uriel/Downloads/work_temp/cross_val_lifelines/"+file_name, dtype={
            'is_family_hist':'boolean',
            'is_hcm':'boolean'
            })

    # load variant data file (same file name as above, different directory)
    variant_data = pd.read_csv("/Users/uriel/Downloads/work_temp/variant_files/"+file_name)
    variant_data = variant_data[['Name']]
    variant_data['var_index'] = variant_data.index

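    # attach each variant's row index (var_index) to the lifelines data by merging on the shared column(s), here "Name";
    # var_index stays a regular column because the cross-validation split below filters on it
    lifelines_data = variant_data.merge(lifelines_data, how="outer")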


    # clean lifelines file; set pathogenicity for deleterious variants to 1
    lifelines_data = lifelines_data.drop(["Name", "death_age", "Unnamed: 0", "birth_date"], axis = 1)
    lifelines_data.loc[lifelines_data['deleterious'] == 1, 'am_pathogenicity'] = 1


    # filter for only missense and deleterious variants
    lifelines_data = lifelines_data[(lifelines_data['deleterious'] == True) | (lifelines_data['missense_variant'] == True)]
                

    # clean lifelines file
    lifelines_data = lifelines_data.drop(['principal_component_2','principal_component_3', 'principal_component_5', 'principal_component_6', 'principal_component_7', 'principal_component_8', 'principal_component_9', 'principal_component_10', 'deleterious', 'synonymous_variant', 'missense_variant'], axis = 1)
    lifelines_data['am_pathogenicity'] = lifelines_data['am_pathogenicity'].astype(float) 

    # cross validation: split up phenotypic data file based on variant file index
    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    testing_set = []
    for train_idx, test_idx in kf.split(variant_data):
            train = lifelines_data[lifelines_data['var_index'].isin(train_idx)]
            test = lifelines_data[lifelines_data['var_index'].isin(test_idx)]

            train = train.drop(['var_index'], axis=1)
            test = test.drop(['var_index'], axis=1)

            # fit CPH and add hazard scores
            cph.fit(train, duration_col="duration", event_col="is_hcm", fit_options = {"step_size":0.1})
            hazard_scores_fold = cph.predict_partial_hazard(test)
            test['hazard'] = hazard_scores_fold
            testing_set.append(test)

    # create new lifelines_data df by joining all testing sets
    lifelines_data = pd.concat(testing_set)


    # at each percentile threshold (1-100), split carriers into those below and those at or above it

    for i in thresholds_list:
        threshold = np.percentile(lifelines_data['hazard'], i)
        # .copy() so the is_hcm recoding below modifies independent frames, not slices of lifelines_data
        bottom = lifelines_data[lifelines_data['hazard'] < threshold].copy()
        top = lifelines_data[lifelines_data['hazard'] >= threshold].copy()
        # recode the boolean event flag as 0/1 for the log-rank test
        bottom['is_hcm'] = np.where(bottom['is_hcm'] == True, 1, 0)
        top['is_hcm'] = np.where(top['is_hcm'] == True, 1, 0)




        result = logrank_test(bottom['duration'], top['duration'], event_observed_A=bottom['is_hcm'], event_observed_B=top['is_hcm'])
        p_vals.update({i:result.p_value})


        """kmf_lowest_25_variant = KaplanMeierFitter()
        kmf_lowest_25_variant.fit(durations=bottom['duration'], event_observed=bottom['is_hcm'], label = 'bottom')
        kmf_lowest_25_variant.plot_survival_function()


        kmf_highest_25_variant = KaplanMeierFitter()
        kmf_highest_25_variant.fit(durations=top['duration'], event_observed=top['is_hcm'], label = 'top')
        kmf_highest_25_variant.plot_survival_function()


    plt.title(gene)
    plt.figure()"""




for gene in genes:
    cross_val(gene)



In [45]:
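# Bonferroni-correct the log-rank p-values across all percentile thresholds
p_adjusted = multipletests(list(p_vals.values()), alpha=0.05, method='bonferroni')
# p_adjusted[1] holds the corrected p-values; map them back to their thresholds
updated_dict = dict(zip(p_vals.keys(), p_adjusted[1]))
print("P-values: ", updated_dict)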

print(min(updated_dict.values()))



P-values:  {1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0, 5: 1.0, 6: 1.0, 7: 1.0, 8: 1.0, 9: 1.0, 10: 1.0, 11: 1.0, 12: 1.0, 13: 1.0, 14: 1.0, 15: 1.0, 16: 0.8056733605207032, 17: 0.6072163848731624, 18: 0.4687949068364246, 19: 0.357010179620144, 20: 0.2604565865341441, 21: 0.19173935511212734, 22: 0.14284829024434684, 23: 0.10297148369530003, 24: 0.07441778594820467, 25: 0.05262335599724817, 26: 0.03737671053316947, 27: 0.025498974845561965, 28: 0.01811354880635237, 29: 0.012488844315123426, 30: 0.00868841910262129, 31: 0.005695636644209561, 32: 0.0040370536239511135, 33: 0.0026751152586261514, 34: 0.008779955285023712, 35: 0.09792257683683615, 36: 0.07125810562157871, 37: 0.050818916110925594, 38: 0.03408545754125135, 39: 0.022311007712853513, 40: 0.014054419706653606, 41: 0.03579192752748764, 42: 0.023399062278453953, 43: 0.015158562135195305, 44: 0.009473771451809889, 45: 0.005839814876126727, 46: 0.003573774658483644, 47: 0.0022348780999650836, 48: 0.0013081916656436582, 49: 0.00074655693348136