# In [None]:
import os
import warnings
import numpy as np
import pandas as pd
from scipy.stats import chi2
from lifelines import CoxPHFitter
from typing import List, Tuple, Union
from pycox.evaluation import EvalSurv
from sksurv.metrics import concordance_index_censored

# Type aliases for the array-like inputs accepted by the D-calibration helpers.
# `np.ndarray` is the array type (`np.array` is only the constructor function), and
# `Tuple[Numeric, ...]` means a tuple of any length rather than a 1-tuple.
Numeric = Union[float, int]
NumericArrayLike = Union[List[Numeric], Tuple[Numeric, ...], np.ndarray]

# Suppress all warnings for the remainder of the run.
warnings.filterwarnings('ignore')

# Root directory; each method's results live in a sub-directory named after it.
ResPath = './run-results'

# Methods to evaluate. (Alternative: os.listdir(ResPath) to pick up every results folder.)
all_methods = [
    'FGCNSurv', 'M2EFM', 'Multimodal_NSCLC', 'SALMON', 'GDP',
    'MiNet', 'SurvivalNet', 'SAE', 'CSAE', 'CustOmics',
    'I-Boost', 'TCGA-omics-integration', 'MDNNMD', 'TF-LogHazardNet', 'TF-ESN',
    'OmiEmbed', 'IPF-LASSO', 'Priority-Lasso', 'blockForest', 'MultimodalSurvivalPrediction',
]

# Output categories; they decide how each method is scored in the evaluation loop.
# Methods whose predictions are hazard ratios:
HR_methods = [
    'IPF-LASSO', 'M2EFM', 'Multimodal_NSCLC', 'SALMON', 'GDP',
    'MiNet', 'SurvivalNet', 'SAE', 'CSAE', 'CustOmics',
    'MultimodalSurvivalPrediction', 'FGCNSurv', 'I-Boost', 'Priority-Lasso', 'TCGA-omics-integration',
]
# Methods whose predictions are survival probabilities over time:
SP_methods = ['TF-LogHazardNet', 'TF-ESN', 'OmiEmbed', 'blockForest']
# Methods whose predictions are a vital status with a probability:
VP_methods = ['MDNNMD']


#### DEFINE FUNCTIONS
## function to find indices for specific times
def find_idx(times, all_times):
    """Return, for each requested time, the index of the nearest entry in ``all_times``.

    ``all_times`` must support element-wise subtraction (e.g. a NumPy array).
    When two entries are equally close, the lower index wins.
    """
    def _nearest(target):
        gaps = np.abs(all_times - target)
        return np.where(gaps == np.min(gaps))[0][0]

    return [_nearest(target) for target in times]


## calculate the survival probability using hazard ratio
def cal_survprob(pred_train, pred_val):
    """Turn predicted hazard ratios into survival curves for the validation samples.

    A Cox model with the single covariate ``log(predTrain)`` is fitted on the
    training predictions. Its baseline cumulative hazard, read at 20 evenly spaced
    times from 0 to the last baseline time, is then combined with the validation
    samples' partial hazards under that same fitted model.

    Parameters
    ----------
    pred_train : pandas.DataFrame
        Training predictions with columns ``predTrain`` (hazard ratio, expected > 0),
        ``time`` and ``status``. It is not modified.
    pred_val : pandas.DataFrame
        Validation predictions with column ``predVal`` (hazard ratio, expected > 0).

    Returns
    -------
    pandas.DataFrame
        Survival probabilities with one row per target time (the index holds the
        times) and one column per validation sample, in ``pred_val`` row order and
        labelled 0..n_val-1.
    """
    # Build the covariate on a new frame instead of adding a column to the caller's DataFrame.
    train_df = pred_train.loc[:, ['time', 'status']].assign(log_HR=np.log(pred_train['predTrain']))

    cph = CoxPHFitter()
    cph.fit(train_df, duration_col='time', event_col='status')
    baseline_cum_hazard = cph.baseline_cumulative_hazard_

    all_times = baseline_cum_hazard.index.values
    target_times = np.linspace(0, np.max(all_times), 20).tolist()
    baseline_haz = baseline_cum_hazard.iloc[find_idx(target_times, all_times), 0].values

    # lifelines pairs baseline_cumulative_hazard_ with predict_partial_hazard, which
    # applies the fitted coefficient and the training-mean centring. The raw predVal
    # values therefore must go through the model rather than being multiplied in directly.
    val_cov = pd.DataFrame({'log_HR': np.log(pred_val['predVal'].to_numpy(dtype=float))})
    partial_hazard = np.asarray(cph.predict_partial_hazard(val_cov), dtype=float).ravel()

    # (n_times, n_val) cumulative hazards -> survival probabilities S(t) = exp(-H(t)).
    survprob_val = pd.DataFrame(np.exp(-np.outer(baseline_haz, partial_hazard)))
    survprob_val.index = target_times
    return survprob_val


