# Cox model

In [None]:
%load_ext autoreload
%autoreload 2

import os
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import lifelines
from lifelines import CoxPHFitter
from sklearn.model_selection import StratifiedKFold

from joblib import Parallel, delayed
from tqdm.notebook import tqdm
import neptune
import warnings
warnings.filterwarnings("ignore")
import shutil
import anndata as ad
import pickle
import pathlib

In [None]:
# Input locations: pre- and post-processed datasets of the source project.
project_name = "210616_centres_dask"
data_path = "/data/analysis/ag-reils/steinfej"
data_pre = f"{data_path}/data/2_datasets_pre/{project_name}"
data_post = f"{data_path}/data/3_datasets_post/{project_name}"

# Output locations for this analysis; both directories are created if missing.
project_label = "21_PGS_Revision"
project_path = f"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}"
figures_path, data_results_path = f"{project_path}/figures", f"{project_path}/data"
for output_dir in (figures_path, data_results_path):
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

## Load data

In [None]:
# Endpoints to model; each needs "<endpoint>_event" and "<endpoint>_event_time" columns.
endpoints = ["MACE"]

In [None]:
from dask.distributed import Client, LocalCluster
# Local dask cluster: used via client.submit/client.gather to load the
# partition/split tables, and as the joblib backend when reading predictions.
# NOTE(review): 20 workers x 100 threads each = 2000 threads; confirm this is
# intended for the host.
cluster = LocalCluster(n_workers=20, threads_per_worker=100)
client = Client(cluster)
client  # last expression of the cell: displays the client summary in the notebook

In [None]:
# 22 partitions labelled "0".."21" (used as partition_<label> directories),
# each holding train/valid/test splits.
partitions = list(map(str, range(22)))
splits = ["train", "valid", "test"]

# Fit Cox models and create predictions

In [None]:
data_temp = pd.read_feather(f"{data_post}/data_merged.feather")

# Inclusion criteria per endpoint (pandas query strings). For MACE, only
# participants without prior myocardial infarction, stroke or statin use are kept.
cohort_queries = {
    "MACE": "myocardial_infarction==False&stroke==False&statins==False",
}

# endpoint -> list of included eids. Fail loudly for an endpoint without
# criteria; previously the loop silently reused the last endpoint's cohort.
eids_dict = {}
for endpoint in tqdm(endpoints):
    if endpoint not in cohort_queries:
        raise KeyError(f"No inclusion criteria defined for endpoint {endpoint!r}")
    eids_incl = data_temp.query(cohort_queries[endpoint]).eid.to_list()
    print(endpoint, len(eids_incl))
    eids_dict[endpoint] = eids_incl

In [None]:
# Dataset description table for the post-processed data.
# NOTE(review): not referenced again in this notebook section; remove if unused.
data_description = pd.read_feather(f"{data_post}/description.feather")

In [None]:
def load_data(dataset_path, partition, split, eids_incl):
    """Load one imputed, normalized partition/split table, indexed by ``eid``.

    Parameters
    ----------
    dataset_path : str
        Root directory containing ``partition_<partition>/<split>/`` folders.
        (Previously ignored in favour of the global ``data_post``.)
    partition : str
        Partition label, e.g. ``"0"``.
    split : str
        One of ``"train"``, ``"valid"``, ``"test"``.
    eids_incl : list
        Accepted for call compatibility but unused. No cohort filtering
        happens here; it is applied when the Cox models are fitted.
    """
    path = f"{dataset_path}/partition_{partition}/{split}/data_imputed_normalized.feather"
    return pd.read_feather(path).set_index("eid")


# One dask future per (partition, split); resolved by client.gather later.
data_all = {
    partition: {
        split: client.submit(load_data, data_post, partition, split, eids_incl)
        for split in splits
    }
    for partition in tqdm(partitions)
}

In [None]:
# Replace the nested dict of futures with the loaded DataFrames (same structure).
data_all = client.gather(data_all)

