In [None]:
import math
import multiprocessing
import os
import pickle
import time

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import patchworklib as pw
import joblib
from joblib import Parallel, delayed

from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    roc_auc_score, roc_curve,
    average_precision_score, precision_recall_curve,
    confusion_matrix, ConfusionMatrixDisplay,
    classification_report,
    recall_score, precision_score,
    PrecisionRecallDisplay
)
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

# Star import kept as in the original notebook: the plotting code below relies
# on the bare plotnine names (ggplot, aes, geom_line, theme_bw, ...).
from plotnine import *

### evaluation functions 


In [None]:
def preprocess(xtrain, xtest ):
    """Standardise the age column of a train/test pair using train statistics.

    Only the column named "age_at_prediction_window" is touched. The scaler is
    fitted on ``xtrain`` and applied to both frames, which are modified in place.

    Returns
    -------
    tuple
        ``(xtrain, xtest, scaler)``; ``scaler`` is ``None`` (and both frames are
        returned untouched) when ``xtrain`` has no age column.
    """
    age_column = "age_at_prediction_window"
    scaled_cols = [name for name in xtrain.columns if name == age_column]

    # Nothing to scale: hand the inputs straight back.
    if not scaled_cols:
        print('\tpreprocess column', None)
        return xtrain, xtest, None

    scaler = StandardScaler().fit(xtrain[scaled_cols])
    print('\tpreprocess column', scaled_cols)

    # Same fitted scaler for both splits, train first.
    for frame in (xtrain, xtest):
        frame[scaled_cols] = scaler.transform(frame[scaled_cols])

    return xtrain, xtest, scaler

   
def preprocess_unmatched(xtrain ):
    """Standardise the age column of a single frame with a freshly fitted scaler.

    The scaler is fitted on ``xtrain`` itself and is not returned. The column
    "age_at_prediction_window" is overwritten in place; the same frame object
    is returned. A frame without that column is returned unchanged.
    """
    age_column = "age_at_prediction_window"
    target_cols = [name for name in xtrain.columns if name == age_column]

    if not target_cols:
        print('\tpreprocess column', None)
        return xtrain

    age_scaler = StandardScaler().fit(xtrain[target_cols])
    print('\tpreprocess column', target_cols)
    xtrain[target_cols] = age_scaler.transform(xtrain[target_cols])
    return xtrain


def preprocess_unmatched_scalar(xinput, scalar):
    """Apply an already-fitted scaler to the age column of ``xinput``.

    Parameters
    ----------
    xinput : DataFrame whose "age_at_prediction_window" column (if present)
        is overwritten in place with ``scalar.transform(...)``.
    scalar : object exposing ``transform``; it is not re-fitted here.

    Returns
    -------
    The same ``xinput`` object.
    """
    age_cols = list(filter(lambda name: name == "age_at_prediction_window", xinput.columns))

    if age_cols:
        xinput[age_cols] = scalar.transform(xinput[age_cols])
        print('Use pretrained Scalar ', scalar)
        return xinput

    # No age column: nothing to transform.
    print('\tpreprocess column', None)
    return xinput


def calculate_metrics(y_true, y_pred_prob, y_pred):
    """Return ROC-AUC and average precision, padded to a six-element tuple.

    ``y_pred`` is accepted but not used. The last four slots are always
    ``None``; callers unpack them as sensitivity, specificity, ppv and npv.
    """
    roc_auc = roc_auc_score(y_true, y_pred_prob)
    avg_precision = average_precision_score(y_true, y_pred_prob)
    return (roc_auc, avg_precision) + (None,) * 4


