# ML505: Model Interpretability -- SHAP & Permutation Importance (OPTIONAL)

---

> **This entire notebook is OPTIONAL.** It covers advanced interpretability techniques that go beyond core tree/ensemble modeling. These methods are valuable in practice but are not required for understanding the fundamentals of tree-based models.

---

## Learning Objectives

By the end of this notebook, you will be able to:

1. Compute and interpret permutation importance using scikit-learn
2. Compare permutation importance with built-in (impurity-based) feature importance
3. Create and interpret partial dependence plots
4. Understand the concept of SHAP values for model explanation
5. Choose the right interpretability method for your use case

## Prerequisites

- Random Forest and ensemble methods (Notebooks 01-02)
- Familiarity with feature importance concepts
- scikit-learn basics

## Table of Contents

1. [Permutation Importance](#1-permutation-importance)
2. [Permutation vs Built-in Importance](#2-permutation-vs-built-in-importance)
3. [Partial Dependence Plots](#3-partial-dependence-plots)
4. [SHAP Values](#4-shap-values)
5. [When to Use Which Method](#5-when-to-use-which-method)
6. [Common Mistakes](#6-common-mistakes)
7. [Exercises](#7-exercises)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.inspection import permutation_importance, PartialDependenceDisplay

plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12
sns.set_style('whitegrid')
np.random.seed(42)

In [None]:
# Load data and train a Random Forest for all interpretability demos
data = load_breast_cancer()
X, y = data.data, data.target
feature_names = data.feature_names

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

rf = RandomForestClassifier(
    n_estimators=200, random_state=42, n_jobs=-1
)
rf.fit(X_train, y_train)

print(f"Train accuracy: {accuracy_score(y_train, rf.predict(X_train)):.4f}")
print(f"Test accuracy:  {accuracy_score(y_test, rf.predict(X_test)):.4f}")

## 1. Permutation Importance

> **(OPTIONAL)** This section covers permutation importance as an alternative to built-in feature importance.

### Concept

Permutation importance measures how much the model's performance **drops** when a single feature is randomly shuffled:

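1. Compute the baseline score on an evaluation set (ideally held-out data)
2. For each feature:
   - Randomly shuffle that feature's values, breaking its link to the target
   - Re-compute the score
   - Importance = baseline score - shuffled score
3. Repeat the shuffle several times (`n_repeats`) and report the mean and standard deviation, since a single shuffle is noisy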

**Advantages over built-in importance:**
- Model-agnostic (works with any model)
- Not biased toward high-cardinality features
- Can be computed on test data (measures real generalization impact)
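
The steps above can be sketched in plain NumPy. This is a minimal illustration, not scikit-learn's implementation: `manual_permutation_importance` is a hypothetical helper, accuracy is assumed as the score, and the "model" is a hand-written rule that only reads column 0, so the irrelevant column should receive an importance of exactly zero.

```python
import numpy as np


def manual_permutation_importance(predict, X, y, n_repeats=10, seed=0):
    """Mean and std of the accuracy drop when each column is shuffled (sketch)."""
    rng = np.random.default_rng(seed)
    baseline = np.mean(predict(X) == y)  # step 1: baseline accuracy
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):  # step 2: one feature at a time
        for r in range(n_repeats):  # step 3: repeat for stability
            X_shuffled = X.copy()
            X_shuffled[:, j] = rng.permutation(X_shuffled[:, j])
            drops[j, r] = baseline - np.mean(predict(X_shuffled) == y)
    return drops.mean(axis=1), drops.std(axis=1)


def rule_model(X):
    # Hypothetical "model": predicts class 1 whenever column 0 is positive
    return (X[:, 0] > 0).astype(int)


rng = np.random.default_rng(42)
X_toy = rng.normal(size=(500, 2))
y_toy = (X_toy[:, 0] > 0).astype(int)

means, stds = manual_permutation_importance(rule_model, X_toy, y_toy)
print("mean importance:", means.round(3))
print("std:            ", stds.round(3))
```

Column 0 should show a drop of roughly 0.5 (shuffling it turns the rule into a coin flip), while column 1 stays at 0.0 because the rule never reads it. scikit-learn's `permutation_importance` follows the same shuffle-and-rescore idea and adds scorer handling and parallelism.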

In [None]:
# Compute permutation importance on the TEST set
perm_imp = permutation_importance(
    rf, X_test, y_test,
    n_repeats=20,
    random_state=42,
    n_jobs=-1,
    scoring='accuracy'
)

# Sort by mean importance
sorted_idx = perm_imp.importances_mean.argsort()[::-1]

# Display top 15
top_n = 15
print(f"Top {top_n} features by permutation importance (on test set):")
print(f"{'Feature':<30} {'Mean':>8} {'Std':>8}")
print(f"{'-'*46}")
for idx in sorted_idx[:top_n]:
    print(f"{feature_names[idx]:<30} {perm_imp.importances_mean[idx]:>8.4f} "
          f"{perm_imp.importances_std[idx]:>8.4f}")

In [None]:
# Visualize permutation importance (top 15)
fig, ax = plt.subplots(figsize=(10, 7))

top_indices = sorted_idx[:top_n][::-1]  # reverse for horizontal bar plot
ax.barh(
    range(top_n),
    perm_imp.importances_mean[top_indices],
    xerr=perm_imp.importances_std[top_indices],
    color='steelblue',
    alpha=0.8
)
ax.set_yticks(range(top_n))
ax.set_yticklabels(feature_names[top_indices])
ax.set_xlabel('Mean Accuracy Decrease')
ax.set_title('Permutation Importance (computed on test set)')
plt.tight_layout()
plt.show()

## 2. Permutation vs Built-in Importance

> **(OPTIONAL)** Comparing the two importance methods reveals important differences.

In [None]:
# Compare built-in (Gini/impurity) importance with permutation importance
builtin_imp = rf.feature_importances_
perm_imp_mean = perm_imp.importances_mean

# Select top features by either method
top_by_builtin = np.argsort(builtin_imp)[::-1][:10]
top_by_perm = np.argsort(perm_imp_mean)[::-1][:10]
top_features = list(set(top_by_builtin) | set(top_by_perm))
top_features.sort(key=lambda i: -builtin_imp[i])

# Plot side-by-side
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Built-in importance
idx_sorted = np.argsort(builtin_imp[top_features])
axes[0].barh(
    range(len(top_features)),
    builtin_imp[np.array(top_features)[idx_sorted]],
    color='steelblue', alpha=0.8
)
axes[0].set_yticks(range(len(top_features)))
axes[0].set_yticklabels(feature_names[np.array(top_features)[idx_sorted]])
axes[0].set_xlabel('Importance')
axes[0].set_title('Built-in (Impurity) Importance')

# Permutation importance
idx_sorted_p = np.argsort(perm_imp_mean[top_features])
axes[1].barh(
    range(len(top_features)),
    perm_imp_mean[np.array(top_features)[idx_sorted_p]],
    color='coral', alpha=0.8
)
axes[1].set_yticks(range(len(top_features)))
axes[1].set_yticklabels(feature_names[np.array(top_features)[idx_sorted_p]])
axes[1].set_xlabel('Mean Accuracy Decrease')
axes[1].set_title('Permutation Importance (test set)')

plt.suptitle('Built-in vs Permutation Feature Importance', fontsize=14)
plt.tight_layout()
plt.show()

print("Key observation: the two methods can rank features differently.")
print("Built-in (impurity) importance is computed from training-set splits and can favor")
print("features with many unique values; permutation importance on the test set measures")
print("how much held-out accuracy actually depends on each feature.")
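
To check the high-cardinality claim directly, the sketch below (assuming scikit-learn, as in the rest of this notebook) appends one column of pure Gaussian noise to the breast cancer features, refits a forest with the same settings, and compares the two importances for that column. The noise column and the variable names are illustrative additions, not part of the original analysis.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

X_bc, y_bc = load_breast_cancer(return_X_y=True)
rng = np.random.default_rng(0)

# Hypothetical extra feature: pure noise with a unique value in every row,
# so trees can always find a split on it even though it carries no signal.
X_noise = np.column_stack([X_bc, rng.normal(size=X_bc.shape[0])])
noise_idx = X_noise.shape[1] - 1

X_tr, X_te, y_tr, y_te = train_test_split(
    X_noise, y_bc, test_size=0.3, random_state=42, stratify=y_bc
)
model = RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1)
model.fit(X_tr, y_tr)

# Impurity (MDI) importance comes from the training-set splits
mdi_noise = model.feature_importances_[noise_idx]
mdi_rank = int((model.feature_importances_ > mdi_noise).sum()) + 1

# Permutation importance is measured on held-out data
perm = permutation_importance(model, X_te, y_te, n_repeats=20, random_state=42)
perm_noise = perm.importances_mean[noise_idx]

print(f"Impurity importance of noise column:    {mdi_noise:.4f} "
      f"(rank {mdi_rank} of {X_noise.shape[1]})")
print(f"Permutation importance of noise column: {perm_noise:.4f}")
```

Because fully grown trees can still split on the noise column, it usually picks up some impurity importance; shuffling it on the test set, by contrast, should barely change accuracy. Exact numbers depend on the random seeds.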

## 3. Partial Dependence Plots

> **(OPTIONAL)** Partial dependence shows the marginal effect of a feature on the prediction.

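A **partial dependence plot (PDP)** shows how the model's average prediction changes as one feature is varied over a grid of values: for each grid value, that feature is set to the value in every row while the other features keep their observed values, and the resulting predictions are averaged. It reveals the shape of the relationship between a feature and the model's output, which is not necessarily the relationship in the real data. For this binary classifier, the plotted response is the predicted probability of class 1 (benign).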
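
The averaging step can be written out directly. This sketch (plain NumPy, with a hypothetical helper `partial_dependence_1d` and a hand-written additive toy model) computes one ICE curve per row and averages the curves into the PDP; scikit-learn's brute-force method follows the same idea.

```python
import numpy as np


def partial_dependence_1d(predict, X, feature, grid):
    """Return (pdp, ice); ice[i, k] is the prediction for row i with X[i, feature] = grid[k]."""
    ice = np.empty((X.shape[0], len(grid)))
    for k, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value  # force the feature to one grid value in every row
        ice[:, k] = predict(X_mod)  # keep each row's prediction: one ICE curve per row
    pdp = ice.mean(axis=0)  # the PDP is the average of the ICE curves
    return pdp, ice


def additive_model(X):
    # Hypothetical toy model, additive in x0 and x1
    return X[:, 0] ** 2 + X[:, 1]


rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 2))
grid = np.linspace(-2, 2, 9)

pdp, ice = partial_dependence_1d(additive_model, X_toy, feature=0, grid=grid)
# For an additive model, the PDP of x0 is grid**2 plus the mean of x1
print(np.allclose(pdp, grid ** 2 + X_toy[:, 1].mean()))  # True
```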

In [None]:
# Identify the top 4 features by permutation importance for PDP
top4_idx = perm_imp.importances_mean.argsort()[::-1][:4]
top4_names = feature_names[top4_idx]
print(f"Top 4 features for partial dependence plots: {list(top4_names)}")

In [None]:
# Partial Dependence Plots for top features
fig, ax = plt.subplots(figsize=(14, 8))
display = PartialDependenceDisplay.from_estimator(
    rf,
    X_train,
    features=list(top4_idx),
    feature_names=feature_names,
    kind='both',  # show individual conditional expectation (ICE) + average
    subsample=100,
    n_jobs=-1,
    ax=ax,
    random_state=42,
)
fig.suptitle('Partial Dependence Plots (Top 4 Features)', fontsize=14)
plt.tight_layout()
plt.show()

print("The dashed orange line (labeled 'average') is the partial dependence curve.")
print("The thin blue lines are Individual Conditional Expectation (ICE) curves, one per sampled row.")
print("ICE curves with different shapes (not just vertical offsets) suggest feature interactions.")
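
One way to separate "different shapes" from "vertical offsets" is to center each ICE curve at its first grid point (a centered ICE plot). In this hypothetical NumPy sketch, an additive toy model produces identical centered curves (spread of about zero), while a model with an `x0 * x1` interaction produces curves that fan out. Recent scikit-learn versions expose a similar option through a `centered` parameter of `PartialDependenceDisplay.from_estimator` (check your installed version).

```python
import numpy as np


def ice_curves(predict, X, feature, grid):
    # curves[i, k] = prediction for row i with X[i, feature] set to grid[k]
    curves = np.empty((X.shape[0], len(grid)))
    for k, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value
        curves[:, k] = predict(X_mod)
    return curves


def centered_spread(curves):
    # Subtract each curve's value at the first grid point, then report the
    # largest across-row standard deviation of the centered curves.
    centered = curves - curves[:, [0]]
    return centered.std(axis=0).max()


def additive_model(X):
    return X[:, 0] ** 2 + X[:, 1]  # no interaction: x1 only shifts the curve


def interacting_model(X):
    return X[:, 0] * X[:, 1]  # the slope in x0 depends on x1


rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 2))
grid = np.linspace(-2, 2, 9)

spread_add = centered_spread(ice_curves(additive_model, X_toy, 0, grid))
spread_int = centered_spread(ice_curves(interacting_model, X_toy, 0, grid))
print(f"additive spread:    {spread_add:.3f}")
print(f"interacting spread: {spread_int:.3f}")
```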

In [None]:
# 2D Partial Dependence Plot (interaction between top 2 features)
fig, ax = plt.subplots(figsize=(8, 6))
display_2d = PartialDependenceDisplay.from_estimator(
    rf,
    X_train,
    features=[(top4_idx[0], top4_idx[1])],
    feature_names=feature_names,
    kind='average',
    grid_resolution=20,  # brute-force 2D PDP evaluates grid_resolution**2 grid points
    n_jobs=-1,
    ax=ax,
)
fig.suptitle(
    f'2D Partial Dependence: {feature_names[top4_idx[0]]} vs {feature_names[top4_idx[1]]}',
    fontsize=13
)
plt.tight_layout()
plt.show()

## 4. SHAP Values

> **(OPTIONAL)** SHAP (SHapley Additive exPlanations) provides theoretically grounded feature attributions.

### Concept

SHAP values are based on **Shapley values** from cooperative game theory. For each prediction, SHAP assigns each feature a value representing its contribution to moving the prediction away from the baseline, i.e., the model's expected output over the background data.

**Properties of SHAP values:**
- **Local accuracy**: The SHAP values for a prediction sum to the difference between the model's output for that row and the baseline (expected) output
- **Consistency**: If a model changes so that a feature's marginal contribution increases or stays the same in every coalition, its SHAP value does not decrease
- **Missingness**: A feature that is absent from the explanation's simplified input receives a SHAP value of 0
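
In symbols, using the standard Shapley definition: with $F$ the set of features and $f_x(S)$ the model's expected output when only the features in $S$ are known, feature $i$ receives

$$
\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F| - |S| - 1)!}{|F|!} \Big[ f_x(S \cup \{i\}) - f_x(S) \Big]
$$

and local accuracy states that

$$
f(x) = \phi_0 + \sum_{i \in F} \phi_i, \qquad \phi_0 = f_x(\varnothing) = \mathbb{E}[f(X)].
$$

How $f_x(S)$ is estimated (for example, which background data fills in the unknown features) is an implementation choice that differs between SHAP explainers.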

For tree-based models, `TreeExplainer` computes exact SHAP values efficiently in polynomial time.
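
To make "all possible feature coalitions" concrete, here is a brute-force sketch for a tiny model. It assumes an interventional value function: features outside a coalition are filled in from a background sample and the predictions are averaged. `TreeExplainer`'s default path-dependent algorithm estimates $f_x(S)$ differently, so this hypothetical `exact_shapley` helper will not reproduce its numbers; it only illustrates the definition. The cost grows as $2^M$ in the number of features $M$, which is why exact enumeration is impractical for the 30 features used here.

```python
import itertools
import math

import numpy as np


def exact_shapley(predict, x, background):
    """Brute-force Shapley values for one row `x` (sketch, O(2^M) evaluations)."""
    M = len(x)

    def value(coalition):
        # Features in the coalition take x's values; the rest keep the
        # background rows' values. Average the predictions (interventional).
        X_eval = background.copy()
        cols = list(coalition)
        if cols:
            X_eval[:, cols] = x[cols]
        return predict(X_eval).mean()

    phi = np.zeros(M)
    for i in range(M):
        others = [j for j in range(M) if j != i]
        for size in range(M):
            # Shapley weight |S|! (M - |S| - 1)! / M!
            weight = math.factorial(size) * math.factorial(M - size - 1) / math.factorial(M)
            for S in itertools.combinations(others, size):
                phi[i] += weight * (value(S + (i,)) - value(S))
    return phi, value(())


def toy_model(X):
    # Hypothetical model with an x0 * x1 interaction plus a linear x2 term
    return X[:, 0] * X[:, 1] + X[:, 2]


rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
x = np.array([1.0, 2.0, -1.0])

phi, base = exact_shapley(toy_model, x, background)
print("SHAP-style values:", phi.round(3))
print("base + sum(phi) equals f(x):", np.isclose(base + phi.sum(), toy_model(x[None, :])[0]))
```

The `x0 * x1` interaction is split between features 0 and 1, yet the values still add up exactly to the prediction minus the baseline, which is the local accuracy property.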

In [None]:
try:
    import shap
    print(f"SHAP version: {shap.__version__}")
    HAS_SHAP = True
except ImportError:
    print("SHAP is not installed. Install with: pip install shap")
    print("Proceeding with a manual conceptual example instead.")
    HAS_SHAP = False

In [None]:
if HAS_SHAP:
    # TreeExplainer for efficient SHAP computation on tree-based models
    explainer = shap.TreeExplainer(rf)
    shap_values = explainer.shap_values(X_test)
    
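    # Older SHAP versions return a list with one array per class; newer versions
    # return a single (n_samples, n_features, n_classes) array.
    sv_pos = shap_values[1] if isinstance(shap_values, list) else shap_values[:, :, 1]

    # Summary plot: global feature importance via SHAP
    # In load_breast_cancer, target 0 = malignant and 1 = benign
    print("SHAP Summary Plot (class 1 = benign):")
    shap.summary_plot(
        sv_pos,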
        X_test,
        feature_names=feature_names,
        max_display=15,
        show=True
    )
else:
    print("SHAP library not available. Showing conceptual example.")
    print()
    print("=== Manual SHAP-like Explanation ===")
    print()
    print("For a single prediction, SHAP values explain how each feature")
    print("contributed to the prediction relative to the average.")
    print()
    
    # Manual demonstration using a single-feature mean-replacement approximation
    sample_idx = 0
    sample = X_test[sample_idx:sample_idx + 1]
    base_pred = rf.predict_proba(X_test).mean(axis=0)
    sample_pred = rf.predict_proba(sample)[0]
    
    print(f"Average prediction (baseline): {base_pred}")
    print(f"This sample's prediction:      {sample_pred}")
    print(f"Difference to explain:         {sample_pred - base_pred}")
    print()
    
    # Approximate feature contributions by replacing one feature at a time with
    # its training-set mean and measuring the change in P(class 1 = benign)
    approx_contributions = []
    for i in range(X_test.shape[1]):
        X_replaced = sample.copy()
        X_replaced[0, i] = X_train[:, i].mean()  # replace with training-set mean
        replaced_pred = rf.predict_proba(X_replaced)[0]
        contribution = sample_pred[1] - replaced_pred[1]
        approx_contributions.append(contribution)
    
    approx_contributions = np.array(approx_contributions)
    top_contrib_idx = np.argsort(np.abs(approx_contributions))[::-1][:10]
    
    print("Approximate feature contributions (top 10):")
    for idx in top_contrib_idx:
        direction = "+" if approx_contributions[idx] > 0 else "-"
        print(f"  {feature_names[idx]:<30} {direction}{abs(approx_contributions[idx]):.4f}")
    
    print()
    print("Note: This is a simplified approximation. True SHAP values account")
    print("for all possible feature coalitions, not just single-feature replacement.")

In [None]:
if HAS_SHAP:
    # Force plot for a single prediction
    print("SHAP Force Plot for sample 0:")
    shap.initjs()
    
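    # Positive class (1 = benign). Older SHAP versions return a list with one
    # array per class; newer versions return a (samples, features, classes) array.
    sv = shap_values[1] if isinstance(shap_values, list) else shap_values[:, :, 1]
    ev = explainer.expected_value[1] if hasattr(explainer.expected_value, '__len__') else explainer.expected_value
    
    shap.force_plot(
        ev,
        sv[0],
        X_test[0],
        feature_names=feature_names,
        matplotlib=True,
        show=False,  # defer display to the plt.show() call below
    )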
    plt.tight_layout()
    plt.show()
else:
    print("SHAP not available -- force plot requires the shap library.")
    print("Install with: pip install shap")

In [None]:
if HAS_SHAP:
    # SHAP bar plot: mean absolute SHAP values (global importance)
    print("SHAP Bar Plot (mean |SHAP value|):")
    shap.summary_plot(
        shap_values[1] if isinstance(shap_values, list) else shap_values[:, :, 1],  # class 1 = benign
        X_test,
        feature_names=feature_names,
        plot_type='bar',
        max_display=15,
        show=True
    )
else:
    print("SHAP not available. See permutation importance above for global feature ranking.")

## 5. When to Use Which Method

> **(OPTIONAL)** Choosing the right interpretability tool for your situation.

| Method | Best For | Pros | Cons |
|--------|----------|------|------|
| **Built-in importance** | Quick feature ranking for tree models | Fast, no extra computation | Biased toward high-cardinality features; computed from training data only |
| **Permutation importance** | Model-agnostic feature ranking | Less biased toward high-cardinality features; can use held-out data; simple concept | Slow for many features or large datasets; understates correlated features |
| **Partial dependence** | Understanding feature-response shape | Shows non-linear effects; intuitive plots | Assumes feature independence; can be misleading with correlations |
| **SHAP values** | Individual prediction explanations | Theoretically sound; local + global; handles interactions | Slower; requires library; complex to explain to non-technical audiences |

### Recommendations

- **Start with permutation importance** for a reliable global feature ranking
- **Use partial dependence** to understand the shape of important feature effects
- **Use SHAP** when you need to explain individual predictions (e.g., "why did the model predict X for this patient?")
- **Avoid relying solely on built-in importance** for final conclusions

## 6. Common Mistakes

1. **Confusing importance with causation**: A feature being important means it is useful for prediction, not that it causes the outcome. Correlation and predictive power are not causation.

2. **Using training data for permutation importance**: Permutation importance should be computed on the **test set** (or held-out data). On training data, it may reflect overfitting patterns rather than true predictive value.

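3. **Ignoring feature correlations**: When features are correlated, permutation importance can underestimate each of them, because permuting one leaves its correlated partner intact for the model to use. This dataset has strongly related features such as `mean radius`, `mean perimeter`, and `mean area`. Consider clustering correlated features and permuting each group together, or keeping one representative per group. SHAP values do not fully solve this either, since they can split credit among correlated features.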

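4. **Over-interpreting partial dependence with correlated features**: A PDP sets the chosen feature to each grid value in every row, which implicitly assumes it is independent of the other features. If two features are highly correlated, the PDP averages over combinations that rarely or never occur in practice (for example, a large `mean radius` paired with a small `mean area`).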

5. **Treating SHAP values as definitive explanations**: SHAP values explain the model, not the underlying reality. A model can learn spurious correlations, and SHAP will faithfully explain those spurious patterns.
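
For the correlated-features pitfall (mistake 3), one remedy is to shuffle a whole group of correlated columns with the same row permutation. The sketch below uses a hypothetical helper `grouped_permutation_importance`, accuracy as the score, and a toy model that averages two near-duplicate columns; it is an illustration in plain NumPy, not a scikit-learn API.

```python
import numpy as np


def grouped_permutation_importance(predict, X, y, groups, n_repeats=10, seed=0):
    """Mean accuracy drop when all columns in each group are shuffled together.

    `groups` maps a label to a list of column indices. The same row permutation
    is applied to every column in a group, so correlations *within* the group
    are preserved while the group's link to y is broken.
    """
    rng = np.random.default_rng(seed)
    baseline = np.mean(predict(X) == y)
    result = {}
    for name, cols in groups.items():
        drops = []
        for _ in range(n_repeats):
            perm = rng.permutation(X.shape[0])
            X_shuffled = X.copy()
            X_shuffled[:, cols] = X[perm][:, cols]
            drops.append(baseline - np.mean(predict(X_shuffled) == y))
        result[name] = float(np.mean(drops))
    return result


def averaging_model(X):
    # Hypothetical model that spreads its reliance across both copies
    return ((X[:, 0] + X[:, 1]) / 2 > 0).astype(int)


rng = np.random.default_rng(1)
x0 = rng.normal(size=2000)
X_toy = np.column_stack([x0, x0 + 0.01 * rng.normal(size=2000)])  # near-duplicate column
y_toy = (x0 > 0).astype(int)

imp = grouped_permutation_importance(
    averaging_model, X_toy, y_toy, groups={"x0": [0], "x1": [1], "x0+x1": [0, 1]}
)
print({name: round(value, 3) for name, value in imp.items()})
```

Shuffling either column alone costs roughly a quarter of the accuracy, because the model still sees the untouched copy, while shuffling both together costs roughly half. In practice, the groups could come from clustering features by their rank correlation.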

## 7. Exercises

### Exercise 1: Permutation Importance on Train vs Test
Compute permutation importance on both the training set and the test set for the Random Forest. Compare the rankings. Are there features that appear important on training data but not on test data? What does this suggest?

### Exercise 2: Partial Dependence with Interactions
Create 2D partial dependence plots for the top 2-3 pairs of features. Do any pairs show interaction effects (non-additive patterns in the 2D PDP)?

### Exercise 3: Compare Interpretability Across Models
Train a `GradientBoostingClassifier` alongside the Random Forest. Compute permutation importance for both models. Do the two models agree on the most important features? What might explain any differences?

In [None]:
# Exercise 1 starter code
# perm_imp_train = permutation_importance(
#     rf, X_train, y_train, n_repeats=20, random_state=42, n_jobs=-1
# )
# perm_imp_test = permutation_importance(
#     rf, X_test, y_test, n_repeats=20, random_state=42, n_jobs=-1
# )
# # Compare: perm_imp_train.importances_mean vs perm_imp_test.importances_mean
# # Look for features with high train importance but low test importance (overfitting signal)