In [None]:
# Demographics: age, one-hot ethnic background (codes 0-4), Townsend index, sex.
# NOTE(review): the original carried the note "na 2 -> 5" on
# ethnic_background_2.0; its meaning is unclear (possibly a recoding of
# missing values) -- confirm.
basics = [
    'age_at_recruitment',
    *[f'ethnic_background_{code}.0' for code in range(5)],
    'townsend_deprivation_index_at_recruitment',
    'sex',
]

# One-hot overall health rating (0-3) and smoking status (0-2).
questionnaire = (
    [f'overall_health_rating_{code}.0' for code in range(4)]
    + [f'smoking_status_{code}.0' for code in range(3)]
)

measurements = [
    'body_mass_index_bmi',
    'weight',
    "standing_height",
    'systolic_blood_pressure',
    'diastolic_blood_pressure',
]

labs = ["cholesterol", "hdl_cholesterol", "ldl_direct", "triglycerides"]

family_history = ['fh_heart_disease']

diagnoses = [
    'diabetes1',
    'diabetes2',
    'chronic_kidney_disease',
    'atrial_fibrillation',
    'migraine',
    'rheumatoid_arthritis',
    'systemic_lupus_erythematosus',
    'severe_mental_illness',
    'erectile_dysfunction',
]

medications = ["antihypertensives", "ass", "atypical_antipsychotics", "glucocorticoids"]

# Polygenic score columns.
pgs_all = ['PGS000011', 'PGS000018', 'PGS000039', 'PGS000057', 'PGS000058', 'PGS000059']

In [None]:
# Clinical feature families by name; insertion order matches the order in
# which the families are concatenated into the clinical feature set.
feature_dict = dict(
    basics=basics,
    questionnaire=questionnaire,
    measurements=measurements,
    labs=labs,
    family_history=family_history,
    medications=medications,
    diagnoses=diagnoses,
)

In [None]:
# Feature sets keyed by model name; these keys become the model groups.
features = {}
features["clinical"] = [
    name
    for family in ("basics", "questionnaire", "measurements", "labs",
                   "family_history", "medications", "diagnoses")
    for name in feature_dict[family]
]
features["clinical_pgs_all"] = features["clinical"] + pgs_all
# Same columns (and the same list object) as clinical_pgs_all; the age x PGS
# interaction terms come from this group's formula, not its column list.
features["clinical_pgs_all*age"] = features["clinical_pgs_all"]
features["sun_pgs"] = [
    "age_at_recruitment",
    "sex",
    'smoking_status_0.0',
    "diabetes2",
    "systolic_blood_pressure",
    "diastolic_blood_pressure",
    "cholesterol",
    "hdl_cholesterol",
    "PGS000018",
]

In [None]:
# Model formula per feature group, passed to CoxPHFitter.fit(formula=...).
# Every entry must be a formula string whose terms exist in that group's data.
formulas = {}
formulas["clinical"] = "+".join(features["clinical"])
formulas["clinical_pgs_all"] = "+".join(features["clinical_pgs_all"])
# Clinical covariates plus age*PGS for each score; in formula notation `a*b`
# expands to a + b + a:b, so age enters as a main effect through these terms.
formulas["clinical_pgs_all*age"] = (
    "+".join([col for col in features["clinical"] if col != "age_at_recruitment"])
    + "+"
    + "+".join([f"age_at_recruitment*{col}" for col in pgs_all])
)
# Built from features["sun_pgs"] like the other groups. It was previously a
# Python list (lifelines expects a string) that also named PGS000039, a column
# absent from this group's data. If PGS000039 belongs in the model, add it to
# features["sun_pgs"] so it is loaded too.
formulas["sun_pgs"] = "+".join(features["sun_pgs"])

## Predictions

In [None]:
# Event-indicator and event-time column names for every endpoint.
events = [f"{name}_event" for name in endpoints]
times = [f"{name}_event_time" for name in endpoints]
# One Cox model group per feature set.
groups = list(features)