def ppv_sensitivity(specificity_levels, _y_true, _y_pred_proba):
    """Sensitivity and PPV at operating points chosen by target specificity.

    For each requested specificity the last ROC point whose false-positive
    rate does not exceed ``1 - specificity`` is selected; its threshold is
    then used to binarise the scores for the precision computation.

    Parameters
    ----------
    specificity_levels : iterable of target specificities.
    _y_true : true binary labels.
    _y_pred_proba : predicted scores; must support ``>=`` followed by
        ``.astype(int)`` (e.g. a NumPy array).

    Returns
    -------
    dict
        ``{'Sensitivity': [...], 'PPV': [...]}``, one entry per level, in the
        order the levels were given.
    """
    fpr, tpr, cutoffs = roc_curve(_y_true, _y_pred_proba)

    sens_at_level = []
    ppv_at_level = []
    for target in specificity_levels:
        # Index of the last ROC point that still satisfies the FPR budget.
        point = np.flatnonzero(fpr <= (1 - target))[-1]
        cutoff = cutoffs[point]

        # Sensitivity is the TPR at that point.
        sens_at_level.append(tpr[point])

        # PPV is the precision of the hard calls made at the same cutoff.
        hard_calls = (_y_pred_proba >= cutoff).astype(int)
        ppv_at_level.append(precision_score(_y_true, hard_calls))

    return {'Sensitivity': sens_at_level, 'PPV': ppv_at_level}



def evaluate_base_model(X_test, y_test, model, CP_num, prediction_window, N, feature_map, model_name, plot=True):
    """Score a fitted classifier on a test set and optionally build a summary figure.

    Parameters
    ----------
    X_test : DataFrame of features; its ``columns`` supply the feature names for the plot.
    y_test : true binary labels.
    model : fitted classifier exposing ``predict_proba`` and ``predict``; when
        ``plot`` is true it must also expose ``classes_`` and ``coef_``.
    CP_num, prediction_window, model_name : used only in the printed message
        and the plot title.
    N : number of largest-|coefficient| features shown in the bar chart.
    feature_map : mapping with ``.get`` used to translate feature names into
        display names (string names are stripped before lookup).
    plot : build the patchworklib figure when true.

    Returns
    -------
    tuple
        ``(ax_all, results_list)``. ``ax_all`` is the composed figure or ``None``
        when ``plot`` is false. ``results_list`` holds, in order: auc, average
        precision, the four placeholder values returned by ``calculate_metrics``,
        then sensitivity and PPV at specificity 0.90 and 0.95.
    """
    best_mod = model

    # Probability of the positive class (second column) and hard labels.
    predicted_proba = best_mod.predict_proba(X_test)[:, 1]
    predicted_labels = best_mod.predict(X_test)

    auc, pre,sensitivity , specificity , ppv, npv = calculate_metrics(y_test, predicted_proba, predicted_labels)
    
    # Sensitivity / PPV at the 90% and 95% specificity operating points.
    add_results = ppv_sensitivity([0.9, 0.95], y_test, predicted_proba)
    sensitivity_90, sensitivity_95 = add_results['Sensitivity']
    ppv_90, ppv_95 = add_results['PPV']
    
    if plot:
        print(f"\nDisplaying performance for CP {CP_num} with a {prediction_window}-year prediction window:\n")

        # Panel 1: precision-recall curve, with the positive-class prevalence
        # drawn as the dashed chance-level line.
        prec, rec, threshold = precision_recall_curve(y_test, predicted_proba)
        prc_df = pd.DataFrame({"Recall": rec, "Precision": prec})
        ap_score = average_precision_score(y_test, predicted_proba)
        base_ap_score = np.mean(y_test)
        # NOTE(review): theme() is called with no arguments — looks like a no-op placeholder; confirm.
        prc_plot = (
            ggplot(prc_df, aes("Recall", "Precision")) + 
            geom_line(color="#3C5488B2") +
            theme_bw() +
            theme() +
            coord_fixed() +
            geom_hline(yintercept=base_ap_score, linetype="dashed") +
            labs(title="Precision-Recall Curve") +
            annotate("text", x=0.15, y=1, label=f"AP={ap_score:.2f}", size=8) +
            annotate("text", x=0.3, y=0.95, label=f"Chance Level AP={base_ap_score:.2f}", size=8)
            )
        ax1 = pw.load_ggplot(prc_plot, figsize=(2.5, 2.5))
    
        # Panel 2: confusion matrix heatmap drawn onto a patchworklib Brick.
        ax2 = pw.Brick(figsize=(2.5, 2.5))
        cm = confusion_matrix(y_test, predicted_labels, labels=best_mod.classes_)
        sns.heatmap(cm, annot=True, linewidth=1, cmap="GnBu", fmt="g",
                    yticklabels=["Control", "Case"], xticklabels=["Control", "Case"], ax=ax2)
        ax2.set_title("Confusion Matrix")
        ax2.set_xlabel("Predicted Label")
        ax2.set_ylabel("True Label")
    
        # Panel 3: the N coefficients with the largest absolute value.
        # Only the first row of coef_ is used.
        coefs = best_mod.coef_[0]
        top_N_feature_index = np.argsort(abs(coefs))[-N:]
        top_N_feature_names = X_test.columns[top_N_feature_index]

        def get_name(x):
            # Translate a raw feature name to its display name; fall back to
            # the raw name when the map has no entry.
            if isinstance(x, str):
                return feature_map.get(x.strip(), x)
            else:
                return feature_map.get(x, x)
        top_N_feature_names = pd.Series(top_N_feature_names).apply(get_name)


        top_N_coefs_abs = abs(coefs)[top_N_feature_index]
        top_N_coefs = coefs[top_N_feature_index]
        coef_plot_df = pd.DataFrame({"feature_name": top_N_feature_names,
                                "abs_coef": top_N_coefs_abs,
                                "coef": top_N_coefs})
        # Categorical with categories sorted by |coef| fixes the bar order in the plot.
        coef_plot_df['feature_name'] = pd.Categorical(coef_plot_df['feature_name'], categories=coef_plot_df.sort_values('abs_coef')['feature_name'])

        feature_importance_plot = (
            ggplot(coef_plot_df, aes("feature_name", "coef")) +
                geom_bar(stat="identity", fill="#91D1C2B2", color="black") +
                coord_flip() +
                theme_bw() +
                labs(x="", y="Feature Coefficients", title=f"{model_name} Model (CP {CP_num} with a {prediction_window}-yr prediction window)")
            )
        ax3 = pw.load_ggplot(feature_importance_plot, figsize=(5, 4))
        # patchworklib layout: PR curve and confusion matrix side by side, bar chart underneath.
        ax_all = (ax1 | ax2)/ax3
    else:
        ax_all = None
    # Order must match the column names assigned by the caller.
    results_list = [auc , pre, sensitivity , specificity , ppv, npv , sensitivity_90, sensitivity_95, ppv_90, ppv_95]

    return ax_all, results_list


