In [47]:
import numpy as np
import pandas as pd
import os
import csv
import datetime
import statsmodels.api as sm
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.formula.api as smf

from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from lifelines import CoxPHFitter, WeibullFitter, WeibullAFTFitter
from datetime import datetime, date, timedelta
from sklearn.model_selection import train_test_split
from os.path import isfile, join
from sklearn.metrics import mean_absolute_error, roc_auc_score, precision_score, recall_score, accuracy_score, mean_absolute_percentage_error
from statsmodels.stats.outliers_influence import variance_inflation_factor
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import calibration_curve

import torch
import torchtuples as tt

from pycox.datasets import metabric
from pycox.models import CoxPH
from pycox.models.loss import CoxPHLoss
from pycox.evaluation import EvalSurv

from sklearn.preprocessing import OrdinalEncoder, StandardScaler, MinMaxScaler
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, KFold
from sklearn.model_selection import StratifiedKFold
from time import time
from sksurv.functions import StepFunction
from sksurv.metrics import (
    concordance_index_censored,
    concordance_index_ipcw,
    cumulative_dynamic_auc,
    integrated_brier_score)
from sksurv.metrics import brier_score
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.preprocessing import OneHotEncoder, encode_categorical
from sksurv.util import Surv
from lifelines.utils import concordance_index
from lifelines import KaplanMeierFitter
from data import load_dataset 
from sklearn.metrics import make_scorer
from sksurv.metrics import concordance_index_censored
from sklearn.pipeline import Pipeline
from sklearn.model_selection import ParameterGrid
from pycox.models import DeepHitSingle
import shap
import warnings 
warnings.filterwarnings('ignore')
from sklearn.isotonic import IsotonicRegression


drive = 'G'
main_path = drive + ':/Shared drives/CKD_Progression/data/CKD_COHORT_Jan2010_Mar2024_v3.csv'
data_path = drive + ':/Shared drives/CKD_Progression/data/'
docs_path = drive + ':/Shared drives/CKD_Progression/docs/'
save_path = drive + ':/Shared drives/CKD_Progression/save/'
resu_path = drive + ':/Shared drives/CKD_Progression/result/'
covariates_path = docs_path + 'covariates.csv'
removecols_path = docs_path + 'remove_columns.csv'

covariates, order_covariates, long_df = load_dataset(get_columns = True)

from sklearn.model_selection import ParameterSampler
from sklearn_pandas import DataFrameMapper
from pycox.models import CoxPH

In [60]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                print("Early stopping triggered!")
                self.early_stop = True

early_stopping = EarlyStopping(patience=5, delta=0.01)