In [None]:
# data[group] = {"features": column list,
#                <partition>: {<split>: DataFrame restricted to those columns}}
data = {}
for group in tqdm(groups):
    columns = features[group] + events + times
    group_data = {"features": columns}
    for partition in partitions:
        group_data[partition] = {
            split: data_all[partition][split].loc[:, columns].copy()
            for split in splits
        }
    data[group] = group_data

In [None]:
from lifelines.utils import concordance_index
import pathlib

def _select_eids(df, eids_incl):
    """Return the rows of ``df`` (indexed by ``eid``) whose eid is in ``eids_incl``, in original order."""
    return df.loc[df.index.isin(eids_incl)]


def _risk_frame(df, risk, time_cols):
    """Build an ``eid`` column plus one column per horizon for the participants of ``df``.

    ``risk`` is ``1 - predict_survival_function(...)``: rows are horizons and
    columns are participants in the row order of ``df``.
    """
    frame = df.reset_index()[["eid"]].copy()
    per_subject = risk.T  # rows = participants, columns = horizons
    for t, col in time_cols.items():
        frame[col] = per_subject[t].to_list()
    return frame


def fit_predict_coxph(data_h5ad, endpoint, group, partition, time, event, eids_incl, dump_path):
    """Fit a Cox PH model on one partition, then dump the model and its risk predictions.

    Parameters
    ----------
    data_h5ad : dict
        Mapping with "train", "valid" and "test" DataFrames indexed by ``eid``
        (e.g. ``data[group][partition]``), holding the group's covariates plus
        the endpoints' event/time columns.
    endpoint, group, partition : str
        Labels used in the output file names and metadata columns. ``group``
        also selects the formula from the module-level ``formulas`` dict.
    time, event : str
        Duration and event-indicator column names.
    eids_incl : list
        Participant ids kept in every split.
    dump_path : str
        Output directory, created if missing. Receives
        ``{endpoint}_{group}_{partition}.p`` (pickled fitter) and
        ``{endpoint}_{group}_{partition}.feather`` (predictions).

    Prints the validation concordance index. The feather file has one row per
    participant and split, with metadata columns and ``0_{t}_Ft`` = 1 - S(t)
    for t = 1..26 (in the units of ``time``).
    """
    pathlib.Path(dump_path).mkdir(parents=True, exist_ok=True)

    train_data = _select_eids(data_h5ad["train"], eids_incl)
    val_data = _select_eids(data_h5ad["valid"], eids_incl)
    test_data = _select_eids(data_h5ad["test"], eids_incl)

    # Covariates are every column except the endpoints' "<endpoint>_event" /
    # "<endpoint>_event_time" columns; then add this endpoint's time and event.
    candidates = [
        col for col in train_data.columns
        if not col.endswith(("_event", "_event_time"))
    ] + [time, event]
    # Drop columns that are constant on the training split. A new list is built;
    # the old in-place remove() while iterating skipped the following column.
    # NOTE(review): formulas[group] still names any dropped covariate, so the
    # fit will likely fail if one is actually dropped -- adjust the formula then.
    covariates_with_tte = [col for col in candidates if train_data[col].nunique() != 1]

    cph = CoxPHFitter()
    cph.fit(train_data[covariates_with_tte], duration_col=time, event_col=event,
            show_progress=True, step_size=0.5, formula=formulas[group])
    with open(f"{dump_path}/{endpoint}_{group}_{partition}.p", "wb") as fh:
        pickle.dump(cph, fh)

    # Higher partial hazard means earlier events, hence the negation for the C-index.
    print(concordance_index(val_data[time], -cph.predict_partial_hazard(val_data[covariates_with_tte]), val_data[event]))

    horizons = list(range(1, 27))
    time_cols = {t: f"0_{t}_Ft" for t in horizons}

    frames = []
    for split_label, split_data in (("train", train_data), ("valid", val_data), ("test", test_data)):
        risk = 1 - cph.predict_survival_function(split_data[covariates_with_tte], times=horizons)
        frames.append(_risk_frame(split_data, risk, time_cols).assign(split=split_label))

    preds = pd.concat(frames, axis=0).assign(
        endpoint=endpoint, features=group, partition=partition, module="COXPH",
        datamodule="UKBBSurvivalDatamodule", net="", calibrated="False")
    preds = preds[["eid", 'endpoint', 'features', 'split', 'partition', 'module', 'datamodule', 'net', 'calibrated'] + list(time_cols.values())].reset_index(drop=True)
    preds.to_feather(f"{dump_path}/{endpoint}_{group}_{partition}.feather")