def parallel_forest_predict_proba(forest, X, n_jobs=-1):
    """Average the per-tree class probabilities of a forest, one joblib task per tree.

    Fixes over the previous version:
    * the function computed the average but returned the un-normalised *sum*
      of the per-tree probabilities; it now returns the average, so each row
      sums to 1;
    * an explicit positive ``n_jobs`` is now honoured instead of being
      silently overwritten.

    Parameters
    ----------
    forest : object exposing ``estimators_``, an iterable of fitted trees that
        provide ``apply`` and ``tree_.value``.
    X : feature matrix passed unchanged to every tree's ``apply``.
    n_jobs : number of joblib workers. ``None`` or any value <= 0 (including
        the default ``-1``) keeps the historical behaviour of using about 70%
        of the available cores, with a minimum of one.

    Returns
    -------
    ndarray
        Mean of the per-tree probability arrays, taken over the tree axis.
    """
    if n_jobs is None or n_jobs <= 0:
        # Leave some head-room on the machine: ~70% of the cores, at least one.
        n_jobs = max(1, math.floor(multiprocessing.cpu_count() * 0.7))

    def predict_proba_tree(tree, X_subset):
        # Leaf reached by every sample, then the class counts/weights stored there.
        leaf_ids = tree.apply(X_subset)
        leaf_values = tree.tree_.value
        # Keep output 0 only, then normalise each row to a probability vector.
        probas = leaf_values[leaf_ids][:, 0]
        return probas / probas.sum(axis=1, keepdims=True)

    tree_probas = Parallel(n_jobs=n_jobs)(
        delayed(predict_proba_tree)(tree, X) for tree in forest.estimators_
    )
    summed_probas = np.sum(tree_probas, axis=0)
    return summed_probas / len(forest.estimators_)


