# Import libraries

In [None]:
# Importing the necessary libraries and data
import pandas as pd
import os
import numpy as np
from pathlib import Path
import itertools

import plotly.express as px
import plotly.io as pio
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter


# sklearn
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report, ConfusionMatrixDisplay
from sklearn import metrics
from sklearn.model_selection import RandomizedSearchCV, train_test_split,  GridSearchCV, KFold
from sklearn.ensemble import GradientBoostingClassifier,RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.calibration import CalibratedClassifierCV
from sklearn.utils.class_weight import compute_sample_weight
from yellowbrick.classifier import ROCAUC, ClassificationReport, ClassPredictionError, ConfusionMatrix

# from imblearn.under_sampling import RandomUnderSampler
# from imblearn.over_sampling import RandomOverSampler

# machine learning models
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier

#interpretml
from interpret import show
from interpret.glassbox._ebm._research import *
from interpret.data import ClassHistogram
from interpret.perf import ROC
from interpret.glassbox import ExplainableBoostingClassifier
from interpret.blackbox import ShapKernel, LimeTabular, MorrisSensitivity

from interpret.provider import InlineProvider
from interpret import set_visualize_provider

# Use interpret's InlineProvider so `show(...)` renders its visualisations
# inside the notebook output.
set_visualize_provider(InlineProvider())

import warnings
# NOTE(review): this silences *every* warning for the rest of the session
# (including library deprecation warnings); consider a narrower filter.
warnings.filterwarnings("ignore")


In [None]:
# plt.rcParams['font.sans-serif'] = ['SimHei']
# # Replace the Chinese fonts with the ones supported by your system.
# plt.rcParams['axes.unicode_minus'] = False  # Used to display the negative sign normally


# Dataset Configuration

In [None]:
# Project root: every relative path used later ('data/', 'tables/', 'images/')
# is resolved against this directory.
work_dir = 'E:/Disk E/Grand Blue/Research studies/HDL_multiclass'
os.chdir(work_dir)

# Output folders for tables and figures; created only when missing.
for output_dir in ('./tables', './images/png', './images/pdf'):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Load the raw data and show how many rows carry each text tertile label.
df_prim = pd.read_excel('data/TertileClass_Gensini_data.xlsx')
print(df_prim['Gensini_tertile_label'].value_counts())

# Drop the raw Gensini score and the text label; the remaining columns are
# the features plus the coded tertile column.
df_prim = df_prim.drop(columns=['Gensini_total_Score', 'Gensini_tertile_label'])
df_prim.info()


In [None]:
# Working copy of the cleaned data used for EDA and model training.
df = df_prim.copy()
df.info()

# Target column: tertile coded as 0, 1, 2.
target_col = 'Gensini_tertile'
print('df: ', Counter(df[target_col]))

# Display names for the target codes: class_names[i] labels code i, so this
# order must follow the numeric encoding (0 = Low, 1 = Moderate, 2 = High).
class_names = ['Low', 'Moderate', 'High']


In [None]:
# df.columns.to_list()

In [None]:
# Exploring the dataset: interpret's per-feature class histograms.
# NOTE: the original author observed that ClassHistogram shows only two of the
# three target categories for this multiclass target, so the view is incomplete.
X = df.drop(columns=target_col)
y = df[target_col]

hist = ClassHistogram().explain_data(X, y, name="Full Data")
show(hist)

In [None]:
# Columns phik should treat as interval (continuous) variables.
interval_cols = ['age', 'TC', 'TG', 'HDL-C', 'LDL-C', 'HDL-2b', 'HDL-3']

# Copy with readable target labels for the phik plots. The spelling matches
# `class_names` ('Moderate', capitalised; previously 'moderate'), so these
# figures use the same labels as the model reports.
df_k = df.copy()
df_k['Gensini_tertile'] = df_k['Gensini_tertile'].replace({0: 'Low', 1: 'Moderate', 2: 'High'})

In [None]:
# Phi_K correlation matrix of all variables.
# NOTE(review): `DataFrame.phik_matrix` and `plot_correlation_matrix` come from
# the `phik` package (`import phik`; `from phik.report import
# plot_correlation_matrix`). Neither is imported in the import cell at the
# top of this notebook; confirm they are imported somewhere before this cell.
phik_overview = df_k.phik_matrix(interval_cols=interval_cols)

plot_correlation_matrix(
    phik_overview.values,
    x_labels=phik_overview.columns,
    y_labels=phik_overview.index,
    vmin=0,
    vmax=1,
    color_map='Greens',
    title=r'correlation $\phi_K$',
    fontsize_factor=1.5,
    figsize=(30, 30),
)
plt.tight_layout()

for ext in ('png', 'pdf'):
    plt.savefig(f'images/{ext}/correlation matrix.{ext}', dpi=300)

plt.show()

In [None]:
# Statistical significance of each Phi_K coefficient, drawn on a -5..5 colour scale.
significance_overview = df_k.significance_matrix(interval_cols=interval_cols)
plot_correlation_matrix(significance_overview.values,
                        x_labels=significance_overview.columns,
                        y_labels=significance_overview.index,
                        vmin=-5, vmax=5, title="Significance of the coefficients",
                        usetex=False, fontsize_factor=1.5, figsize=(30, 30))

# Apply the layout *before* saving. Previously tight_layout() ran after
# savefig(), so the written PNG/PDF never got the adjusted layout.
plt.tight_layout()
plt.savefig('images/png/Statistical significance.png', dpi=300)
plt.savefig('images/pdf/Statistical significance.pdf', dpi=300)

In [None]:
# Global Phi_K correlation coefficient (g_k) of each variable, one column.
global_correlation, global_labels = df_k.global_phik(interval_cols=interval_cols)

plot_correlation_matrix(
    global_correlation,
    x_labels=[''],
    y_labels=global_labels,
    vmin=0,
    vmax=1,
    figsize=(15, 15),
    color_map="Greens",
    title=r"$g_k$",
    fontsize_factor=1.5,
)
plt.tight_layout()

