In [1]:
import pandas as pd
import shap
import lime
import lime.lime_tabular
import joblib
import matplotlib.pyplot as plt
import numpy as np
import warnings
# Silence every UserWarning for the rest of the session.
# NOTE(review): this is very broad and can hide real problems (e.g. feature-name
# mismatches between training and inference); consider narrowing with `message=`.
warnings.filterwarnings("ignore", category=UserWarning)

# Load preprocessed data (all four splits live in the same directory).
DATA_DIR = "C:/Users/chigu/Desktop/stroke_prediction_project/Data"
X_train = pd.read_csv(f"{DATA_DIR}/X_train_preprocessed.csv")
X_test = pd.read_csv(f"{DATA_DIR}/X_test_preprocessed.csv")
# Labels are flattened to 1-D arrays; this assumes each y CSV holds exactly one column.
y_train = pd.read_csv(f"{DATA_DIR}/y_train_preprocessed.csv").to_numpy().ravel()
y_test = pd.read_csv(f"{DATA_DIR}/y_test_preprocessed.csv").to_numpy().ravel()

# Guard against malformed label files: if a y CSV was saved with its index (two
# columns), ravel() interleaves index and label and the length no longer matches X.
if len(X_train) != len(y_train) or len(X_test) != len(y_test):
    raise ValueError(
        f"Feature/label row mismatch: X_train={len(X_train)} vs y_train={len(y_train)}, "
        f"X_test={len(X_test)} vs y_test={len(y_test)}"
    )

# Ensure X_test has exactly X_train's columns, in X_train's order, because the model
# and the explainers below are fed both frames interchangeably.
if not X_train.columns.equals(X_test.columns):
    missing_cols = X_train.columns.difference(X_test.columns)
    extra_cols = X_test.columns.difference(X_train.columns)
    print("Warning: X_train and X_test have different columns!")
    # Be explicit about what the realignment does so it is never silent.
    if len(missing_cols):
        print(f"  Missing from X_test (filled with 0): {list(missing_cols)}")
    if len(extra_cols):
        print(f"  Extra in X_test (dropped): {list(extra_cols)}")
    X_test = X_test.reindex(columns=X_train.columns, fill_value=0)

print("Data shapes:", X_train.shape, X_test.shape)

# Restore the trained classifier from disk.
# NOTE: joblib.load unpickles the file and can execute arbitrary code, so only
# load model files produced by this project / from a trusted location.
model_path = "C:/Users/chigu/Desktop/stroke_prediction_project/Models/stroke_model.pkl"
model = joblib.load(model_path)

# SHAP explanation: `shap.Explainer` chooses an algorithm for `model`, with X_train
# as background data. Named `shap_explainer` so the LIME explainer created further
# down does not shadow it.
shap_explainer = shap.Explainer(model, X_train)
# Additivity check disabled, as in the original analysis.
# NOTE(review): record why the check fails for this model before trusting exact sums.
shap_values = shap_explainer(X_test, check_additivity=False)

# Check SHAP values shape
print("SHAP Values Shape:", shap_values.values.shape)

# A multi-output explanation has values of shape (n_samples, n_features, n_outputs).
# Keep a single output by slicing the Explanation itself: this carries the matching
# base values, data and feature names along, instead of rebuilding an Explanation
# by hand (which assumed 2-D base_values and dropped the original metadata).
if shap_values.values.ndim > 2:
    class_index = 1  # Stroke; assumes model.classes_ is [0, 1] -- verify if labels change
    shap_values_fixed = shap_values[:, :, class_index]
else:
    shap_values_fixed = shap_values  # already one value per (sample, feature)

SHAP_OUTPUT_DIR = "C:/Users/chigu/Desktop/stroke_prediction_project/Models"

# SHAP Summary Plot. bbox_inches="tight" keeps long feature labels from being
# clipped at the figure edge when the PNG is written.
shap.summary_plot(shap_values_fixed.values, X_test, show=False)
plt.savefig(f"{SHAP_OUTPUT_DIR}/shap_summary_plot.png", bbox_inches="tight")
plt.close()

# SHAP Beeswarm Plot, drawn from the Explanation's own data and feature names.
shap.plots.beeswarm(shap_values_fixed, show=False)
plt.savefig(f"{SHAP_OUTPUT_DIR}/shap_beeswarm_plot.png", bbox_inches="tight")
plt.close()

# LIME Explanation.
# A fixed seed makes both the choice of explained row and LIME's perturbation
# sampling reproducible between runs.
RANDOM_STATE = 42
feature_names = X_train.columns.tolist()

explainer = lime.lime_tabular.LimeTabularExplainer(
    X_train.to_numpy(),
    feature_names=feature_names,
    class_names=['No Stroke', 'Stroke'],
    mode='classification',
    random_state=RANDOM_STATE,
)


def _predict_proba_frame(rows):
    """Return ``model.predict_proba`` for LIME's perturbed samples.

    LIME passes a bare 2-D NumPy array. Re-attaching X_train's column names
    gives the model the same named-column input it sees elsewhere in this
    script. NOTE(review): assumes the model was fitted on a DataFrame with
    these columns; for scikit-learn estimators this avoids the "X does not
    have valid feature names" UserWarning.
    """
    return model.predict_proba(pd.DataFrame(rows, columns=feature_names))


# Explain a single, reproducibly chosen test instance with LIME.
rng = np.random.default_rng(RANDOM_STATE)
idx = int(rng.integers(0, len(X_test)))
exp = explainer.explain_instance(X_test.iloc[idx].to_numpy(), _predict_proba_frame)
lime_output_path = "C:/Users/chigu/Desktop/stroke_prediction_project/Models/lime_explanation.html"
exp.save_to_file(lime_output_path)

print(f"LIME explained test row index {idx} (seed {RANDOM_STATE}).")
print(f"✅ Explainability analysis with SHAP & LIME completed successfully!\nLIME results saved at: {lime_output_path}")

Data shapes: (7778, 10) (1022, 10)




SHAP Values Shape: (1022, 10, 2)
✅ Explainability analysis with SHAP & LIME completed successfully!
LIME results saved at: C:/Users/chigu/Desktop/stroke_prediction_project/Models/lime_explanation.html
