## Initialize

In [None]:
%load_ext autoreload
%autoreload 2

import os
from tqdm.auto import tqdm
import pathlib

import numpy as np
import pandas as pd
import lifelines

In [None]:
# Start a local Dask cluster (10 workers, 5 threads per worker) and a client bound to it;
# `client` is used below to scatter data and submit the bootstrap tasks.
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers=10, threads_per_worker=5)
client = Client(cluster)
# Last expression of the cell, so the notebook displays it (presumably to show scheduler/dashboard info).
cluster.scheduler

In [None]:
# --- Input datasets (pre/post-processed) for this project run -------------
project_name = "210616_centres_dask"
data_path = "/data/analysis/ag-reils/steinfej"
data_pre = f"{data_path}/data/2_datasets_pre/{project_name}"
data_post = f"{data_path}/data/3_datasets_post/{project_name}"

# --- Output locations for figures and result tables -----------------------
project_label = "21_PGS_Revision"
project_path = f"/data/analysis/ag-reils/ag-reils-shared/cardioRS/results/projects/{project_label}"
figures_path = f"{project_path}/figures"
data_results_path = f"{project_path}/data"

# Create both output directories up front (no error if they already exist).
for output_dir in (figures_path, data_results_path):
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

In [None]:
# Load the full merged dataset (all columns) for this project run.
data = pd.read_feather(path=f"{data_post}/data_merged.feather")

In [None]:
endpoints = ["MACE"]
# Event indicator and time-to-event column for every endpoint, in sorted order.
endpoint_labels = sorted(
    f"{ep}{suffix}" for ep in endpoints for suffix in ("_event", "_event_time")
)
# Read only the id plus the endpoint columns from the merged dataset.
endpoint_data = pd.read_feather(
    f"{data_post}/data_merged.feather",
    columns=["eid", *endpoint_labels],
)

In [None]:
# Model predictions (all splits/partitions) produced by the upstream training run.
preds = pd.read_feather(path=f"{data_results_path}/predictions_210703_centres_FINAL.feather")

In [None]:
# Bootstrapping or even recruitment centers?

In [None]:
# Keep only test-split rows and the columns needed for benchmarking.
data_test = preds.loc[
    preds["split"] == "test",
    ["eid", "endpoint", "module", "features", "split", "partition", "Ft_10"],
]

# Distinct values (in order of first appearance) that the bootstrap loops iterate over.
modules, features, partitions = (
    data_test[col].unique().tolist() for col in ("module", "features", "partition")
)

In [None]:
# Number of bootstrap resamples and their iteration ids (0 .. n_iterations - 1).
n_iterations = 1000
iterations = list(range(n_iterations))

In [None]:
# Endpoint whose event/time columns are benchmarked below.
endpoint = "MACE"

In [None]:
from IPython.display import clear_output
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc, integrated_brier_score
from lifelines.utils import concordance_index
from dask.diagnostics import ProgressBar

def calculate_per_endpoint(df, partition, iteration, endpoint, module, feature, time):
    """Compute the concordance index of the ``Ft_{time}`` scores, censored at ``time``.

    Rows whose ``{endpoint}_event`` is 0, or whose event happens after ``time``,
    are treated as censored at ``time``. All other rows keep their observed
    event time and count as events. Rows with a missing event time, score, or
    event are then dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``{endpoint}_event``, ``{endpoint}_event_time`` and
        ``Ft_{time}``.
    partition, iteration, module, feature
        Identifiers copied verbatim into the returned row.
    endpoint : str
        Prefix of the event/time columns.
    time : int
        Horizon at which outcomes are censored; also selects ``Ft_{time}``.

    Returns
    -------
    dict
        One result row: the identifiers, ``n`` (rows used after dropping
        missing values), ``time`` and ``cindex``.
    """
    event_col = df[f"{endpoint}_event"]
    time_col = df[f"{endpoint}_event_time"]

    # Vectorised censoring at the horizon. This computes the same mask once,
    # instead of two Python-level passes over the rows.
    censored = (event_col == 0) | (time_col > time)
    df = df.assign(
        event=np.where(censored, 0, 1),
        event_time=np.where(censored, time, time_col),
    )
    df = df.dropna(subset=["event_time", f"Ft_{time}", "event"], axis=0)

    # lifelines expects scores where higher means longer survival. Ft_{time} is
    # presumably a risk score (higher = earlier event), hence the 1 - C flip.
    cindex = 1 - concordance_index(df["event_time"], df[f"Ft_{time}"], df["event"])
    return {"endpoint": endpoint, "module": module, "features": feature, "partition": partition,
            "iteration": iteration, "n": len(df), "time": time, "cindex": cindex}