for ext in ('png', 'pdf'):
    plt.savefig(f'images/{ext}/Global correlation.{ext}', dpi=300)


In [None]:
# Outlier significance of HDL-2b bins against the Gensini tertile.
var_1 = "HDL-2b"
var_2 = "Gensini_tertile"

# Only var_1 is continuous; the target column is categorical.
tmp_interval_cols = [var_1]

outlier_signifs, binning_dict = df_k[[var_1, var_2]].outlier_significance_matrix(
    interval_cols=tmp_interval_cols, retbins=True
)

zvalues = outlier_signifs.values
xlabels = outlier_signifs.columns
ylabels = outlier_signifs.index

plot_correlation_matrix(zvalues, x_labels=xlabels, y_labels=ylabels,
                        x_label=var_2, y_label=var_1,
                        vmin=-5, vmax=5, title='outlier significance',
                        identity_layout=False, fontsize_factor=1.2,
                        figsize=(14, 10))

# PNG and PDF now share the "Outlier significance" stem. The PDF used to be
# mislabelled "Global correlation_<var>.pdf" (copy-paste from the cell above).
plt.savefig(f'images/png/Outlier significance_{var_1}.png', dpi=300)
plt.savefig(f'images/pdf/Outlier significance_{var_1}.pdf', dpi=300)


In [None]:
# Outlier significance of HDL-3 bins against the Gensini tertile.
var_1 = "HDL-3"
var_2 = "Gensini_tertile"

# Only var_1 is continuous; the target column is categorical.
tmp_interval_cols = [var_1]

outlier_signifs, binning_dict = df_k[[var_1, var_2]].outlier_significance_matrix(
    interval_cols=tmp_interval_cols, retbins=True
)

zvalues = outlier_signifs.values
xlabels = outlier_signifs.columns
ylabels = outlier_signifs.index

plot_correlation_matrix(zvalues, x_labels=xlabels, y_labels=ylabels,
                        x_label=var_2, y_label=var_1,
                        vmin=-5, vmax=5, title='outlier significance',
                        identity_layout=False, fontsize_factor=1.2,
                        figsize=(14, 10))

# PNG and PDF now share the "Outlier significance" stem. The PDF used to be
# mislabelled "Global correlation_<var>.pdf" (copy-paste from another cell).
plt.savefig(f'images/png/Outlier significance_{var_1}.png', dpi=300)
plt.savefig(f'images/pdf/Outlier significance_{var_1}.pdf', dpi=300)


# resampling function

In [None]:
# def resample_data (df_res):
#     """
#     Data resampling function
#     """
#     X = df_res.drop(target_col, axis=1)
#     y = df_res[target_col]

#     X_train, X_test, y_train, y_test = train_test_split(
#         X, y, test_size=0.3, random_state=42, shuffle=True, stratify=y
#     )
#     print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
#     print('y_train: ', Counter(y_train))
#     print('y_test: ', Counter(y_test))

#     # define oversampling strategy
#     over_strategy = {0: 270, 1: 270, 2: 450}
#     ros = RandomOverSampler(sampling_strategy=over_strategy, random_state=42)
#     X_train_ros, y_train_ros = ros.fit_resample(X_train, y_train)
#     print('oversampling: ',X_train_ros.shape, y_train_ros.shape)
#     print(Counter(y_train_ros))
#     print('\n')

#     # define undersampling strategy   
#     under_strategy = {0: 270, 1: 270, 2: 270}
#     rus = RandomUnderSampler(sampling_strategy=under_strategy, random_state=42)
#     X_train_res, y_train_res = rus.fit_resample(X_train_ros, y_train_ros) 
#     print('undersampling: ', X_train_res.shape, y_train_res.shape)
#     print(Counter(y_train_res))

#     # Merge training data and labels into a DataFrame
#     train_df = pd.concat([X_train, y_train], axis=1)
    
#     return X_train_res, y_train_res, X_test, y_test


# Model training (function definition)

In [None]:
# Defining multiple functions

def model_tuning(clf, model_name, search_space):
    """
    Tune `clf` with a cross-validated grid search on a 70/30 split of `df`.

    The global `df` is split into train/test sets, stratified on `target_col`
    with random_state=42, so the split is the same on every call. A 10-fold
    grid search scored by macro-F1 then runs on the training part.

    Parameters
    ----------
    clf : estimator
        Unfitted scikit-learn compatible classifier. It is wrapped as
        ``Pipeline([('clf', clf)])``, so grid keys must use the 'clf__' prefix.
    model_name : str
        Display name, printed before tuning starts.
    search_space : dict or list of dict
        ``param_grid`` for GridSearchCV. ``[{}]`` only cross-validates `clf`
        exactly as configured.

    Returns
    -------
    tuple
        ``(best_classifier, X_train, X_test, y_train, y_test)``. The best
        classifier is refit on the whole training split.
    """
    # Local import: sklearn is already a dependency of this notebook.
    from sklearn.model_selection import StratifiedKFold

    X = df.drop(target_col, axis=1)
    y = df[target_col]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42, shuffle=True, stratify=y
    )
    print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
    print('y_train: ', Counter(y_train))
    print('y_test: ', Counter(y_test))

    print('\n')
    print(model_name)

    pipe = Pipeline([('clf', clf)])

    # Stratified folds keep the tertile proportions in every fold. Plain KFold
    # (used before) does not guarantee this, which distorts a macro-averaged F1.
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    scoring = {'F1_macro': metrics.make_scorer(metrics.f1_score, average='macro')}

    grid = GridSearchCV(
        pipe,
        param_grid=search_space,
        cv=cv,
        scoring=scoring,
        refit='F1_macro',
        verbose=1,
        n_jobs=-1
    )

    grid_model = grid.fit(X_train, y_train)

    print('\nThe optimal parameters are：')
    print(grid_model.best_params_)

    # This header previously repeated "optimal parameters"; it prints the model.
    print('\nThe optimal model is:')
    clf_best_model = grid_model.best_estimator_['clf']
    print(clf_best_model)

    return clf_best_model, X_train, X_test, y_train, y_test