## calculate time-dependent CIndex and IBS
def cal_tdci_ibs(surv_test, duration_test, event_test):
    """Return Antolini's time-dependent C-index and the integrated Brier score.

    ``surv_test`` is the survival-probability table handed to pycox ``EvalSurv``
    (in this script, rows are time points and columns are samples). Censoring
    weights come from a Kaplan-Meier fit (``censor_surv='km'``). The IBS is
    integrated over 20 evenly spaced times from 0 to the largest observed duration.
    """
    evaluator = EvalSurv(surv_test, duration_test, event_test, censor_surv='km')
    ibs_grid = np.linspace(0, np.max(duration_test), 20)
    return evaluator.concordance_td('antolini'), evaluator.integrated_brier_score(ibs_grid)


## calculate risk score using survival probability
def cal_ci(predRes):
    """Collapse survival-probability curves into one risk score per sample.

    The score is the sum over time points of ``-log S(t)``, scaled so that the
    largest score is 1. Higher scores mean higher risk.

    Parameters
    ----------
    predRes : array-like of shape (n_samples, n_times)
        Survival probabilities; each row is one sample's curve.

    Returns
    -------
    numpy.ndarray of shape (n_samples,)
    """
    probs = np.asarray(predRes, dtype=float)
    # log(0) = -inf would make the score (and the normalisation) inf/NaN, so clip
    # zeros up to the smallest positive float.
    risk_score = -np.log(np.clip(probs, np.finfo(float).tiny, None)).sum(axis=1)
    # Scaling by a positive constant does not change the ranking. Skip it when the
    # maximum is not positive: dividing by 0 gives NaN, and dividing by a negative
    # number would reverse the order. `initial` also makes empty input safe.
    max_score = np.max(risk_score, initial=0.0)
    if max_score > 0:
        risk_score = risk_score / max_score
    return risk_score


## Calculate D-Calibration (from the paper "Effective Ways to Build and Evaluate Individual Survival Distributions").
## D-Calibration can be viewed as a form of moderate calibration.
# Helper functions
def check_indicators(indicators: np.ndarray) -> None:
    """Validate that every event indicator is 0 (censored) or 1 (event).

    Accepts any array-like, including 0-d arrays and plain lists.

    Raises
    ------
    ValueError
        If any entry is neither 0 nor 1.
    """
    # np.asarray + np.all also work for 0-d and list input. Python's all() raises
    # TypeError when it tries to iterate over those.
    indicators = np.asarray(indicators)
    if not np.all((indicators == 0) | (indicators == 1)):
        raise ValueError(
            "Event indicators must be 0 or 1 where 0 indicates censorship and 1 is an event."
        )


def to_array(array_like: NumericArrayLike, to_boolean: bool = False) -> np.ndarray:
    """Convert ``array_like`` to a validated 1-d NumPy array.

    Parameters
    ----------
    array_like : NumericArrayLike
        Event times, survival probabilities or event indicators.
    to_boolean : bool, default False
        If True, the values are validated as 0/1 event indicators and returned
        as a boolean array.

    Returns
    -------
    numpy.ndarray
        The converted array (boolean when ``to_boolean`` is True).

    Raises
    ------
    ValueError
        If the input has more than one dimension, if indicators are not all
        0/1 (``to_boolean=True``), or if any value is negative.
    """
    array = np.asarray(array_like)
    if array.ndim > 1:
        raise ValueError(
            f"Input should be a 1-d array. Got a shape of {array.shape} instead."
        )
    if to_boolean:
        # Check indicators before the non-negativity test. A negative indicator then
        # gets the indicator-specific error instead of the "event times" message.
        check_indicators(array)
        return array.astype(bool)
    if np.any(array < 0):
        raise ValueError("All event times must be greater than or equal to zero.")
    return array