def predict_proba_stack( _loadmodel,  input_np_format, max_workers=12):
    """Average per-tree class probabilities of a (re)loaded forest using threads.

    Fix over the previous version: the unused ``start = time.time()`` statement
    was removed; ``time`` was never imported in this notebook, so the call
    raised ``NameError`` on a fresh kernel.

    Side effects: sets ``_loadmodel.n_classes_ = 2`` and adds a
    ``monotonic_cst = None`` attribute to every tree that lacks one
    (presumably so trees saved under an older scikit-learn keep working —
    confirm against the environment that produced the pickles).

    Parameters
    ----------
    _loadmodel : object exposing ``estimators_``, an iterable of fitted trees
        that provide ``apply`` and ``tree_.value``.
    input_np_format : feature matrix passed unchanged to every tree's ``apply``.
    max_workers : size of the thread pool; defaults to the previously
        hard-coded value of 12.

    Returns
    -------
    ndarray
        Mean over trees of the row-normalised leaf values.
    """
    from concurrent.futures import ThreadPoolExecutor

    def predict_proba_tree(tree, X_subset):
        # Leaf reached by every sample, then the class counts/weights stored there.
        leaf_ids = tree.apply(X_subset)
        leaf_values = tree.tree_.value
        # Keep output 0 only, then normalise each row to a probability vector.
        probas = leaf_values[leaf_ids][:, 0]
        return probas / probas.sum(axis=1, keepdims=True)

    _loadmodel.n_classes_ = 2
    saved_estimators = _loadmodel.estimators_

    for tree in saved_estimators:
        if not hasattr(tree, "monotonic_cst"):
            tree.monotonic_cst = None  # Set default value to None

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        tree_probas = list(
            executor.map(lambda tree: predict_proba_tree(tree, input_np_format), saved_estimators)
        )

    return np.mean(tree_probas, axis=0)