# Function to create directories if they don't exist
def create_directory(path):
    """
    Create directory `path`, including missing parents; no-op if it already exists.

    ``exist_ok=True`` replaces the earlier check-then-create sequence, which
    could fail if another process created the directory between the check and
    the call. If `path` exists as a regular file, FileExistsError is raised
    instead of being silently ignored.
    """
    os.makedirs(path, exist_ok=True)


# Function to save classification reports to an Excel file
def save_classification_reports_to_excel(results, file_path):
    """
    Write classification reports to one Excel workbook, one worksheet per model.

    Parameters
    ----------
    results : dict
        Maps sheet (model) name to a report dict, e.g. the output of
        ``classification_report(..., output_dict=True)``. Each report is
        transposed so that classes/averages become rows.
    file_path : str or path-like
        Destination .xlsx file (written with the openpyxl engine).
    """
    with pd.ExcelWriter(file_path, engine='openpyxl') as writer:
        for sheet, report_dict in results.items():
            pd.DataFrame(report_dict).T.to_excel(writer, sheet_name=sheet)

# Function to create and save individual plots
def save_individual_plots(clf_best_model, X_train, y_train, X_test, y_test, model_name, save_path):
    """
    Save four yellowbrick diagnostics for one model as separate PNG and PDF files.

    Each visualizer (ROC AUC, classification report, class prediction error,
    confusion matrix) is fitted on the training split and scored on the test
    split. The plot title is blanked, and the figure is written to
    ``{save_path}/png/<stem>_<kind>.png`` and ``{save_path}/pdf/<stem>_<kind>.pdf``.
    ``<stem>`` is `model_name` lower-cased with spaces replaced by underscores.

    Parameters
    ----------
    clf_best_model : estimator
        Classifier to visualise.
    X_train, y_train, X_test, y_test : array-like
        Train/test splits.
    model_name : str
        Used to build the output file names.
    save_path : str
        Base output directory; its 'png' and 'pdf' sub-folders are created if missing.
    """
    stem = model_name.lower().replace(" ", "_")

    # Previously the folders had to exist already (create_directory was commented out).
    for ext in ('png', 'pdf'):
        os.makedirs(f'{save_path}/{ext}', exist_ok=True)

    # (file-name suffix, yellowbrick visualizer class, extra constructor kwargs)
    plot_specs = (
        ('roc_auc', ROCAUC, {}),
        ('class_report', ClassificationReport, {'support': True}),
        ('pred_error', ClassPredictionError, {}),
        ('conf_matrix', ConfusionMatrix, {}),
    )

    for suffix, visualizer_cls, extra_kwargs in plot_specs:
        fig, ax = plt.subplots()
        try:
            visualizer = visualizer_cls(clf_best_model, classes=class_names, ax=ax, **extra_kwargs)
            visualizer.fit(X_train, y_train)
            visualizer.score(X_test, y_test)
            visualizer.finalize()
            # Set the title on this plot's own axes rather than on pyplot's "current" axes.
            ax.set_title('')
            for ext in ('png', 'pdf'):
                fig.savefig(f'{save_path}/{ext}/{stem}_{suffix}.{ext}', dpi=300)
        finally:
            # Close even if a visualizer fails, so figures do not pile up.
            plt.close(fig)