# get the survival probability for each patient at the time they died or were censored
def get_1surv_prob(pred_df, survival_times):
    """Read each sample's survival probability at its own observed time.

    Column ``i`` of ``pred_df`` is taken to be the curve of the sample whose
    time is ``survival_times[i]``; the index of ``pred_df`` holds the time grid.
    The grid point nearest to the observed time is used. If several grid points
    are equally near, the smallest of their probabilities is taken.

    Parameters
    ----------
    pred_df : pandas.DataFrame
        Survival probabilities, rows indexed by time, one column per sample.
    survival_times : array-like of float
        Observed (event or censoring) time per sample.

    Returns
    -------
    numpy.ndarray of shape (len(survival_times),)
        NaN where every candidate probability is NaN.
    """
    time_points = np.asarray(pred_df.index.values, dtype=float)
    surv = pred_df.to_numpy(dtype=float)
    probabilities = np.zeros(len(survival_times))

    for i, t in enumerate(survival_times):
        distances = np.abs(time_points - t)
        candidates = surv[distances == np.min(distances), i]
        # The minimum over the nearest candidates equals the single value when there
        # is no tie. nanmin skips NaNs, as pandas' min() did, and an all-NaN tie gives
        # NaN instead of an IndexError.
        probabilities[i] = np.nan if np.all(np.isnan(candidates)) else np.nanmin(candidates)
    return probabilities



# function to calculate D-Calibration
def d_calibration(
        event_indicators: NumericArrayLike,
        predictions: NumericArrayLike,
        bins: int = 10,
) -> dict:
    """Run the D-calibration goodness-of-fit test on predicted survival probabilities.

    The unit interval is split into ``bins`` equal-width bins, and every sample
    contributes a total mass of 1 across the bins:

    * an uncensored sample adds 1 to the bin holding its prediction;
    * a censored sample whose prediction ``p`` falls in bin ``b`` adds
      ``1 - b / (bins * p)`` to bin ``b`` and ``1 / (bins * p)`` to every bin
      with a lower index.

    The bin totals are compared with the uniform expectation ``n / bins`` using a
    chi-square statistic with ``bins - 1`` degrees of freedom.

    Parameters
    ----------
    event_indicators : 1-d array-like of 0/1
        1 for an observed event, 0 for censoring.
    predictions : 1-d array-like of float
        Predicted survival probability per sample. In this script it is the value
        returned by ``get_1surv_prob``, i.e. the curve read at the grid point
        nearest the sample's observed time.
    bins : int, default 10
        Number of equal-width probability bins.

    Returns
    -------
    dict
        ``p_value`` of the chi-square test; ``bin_proportions`` (bin totals / n);
        ``censored_contributions`` and ``uncensored_contributions`` (each / n).

    Raises
    ------
    ValueError
        From ``to_array``: if an input is not 1-d, contains negative values, or
        the indicators are not all 0 or 1.

    Notes
    -----
    A censored prediction of exactly 0 yields NaN contributions (0 * inf), so
    the p-value becomes NaN.
    """
    event_indicators = to_array(event_indicators, to_boolean=True)
    predictions = to_array(predictions)

    # floor(p * bins) would put p = 1 (or larger) past the last bin, so cap it at bins - 1.
    bin_index = np.minimum(np.floor(predictions * bins), bins - 1).astype(int)
    censored_bin_indexes = bin_index[~event_indicators]
    uncensored_bin_indexes = bin_index[event_indicators]

    censored_predictions = predictions[~event_indicators]
    # Share of a censored sample's mass that stays in its own bin: 1 - b / (bins * p).
    censored_contribution = 1 - (censored_bin_indexes / bins) * (
            1 / censored_predictions
    )
    # Share given to each strictly lower bin: 1 / (bins * p).
    censored_following_contribution = 1 / (bins * censored_predictions)

    # Row b is True exactly in columns 0..b-1 (the bins below bin b).
    contribution_pattern = np.tril(np.ones([bins, bins]), k=-1).astype(bool)

    # Per-bin sum of the lower-bin shares of all censored samples.
    following_contributions = np.matmul(
        censored_following_contribution, contribution_pattern[censored_bin_indexes]
    )
    # Per-bin sum of the own-bin shares of all censored samples (one-hot rows of eye).
    single_contributions = np.matmul(
        censored_contribution, np.eye(bins)[censored_bin_indexes]
    )
    # Count of uncensored samples per bin.
    uncensored_contributions = np.sum(np.eye(bins)[uncensored_bin_indexes], axis=0)
    bin_count = (
            single_contributions + following_contributions + uncensored_contributions
    )
    # Pearson chi-square against the uniform expected count n / bins.
    chi2_statistic = np.sum(
        np.square(bin_count - len(predictions) / bins) / (len(predictions) / bins)
    )
    return dict(
        p_value=1 - chi2.cdf(chi2_statistic, bins - 1),
        bin_proportions=bin_count / len(predictions),
        censored_contributions=(single_contributions + following_contributions)
                               / len(predictions),
        uncensored_contributions=uncensored_contributions / len(predictions),
    )