def evaluate_ensemble_model( X_test, y_test, model, CP_num, prediction_window, N, feature_map, model_name='', plot=True):
    """Score a tree-ensemble model on a test set via ``predict_proba_stack``.

    Parameters
    ----------
    X_test : DataFrame of features; ``X_test.values`` is what gets scored.
    y_test : true binary labels.
    model : ensemble exposing ``estimators_``. It is modified in place:
        ``n_classes_``, ``n_outputs_`` and ``classes_`` are overwritten and each
        tree without a ``monotonic_cst`` attribute gets one set to ``None``.
    CP_num, prediction_window, N, feature_map, model_name : only used inside
        the plotting branch.
    plot : NOTE(review) — currently ignored; it is overwritten with ``False``
        below, so no figure is ever built.

    Returns
    -------
    tuple
        ``(ax_all, results_list)``. With the current override ``ax_all`` is
        always ``None``. ``results_list`` holds, in order: auc, average
        precision, the four placeholder values returned by ``calculate_metrics``,
        then sensitivity and PPV at specificity 0.90 and 0.95.
    """

    best_mod = model

    saved_estimators = best_mod.estimators_

    # Force binary-classifier metadata onto the model object.
    best_mod.n_classes_ = 2
    best_mod.n_outputs_ = 1
    best_mod.classes_ = np.array([ 0,1]) 

    for tree in saved_estimators:
        if not hasattr(tree, "monotonic_cst"):
            tree.monotonic_cst = None  # Set default value to None

    # Score the raw NumPy values; keep the positive-class column only.
    X_test_input = X_test.values
    predicted_proba = predict_proba_stack(best_mod, X_test_input)[:, 1]

    # Hard labels are not computed here; calculate_metrics does not use them.
    # predicted_labels = best_mod.predict(X_test_input)
    predicted_labels = None
    auc, pre,sensitivity , specificity , ppv, npv = calculate_metrics(y_test, predicted_proba, predicted_labels)

    # Sensitivity / PPV at the 90% and 95% specificity operating points.
    add_results = ppv_sensitivity([0.9, 0.95], y_test, predicted_proba)
    sensitivity_90, sensitivity_95 = add_results['Sensitivity']
    ppv_90, ppv_95 = add_results['PPV']
    # NOTE(review): this overrides the caller's `plot` argument, so the branch
    # below is dead code. If it is ever re-enabled, `predicted_labels` is None
    # at this point and would be passed to confusion_matrix as-is.
    plot = False
    if plot: # plot prc curve, confusion matrix, feature importance using default method from each type of models itself. Could skip this plot. 
        print(f"\nDisplaying performance for CP {CP_num} with a {prediction_window}-year prediction window:\n")

        # Panel 1: precision-recall curve with the prevalence as chance level.
        prec, rec, threshold = precision_recall_curve(y_test, predicted_proba)
        prc_df = pd.DataFrame({"Recall": rec, "Precision": prec})
        ap_score = average_precision_score(y_test, predicted_proba)
        base_ap_score = np.mean(y_test)

        prc_plot = (
            ggplot(prc_df, aes("Recall", "Precision")) + 
            geom_line(color="#3C5488B2") +
            theme_bw() +
            theme() +
            coord_fixed() +
            geom_hline(yintercept=base_ap_score, linetype="dashed") +
            labs(title="Precision-Recall Curve") +
            annotate("text", x=0.15, y=1, label=f"AP={ap_score:.2f}", size=8) +
            annotate("text", x=0.3, y=0.95, label=f"Chance Level AP={base_ap_score:.2f}", size=8)
            )
        ax1 = pw.load_ggplot(prc_plot, figsize=(2.5, 2.5))


        # Panel 2: confusion matrix heatmap drawn onto a patchworklib Brick.
        ax2 = pw.Brick(figsize=(2.5, 2.5))
        cm = confusion_matrix(y_test, predicted_labels, labels=best_mod.classes_)
        sns.heatmap(cm, annot=True, linewidth=1, cmap="GnBu", fmt="g",
                    yticklabels=["Control", "Case"], xticklabels=["Control", "Case"], ax=ax2)
        ax2.set_title("Confusion Matrix")
        ax2.set_xlabel("Predicted Label")
        ax2.set_ylabel("True Label")

        # Panel 3: the N features with the largest model-reported importance.
        feature_importances = best_mod.feature_importances_
        top_N_feature_index = np.argsort(feature_importances)[-N:]
        top_N_feature_names = X_test.columns[top_N_feature_index]

        def get_name(x):
            # Translate a raw feature name to its display name; fall back to
            # the raw name when the map has no entry.
            if isinstance(x, str):
                return feature_map.get(x.strip(), x)
            else:
                return feature_map.get(x, x)

        top_N_feature_names = pd.Series(top_N_feature_names).apply(get_name)

        top_N_importances = feature_importances[top_N_feature_index]

        coef_plot_df = pd.DataFrame({
            "feature_name": top_N_feature_names,
            "importance": top_N_importances
        })

        # Categorical with categories sorted by importance fixes the bar order in the plot.
        coef_plot_df['feature_name'] = pd.Categorical(coef_plot_df['feature_name'], categories=coef_plot_df.sort_values('importance')['feature_name'])

        feature_importance_plot = (
            ggplot(coef_plot_df, aes("feature_name", "importance")) +
                geom_bar(stat="identity", fill="#91D1C2B2", color="black") +
                coord_flip() +
                theme_bw() +
                labs(x="", y="Feature Importance", title=f"{model_name} Model (CP {CP_num} with a {prediction_window}-yr prediction window)")
            )
        ax3 = pw.load_ggplot(feature_importance_plot, figsize=(5, 4))
        # patchworklib layout: PR curve and confusion matrix side by side, bar chart underneath.
        ax_all = (ax1 | ax2)/ax3
    else:
        ax_all = None
    # Order must match the column names assigned by the caller.
    results_list = [auc , pre, sensitivity , specificity , ppv, npv , sensitivity_90, sensitivity_95, ppv_90, ppv_95]

    return ax_all , results_list


