In [1]:
# =====================================
# 🧠 SHAP Explainability for XGBoost
# =====================================

import pandas as pd
import shap
import xgboost as xgb
import joblib
import json
import matplotlib.pyplot as plt
from pathlib import Path

In [3]:
# ------------------------------------------------------------
# 1️⃣ Load the processed hold-out data
# ------------------------------------------------------------
# (The model itself is loaded by the SHAP cell further down.)
# NOTE(review): the preview further down lists one-hot `customerID_*`
# columns among the features. A per-customer identifier normally carries no
# generalizable signal — confirm whether it should be dropped upstream.
X_TEST_CSV = "../data/processed/X_test_processed.csv"
Y_TEST_CSV = "../data/processed/y_test.csv"

X_test = pd.read_csv(X_TEST_CSV)
# Collapse the label frame into a flat 1-D array (assumes the file holds a
# single label column, which the shape check below confirms).
y_test = pd.read_csv(Y_TEST_CSV).to_numpy().ravel()

# ==============================
# ✅ Quick Sanity Checks
# ==============================
print(f"X_test shape: {X_test.shape}")
print(f"y_test shape: {y_test.shape}")

X_test shape: (1409, 7072)
y_test shape: (1409,)


In [None]:
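# ==============================
# 📊 Load Processed Test Data (superseded)
# ==============================
# NOTE: this cell duplicates the loading cell above and is kept only for
# reference; every line is commented out, so running it has no effect.

# Define data paths (adjust if your structure differs)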
# X_test_path = Path("../data/processed/X_test_processed.csv")
# y_test_path = Path("../data/processed/y_test.csv")

# Load test sets
# X_test = pd.read_csv(X_test_path)
# y_test = pd.read_csv(y_test_path).values.ravel()  # Flatten to 1D array

# ==============================
# ✅ Quick Sanity Checks
# ==============================
# print(f"X_test shape: {X_test.shape}")
# print(f"y_test shape: {y_test.shape}")

# ------------------------------------------------------------
# 2️⃣ TODO: load the model and re-export the booster in XGBoost's native
#    format, intended as a workaround for the `base_score` string bug.
#    Not implemented here — the SHAP cell below loads the pickled model
#    directly with joblib.
# ------------------------------------------------------------



In [4]:
# Preview the first five rows; the frame is very wide, so pandas elides most columns.
first_five_rows = X_test.iloc[:5]
display(first_five_rows)

Unnamed: 0,SeniorCitizen,tenure,MonthlyCharges,TotalCharges,customerID_0003-MKNFE,customerID_0004-TLHLJ,customerID_0011-IGKFF,customerID_0013-EXCHZ,customerID_0013-MHZWF,customerID_0013-SMEOE,...,StreamingTV_No internet service,StreamingTV_Yes,StreamingMovies_No internet service,StreamingMovies_Yes,Contract_One year,Contract_Two year,PaperlessBilling_Yes,PaymentMethod_Credit card (automatic),PaymentMethod_Electronic check,PaymentMethod_Mailed check
0,-0.441773,1.608483,1.629976,2.707614,False,False,False,False,False,False,...,False,True,False,True,False,True,True,True,False,False
1,2.263606,-0.996684,1.168725,-0.611505,False,False,False,False,False,False,...,False,True,False,True,False,False,True,True,False,False
2,-0.441773,0.346606,0.445324,0.39949,False,False,False,False,False,False,...,False,True,False,False,True,False,True,True,False,False
3,-0.441773,-0.589626,0.440347,-0.365546,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,True,False
4,-0.441773,1.608483,0.588013,1.588523,False,False,False,False,False,False,...,False,True,False,True,False,True,True,True,False,False


In [5]:
# Features and labels must have the same number of rows before any analysis.
n_feature_rows = X_test.shape[0]
n_labels = y_test.shape[0]
assert n_feature_rows == n_labels, "Mismatch between X_test and y_test length!"
print("✅ Data loaded successfully and aligned.")

✅ Data loaded successfully and aligned.


In [7]:
# shap, joblib, plt and Path are already imported in the first cell.

# ------------------------------------------------------------
# 1️⃣ Load the trained model
# ------------------------------------------------------------
# SECURITY: joblib.load unpickles the file, which can execute arbitrary
# code — only load model artifacts produced by this project.
model = joblib.load("../models/xgb_churn_model.pkl")

# X_test comes from the data-loading cell above; run that cell first.
print("Model and test data loaded successfully.")

# ------------------------------------------------------------
# 2️⃣ Compute SHAP values for every test row
# ------------------------------------------------------------
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

print("✅ SHAP explainer initialized successfully!")

# ------------------------------------------------------------
# 3️⃣ Save global explainability plots
# ------------------------------------------------------------
images_dir = Path("images")
images_dir.mkdir(parents=True, exist_ok=True)


def save_current_figure(path, title=""):
    """Tighten, save and close the current Matplotlib figure.

    The title is applied *after* SHAP has drawn, so it lands on whatever
    figure/axes SHAP actually used instead of one set up beforehand.

    Args:
        path: Destination image file (str or Path).
        title: Optional title for the current axes; skipped when empty.
    """
    fig = plt.gcf()
    if title:
        plt.title(title)
    fig.tight_layout()
    fig.savefig(path, bbox_inches="tight")
    # Close explicitly so repeated runs do not accumulate open figures.
    plt.close(fig)


# Summary plot
shap.summary_plot(shap_values, X_test, show=False)
save_current_figure(
    images_dir / "shap_summary_plot.png",
    "SHAP Summary Plot - Global Feature Importance",
)

# Bar plot
shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
save_current_figure(
    images_dir / "shap_bar_plot.png",
    "SHAP Bar Plot - Mean Absolute Feature Impact",
)

# ------------------------------------------------------------
# 4️⃣ Local explanation for a single test row
# ------------------------------------------------------------
sample_index = 5
shap.plots.waterfall(
    shap.Explanation(
        values=shap_values[sample_index],
        base_values=explainer.expected_value,
        data=X_test.iloc[sample_index],
        feature_names=list(X_test.columns),
    ),
    show=False,
)
save_current_figure(images_dir / f"shap_local_explanation_{sample_index}.png")

# Report the real (relative) output directory, not a root-looking "/images/".
print(f"✅ All SHAP plots saved to {images_dir}/")


Model and test data loaded successfully.
✅ SHAP explainer initialized successfully!
✅ All SHAP plots saved to /images/