def calculate_per_iteration(data_bm, endpoint, iteration, eids_bs, time):
    """Score every (module, features, partition) combination on one bootstrap sample.

    For each module in the global ``modules`` list, the rows of ``data_bm`` are
    resampled by ``eid`` using ``eids_bs`` (ids may repeat). The resampled rows
    are split by the global ``features`` and ``partitions`` lists, and each
    split is scored with ``calculate_per_endpoint``.

    Parameters
    ----------
    data_bm : pandas.DataFrame
        Predictions with ``eid``, ``module``, ``features``, ``partition``,
        ``Ft_{time}``, ``{endpoint}_event`` and ``{endpoint}_event_time``.
    endpoint : str
        Prefix of the event/time columns (previously ignored: ``MACE`` was hard-coded).
    iteration : int
        Bootstrap iteration id, copied into every result row.
    eids_bs : array-like
        Bootstrap draw of eids, e.g. sampled with replacement.
    time : int
        Evaluation horizon; selects ``Ft_{time}`` (previously ignored: 10 was hard-coded).

    Returns
    -------
    list of dict
        One result row per evaluated (module, feature, partition).
    """
    prediction_col = f"Ft_{time}"
    keep_cols = ["eid", prediction_col, f"{endpoint}_event", f"{endpoint}_event_time"]

    results = []
    for module in modules:
        # Resample by eid. .loc with repeated labels returns the matching rows once per draw.
        # NOTE(review): raises KeyError if a drawn eid has no row for this module;
        # presumably every module covers the same eids. Confirm.
        temp_module = data_bm.query("module==@module").set_index("eid").loc[eids_bs].reset_index()
        for feature in features:
            temp_features = temp_module.query("features==@feature")
            if temp_features.empty:
                # This feature set was not fitted for this module.
                continue
            for partition in partitions:
                temp_partition = temp_features.query("partition==@partition")[keep_cols]
                results.append(
                    calculate_per_endpoint(temp_partition, partition, iteration, endpoint,
                                           module, feature, time=time)
                )
    return results

# Test-set predictions for this endpoint, joined with the endpoint's event/time columns.
data_bm = data_test.query("endpoint==@endpoint").merge(
    endpoint_data[["eid", f"{endpoint}_event", f"{endpoint}_event_time"]], on="eid", how="left"
)
eids = data_bm.eid.unique()

# Fixed seed so the bootstrap resamples, and therefore the saved results, are reproducible.
bootstrap_seed = 210705
rng = np.random.default_rng(bootstrap_seed)

# Send data_bm to the workers ONCE. Every task references the same future,
# instead of re-serialising and re-sending the whole frame on every iteration.
data_future = client.scatter(data_bm, broadcast=True)

# dask.diagnostics.ProgressBar only instruments the local schedulers, so it is not
# used here. tqdm below tracks task submission only, not execution.
rows = []
for iteration in tqdm(iterations):
    # Resample eids with replacement, keeping the original sample size.
    eids_bs = rng.choice(eids, size=len(eids))
    rows.append(client.submit(calculate_per_iteration, data_future, endpoint, iteration, eids_bs, time=10))

In [None]:
# dask.distributed progress bar over the submitted futures in `rows`.
from dask.distributed import progress
progress(rows)

In [None]:
# Wait for all bootstrap futures and replace them with their results:
# one list of result dicts per iteration (as returned by calculate_per_iteration).
rows = client.gather(rows)

In [None]:
# Flatten the per-iteration lists into a single list of result dicts (one per row).
rows = [
    result_row
    for iteration_results in rows
    for result_row in iteration_results
]

In [None]:
# Build the results table directly from the list of row dicts.
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0. Building the
# frame in one go also gives proper per-column dtypes, instead of the dtypes
# inherited from an empty placeholder frame. The column order matches the original.
benchmark_columns = ["endpoint", "module", "features", "partition", "iteration", "n", "time", "cindex"]
benchmark_endpoints_pp = pd.DataFrame(rows, columns=benchmark_columns)
clear_output()

In [None]:
# Persist the bootstrap C-index table as a feather file in data_results_path.
name = "benchmark_cindex_MACE_210705_centres_FINAL"
output_file = pathlib.Path(data_results_path) / f"{name}.feather"
benchmark_endpoints_pp.to_feather(output_file)