def run_direct_evaluate_pipeline(X, y, all_map, model_type,  years=[0,1,2,5,10], pre_trained_model=None, score=None, f_reference=None, model_name=None, scalars=None, plot=False):
    """Evaluate one pre-trained model on held-out data for every prediction window.

    Fixes over the previous version:
    * the 'std' summary row was computed after the 'mean' row had been
      appended, so the mean row was counted as an extra observation; both
      statistics are now computed from the CV rows only;
    * in the 'xgb' branch the index-keyed feature map was written back into
      ``all_map`` and then reused as the lookup source for the next window;
      the caller's map is now left untouched and a per-window map is built;
    * the bare ``except:`` is narrowed to ``except Exception`` so that
      KeyboardInterrupt / SystemExit are no longer swallowed.

    Parameters
    ----------
    X : mapping from prediction window to a feature DataFrame.
    y : mapping from prediction window to the matching labels.
    all_map : dict translating feature names into display names.
    model_type : 'rf' and 'xgb' are scored with ``evaluate_ensemble_model``;
        anything else with ``evaluate_base_model``.
    years : prediction windows to evaluate; processed in reverse order.
    pre_trained_model : fitted model used for every window. Windows are
        skipped entirely when this is falsy.
    score : accepted for compatibility; not used.
    f_reference : mapping keyed by ``'CP_{cp}_{window}_yr'``. Each value is
        either a DataFrame whose columns (minus 'person_id') are the model's
        features, or a list / object with ``to_list()`` of feature names.
    model_name : label forwarded to the evaluation function.
    scalars : fitted scaler applied to the age column for every window.
    plot : forwarded to the evaluation function; the returned figure is
        displayed when true.

    Returns
    -------
    tuple
        ``(show_results_dict, cp_year_results)``. The first maps
        ``'{cp}_{window}_{model_type}'`` to a DataFrame with one row per CV run
        plus 'mean' and 'std' rows; the second maps
        ``'{window}_{model_type}_{cv}'`` to the raw list of metrics.
    """
    metric_names = ['auc', 'pre', 'sensitivity', 'specificity', 'ppv', 'npv', 'sensitivity_90', 'sensitivity_95', 'ppv_90', 'ppv_95']
    cp_year_results = {}
    show_results_dict = {}

    for cp in [1]:
        for prediction_window in reversed(years):
            if pre_trained_model: # should always pass pretrained model inside
                cv_results = []
                for cv in range(1):  # test each cv from the pre-trained model or only test a part of cvs
                    print('CV: ', cv, '| Prediction window: ', prediction_window, '| Model type: ', model_type)
                    # The same model and scaler are used for every window.
                    model = pre_trained_model
                    scalar = scalars

                    # Feature list the model was trained on, without the id column.
                    reference_key = f'CP_{cp}_{prediction_window}_yr'
                    try:
                        saved_model_features = f_reference[reference_key].drop('person_id', axis=1).columns
                    except Exception:
                        # Not a DataFrame: fall back to a plain sequence of names.
                        refer_cols = f_reference[reference_key]
                        if not isinstance(refer_cols, list):
                            refer_cols = refer_cols.to_list()
                        saved_model_features = [i for i in refer_cols if i != 'person_id']

                    current_features = X[prediction_window].columns
                    overlapping_cols = set(saved_model_features).intersection(set(current_features))
                    missing_cols = set(saved_model_features) - set(current_features)
                    print('Current cols\t', len(current_features))
                    print('Reference cols\t', len(saved_model_features))
                    print('Overalaping cols\t', len(overlapping_cols))
                    print('Missing_cols cols\t', len(missing_cols), missing_cols)

                    # Align to the training layout: same columns, same order;
                    # features absent from the new data are filled with 0.
                    f_input = X[prediction_window].reindex(columns=saved_model_features, fill_value=0)

                    # Per-window display-name map; the caller's all_map is never modified.
                    feature_map = all_map
                    if model_type == 'xgb':
                        # Columns are renamed to their positional index, except
                        # the age column which keeps its name.
                        X_newnames = dict(zip(f_input.columns, range(len(f_input.columns))))
                        X_newnames['age_at_prediction_window'] = 'age_at_prediction_window'
                        f_input.columns = [X_newnames[i] for i in f_input.columns]
                        # Re-key the display-name map by the new column labels.
                        feature_map = {v: all_map.get(k, k) for k, v in X_newnames.items()}

                    f_input = preprocess_unmatched_scalar(f_input, scalar)

                    y_input = y[prediction_window]
                    if model_type == 'rf' or model_type == 'xgb':
                        axall, results_list = evaluate_ensemble_model(f_input, y_input, model, 1, prediction_window, 30, feature_map, model_name=model_name, plot=plot)
                    else:
                        axall, results_list = evaluate_base_model(f_input, y_input, model, 1, prediction_window, 30, feature_map, model_name=model_name, plot=plot)
                    if plot:
                        display(axall)

                    # Keep the raw metrics per window/model/cv, and collect them for the summary table.
                    cp_year_results[f"{str(prediction_window)}_{model_type}_{str(cv)}"] = results_list
                    cv_results.append(results_list)

                showresults = pd.DataFrame(cv_results)
                showresults.columns = metric_names
                # Compute both statistics before appending either summary row,
                # so they describe the CV rows only.
                cv_mean = showresults.mean()
                cv_std = showresults.std()
                showresults.loc['mean'] = cv_mean
                showresults.loc['std'] = cv_std
                display('Show results of cvs', showresults)
                show_results_dict[f"{str(cp)}_{str(prediction_window)}_{model_type}"] = showresults

    return show_results_dict, cp_year_results 