## function to create table for each metric
def create_metric_table(data, metric, datasets):
    """Build a dataset-by-method table for one metric.

    ``data`` maps each method name to a dict of per-dataset metric lists. The
    result has one column per method (in ``data`` order) and one row per
    entry of ``datasets``.
    """
    columns = {}
    for method_name, metric_lists in data.items():
        columns[method_name] = metric_lists[metric]
    return pd.DataFrame(columns, index=datasets)


## create summary table: mean of each of the four metrics (C-index, td C-index, IBS, calibration) across datasets
def create_summary_table(cindex_table, tdcindex_table, ibs_table, cal_table):
    """Average each per-dataset metric table over its rows (datasets).

    The result has one row per method (the tables' columns) and the columns
    ``Cindex``, ``tdCindex``, ``IBS`` and ``Calibration``; NaNs are skipped.
    """
    labelled_tables = (
        ('Cindex', cindex_table),
        ('tdCindex', tdcindex_table),
        ('IBS', ibs_table),
        ('Calibration', cal_table),
    )
    return pd.DataFrame({label: table.mean() for label, table in labelled_tables})


#### CALCULATE METRICS FOR ALL METHODS
# Evaluation protocol: every <ResPath>/<method>/<dataset> directory holds repeats
# Time1..Time5, each with folds 1..10 saved as Train_Res_<fold>.csv / Val_Res_<fold>.csv.
N_REPEATS = 5
N_FOLDS = 10
# D-calibration p-value above which a fold counts as calibrated.
CAL_ALPHA = 0.05


def _harrell_ci(val_res, risk_fn):
    """Return Harrell's C-index of ``risk_fn()`` (higher = riskier) on ``val_res``; NaN on failure.

    ``risk_fn`` is called inside the ``try`` so that failures while computing the
    risk score are also recorded as a missing value.
    """
    try:
        return concordance_index_censored(event_indicator=val_res['status'].values.astype(bool),
                                          event_time=val_res['time'].values,
                                          estimate=risk_fn())[0]
    except Exception:  # best effort: an unscorable fold is recorded as missing (NaN)
        return np.nan


def _curve_metrics(build_surv, durations, events):
    """Return (td C-index, IBS, D-calibration p-value) for the curves from ``build_surv()``.

    Any failure, including while building the survival table, yields three NaNs
    rather than aborting the whole run.
    """
    try:
        surv = build_surv()
        td_ci, ibs = cal_tdci_ibs(surv, durations, events)
        cal_p = d_calibration(events, get_1surv_prob(surv, durations))["p_value"]
    except Exception:  # best effort: a failed fold is recorded as missing (NaN)
        return np.nan, np.nan, np.nan
    return td_ci, ibs, cal_p


def _sp_surv_table(pred_val):
    """Reshape per-sample survival columns named ``<prefix>_<time>`` into a time x sample table."""
    time_points = [int(col.split('_')[-1]) for col in pred_val.columns if '_' in col]
    surv = pred_val.T.reset_index(drop=True)
    surv.index = time_points
    return surv


def _evaluate_fold(method, train_res, val_res):
    """Return (C-index, td C-index, IBS, D-calibration p-value) for one validation fold."""
    true_dat = np.array(val_res[['time', 'status']])
    durations, events = true_dat[:, 0], true_dat[:, 1]

    if method in HR_methods:
        ci = _harrell_ci(val_res, lambda: val_res['predVal'])
        td_ci, ibs, cal_p = _curve_metrics(lambda: cal_survprob(train_res, val_res), durations, events)
    elif method in SP_methods:
        pred_val = val_res.drop(['time', 'status'], axis=1)
        ci = _harrell_ci(val_res, lambda: cal_ci(pred_val.to_numpy()))
        td_ci, ibs, cal_p = _curve_metrics(lambda: _sp_surv_table(pred_val), durations, events)
    elif method in VP_methods:
        # 1 - predVal is used as the risk score; no survival curve is derived for
        # these methods, so the curve-based metrics are not available.
        ci = _harrell_ci(val_res, lambda: 1 - val_res['predVal'].to_numpy())
        td_ci = ibs = cal_p = np.nan
    else:
        raise ValueError(f"Method {method!r} is not in HR_methods, SP_methods or VP_methods.")
    return ci, td_ci, ibs, cal_p