# Function to create combined subplots
def create_combined_subplot(fig_title, subplot_func, tuned=None):
    """
    Draw one diagnostic panel per entry of the global `models` list on a 2-column grid.

    Parameters
    ----------
    fig_title : str
        Figure-level title (suptitle).
    subplot_func : callable
        One of the ``plot_*`` helpers. It is called as
        ``subplot_func(clf_best_model, X_train, y_train, X_test, y_test, ax, model_name)``.
    tuned : dict, optional
        Maps ``model_name`` to ``(clf_best_model, X_train, X_test, y_train, y_test)``,
        as returned by `model_tuning`. Models found here are reused instead of
        re-running their grid search; any others are tuned as before.

    Returns
    -------
    matplotlib.figure.Figure
    """
    num_models = len(models)
    rows = (num_models + 1) // 2
    # squeeze=False keeps `axes` 2-D even with a single row (<= 2 models).
    # Without it, axes[i // 2, i % 2] raises IndexError in that case.
    fig, axes = plt.subplots(rows, 2, figsize=(12, 6 * rows), sharex=False, sharey=False, squeeze=False)

    for i, (clf, model_name, search_space) in enumerate(models):
        ax = axes[i // 2, i % 2]
        if tuned is not None and model_name in tuned:
            clf_best_model, X_train, X_test, y_train, y_test = tuned[model_name]
        else:
            clf_best_model, X_train, X_test, y_train, y_test = model_tuning(clf, model_name, search_space)
        subplot_func(clf_best_model, X_train, y_train, X_test, y_test, ax, model_name)

    # Hide the unused slot when the number of models is odd.
    if num_models % 2 != 0:
        fig.delaxes(axes[rows - 1, 1])

    # Add the title first so tight_layout (matplotlib >= 3.3) reserves room for it.
    # The old y=1.02 placed it above the canvas, so savefig() without
    # bbox_inches='tight' cut it off.
    fig.suptitle(fig_title)
    fig.tight_layout()
    return fig

# Shared driver for the combined-figure panel functions below.
def _draw_visualizer(visualizer, X_train, y_train, X_test, y_test, ax, title):
    """Fit `visualizer` on the train split, score it on the test split, finalize it, and title `ax`."""
    visualizer.fit(X_train, y_train)
    visualizer.score(X_test, y_test)
    visualizer.finalize()
    ax.set_title(title)


# Function for ROC AUC
def plot_rocauc(clf_best_model, X_train, y_train, X_test, y_test, ax, model_name):
    """Draw a yellowbrick ROCAUC plot for `clf_best_model` on `ax`."""
    visualizer_rocauc = ROCAUC(clf_best_model, classes=class_names, ax=ax)
    _draw_visualizer(visualizer_rocauc, X_train, y_train, X_test, y_test, ax, f'ROC AUC - {model_name}')


# Function for Classification Report
def plot_class_report(clf_best_model, X_train, y_train, X_test, y_test, ax, model_name):
    """Draw a yellowbrick ClassificationReport (with support) for `clf_best_model` on `ax`."""
    visualizer_class_report = ClassificationReport(clf_best_model, classes=class_names, support=True, ax=ax)
    _draw_visualizer(visualizer_class_report, X_train, y_train, X_test, y_test, ax,
                     f'Classification Report - {model_name}')


# Function for Prediction Error
def plot_pred_error(clf_best_model, X_train, y_train, X_test, y_test, ax, model_name):
    """Draw a yellowbrick ClassPredictionError plot for `clf_best_model` on `ax`."""
    visualizer_pred_error = ClassPredictionError(clf_best_model, classes=class_names, ax=ax)
    _draw_visualizer(visualizer_pred_error, X_train, y_train, X_test, y_test, ax,
                     f'Prediction Error - {model_name}')


# Function for Confusion Matrix
def plot_conf_matrix(clf_best_model, X_train, y_train, X_test, y_test, ax, model_name):
    """Draw a yellowbrick ConfusionMatrix for `clf_best_model` on `ax`."""
    visualizer_conf_matrix = ConfusionMatrix(clf_best_model, classes=class_names, ax=ax)
    _draw_visualizer(visualizer_conf_matrix, X_train, y_train, X_test, y_test, ax,
                     f'Confusion Matrix - {model_name}')


# New function to extract weighted avg metrics and plot heatmap
def plot_weighted_avg_heatmap(results, save_path):
    """
    Plot, save and show a heatmap of weighted-average precision, recall and F1 per model.

    Parameters
    ----------
    results : dict
        Maps model name to a ``classification_report(..., output_dict=True)``
        dict; its 'weighted avg' entry is read.
    save_path : str
        Base directory; writes ``png/weighted_avg_heatmap.png`` and
        ``pdf/weighted_avg_heatmap.pdf`` beneath it.
    """
    metric_columns = ['Model', 'Precision', 'Recall', 'F1-score']
    rows = [
        {
            'Model': name,
            'Precision': rep['weighted avg']['precision'],
            'Recall': rep['weighted avg']['recall'],
            'F1-score': rep['weighted avg']['f1-score'],
        }
        for name, rep in results.items()
    ]
    df_ave = pd.DataFrame(rows, columns=metric_columns).set_index('Model')

    plt.figure(figsize=(10, 6))
    # Tip: pass annot_kws={"color": "black"} to force black annotation text.
    sns.heatmap(df_ave, annot=True, cmap='YlOrRd', fmt='.3f')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    for ext in ('png', 'pdf'):
        plt.savefig(f'{save_path}/{ext}/weighted_avg_heatmap.{ext}', dpi=300)
    plt.show()


# Reference: initial (untuned) model definitions

In [None]:
# Building the basic model
# log_clf = LogisticRegression(multi_class='multinomial', solver='lbfgs', random_state=42)


# log_clf = LogisticRegression(random_state=42)
# lasso_reg = Lasso(random_state=42)
# elastic_net = ElasticNet(random_state=42)
# random_forest_clf = RandomForestClassifier(random_state=42)
# extra_trees_clf = ExtraTreesClassifier(random_state=42)

# Create Decision Tree classifer object
# dt_clf = DecisionTreeClassifier(random_state=42)

# svm_clf = SVC(probability=True, random_state=42)
# mlp_clf = MLPClassifier(max_iter=10000, random_state=42)


# xgb_clf = XGBClassifier(nthread=-1, random_state=42)

# ebm_clf = ExplainableBoostingClassifier(greedy_ratio=0.5, random_state=42)

# # lightgbm for classification
# lgb_clf = LGBMClassifier(random_state=42)

# # catboost for classification
# catb_clf = CatBoostClassifier(verbose=0, n_estimators=100)
# catb_clf = CatBoostClassifier(random_state=42)

In [None]:
# # XGBoost： compute sample weights for handling imbalanced data
# Merge training data and labels into a DataFrame
# train_df = pd.concat([X_train, y_train], axis=1)
# sample_weights = compute_sample_weight(
#     class_weight='balanced',
#     y=train_df[target_col] #provide your own target name
# )

# xgb_clf = XGBClassifier(nthread=-1, random_state=42,  sample_weight=sample_weights)


# Model hyperparameter configuration

In [None]:
# Hyper-parameter grids for GridSearchCV, one list per model.
# Keys need the 'clf__' prefix because each estimator is wrapped as
# Pipeline([('clf', estimator)]) during tuning. A grid of [{}] means "no search":
# GridSearchCV only cross-validates the estimator exactly as configured.
# The commented entries record earlier search ranges; uncomment to search them.

# Logistic regression
search_space_log = [dict(
    # clf__solver=['newton-cg', 'lbfgs', 'liblinear'],
    # clf__penalty=['l2'],
    # clf__C=[0.001, 0.01, 0.1, 1, 10, 100, 1000],
    # clf__C=[0.01, 0.1, 1, 10],
)]

# XGBoost
search_space_xgb = [dict(
    # clf__n_estimators=[5000],
    # clf__learning_rate=[0.01, 0.1, 0.2, 0.3],
    # clf__max_depth=[3, 4, 5],
    # clf__colsample_bytree=[i / 10.0 for i in range(1, 3)],
    # clf__gamma=[i / 10.0 for i in range(3)],
    # The fs__* keys below would also need an 'fs' feature-selection pipeline step:
    # fs__score_func=[chi2],
    # fs__k=[10],
)]

# Random forest
search_space_rf = [dict(
    # clf__bootstrap=[True],
    # clf__max_depth=[80, 90, 100, 110],
    # clf__max_features=[2, 3],
    # clf__min_samples_leaf=[3, 4, 5],
    # clf__min_samples_split=[8, 10, 12],
    # clf__n_estimators=[100, 200, 300, 1000],
)]

# Explainable Boosting Machine
search_space_ebm = [dict(
    # clf__validation_size=[0.1, 0.15, 0.2],
    # clf__greedy_ratio=[0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 4.0],
    # clf__greedy_ratio=[0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 4.0],
)]

# search_space_ebm = [
#     {   "clf__max_bins": [1024, 4096, 16384, 65536],
#      # "clf__max_interaction_bins": [8, 16, 32, 64, 128, 256],
#      # "clf__outer_bags": [50],
#      # "clf__learning_rate": [0.02, 0.01, 0.005, 0.0025],
#      "clf__greedy_ratio": [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 4.0],
#      # "clf__cyclic_progress": [0.0, 0.5, 1.0],
#      # "clf__smoothing_rounds": [0, 50, 100, 200, 500, 1000, 2000, 4000],
#      "clf__max_leaves": [3, 4]  
#     }
# ]


# LightGBM Model tuning
# Define our search space for grid search
# search_space_lgb = [
#     {
#         # 'clf__n_estimators': [5000],
#         # 'clf__learning_rate': [0.1],
#         # 'clf__max_depth': range(1,11),
#         # 'clf__boosting_type': ['gbdt', 'dart', 'goss']


#     }
# ]


# CatBoost Model tuning
# Define our search space for grid search
# search_space_catb = [
#     {
#         # 'clf__n_estimators': [10, 50, 100, 500, 1000, 5000],
#         # 'clf__learning_rate': [0.1],
#         # 'clf__max_depth': range(1,11)


#     }
# ]


# Configuration of the models to compare

In [None]:
# # List of models and their corresponding search spaces
# models = [
#     (LogisticRegression(multi_class='multinomial', solver='lbfgs',penalty= 'l2', random_state=42), 'Logistic Regression', search_space_log),
#     (XGBClassifier(nthread=-1, max_depth=4,  random_state=42), 'XGBoost', search_space_xgb),
#     (ExplainableBoostingClassifier(greedy_ratio=0.5, random_state=42), 'Explainable Boosting Machine', search_space_ebm),
#     (RandomForestClassifier(max_features = 3, random_state=42), 'Random Forest', search_space_rf)
    
# ]

In [None]:
# List of models and their corresponding search spaces:
# (unfitted estimator, display name, GridSearchCV param_grid) triples.
models = [
    # With solver='lbfgs', scikit-learn already fits a multinomial model for
    # multiclass targets. The explicit multi_class='multinomial' argument is
    # deprecated since scikit-learn 1.5, so it is omitted here; the fitted
    # model is unchanged.
    (LogisticRegression(C=1, solver='lbfgs', penalty='l2', random_state=42), 'Logistic Regression', search_space_log),
    (XGBClassifier(nthread=-1, max_depth=3, random_state=42), 'XGBoost', search_space_xgb),
    (ExplainableBoostingClassifier(greedy_ratio=4.0, random_state=42), 'Explainable Boosting Machine', search_space_ebm),
    (RandomForestClassifier(max_depth=80, random_state=42), 'Random Forest', search_space_rf),
]

# Machine learning model training (calling functions)

In [None]:
#-------------------- call functions --------------------

# Seaborn "darkgrid" theme for every figure drawn below.
sns.set_theme(style="darkgrid")


def _save_combined_figure(figure, stem):
    """Save `figure` as ./images/png/<stem>.png and ./images/pdf/<stem>.pdf at 300 dpi."""
    for ext in ('png', 'pdf'):
        figure.savefig(f'./images/{ext}/{stem}.{ext}', dpi=300)


# Tune every model, save its individual diagnostic plots, and keep its
# test-set classification report (as a dict) under the model's name.
classification_reports = {}
for clf, model_name, search_space in models:
    clf_best_model, X_train, X_test, y_train, y_test = model_tuning(clf, model_name, search_space)
    save_individual_plots(clf_best_model, X_train, y_train, X_test, y_test, model_name, './images')

    y_pred = clf_best_model.predict(X_test)
    report = classification_report(y_test, y_pred, output_dict=True, target_names=class_names)
    classification_reports[model_name] = report

# One workbook (a sheet per model), then the weighted-average metrics heatmap.
save_classification_reports_to_excel(classification_reports, './tables/classification_reports.xlsx')
plot_weighted_avg_heatmap(classification_reports, './images')

# Combined figures with one panel per model, each saved as PNG and PDF.
fig_rocauc = create_combined_subplot('Combined ROC AUC', plot_rocauc)
_save_combined_figure(fig_rocauc, 'combined_rocauc')

fig_class_report = create_combined_subplot('Combined Classification Report', plot_class_report)
_save_combined_figure(fig_class_report, 'combined_class_report')

fig_pred_error = create_combined_subplot('Combined Prediction Error', plot_pred_error)
_save_combined_figure(fig_pred_error, 'combined_pred_error')

fig_conf_matrix = create_combined_subplot('Combined Confusion Matrix', plot_conf_matrix)
_save_combined_figure(fig_conf_matrix, 'combined_conf_matrix')

plt.show()


# EBM model interpretability

In [None]:
df2 = df_prim.copy()

# Replace 0/1 codes with readable labels for EBM modelling and its plots.
columns_to_replace = ['hypertension', 'diabetes', 'stroke', 'kidney disease', 'Thyroid Dysfunction', 'COPD']
replacement_dict = {1: 'Yes', 0: 'No'}
df2[columns_to_replace] = df2[columns_to_replace].replace(replacement_dict)

df2['sex'] = df2['sex'].replace({1: 'Male', 0: 'Female'})
# The target labels use the same spelling as `class_names` ('Moderate',
# capitalised; previously 'moderate'), so EBM class names match the other
# figures and reports.
df2['Gensini_tertile'] = df2['Gensini_tertile'].replace({0: 'Low', 1: 'Moderate', 2: 'High'})
df2.head()

In [None]:
# Features and (labelled) target from df2.
X = df2.drop(columns=target_col)
y = df2[target_col]

# Stratified 70/30 train/test split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.3, shuffle=True, random_state=42
)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
for tag, labels in (('y_train: ', y_train), ('y_test: ', y_test)):
    print(tag, Counter(labels))