def reconstruct_rf(name, dirc='rf_chunks', modelkey=None, modelcv=None):
    """Rebuild a RandomForestClassifier from estimator chunks stored on disk.

    Chunk files are the ``.pkl`` files under ``<dirc>/<name>`` whose file name
    contains ``'chunk_'`` and whose path contains ``'<modelkey>_<modelcv>'``.
    Each file is loaded with joblib and its content appended to the forest's
    ``estimators_`` list, in sorted path order.

    Parameters
    ----------
    name : sub-directory of ``dirc`` holding the chunks; also stored on the
        returned model as ``modelname``.
    dirc : root directory of the saved chunks.
    modelkey, modelcv : select which chunks to load; also stored on the model.

    Returns
    -------
    RandomForestClassifier
        An otherwise default-constructed forest with ``estimators_`` and
        ``n_estimators`` filled in.

    Raises
    ------
    FileNotFoundError
        If the directory does not exist or no chunk file matches.
    """
    model_dir = os.path.join(dirc, name)
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Error: The directory '{model_dir}' does not exist. Make sure the chunks are saved.")

    # All chunk pickles in the directory, in sorted path order.
    candidate_paths = sorted(
        os.path.join(model_dir, fname)
        for fname in os.listdir(model_dir)
        if fname.endswith(".pkl") and ('chunk_' in fname)
    )

    # Keep only the chunks belonging to the requested model key / CV fold.
    wanted_tag = str(modelkey) + '_' + str(modelcv)
    chunk_paths = [path for path in candidate_paths if wanted_tag in path]
    if not chunk_paths:
        raise FileNotFoundError(f"No chunk files found in '{model_dir}'. Ensure the chunks were saved correctly.")

    forest = RandomForestClassifier()
    forest.estimators_ = []
    print("Ori", len(forest.estimators_))

    for path in chunk_paths:
        trees = joblib.load(path)
        forest.estimators_.extend(trees)
        print(f"Loaded {path}", len(trees), len(forest.estimators_))

    forest.n_estimators = len(forest.estimators_)

    # Remember where the forest came from.
    forest.modelname = name
    forest.modelkey = modelkey
    forest.modelcv = modelcv

    print(f"Reconstructed RF model with {forest.n_estimators} estimators.")
    return forest