In [None]:
# Output directory for the pickled Cox models and per-partition prediction files.
dump_path = f"{data_post}/COXPH/210631_PGS_REVISION"

In [None]:
# Fit one Cox model per endpoint x feature group x partition.
for endpoint in tqdm(endpoints):
    time, event = f"{endpoint}_event_time", f"{endpoint}_event"
    eids_incl = eids_dict[endpoint]
    for group in tqdm(groups):
        print(group)
        for partition in partitions:
            fit_predict_coxph(
                data[group][partition], endpoint, group, partition,
                time=time, event=event, eids_incl=eids_incl, dump_path=dump_path,
            )

# Read and Process Predictions

In [None]:
import glob
# All prediction feather files written by the fitting loop, in sorted path order.
files = sorted(glob.glob(f"{dump_path}/*.feather"))

In [None]:
import joblib
import pandas as pd
from joblib import Parallel, delayed
from tqdm.auto import tqdm


def get_df(path):
    """Read one prediction feather file."""
    return pd.read_feather(path)


# Read every prediction file in parallel through joblib's dask backend,
# skipping None/NaN entries.
with joblib.parallel_backend('dask'):
    readable = (path for path in tqdm(files) if path is not None and not pd.isna(path))
    dfs = Parallel(n_jobs=80)(delayed(get_df)(path) for path in readable)

In [None]:
# Stack all per-file predictions into one table with a fresh 0..n-1 index.
predictions = pd.concat(dfs, ignore_index=True)

In [None]:
def convert_to_float32(df):
    """Downcast every float64 column of ``df`` to float32 in place and return ``df``.

    Prints the name of each converted column.
    """
    for name in tqdm(df.columns.to_list()):
        if df[name].dtype != "float64":
            continue
        print(name, "convert")
        df[name] = df[name].astype("float32")
    return df

# Store string (object) metadata columns as categoricals to save memory.
for name in tqdm(predictions.columns.to_list()):
    if predictions[name].dtype != "object":
        continue
    predictions[name] = predictions[name].astype("category")

# Partition labels were written as strings; store them as integers.
predictions["partition"] = predictions["partition"].astype(int)
predictions = convert_to_float32(predictions)

In [None]:
def fix_column_names(df):
    """Rename the columns containing "Ft" to ``Ft_0``, ``Ft_1``, ... by position.

    The i-th (0-based) such column becomes ``Ft_i``, so the original name is
    ignored. With prediction columns ordered 0_1_Ft, 0_2_Ft, ..., this maps
    ``0_{t}_Ft`` to ``Ft_{t-1}`` (e.g. 0_11_Ft -> Ft_10). Other columns are
    unchanged. Returns a new DataFrame.
    """
    ft_columns = [name for name in df.columns if "Ft" in name]
    renames = {old: f"Ft_{position}" for position, old in enumerate(ft_columns)}
    return df.rename(renames, axis="columns")
# Rename the 0_{t}_Ft prediction columns to positional Ft_0, Ft_1, ... names.
predictions = fix_column_names(predictions)

In [None]:
# Persist the combined Cox predictions for downstream analysis.
predictions.to_feather(f"{data_results_path}/predictions_cox_210631_REVISION.feather")