In [59]:
def load_input(pathway):
    transition_df = long_df[long_df['pathway'] == pathway]
    univariate  = pd.read_excel(resu_path + f'univariate/LR_test/{pathway}.xlsx', sheet_name = pathway)
    multivariate_covariates = univariate[univariate['pvalue'] < 0.20]['variable'].tolist()
    transition_df = transition_df[multivariate_covariates + ['time', 'status']]
    df_train = transition_df.copy()
    df_tests = df_train.sample(frac = 0.2, random_state=42)
    df_train = df_train.drop(df_tests.index)
    df_valid = df_train.sample(frac = 0.2, random_state=42)
    df_train = df_train.drop(df_valid.index)

    standardize = [([col], StandardScaler()) for col in multivariate_covariates]
    x_mapper = DataFrameMapper(standardize)

    X_train = x_mapper.fit_transform(df_train).astype('float32')
    X_valid = x_mapper.transform(df_valid).astype('float32')
    X_tests = x_mapper.transform(df_tests).astype('float32')

    num_durations = 10
    labtrans = DeepHitSingle.label_transform(num_durations)
    
    get_target = lambda df: (df['time'].values, df['status'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_valid = labtrans.transform(*get_target(df_valid))
    durations_test, events_test = get_target(df_tests)
    val = X_valid, y_valid
    
    train = (X_train, y_train)
    val   = (X_valid, y_valid)
    durations_test, events_test = get_target(df_tests)
    return X_train, y_train, val, X_tests, durations_test, events_test, df_tests

In [68]:
def compute_ace(surv, durations, events, time_grid):
    ace = 0
    n_bins = len(time_grid)
    for t in time_grid:
        nearest_time_index = surv.index.get_loc(t, method='nearest')
        pred_surv = surv.iloc[nearest_time_index].values
    
        observed = (durations >= t).astype(int)
        observed_censored = observed[events == 1]
        pred_surv_censored = pred_surv[events == 1]
        bin_error = np.abs(pred_surv_censored.mean() - observed_censored.mean())
        ace += bin_error
    
    return ace / n_bins

def conformal_coverage(surv, durations, events, alpha):
    final_time_index = surv.index[-1]  
    last_surv = surv.loc[final_time_index].values

    lower_bound = np.quantile(last_surv, alpha / 2)
    upper_bound = np.quantile(last_surv, 1 - alpha / 2)

    coverage = np.mean((durations >= lower_bound) & (durations <= upper_bound))
    return coverage

def evaluate_deephit(main_model, val, X_tests, durations_test, events_test):
    surv = main_model.predict_surv_df(X_tests)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

    time_grid = np.linspace(durations_test.min(), durations_test.max(), 1000)
    c_index = ev.concordance_td()
    brier_score = ev.integrated_brier_score(time_grid)

    X_val, (durations_val, events_val) = val 
    mean_survival = surv.mean(axis=0).values.ravel()

    iso_reg = IsotonicRegression(out_of_bounds = 'clip')
    iso_reg.fit(mean_survival, durations_test)

    recalibrated_surv = surv.copy()
    for col in recalibrated_surv.columns:
        recalibrated_surv[col] = iso_reg.transform(surv[col].values)
    recalibrated_ev = EvalSurv(recalibrated_surv, durations_test, events_test, censor_surv='km')
    recalibrated_c_index = recalibrated_ev.concordance_td()
    recalibrated_brier_score = recalibrated_ev.integrated_brier_score(time_grid)

    ace = compute_ace(recalibrated_surv, durations_test, events_test, time_grid)
    coverage = conformal_coverage(recalibrated_surv, durations_test, events_test, alpha = 0.05)
    return c_index, brier_score, recalibrated_c_index, recalibrated_brier_score, ace, coverage

def train_deephit(X_train, y_train, val):
    input_features, out_features = X_train.shape[1], labtrans.out_features
    num_nodes = [128, 128, 64]
    batch_norm, output_bias = True, False
    dropout = 0.2

    deephit = tt.practical.MLPVanilla(input_features, 
                                    num_nodes, 
                                    out_features, 
                                    batch_norm,
                                    dropout, 
                                    output_bias = output_bias)

    optimizer = tt.optim.AdamWR(decoupled_weight_decay = 0.01, 
                                cycle_eta_multiplier = 0.8,
                                cycle_multiplier = 2)

    main_model = DeepHitSingle(deephit, optimizer, alpha = 0.2, sigma = 0.1, duration_index = labtrans.cuts)
    batch_size = 128
    epochs = 100
    callbacks = [tt.callbacks.EarlyStopping()]
    verbose = False

    lrfind = main_model.lr_finder(X_train, y_train, batch_size, tolerance = 50)
    main_model.optimizer.set_lr(0.0001) 
    log = main_model.fit(X_train, y_train, 
                        batch_size, 
                        epochs, 
                        callbacks, 
                        verbose,
                        val_data = val, 
                        val_batch_size = batch_size)
    return main_model, log

def custom_predict(data):
    if isinstance(data, np.ndarray):
        data = torch.tensor(data, dtype=torch.float32)
    return main_model.net(data).detach().numpy()

def get_shap(X_train, X_tests):
    X_train_np = X_train.values if hasattr(X_train, 'values') else X_train
    X_tests_np = X_tests.values if hasattr(X_tests, 'values') else X_tests

    explainer = shap.Explainer(custom_predict, X_train_np)
    shap_values = explainer(X_tests_np)
    return shap_values, X_train_np, X_tests_np

def get_calibration_plot(y_test, y_calib, calibrated, prediction, pathway):
    fraction_of_positives_calibrated, mean_predicted_value_calibrated = calibration_curve(
        y_calib['status'], calibrated, n_bins=20)
    fraction_of_positives_uncalibrated, mean_predicted_value_uncalibrated = calibration_curve(
        y_test['status'], prediction, n_bins=20)

    try:
        inner, outer = pathway.split('_')[0], pathway.split('_')[2]
    except IndexError:
        inner, outer = pathway, "Unknown"
    

    plt.figure(figsize=(8, 7))
    plt.plot(mean_predicted_value_calibrated, fraction_of_positives_calibrated, 
             marker = 'o', markersize=8, color='blue', label='Calibrated')
    plt.plot(mean_predicted_value_uncalibrated, fraction_of_positives_uncalibrated, 
             marker = 'o', markersize=8, color='red', label='Uncalibrated')
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfect Calibration')
    
    plt.xlabel('Mean Predicted Probability', fontsize=15)
    plt.ylabel('Fraction of Positives', fontsize=15)
    plt.title(f'{inner} to {outer} Calibration Plot', fontsize=18)
    plt.legend(fontsize = 12)
    plt.grid(alpha = 0.3)

In [62]:
def get_loss_plot(pathway, log, loss_path):
    formatted_transition = pathway.replace('_to_', r' $\rightarrow$ ')
    loss_df = log.to_pandas()
    plt.figure(figsize = (8, 6))
    plt.plot(loss_df.index, loss_df['train_loss'], label = 'Training',        marker = '.', linewidth = 2, color = 'blue')
    plt.plot(loss_df.index, loss_df['val_loss'],   label = 'Validation Loss', marker = '.', linewidth = 2, color = 'red')
    plt.xlabel('Epochs', fontsize = 15, fontname = 'Arial')
    plt.ylabel('Loss',   fontsize = 15, fontname = 'Arial')
    plt.title(f'DeepHit Loss: {formatted_transition}', fontsize=18, fontname='Arial')
    plt.xticks(fontsize = 13, fontname = 'Arial')
    plt.yticks(fontsize = 13, fontname = 'Arial')
    plt.legend(fontsize = 13)
    plt.savefig(loss_path + f'{pathway}.png', dpi = 300, bbox_inches = 'tight')
    plt.close()  
    
def get_shap_local_plot(pathway, shap_values, X_tests_np, df_tests, shap_path):
    plt.rcParams['font.family'] = 'Arial'
    rename = pd.read_csv(docs_path + 'rename_columns_forest.csv')
    rename_dict = dict(zip(rename['variable'], rename['covariate']))
    renaming = df_tests.copy()
    renaming = renaming.rename(columns=rename_dict)

    plt.figure(figsize=(10, 12))
    shap.summary_plot(shap_values, X_tests_np, feature_names = renaming.columns, show = False)
    plt.savefig(shap_path + f'local_{pathway}.png', dpi=300, bbox_inches='tight')
    plt.close()

def get_shap_globe_plot(pathway, shap_values, X_tests_np, df_tests, shap_path):
    plt.rcParams['font.family'] = 'Arial'
    rename = pd.read_csv(docs_path + 'rename_columns_forest.csv')
    rename_dict = dict(zip(rename['variable'], rename['covariate']))
    renaming = df_tests.copy()
    renaming = renaming.rename(columns = rename_dict)

    formatted_transition = pathway.replace('_to_', r' $\rightarrow$ ')
    features_df = pd.DataFrame(X_tests_np, columns = renaming.columns[:-2])
    feature_means = features_df.mean().sort_values()
    colors = feature_means.apply(lambda x: 'blue' if x > 0 else 'red')
    plt.figure(figsize = (10, 12))
    feature_means.plot(kind = 'barh', color = colors)
    plt.yticks(fontsize = 16)
    plt.xticks(fontsize = 14)
    plt.xlabel('Importance Mean', fontsize = 20)
    plt.ylabel('Feature', fontsize = 20)
    plt.title(f'DeepHit Feature Importance: {formatted_transition}', fontsize = 20)
    plt.tight_layout()
    plt.savefig(shap_path + f'global_{pathway}.png', dpi = 300, bbox_inches = 'tight')
    plt.close()  

In [63]:
date = datetime.now().strftime("%Y-%m-%d")
save_main = resu_path + 'modeling/deephit/'
loss_path = save_main + date + '/loss/'
shap_path = save_main + date + '/shap/'
if not os.path.exists(save_main + date):
    os.mkdir(save_main + date)
    os.mkdir(loss_path)
    os.mkdir(shap_path)

In [65]:
CINDEX, BRIER = [], []
RINDEX, RRIER = [], []
ACE, COVERAGE = [], []
TIME = []
pathways = long_df['pathway'].unique().tolist()
for path in ['CKD3A_to_DEAD', 'CKD3A_to_CKD4', 'CKD3A_to_CKD5A', 'CKD3A_to_CKD5B', 'CKD3B_to_CKD5A', 'CKD3B_to_CKD5B', 'CKD4_to_CKD5B']:
    pathways.remove(path)

for pathway in pathways:
    X_train, y_train, val, X_tests, durations_test, events_test, df_tests = load_input(pathway)

    start = time()
    main_model, log = train_deephit(X_train, y_train, val)
    c_index, brier_score, r_index, recab_brier, ace, coverage = evaluate_deephit(main_model, val, X_tests, durations_test, events_test)
    fit_predict_time = np.round(time() - start, 3)

    CINDEX.append(c_index), BRIER.append(brier_score)
    RINDEX.append(r_index), RRIER.append(recab_brier)

    ACE.append(ace), COVERAGE.append(coverage), TIME.append(fit_predict_time)
    
    shap_values, X_train_np, X_tests_np = get_shap(X_train, X_tests)
    get_loss_plot(pathway, log, loss_path)
    get_shap_local_plot(pathway, shap_values, X_tests_np, df_tests, shap_path)
    get_shap_globe_plot(pathway, shap_values, X_tests_np, df_tests, shap_path)

scores = (pathways, TIME, CINDEX, BRIER, RINDEX, RRIER, ACE, COVERAGE)
scores = pd.DataFrame(scores).T
scores.columns = ['transition', 'time', 'cindex', 'brier', 'r_cindex', 'r_brier', 'ace', 'coverage']
scores.to_csv(save_main + date + 'results.csv', index = False)

Unnamed: 0,transition,time,cindex,brier,r_cindex,r_brier,ace,coverage
0,CKD3A_to_CKD3B,16.092,0.52621,0.251224,0.502848,1928.832388,39.997302,0.0
