# Heart Disease Prediction - Explainable AI (XAI) Analysis

This notebook applies two Explainable AI (XAI) techniques, SHAP and LIME, to interpret the heart disease prediction model, both globally (across the whole dataset) and locally (for individual patients).

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap
from lime import lime_tabular
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Set up plotting
plt.style.use('ggplot')
sns.set(style="whitegrid")
%matplotlib inline

In [None]:
# Load the dataset, model, and scaler
df = pd.read_csv('../data/heart.csv')
model = joblib.load('../backend/model/model.pkl')
scaler = joblib.load('../backend/model/scaler.pkl')
feature_names = joblib.load('../backend/model/feature_names.pkl')

# Separate features and target; reorder columns to the feature order saved at training time
X = df.drop('target', axis=1)[feature_names]
y = df['target']

# Scale the features
X_scaled = scaler.transform(X)

print(f"Model loaded: {type(model).__name__}")
print(f"Dataset shape: {X.shape}")

## 1. Global Explanations with SHAP

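SHAP (SHapley Additive exPlanations) assigns each feature a contribution to an individual prediction, based on Shapley values from cooperative game theory. For each prediction, the contributions sum to the difference between the model output and the explainer's base (expected) value. Aggregating them across the entire dataset shows which features drive the model's predictions globally.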
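To make the additivity property concrete, here is a minimal sketch that computes exact Shapley values by brute-force enumeration of feature coalitions. It is not how the `shap` library computes them: it uses a hypothetical four-feature model `toy_model` and a single all-zeros `reference` row in place of the background data that `KernelExplainer` averages over. The helper name `exact_shapley` is likewise hypothetical.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_shapley(f, x, reference):
    """Exact Shapley values of f at x by enumerating every feature coalition.

    Features outside a coalition S are set to the reference row, so the
    coalition value is v(S) = f(x on S, reference elsewhere). The cost grows
    as 2**n_features, so this is only feasible for a handful of features.
    """
    n = len(x)
    phi = np.zeros(n)

    def value(coalition):
        idx = np.array(coalition, dtype=int)
        z = reference.copy()
        z[idx] = x[idx]
        return f(z)

    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for s in combinations(others, size):
                phi[i] += weight * (value(s + (i,)) - value(s))
    return phi


def toy_model(z):
    # Hypothetical risk score with an interaction between features 1 and 2
    return 0.5 * z[0] + 2.0 * z[1] * z[2] - 1.0 * z[3]


x = np.array([1.0, 2.0, 3.0, 0.5])
reference = np.zeros(4)
phi = exact_shapley(toy_model, x, reference)
base_value = toy_model(reference)

# phi ≈ [0.5, 6.0, 6.0, -0.5]: the 12.0 interaction term is split equally
print(phi)
# Additivity: base value + sum of contributions equals the prediction (12.0)
print(base_value + phi.sum(), toy_model(x))
```

The enumeration is exponential in the number of features, which is why `TreeExplainer` (exact, tree-specific algorithm) and `KernelExplainer` (sampling-based approximation) exist.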

In [None]:
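# Initialize SHAP explainer
# Tree-based models (Random Forest, XGBoost, ...) expose feature_importances_ and support the fast TreeExplainer
if hasattr(model, 'feature_importances_'):
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_scaled)
else:
    # Model-agnostic fallback (e.g., Logistic Regression): KernelExplainer with a
    # 100-row background sample. This is slow when run on the full dataset.
    explainer = shap.KernelExplainer(model.predict_proba, shap.sample(X_scaled, 100))
    shap_values = explainer.shap_values(X_scaled)

# Keep the SHAP values for the positive class (heart disease). Depending on the
# explainer and SHAP version, a binary classifier yields a list with one array per
# class, a (samples, features, classes) array, or a single (samples, features) array.
if isinstance(shap_values, list):
    shap_values = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_values = shap_values[:, :, 1]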

In [None]:
# Global feature importance: mean absolute SHAP value per feature
# show=False keeps the SHAP figure open so the title is drawn on it
shap.summary_plot(shap_values, X_scaled, feature_names=feature_names,
                  plot_type="bar", plot_size=(12, 8), show=False)
plt.title('Global Feature Importance (SHAP)')
plt.tight_layout()
plt.show()

In [None]:
# SHAP summary (beeswarm) plot: one dot per patient, colored by the (scaled) feature value
# show=False keeps the SHAP figure open so the title is drawn on it
shap.summary_plot(shap_values, X_scaled, feature_names=feature_names,
                  plot_size=(12, 10), show=False)
plt.title('Feature Value Impact on Predictions (SHAP)')
plt.tight_layout()
plt.show()

## 2. Local Explanations with SHAP

Let's examine individual predictions to understand how the model arrives at specific decisions.
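A SHAP decision plot can be read as a running total: starting from the base value, each feature's SHAP value is added in turn, least important first, until the path reaches the model output. The sketch below reproduces that path for a single row with NumPy. `decision_path` and the feature names and values in the usage example are hypothetical, and the ordering assumes the plot's default importance ordering (which, for a single row, is by absolute SHAP value).

```python
import numpy as np


def decision_path(base_value, shap_row, feature_names):
    """Running total of one row's SHAP values, as traced by a decision plot.

    Features are visited from least to most important (by |SHAP value|), so
    the path starts at the base value and ends at the model output.
    """
    order = np.argsort(np.abs(shap_row))
    running = base_value + np.cumsum(shap_row[order])
    return [(feature_names[i], float(total)) for i, total in zip(order, running)]


# Hypothetical values for one patient (model output in probability units)
names = ['feature_a', 'feature_b', 'feature_c']
shap_row = np.array([0.10, -0.05, 0.20])
path = decision_path(0.45, shap_row, names)
for name, total in path:
    print(f"{name:>10}: {total:.2f}")
```

The last running total equals the base value plus the sum of all SHAP values, which is the additivity property that makes force and decision plots consistent with the model's prediction.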

In [None]:
# Select a sample patient for explanation
sample_idx = 50  # Can be changed to examine different patients
sample_X = X.iloc[sample_idx:sample_idx+1]
sample_X_scaled = scaler.transform(sample_X)
true_label = y.iloc[sample_idx]
prediction = model.predict(sample_X_scaled)[0]
prediction_proba = model.predict_proba(sample_X_scaled)[0, 1]

print(f"Sample patient data:\n{sample_X}")
print(f"True label: {'Heart Disease' if true_label == 1 else 'No Heart Disease'}")
print(f"Predicted: {'Heart Disease' if prediction == 1 else 'No Heart Disease'} (P(heart disease) = {prediction_proba:.4f})")

In [None]:
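# SHAP force plot for the sample patient
shap_values_sample = explainer.shap_values(sample_X_scaled)

# Keep the positive class (heart disease); depending on the explainer and SHAP version
# the output is a per-class list, a (samples, features, classes) array, or a 2-D array
if isinstance(shap_values_sample, list):
    shap_values_sample = shap_values_sample[1]
elif np.ndim(shap_values_sample) == 3:
    shap_values_sample = shap_values_sample[:, :, 1]

# Base value for the positive class: a per-class array for most classifiers,
# a single scalar for single-output explanations (e.g., XGBoost log-odds)
expected_value = np.atleast_1d(explainer.expected_value)
base_value = expected_value[1] if expected_value.size > 1 else expected_value[0]

# matplotlib=True draws its own figure; raw (unscaled) values are displayed for
# readability, while the SHAP values were computed on the scaled features
shap.force_plot(base_value,
                shap_values_sample[0],
                sample_X.iloc[0],
                matplotlib=True,
                figsize=(20, 3),
                show=True)

In [None]:
# Decision plot: cumulative SHAP contributions from the base value to the model output
shap.decision_plot(base_value,
                   shap_values_sample,
                   feature_names=list(feature_names),
                   show=False)
plt.title('SHAP Decision Plot for Sample Patient')
plt.tight_layout()
plt.show()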

## 3. LIME Explanations

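LIME (Local Interpretable Model-agnostic Explanations) explains a single prediction by perturbing the patient's features, querying the model on the perturbed samples, and fitting a simple weighted linear model that approximates the original model's behavior in that local neighborhood.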
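The core loop described above can be sketched from scratch in NumPy. This is a simplified illustration, not the `lime` package's implementation: it perturbs with Gaussian noise instead of sampling from training-data statistics, skips LIME's default discretization of continuous features, and uses plain weighted least squares where LIME fits a ridge regression by default. The exponential kernel and the default width `0.75 * sqrt(n_features)` follow our understanding of `lime_tabular`'s defaults. `lime_style_weights` and `linear_model` are hypothetical names.

```python
import numpy as np


def lime_style_weights(predict_fn, x, scale, num_samples=5000, kernel_width=None, seed=0):
    """Simplified LIME: fit a proximity-weighted linear surrogate around x.

    Returns (intercept, local_weights). Perturbations are Gaussian with
    per-feature standard deviation `scale`; distances are measured in those
    scaled units.
    """
    rng = np.random.default_rng(seed)
    n_features = x.shape[0]
    if kernel_width is None:
        kernel_width = 0.75 * np.sqrt(n_features)

    # 1. Sample perturbations around x and query the black-box model
    Z = x + rng.normal(size=(num_samples, n_features)) * scale
    y = predict_fn(Z)

    # 2. Weight samples by proximity with an exponential kernel
    d = np.linalg.norm((Z - x) / scale, axis=1)
    weights = np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

    # 3. Weighted least squares on features centered at x (intercept = surrogate value at x)
    A = np.column_stack([np.ones(num_samples), Z - x])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[0], coef[1:]


# Hypothetical linear "model": the surrogate should recover its slopes exactly
w_true = np.array([0.8, -0.3, 0.0])


def linear_model(Z):
    return Z @ w_true + 0.1


x0 = np.array([1.0, 2.0, -1.0])
intercept, local_weights = lime_style_weights(linear_model, x0, scale=np.ones(3))
print(intercept, local_weights)
```

For a linear model the surrogate matches the model everywhere, so the weights equal its coefficients; for a nonlinear model such as a random forest they only describe the neighborhood around `x`, which is what makes LIME explanations local.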

In [None]:
# Initialize LIME explainer
lime_explainer = lime_tabular.LimeTabularExplainer(
    X_scaled,
    feature_names=feature_names,
    class_names=['No Heart Disease', 'Heart Disease'],
    mode='classification'
)

In [None]:
# Generate LIME explanation for sample patient
lime_exp = lime_explainer.explain_instance(
    sample_X_scaled[0],
    model.predict_proba,
    num_features=len(feature_names),
    top_labels=1
)

# Plot LIME explanation
plt.figure(figsize=(10, 8))
lime_exp.as_pyplot_figure(label=1)  # For the positive class
plt.title('LIME Explanation for Sample Patient')
plt.tight_layout()
plt.show()

## 4. Comparing XAI Techniques

Now let's compare the explanations from both SHAP and LIME for the same patient.
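SHAP values (in model-output units) and LIME weights (slopes of a surrogate fit on discretized features) are on different scales, so comparing rankings is more meaningful than comparing magnitudes. Below is a minimal sketch of two agreement measures, assuming both attributions are NumPy arrays aligned to the same feature order (for LIME, for example, rebuilt from `lime_exp.as_map()[1]`). The function names and example numbers are hypothetical.

```python
import numpy as np


def top_k_overlap(a, b, k):
    """Fraction of the k features with the largest |attribution| shared by a and b."""
    top_a = set(np.argsort(-np.abs(a))[:k].tolist())
    top_b = set(np.argsort(-np.abs(b))[:k].tolist())
    return len(top_a & top_b) / k


def rank_correlation(a, b):
    """Spearman-style correlation between |attribution| ranks (ties broken arbitrarily)."""
    rank_a = np.argsort(np.argsort(np.abs(a)))
    rank_b = np.argsort(np.argsort(np.abs(b)))
    return float(np.corrcoef(rank_a, rank_b)[0, 1])


# Hypothetical attributions for four features, in the same feature order
shap_row = np.array([0.30, -0.10, 0.05, 0.20])
lime_row = np.array([0.25, -0.02, 0.08, 0.15])
print(top_k_overlap(shap_row, lime_row, k=2))  # both rank features 0 and 3 highest
print(rank_correlation(shap_row, lime_row))
```

High top-k overlap with a lower rank correlation usually means the two methods agree on the dominant features but order the minor ones differently.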

In [None]:
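# Extract LIME weights for the positive class. as_map() keys each weight by feature
# index, so names come straight from feature_names instead of being parsed out of
# rule strings (range rules such as "a < x <= b" start with a number, not the name)
lime_weights = lime_exp.as_map()[1]
lime_df = pd.DataFrame({
    'Feature': [feature_names[i] for i, _ in lime_weights],
    'Importance': [weight for _, weight in lime_weights]
}).sort_values(by='Importance', key=abs, ascending=False)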

# Extract SHAP values for the sample
shap_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': shap_values_sample[0]
}).sort_values(by='Importance', key=abs, ascending=False)

# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# SHAP plot
sns.barplot(x='Importance', y='Feature', data=shap_df.head(10), ax=axes[0])
axes[0].set_title('Top 10 Features (SHAP)')

# LIME plot
sns.barplot(x='Importance', y='Feature', data=lime_df.head(10), ax=axes[1])
axes[1].set_title('Top 10 Features (LIME)')

plt.tight_layout()
plt.show()

## 5. Implementing a Function for Explanation

Finally, let's develop a function that can be used in the backend API to generate explanations for any patient.
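An API-facing explanation function has to cope with the different output shapes SHAP explainers can produce for a binary classifier (depending on the explainer type and SHAP version), and with NumPy scalar types such as `np.float32`, which the standard `json` module cannot serialize. Here is a minimal sketch of that normalization step; `positive_class_attributions`, `to_json_payload`, and the example values are hypothetical.

```python
import json

import numpy as np


def positive_class_attributions(shap_values, expected_value, positive_index=1):
    """Return (2-D SHAP array, float base value) for the positive class.

    Accepts the shapes a SHAP explainer may return for a binary classifier:
    a list with one array per class, a (samples, features, classes) array,
    or a (samples, features) array for single-output models.
    """
    if isinstance(shap_values, list):
        values = np.asarray(shap_values[positive_index])
    else:
        values = np.asarray(shap_values)
        if values.ndim == 3:
            values = values[:, :, positive_index]
    base = np.atleast_1d(np.asarray(expected_value, dtype=float))
    base_value = float(base[positive_index]) if base.size > 1 else float(base[0])
    return values, base_value


def to_json_payload(values, base_value, feature_names):
    """JSON-safe dict for the first row: NumPy scalars become Python floats."""
    return {
        'expected_value': base_value,
        'shap_values': {name: float(v) for name, v in zip(feature_names, values[0])},
    }


# Hypothetical outputs for one patient and two features, in two layouts
names = ['feature_a', 'feature_b']
per_class_list = [np.array([[-0.2, 0.1]], dtype=np.float32),
                  np.array([[0.2, -0.1]], dtype=np.float32)]
stacked = np.stack(per_class_list, axis=-1)  # shape (1, 2, 2)

for raw in (per_class_list, stacked):
    values, base = positive_class_attributions(raw, np.array([0.6, 0.4]))
    print(json.dumps(to_json_payload(values, base, names)))
```

Both layouts produce the same payload, so the API response stays stable if the SHAP version or model type changes.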

In [None]:
def generate_explanation(patient_data, model, scaler, feature_names, background_data):
    """
    Generate SHAP and LIME explanations for a patient's prediction.
    
    Args:
        patient_data (pd.DataFrame): Patient features as a DataFrame (unscaled)
        model: Trained ML model
        scaler: Feature scaler used during training
        feature_names (list): List of feature names
        background_data (np.ndarray): Scaled training data, used as the SHAP background
            sample (non-tree models) and as LIME's training data. Passed explicitly so
            the function also works outside this notebook.
        
    Returns:
        dict: Dictionary containing prediction and explanations
    """
    # Scale the input data
    patient_data_scaled = scaler.transform(patient_data)
    
    # Make prediction
    prediction = model.predict(patient_data_scaled)[0]
    probability = model.predict_proba(patient_data_scaled)[0, 1]
    
    # Initialize SHAP explainer
    if hasattr(model, 'feature_importances_'):
        explainer = shap.TreeExplainer(model)
    else:
        # Use a subset of the training data as background
        explainer = shap.KernelExplainer(model.predict_proba, shap.sample(background_data, 100))
    
    # Calculate SHAP values
    shap_values = explainer.shap_values(patient_data_scaled)
    
    # Keep the positive class (heart disease): a per-class list, a
    # (samples, features, classes) array, or a 2-D array, depending on the SHAP version
    if isinstance(shap_values, list):
        shap_values = shap_values[1]
    elif np.ndim(shap_values) == 3:
        shap_values = shap_values[:, :, 1]
    
    # Base value for the positive class (a scalar for single-output models such as XGBoost)
    expected_value = np.atleast_1d(explainer.expected_value)
    base_value = expected_value[1] if expected_value.size > 1 else expected_value[0]
    
    # Get SHAP values as dictionary
    shap_dict = {feature: float(value) for feature, value in zip(feature_names, shap_values[0])}
    
    # Initialize LIME explainer
    lime_explainer = lime_tabular.LimeTabularExplainer(
        background_data,
        feature_names=feature_names,
        class_names=['No Heart Disease', 'Heart Disease'],
        mode='classification'
    )
    
    # Generate LIME explanation; labels=(1,) guarantees the positive class is explained
    lime_exp = lime_explainer.explain_instance(
        patient_data_scaled[0],
        model.predict_proba,
        num_features=len(feature_names),
        labels=(1,)
    )
    
    # Extract LIME explanation as (condition, weight) pairs; conditions are rules
    # on the scaled feature values
    lime_list = lime_exp.as_list(label=1)
    lime_explanation = [{'feature': condition, 'weight': float(weight)} for condition, weight in lime_list]
    
    # Global feature importance: built-in importances for tree models,
    # absolute coefficients for linear models, omitted otherwise
    if hasattr(model, 'feature_importances_'):
        feature_importance = {feature: float(importance) 
                              for feature, importance in zip(feature_names, model.feature_importances_)}
    elif hasattr(model, 'coef_'):
        feature_importance = {feature: float(abs(coef)) 
                              for feature, coef in zip(feature_names, model.coef_[0])}
    else:
        feature_importance = {}
    
    # Create explanation dictionary
    explanation = {
        'prediction': int(prediction),
        'probability': float(probability),
        'shap_values': shap_dict,
        'feature_importance': feature_importance,
        'lime_explanation': lime_explanation,
        'expected_value': float(base_value)
    }
    
    return explanation

In [None]:
# Test the explanation function with our sample patient
explanation = generate_explanation(sample_X, model, scaler, feature_names, X_scaled)

# Print the explanation in a readable format
print(f"Prediction: {'Heart Disease' if explanation['prediction'] == 1 else 'No Heart Disease'}")
print(f"Probability: {explanation['probability']:.4f}")
print("\nTop 5 SHAP values:")
for feature, value in sorted(explanation['shap_values'].items(), key=lambda x: abs(x[1]), reverse=True)[:5]:
    print(f"  {feature}: {value:.4f}")
    
print("\nTop 5 LIME features:")
for item in sorted(explanation['lime_explanation'], key=lambda x: abs(x['weight']), reverse=True)[:5]:
    print(f"  {item['feature']}: {item['weight']:.4f}")
    
print("\nTop 5 global feature importance:")
for feature, value in sorted(explanation['feature_importance'].items(), key=lambda x: x[1], reverse=True)[:5]:
    print(f"  {feature}: {value:.4f}")

In [None]:
# Save the explanation function for use in the backend
import inspect
function_code = inspect.getsource(generate_explanation)

with open('../backend/model/explainability.py', 'w') as f:
    f.write("import pandas as pd\n")
    f.write("import numpy as np\n")
    f.write("import shap\n")
    f.write("from lime import lime_tabular\n\n")
    f.write(function_code)

print("Explanation function saved to '../backend/model/explainability.py'")

## XAI Analysis Summary

In this notebook, we've implemented and analyzed several Explainable AI techniques for our heart disease prediction model:

1. **SHAP (Global Explanations)**: We identified which features are most important overall for predicting heart disease.

2. **SHAP (Local Explanations)**: We examined how specific feature values contribute to individual predictions through force plots and decision plots.

3. **LIME Explanations**: We generated local explanations using LIME, which creates simple approximations of the model for individual predictions.

4. **Comparison**: We compared SHAP and LIME explanations, noting similarities and differences in their interpretations.

5. **API Integration**: We created a reusable function that can be integrated into our backend API to generate explanations for any new patient.

These XAI techniques provide transparency into our model's decision-making process, helping both healthcare providers and patients understand the factors driving heart disease risk predictions.