In [None]:
# Train the Explainable Boosting Machine on the labelled training split.
# greedy_ratio=4.0 is the same setting used for the EBM in the model comparison.
ebm = ExplainableBoostingClassifier(random_state=42, greedy_ratio=4.0)
ebm.fit(X_train, y_train)

In [None]:
# Display the fitted EBM's constructor parameters as the cell output.
ebm.get_params()

In [None]:
# Global (whole-model) explanation of the fitted EBM, rendered inline.
ebm_global = ebm.explain_global(name='EBM')
show(ebm_global)

In [None]:
# Custom colours
# custom_colors = {'Low': 'green', 'Moderate': 'blue', 'High': 'orange'}

In [None]:
# Global explanations: the behaviour of the entire model.
ebm_global = ebm.explain_global(name='EBM')
show(ebm_global)

# Export the plot of every EBM term to images/png/fig_<term>.png and
# images/pdf/fig_<term>.pdf. scale=3 raises the raster (PNG) resolution.
# (An earlier, fully commented-out experiment that recoloured the traces
#  has been removed.)
for index, value in enumerate(ebm.term_names_):
    plotly_fig = ebm_global.visualize(index)
    for ext in ('png', 'pdf'):
        plotly_fig.write_image(f"images/{ext}/fig_{value}.{ext}", format=ext, scale=3, engine="kaleido")