<Figure size 100x100 with 0 Axes>

### Directly evaluate existing models on the unmatched test data

In [None]:
# Load the name -> description lookup used for plot labels.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('all_map.pkl', 'rb') as map_file:
    all_map = pickle.load(map_file)

# Prefix every description with its own key so labels show both.
for i, v in all_map.items():
    all_map[i] = i + ' ' + v

Test the previously downloaded model.

In [5]:

hold_out_portion = 0.5
ratio = 10

# Digits after the decimal point of the hold-out fraction ("5" for 0.5).
# Computed outside the f-strings: reusing the enclosing quote character inside
# a replacement field is a SyntaxError before Python 3.12.
portion_tag = str(hold_out_portion).split('.')[-1]

# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('./MiddleFeatures/demo_unmatched_fs.pkl', 'rb') as fh:
    test_f = pickle.load(fh)
with open(f'./MiddleFeatures/test_t_portion_{portion_tag}.pkl', 'rb') as fh:
    test_t = pickle.load(fh)
with open(f'./MiddleFeatures/test_e_portion_{portion_tag}.pkl', 'rb') as fh:
    test_e = pickle.load(fh)


In [None]:


# Best CV fold per prediction window; the two lists are index-aligned.
best_cvs = [4, 0, 1, 1, 0]
windows = [10, 5, 2, 1, 0]

reference_cols = joblib.load('./rf_chunks/model_test_rf_all_feature/reference_cols_xgb_all_feature.pkl')
scalarmodels = joblib.load('./rf_chunks/model_test_rf_all_feature/test_rf_scalar_all_feature.pkl')

show_results_dict, results = {}, {}

for inde, (window, best_cv) in enumerate(zip(windows, best_cvs)):

    print('Finetune for prediction window', window)

    # Rebuild the forest and fetch the scaler saved for this window / fold.
    model_key = f'1_{window}_rf'
    pretrainedRF = reconstruct_rf('model_test_rf_all_feature', modelkey=model_key, modelcv=best_cv)
    pretrained_scalar = scalarmodels[model_key][best_cv]

    # Evaluate this single window and file both outputs under it.
    window_summary, window_metrics = run_direct_evaluate_pipeline(
        test_f, test_t, all_map, 'rf',
        years=[window],
        pre_trained_model=pretrainedRF,
        score=None,
        model_name='all feature',
        scalars=pretrained_scalar,
        f_reference=reference_cols,
        plot=False,
    )
    show_results_dict[window] = window_summary
    results[window] = window_metrics



In [11]:

# Persist the raw per-window metrics and the summary tables.
# Context managers make sure both files are flushed and closed.
with open('rf_chunks/model_test_rf_all_feature_direct/demoonlyage_rs_years_unmatched.pkl', 'wb') as out_file:
    pickle.dump(results, out_file)

with open('rf_chunks/model_test_rf_all_feature_direct/demoonlyage_showrs_years_unmatched.pkl', 'wb') as out_file:
    pickle.dump(show_results_dict, out_file)
