# 1. Imports

In [1]:
import joblib
import os

import numpy as np
import pandas as pd

import shap
import matplotlib.pyplot as plt

import lightgbm as lgbm
from catboost import CatBoostClassifier
from sklearn.linear_model import LogisticRegression

  from .autonotebook import tqdm as notebook_tqdm



# 2. Load models

In [2]:
model_dir = "../model"


def _load_pickled(filename):
    """Load a joblib-serialized model stored under ``model_dir``.

    joblib.load unpickles the file, which can execute arbitrary code, so only
    point this at model artifacts produced by a trusted training run.
    """
    return joblib.load(os.path.join(model_dir, filename))


# CatBoost is restored via the estimator's own load_model from a .cbm file,
# not through joblib like the other models.
catboost_model = CatBoostClassifier()
catboost_model.load_model(os.path.join(model_dir, "catboost_model.cbm"))

# NOTE(review): "LighGBM" looks like a typo, but it has to match the artifact
# name on disk, so it is kept as-is.
lightgbm_model = _load_pickled("LighGBM_model.pkl")
lr_model = _load_pickled("LR_model.pkl")
rf_model = _load_pickled("random_forest_model.pkl")

# 3. Load test features (X_test)

In [3]:
# Held-out feature matrix used by every explainer and plot below.
X_TEST_PATH = '../data/test/X_test.csv'
test_X = pd.read_csv(X_TEST_PATH)

# 4. SHAP value calculation

In [4]:
# The three tree ensembles (CatBoost, LightGBM, Random Forest) are all
# explained with shap.TreeExplainer on the same test features.
explainer_cb = shap.TreeExplainer(catboost_model)
explainer_lgb = shap.TreeExplainer(lightgbm_model)
explainer_rf = shap.TreeExplainer(rf_model)

shap_values_cb = explainer_cb.shap_values(test_X)
shap_values_lgb = explainer_lgb.shap_values(test_X)
shap_values_rf = explainer_rf.shap_values(test_X)

# The logistic regression uses LinearExplainer, which takes background data;
# test_X doubles as that background here.
# NOTE(review): a training-set sample is the more common background choice —
# confirm that explaining against the test set itself is intended.
explainer_lr = shap.LinearExplainer(lr_model, test_X)
# Cast the linear-model SHAP output to a float ndarray.
shap_values_lr = np.array(explainer_lr.shap_values(test_X), dtype=float)



# 5. Plot SHAP

In [9]:
# Render one SHAP summary (beeswarm) plot per model and save each as a PNG.


def _positive_class(shap_values):
    """Collapse classifier SHAP output to a single (n_samples, n_features) matrix.

    Depending on the model type and the shap version, SHAP values for a
    classifier arrive as a 2-D array, a 3-D array whose last axis indexes the
    class, or a list holding one array per class. In the multi-output layouts
    class index 1 is selected, so every model yields a single-output beeswarm.
    NOTE(review): assumes a binary task where class 1 is the class of
    interest — confirm against the training setup.

    Args:
        shap_values: SHAP values in any of the layouts described above.

    Returns:
        np.ndarray: 2-D array of per-sample, per-feature SHAP values.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[1])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, 1]
    return values


def save_shap_summary(shap_values, features, title, path, figsize=(10, 10)):
    """Draw a SHAP summary plot for one model and write it to ``path``.

    Args:
        shap_values: SHAP output for ``features`` (any layout accepted by
            ``_positive_class``).
        features: feature frame the SHAP values were computed on.
        title: title placed on the figure.
        path: output image path; its parent directory is created if missing.
        figsize: figure size in inches, (width, height).
    """
    fig = plt.figure(figsize=figsize)
    # summary_plot resizes the current figure unless plot_size is passed
    # explicitly (its default "auto" picks its own size), which would
    # otherwise discard the figsize requested above.
    shap.summary_plot(
        _positive_class(shap_values), features, plot_size=figsize, show=False
    )
    plt.title(title)
    plt.tight_layout()
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    plt.savefig(path)
    # Close the figure shap drew on and `fig` (normally the same object) so
    # repeated calls do not accumulate open figures.
    plt.close(plt.gcf())
    plt.close(fig)


SHAP_PLOTS = [
    (shap_values_cb, "CatBoost SHAP Values", "../plots/shap_values_catboost.png"),
    (shap_values_lgb, "LightGBM SHAP Values", "../plots/shap_values_lightgbm.png"),
    (shap_values_lr, "Logistic Regression SHAP Values", "../plots/shap_values_logistic.png"),
    (shap_values_rf, "Random Forest SHAP Values", "../plots/shap_values_random_forest.png"),
]

for model_shap_values, plot_title, out_path in SHAP_PLOTS:
    save_shap_summary(model_shap_values, test_X, plot_title, out_path)

<Figure size 640x480 with 0 Axes>