print('All images saved.')

# Function to combine global explanation plots

In [None]:
def save_combined_plots(ebm_global, variables, list_var, scale=3):
    """
    Render the global EBM plots of several terms and stack them in one n x 1 figure.

    Each term's plot is first written to ``images/png/fig_<term>.png``,
    overwriting any existing file. The PNGs are then read back and stacked
    top to bottom. The result is saved as
    ``images/png/combined_fig_<t1>_<t2>...png`` and the matching
    ``images/pdf/...pdf``, then shown.

    Parameters
    ----------
    ebm_global : explanation object
        Object with a ``visualize(index)`` method returning a plotly figure,
        e.g. the result of ``ebm.explain_global()``.
    variables : list of str
        Term names to plot, in display order. An empty list does nothing.
    list_var : sequence of str
        All term names in the order ``ebm_global.visualize`` indexes them.
    scale : float, default 3
        Scale factor passed to plotly's ``write_image`` for the per-term PNGs.

    Raises
    ------
    ValueError
        If a name in `variables` is not present in `list_var`.
    """
    if not variables:
        # Nothing to draw; plt.subplots(0, 1) would raise.
        return

    # list(...) also accepts numpy arrays, which have no .index method.
    all_terms = list(list_var)

    image_files = []
    for var in variables:
        plotly_fig = ebm_global.visualize(all_terms.index(var))
        image_path = f"images/png/fig_{var}.png"
        plotly_fig.write_image(image_path, format="png", scale=scale, engine="kaleido")
        image_files.append(image_path)

    num_vars = len(variables)
    # squeeze=False gives a (num_vars, 1) array even when num_vars == 1.
    fig, axs = plt.subplots(num_vars, 1, figsize=(8, 6 * num_vars), squeeze=False)

    for ax, var, image_file in zip(axs[:, 0], variables, image_files):
        ax.imshow(mpimg.imread(image_file))
        ax.axis('off')
        ax.set_title(var)

    fig.tight_layout()

    combined_name = "_".join(variables)
    fig.savefig(f"images/png/combined_fig_{combined_name}.png", format="png", dpi=300)
    fig.savefig(f"images/pdf/combined_fig_{combined_name}.pdf", format="pdf", dpi=300)

    plt.show()




In [None]:
# Combined global plots for related pairs of lipid variables.
ebm_global = ebm.explain_global(name='EBM')
list_var = ebm.term_names_

for pair in (['HDL-2b', 'HDL-3'], ['LDL-C', 'TC'], ['HDL-C', 'TG']):
    save_combined_plots(ebm_global, pair, list_var)


# Other outputs of the EBM model

In [None]:
# Export the EBM's overall summary plot (visualize() without a term index).
ebm_explanation = ebm.explain_global()
plotly_fig = ebm_explanation.visualize()

for ext in ('png', 'pdf'):
    plotly_fig.write_image(f"images/{ext}/fig_summary.{ext}", format=ext, scale=3, engine="kaleido")
print('summary images saved.')

In [None]:
def create_and_export_dataframe(data, filename):
    """Build a sorted importance table from *data*, print a preview and save it.

    Parameters
    ----------
    data : dict or DataFrame-like
        Passed straight to ``pd.DataFrame``. It must provide an ``'Importance'``
        column. The FIRST column is treated as the term/group name, with group
        members separated by commas.
    filename : str, path-like, or anything ``DataFrame.to_excel`` accepts
        Output target. A path ending in ``.csv`` is written with
        ``DataFrame.to_csv``; anything else is written with
        ``DataFrame.to_excel`` as before. For path targets, missing parent
        directories are created first.

    Returns
    -------
    pandas.DataFrame
        Rows sorted by descending ``'Importance'`` with a fresh 0..n-1 index,
        plus two derived columns:
        ``'Importance value'`` (``'Importance'`` rounded to 3 decimals) and
        ``'Number of elements'`` (commas in the first column + 1).
    """
    df_in = pd.DataFrame(data)

    # Most important first; drop the original positions from the index
    df_in = df_in.sort_values(by='Importance', ascending=False).reset_index(drop=True)

    # Rounded copy used for plotting/labels; the raw column is kept intact
    df_in['Importance value'] = df_in['Importance'].round(3)

    # Group names are comma-joined term lists, so commas + 1 = group size.
    # NOTE: a single term whose name itself contains a comma is over-counted.
    df_in["Number of elements"] = df_in.iloc[:, 0].apply(lambda x: str(x).count(',') + 1)

    print(df_in.head())

    is_path = isinstance(filename, (str, os.PathLike))
    if is_path:
        # Create e.g. 'tables/' on first use instead of failing inside pandas
        Path(filename).parent.mkdir(parents=True, exist_ok=True)

    if is_path and Path(filename).suffix.lower() == '.csv':
        df_in.to_csv(filename, index=False)
    else:
        df_in.to_excel(filename, index=False)
    print(f"\nResults have been written to {filename}")
    return df_in