def _calibrated_flag(p_value):
    """Return 1 if the D-calibration p-value exceeds CAL_ALPHA, 0 if not, NaN if it is missing."""
    if np.isnan(p_value):
        return np.nan
    return 1 if p_value > CAL_ALPHA else 0


def _summarise_dataset(method, dataset_dir):
    """Average C-index / td C-index / IBS over all repeats x folds and count calibrated folds.

    Failed folds count as not calibrated. The count is NaN only when no fold has
    a D-calibration result at all (e.g. VP methods), instead of a misleading 0.
    """
    fold_scores = []
    for rep in range(1, N_REPEATS + 1):
        rep_dir = os.path.join(dataset_dir, f"Time{rep}")
        for fold in range(1, N_FOLDS + 1):
            train_res = pd.read_csv(os.path.join(rep_dir, f"Train_Res_{fold}.csv"), header=0, index_col=0)
            val_res = pd.read_csv(os.path.join(rep_dir, f"Val_Res_{fold}.csv"), header=0, index_col=0)
            fold_scores.append(_evaluate_fold(method, train_res, val_res))

    ci, td_ci, ibs, cal_p = (np.asarray(col, dtype=float) for col in zip(*fold_scores))
    cal_flags = np.array([_calibrated_flag(p) for p in cal_p], dtype=float)
    n_calibrated = np.nan if np.all(np.isnan(cal_flags)) else np.nansum(cal_flags)
    return np.nanmean(ci), np.nanmean(td_ci), np.nanmean(ibs), n_calibrated


# Fail fast if a method has no output category; otherwise its metrics would be undefined.
_known_methods = set(HR_methods) | set(SP_methods) | set(VP_methods)
_unclassified = [m for m in all_methods if m not in _known_methods]
if _unclassified:
    raise ValueError(f"No output category (HR/SP/VP) for methods: {_unclassified}")

allRes = {}
_datasets_by_method = {}
for method in all_methods:
    method_dir = os.path.join(ResPath, method)
    all_datasets = sorted(x for x in os.listdir(method_dir) if "TCGA-" in x)
    _datasets_by_method[method] = all_datasets

    allCIndex, alltdCIndex, allIBS, allCal = [], [], [], []
    for dataset in all_datasets:
        ci, td_ci, ibs, n_cal = _summarise_dataset(method, os.path.join(method_dir, dataset))
        allCIndex.append(ci)
        alltdCIndex.append(td_ci)
        allIBS.append(ibs)
        allCal.append(n_cal)

    allRes[method] = {'CIndex': allCIndex, 'tdCIndex': alltdCIndex, 'IBS': allIBS, 'Cal': allCal}

# The metric tables are indexed by `all_datasets` (the last method's list), so every
# method must have results for exactly the same datasets or rows would be mislabelled.
_mismatched = [m for m, ds in _datasets_by_method.items() if ds != all_datasets]
if _mismatched:
    raise ValueError(
        f"Methods {_mismatched} have a different set of TCGA datasets than {all_methods[-1]!r}."
    )


## Per-dataset tables (rows: datasets, columns: methods) for each metric, then the cross-dataset summary.
CIndex_Table, tdCIndex_Table, IBS_Table, Cal_Table = (
    create_metric_table(allRes, metric_key, all_datasets)
    for metric_key in ('CIndex', 'tdCIndex', 'IBS', 'Cal')
)
Res_Table = create_summary_table(CIndex_Table, tdCIndex_Table, IBS_Table, Cal_Table)

# Write every table into the results root.
for _table, _file_suffix in (
        (CIndex_Table, "/CIndex_Table.csv"),
        (tdCIndex_Table, "/tdCIndex_Table.csv"),
        (IBS_Table, "/IBS_Table.csv"),
        (Cal_Table, "/Calibration_Table.csv"),
        (Res_Table, "/Res_Summary_Table.csv"),
):
    _table.to_csv(ResPath + _file_suffix, sep=",", header=True)

print(Res_Table)