In [None]:
# # Debugging Code
# import plotly.graph_objects as go

# fig = go.Figure(data=go.Bar(y=[2, 3, 1]))
# fig.write_image("images/fig1.png", engine="kaleido")


In [None]:
# Global term importances of the fitted EBM, exported as an Excel table
importances = ebm.term_importances()
names = ebm.term_names_

data = dict(zip(('Term Name', 'Importance'), (names, importances)))
df_m1 = create_and_export_dataframe(data, 'tables/term_importances.xlsx')


In [None]:
# Apply seaborn's default theme. set_theme() is the documented interface;
# set() is only an alias that may be removed in future releases.
sns.set_theme()

# Horizontal bar chart of every term's importance, most important at the top
plt.figure(figsize=(18, 16))
barplot = sns.barplot(
    x='Importance value',
    y='Term Name',
    data=df_m1,
    orient='h',
    order=df_m1.sort_values('Importance value', ascending=False)['Term Name'],
)

# Label each bar with its value (3 decimals), 5 points past the bar end
for p in barplot.patches:
    barplot.annotate(
        format(p.get_width(), '.3f'),
        (p.get_width(), p.get_y() + p.get_height() / 2.),
        ha='left',
        va='center',
        xytext=(5, 0),
        textcoords='offset points',
    )

plt.title('Global Term/Feature Importances')
plt.xlabel('Mean Absolute Score (Weighted)')
plt.ylabel('')

plt.tight_layout()
plt.savefig('images/png/barplot_summary.png', format='png', dpi=300)
plt.savefig('images/pdf/barplot_summary.pdf', format='pdf', dpi=300)

plt.show()

In [None]:
# Two variable sets: conventional lipid measures and HDL subfractions
group1 = ['TC', 'TG', 'LDL-C', 'HDL-C']
group2 = ['HDL-2b', 'HDL-3']

# Every combination (unordered subset, not permutation) that contains at least
# one variable from group1 and at least one from group2. Ordering: by size of
# the group1 part, then size of the group2 part, then itertools order.
combinations = [
    list(combo1) + list(combo2)
    for i in range(1, len(group1) + 1)
    for j in range(1, len(group2) + 1)
    for combo1 in itertools.combinations(group1, i)
    for combo2 in itertools.combinations(group2, j)
]

five_feature_group = ['TC', 'TG', 'LDL-C', 'HDL-2b', 'HDL-3']
# Every EBM term outside the five features above (HDL-C is therefore included)
all_other_terms = [term for term in ebm.term_names_ if term not in five_feature_group]
all_terms_group = list(ebm.term_names_)

# append_group_importance starts a fresh global explanation when global_exp is
# None, so the first combination needs no special case.
my_global_exp = None
for combo in combinations:
    my_global_exp = append_group_importance(combo, ebm, X_train, global_exp=my_global_exp)

my_global_exp = append_group_importance(group1, ebm, X_train, global_exp=my_global_exp)
my_global_exp = append_group_importance(group2, ebm, X_train, global_exp=my_global_exp)
my_global_exp = append_group_importance(
    all_other_terms, ebm, X_train, global_exp=my_global_exp,
    group_name="all_other_terms_excluding_TC_TG_LDL-C_HDL-2b_HDL-3",
)
my_global_exp = append_group_importance(
    all_terms_group, ebm, X_train, global_exp=my_global_exp,
    group_name="all_terms",
)

# Show the final result
show(my_global_exp)


In [None]:
# Every group whose importance is computed: the cross-set combinations, each
# variable set on its own, the remaining terms, and all terms together.
list_imp = [*combinations, group1, group2, all_other_terms, all_terms_group]

my_dict = get_group_and_individual_importances(list_imp, ebm, X_train)
# for term, importance in my_dict.items():
#     print(f"Term: {term} - Importance: {importance}")

In [None]:
# Export the group and individual importances computed in the previous cell.
# my_dict was already computed there from the same list_imp, ebm and X_train,
# so it is reused instead of recomputing identical values at double cost.
# (Re-run the previous cell first if list_imp changes.)
data2 = {'Term': list(my_dict.keys()), 'Importance': list(my_dict.values())}
df_m2 = create_and_export_dataframe(data2, 'tables/group_importances.xlsx')



In [None]:
# Replace the long comma-joined keys of the two catch-all groups with short
# labels. Rows are matched by their set of member terms, because group keys are
# comma-separated term lists (the same convention the "Number of elements"
# column relies on). Taking the two largest "Number of elements" values instead
# mislabels rows whenever a term combination is at least as large as the
# all-other-terms group.
def _term_set(term):
    """Return the member term names encoded in a comma-joined group key."""
    return frozenset(part.strip() for part in str(term).split(','))


readable_group_names = {
    _term_set(', '.join(all_terms_group)): 'all_terms',
    _term_set(', '.join(all_other_terms)): 'all_other_terms_excluding_TC_TG_LDL-C_HDL-2b_HDL-3',
}
df_m2['Term'] = [readable_group_names.get(_term_set(term), term) for term in df_m2['Term']]


In [None]:
# Preview the first five rows of the group-importance table
df_m2.head(n=5)

In [None]:
# Bucket the group-importance table by group size ("Number of elements")
grouped_data = df_m2.groupby(by="Number of elements")

# # Plot bar charts for each group
# for group, group_data in grouped_data:
#     # Sort by Importance value in descending order and take the top 3
#     top3_data = group_data.nlargest(3, 'Importance value')

#     # Create the plot
#     plt.figure(figsize=(10, 6))
#     barplot = sns.barplot(
#         x='Importance value', 
#         y=df_m2.columns[0], 
#         data=top3_data, 
#         orient='h',
#         order=top3_data.sort_values('Importance value', ascending=False)[df_m2.columns[0]]
#     )

#     # Add value labels
#     for p in barplot.patches:
#         barplot.annotate(
#             format(p.get_width(), '.3f'),  # Keep three decimal places
#             (p.get_width(), p.get_y() + p.get_height() / 2.), 
#             ha = 'left', 
#             va = 'center',
#             xytext = (5, 0), 
#             textcoords = 'offset points'
#         )

#     # Set the title
#     plt.title(f'Number of elements in the group: {group}')
#     plt.xlabel('Mean Absolute Score (Weighted)')
#     plt.ylabel(df_m2.columns[0])

#     # Save the plot as PNG and PDF formats
#     plt.tight_layout()
#     # plt.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9)
#     plt.savefig(f'images/png/barplot_group_{group}.png', format='png', dpi=300)
#     plt.savefig(f'images/pdf/barplot_group_{group}.pdf', format='pdf', dpi=300)

#     # Show the plot
#     plt.show()


In [None]:
# One panel per distinct "Number of elements", each showing the 3 most
# important groups of that size, in a 2-column grid. The row count follows the
# number of groups, so no group is dropped when there are more than 8. Leftover
# grid cells are hidden, so no empty frame is drawn when there are fewer.
n_cols = 2
n_groups = grouped_data.ngroups
n_rows = max(1, (n_groups + n_cols - 1) // n_cols)
# 4.5 inches per row reproduces the 16x18 figure for 8 groups (4 rows);
# squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4.5 * n_rows), squeeze=False)
panel_axes = axes.flatten()

for (group, group_data), ax in zip(grouped_data, panel_axes):
    # Top 3 groups of this size by importance
    top3_data = group_data.nlargest(3, 'Importance value')

    barplot = sns.barplot(
        x='Importance value',
        y=df_m2.columns[0],
        data=top3_data,
        orient='h',
        order=top3_data.sort_values('Importance value', ascending=False)[df_m2.columns[0]],
        ax=ax,
        width=0.5,
    )

    # Label each bar with its value (3 decimals), 5 points past the bar end
    for p in barplot.patches:
        barplot.annotate(
            format(p.get_width(), '.3f'),
            (p.get_width(), p.get_y() + p.get_height() / 2.),
            ha='left',
            va='center',
            xytext=(5, 0),
            textcoords='offset points',
        )

    ax.set_title(f'Number of elements in the group: {group}', fontsize=14)
    ax.set_xlabel('Mean Absolute Score (Weighted)', fontsize=12)
    ax.set_ylabel('')

# Hide grid cells not used by any group (e.g. an odd number of groups)
for ax in panel_axes[n_groups:]:
    ax.set_visible(False)

# Adjust layout to avoid overlap
plt.tight_layout()

# Save the plot as PNG and PDF formats
plt.savefig('images/png/barplot_groups.png', format='png', dpi=300)
plt.savefig('images/pdf/barplot_groups.pdf', format='pdf', dpi=300)

plt.show()


In [None]:
# Local explanations: explain the individual predictions for the first
# n_local_cases rows of the test set
n_local_cases = 80
ebm_local = ebm.explain_local(X_test[:n_local_cases], y_test[:n_local_cases], name='EBM')
show(ebm_local)

# Save the explanation plot for the case at index 1, i.e. the SECOND explained
# case (the index is 0-based)
local_case_index = 1
plotly_fig_local = ebm_local.visualize(local_case_index)

# # Custom colors
# colors = {
#     'mild': '#9fc377',    # Green
#     'moderate': '#0272a2', # Blue
#     'severe': '#ca0b03',   # Red
# }

# # Modify legend colors (backup code)
# for trace in plotly_fig_local['data']:
#     if 'name' in trace:
#         class_name = trace['name']
#         trace['marker']['color'] = colors.get(class_name, '#000000')  # Set color, default to black

# Plotly's write_image has no DPI setting: scale=3 renders the image at three
# times the figure's layout pixel size.
plotly_fig_local.write_image("images/png/fig_local.png", format="png", scale=3, engine="kaleido")

# Known labeling issue: the class labels 0, 1, 2 shown in the image must be
# matched against the dropdown menu to identify the actual classes; the text
# is corrected in the PDF, and a screenshot of it is saved as the PNG.

plotly_fig_local.write_image("images/pdf/fig_local.pdf", format="pdf", scale=3, engine="kaleido")
print('Local Explanations images saved.')


In [None]:
# Raw data behind the local explanation of the case at index 1
data = ebm_local.data(key=1)
print(data)

In [None]:
# Export the data behind the selected local explanation: one row per term
# (followed by the rows listed under data['extra'] — presumably the intercept,
# TODO confirm), one column per class label, plus the case's feature values.
main_part = data
extra_part = data['extra']

feature_names = main_part['names'] + extra_part['names']
scores = np.array(main_part['scores'] + extra_part['scores'])
values = main_part['values'] + extra_part['values']

class_labels = data['meta']['label_names']
df_local = pd.DataFrame(scores, index=feature_names, columns=class_labels)
df_local['values'] = values

df_local.to_excel('tables/feature_scores_local_explanations.xlsx')

print("Data has been written.")


In [None]:
# Display the per-term, per-class local score table built in the previous cell
df_local