# Lab 7.3 — SHAP Values (Explainable AI for Digital Health)

In this lab, you will train a simple **heart disease risk** classifier and use **SHAP** to explain:
- **Global** behavior (which features matter most overall)
- **Local** behavior (why a specific case gets a high/low risk prediction)

Dataset: `heart_disease_uci.csv`


## Learning objectives
By the end of this lab, you should be able to:
1. Load and preprocess a clinical dataset (mixed numeric + categorical)
2. Train and evaluate a baseline classification model
3. Use **SHAP** to explain predictions:
   - Global: *summary plot*, *feature importance bar plot*
   - Local: *waterfall plot* for individual cases
4. Communicate model explanations in plain language, without implying causality


In [None]:
# If you run this in Google Colab, install dependencies first
!pip -q install shap scikit-learn pandas matplotlib


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    accuracy_score,
)

from sklearn.ensemble import RandomForestClassifier

import shap


## 1) Load data

If you are using **Colab**, upload `heart_disease_uci.csv` to `/content/`.


In [None]:
def resolve_data_path(filename: str) -> str:
    candidates = [
        filename,
        os.path.join(".", filename),
        os.path.join("/content", filename),  # Colab default
    ]
    for p in candidates:
        if os.path.exists(p):
            return p
    raise FileNotFoundError(
        f"Could not find {filename}. If you are on Colab, click the folder icon and upload it to /content."
    )

data_path = resolve_data_path("heart_disease_uci.csv")
df = pd.read_csv(data_path)

print("Loaded:", data_path)
print("Shape:", df.shape)
df.head()


## 2) Define the prediction target

In this dataset, `num` encodes **disease severity** as an integer:
- `num = 0` means *no heart disease*
- `num > 0` means *heart disease present*

For this lab, we create a **binary** target:  
`target = 1 if num > 0 else 0`


In [None]:
# Basic checks
print(df["num"].value_counts(dropna=False).sort_index())

# Binary target
df["target"] = (df["num"] > 0).astype(int)

print("\nBinary target distribution:")
print(df["target"].value_counts(normalize=True).rename("proportion"))


## 3) Prepare features (X) and labels (y)

We drop:
- `id` (identifier)
- `num` (original multi-class / severity label)

Everything else becomes an input feature.


In [None]:
drop_cols = ["id", "num"]
X = df.drop(columns=drop_cols + ["target"], errors="ignore")
y = df["target"].astype(int)

# Identify feature types
cat_cols = X.select_dtypes(include=["object", "bool"]).columns.tolist()
num_cols = [c for c in X.columns if c not in cat_cols]

print("Numeric columns:", num_cols)
print("Categorical/boolean columns:", cat_cols)


## 4) Train/test split
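
The split in this section passes `stratify=y`, which keeps the class proportions approximately equal in the train and test sets. A minimal sketch on synthetic labels (the arrays `X_toy` and `y_toy` are invented for illustration and are not the lab data):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced labels: 80 negatives, 20 positives (illustrative only)
y_toy = np.array([0] * 80 + [1] * 20)
X_toy = np.arange(len(y_toy)).reshape(-1, 1)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_toy, y_toy, test_size=0.25, random_state=42, stratify=y_toy
)

# With stratification, both splits keep roughly the same positive rate
print("Train positive rate:", y_tr.mean())
print("Test positive rate :", y_te.mean())
```

Without `stratify`, a small test set can end up with noticeably more or fewer positive cases than the training set, which makes the evaluation metrics harder to interpret.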


In [None]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.25,
    random_state=42,
    stratify=y
)

print("Train:", X_train.shape, " Test:", X_test.shape)


## 5) Build a baseline model (preprocessing + classifier)

We will use:
- **Imputation** (fill missing values)
- **One-hot encoding** for categorical features
- **RandomForestClassifier** as a strong, simple baseline


In [None]:
numeric_transformer = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="median")),
])

categorical_transformer = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("onehot", OneHotEncoder(handle_unknown="ignore")),
])

preprocess = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, num_cols),
        ("cat", categorical_transformer, cat_cols),
    ],
    remainder="drop",
)

model = RandomForestClassifier(
    n_estimators=400,
    random_state=42,
    class_weight="balanced",
    n_jobs=-1,
)

clf = Pipeline(steps=[
    ("preprocess", preprocess),
    ("model", model),
])

clf


## 6) Train + evaluate
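
The evaluation in this section reports accuracy and ROC-AUC. As a rough sketch of how they differ, the example below uses invented labels and scores: accuracy depends on a decision threshold (0.5 is assumed here), while ROC-AUC only depends on how well the scores rank positives above negatives.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Invented labels and predicted probabilities, for illustration only
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# Accuracy needs hard labels, so it depends on the chosen threshold
y_hat = (y_score >= 0.5).astype(int)
toy_acc = accuracy_score(y_true, y_hat)

# ROC-AUC: share of (positive, negative) pairs where the positive scores higher
pos, neg = y_score[y_true == 1], y_score[y_true == 0]
pairwise_auc = np.mean([p > n for p in pos for n in neg])
toy_auc = roc_auc_score(y_true, y_score)

print("accuracy:", toy_acc)
print("ROC-AUC (sklearn):", toy_auc, "| pairwise:", pairwise_auc)
```

The pairwise calculation matches `roc_auc_score` here because the toy scores contain no ties.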


In [None]:
clf.fit(X_train, y_train)

# Predictions
y_pred = clf.predict(X_test)
y_prob = clf.predict_proba(X_test)[:, 1]

acc = accuracy_score(y_test, y_pred)
auc = roc_auc_score(y_test, y_prob)

print(f"Accuracy: {acc:.3f}")
print(f"ROC-AUC : {auc:.3f}\n")

print("Classification report:")
print(classification_report(y_test, y_pred))

print("Confusion matrix:")
confusion_matrix(y_test, y_pred)


## 7) SHAP explanations

### Important concept
SHAP explains **the model's behavior** (how it uses inputs), not medical causality.
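
To see what a SHAP value is, it can help to compute one by brute force. The sketch below uses a hypothetical three-feature function `toy_model` and an invented background cohort, and averages over the background to "remove" a feature. This is the textbook Shapley definition for illustration only; it is not how `TreeExplainer` computes its values internally.

```python
from itertools import combinations
from math import factorial

# Hypothetical toy risk model with an interaction term (not the lab's random forest)
def toy_model(age, chol, smoker):
    return 0.1 + 0.004 * age + 0.001 * chol + 0.2 * smoker + 0.002 * age * smoker

# Invented background cohort: "removing" a feature means averaging over these rows
background = [(40, 180, 0), (55, 220, 1), (65, 260, 0), (50, 200, 1)]
case = (60, 250, 1)  # the case we want to explain
n = len(case)

def value(subset):
    """Average model output when only the features in `subset` are fixed to the case."""
    total = 0.0
    for row in background:
        x = [case[i] if i in subset else row[i] for i in range(n)]
        total += toy_model(*x)
    return total / len(background)

def shapley(i):
    """Exact Shapley value of feature i: weighted average marginal contribution."""
    others = [j for j in range(n) if j != i]
    phi = 0.0
    for k in range(len(others) + 1):
        for s in combinations(others, k):
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            phi += weight * (value(set(s) | {i}) - value(set(s)))
    return phi

base_value = value(set())              # average prediction over the background
phis = [shapley(i) for i in range(n)]  # one contribution per feature
prediction = toy_model(*case)

print("base value :", round(base_value, 4))
print("SHAP values:", [round(p, 4) for p in phis])
print("base + sum :", round(base_value + sum(phis), 4), "| prediction:", round(prediction, 4))
```

The contributions always add up to the gap between the baseline and the case's prediction. They describe how this function uses its inputs, and say nothing about whether changing an input would change a patient's actual risk.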

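Because the random forest sees one-hot encoded inputs rather than the raw columns, we compute SHAP values on the **transformed feature matrix**. Each category of a categorical column therefore appears as its own feature in the plots.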
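
If you want one contribution per original column, SHAP values for the one-hot columns can be summed, since the values are additive. A minimal sketch with an invented SHAP matrix and invented feature names; the helper `original_column` is hypothetical and assumes one-hot names of the form `<column>_<category>`, which can misfire if one column name is a prefix of another.

```python
import numpy as np

# Invented SHAP matrix: 3 cases x 5 transformed features (illustration only)
toy_feature_names = ["age", "chol", "cp_asymptomatic", "cp_typical angina", "sex_Male"]
toy_shap = np.array([
    [ 0.05, -0.02,  0.10, -0.01,  0.03],
    [-0.04,  0.01, -0.06,  0.02, -0.02],
    [ 0.02,  0.03,  0.08, -0.03,  0.04],
])
toy_cat_cols = ["cp", "sex"]  # original categorical columns (assumed known)

def original_column(name):
    """Map a one-hot name like 'cp_asymptomatic' back to its source column 'cp'."""
    for col in toy_cat_cols:
        if name.startswith(col + "_"):
            return col
    return name  # numeric features keep their own name

# Collect the transformed-column indices that belong to each original column
groups = {}
for j, name in enumerate(toy_feature_names):
    groups.setdefault(original_column(name), []).append(j)

# Sum the one-hot columns to get one contribution per original column and case
grouped = {col: toy_shap[:, idx].sum(axis=1) for col, idx in groups.items()}
for col, vals in grouped.items():
    print(col, np.round(vals, 3))
```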


In [None]:
# Transform data into the model's feature space
X_train_t = clf.named_steps["preprocess"].transform(X_train)
X_test_t  = clf.named_steps["preprocess"].transform(X_test)

# Convert sparse → dense (SHAP plots are easiest with dense arrays)
def to_dense(X):
    return X.toarray() if hasattr(X, "toarray") else np.asarray(X)

X_train_dense = to_dense(X_train_t)
X_test_dense  = to_dense(X_test_t)

# Build feature names (numeric + one-hot categorical)
ohe = clf.named_steps["preprocess"].named_transformers_["cat"].named_steps["onehot"]
cat_feature_names = ohe.get_feature_names_out(cat_cols).tolist()
feature_names = num_cols + cat_feature_names

len(feature_names), X_test_dense.shape


### 7.1 Global explanation — SHAP summary plot
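- Shows which features matter most overall (most influential at the top)
- Each dot is one test case; its horizontal position is that case's SHAP value for the feature (right of zero pushes predicted risk up, left pushes it down)
- Color indicates feature value (low → high)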


In [None]:
# TreeExplainer works well for tree-based models (like RandomForest)
explainer = shap.TreeExplainer(clf.named_steps["model"])

# For binary classification, the return type depends on the SHAP version:
# - a list of arrays (one per class), or
# - a 3-D array of shape (n_samples, n_features, n_classes), or
# - a single 2-D array.
shap_values = explainer.shap_values(X_test_dense)

# Pick SHAP values for the positive class (heart disease present)
if isinstance(shap_values, list) and len(shap_values) == 2:
    shap_pos = shap_values[1]
elif getattr(shap_values, "ndim", 2) == 3:
    shap_pos = shap_values[:, :, 1]
else:
    shap_pos = shap_values

# Summary plot (beeswarm)
shap.summary_plot(shap_pos, X_test_dense, feature_names=feature_names, show=True)


### 7.2 Global explanation — feature importance (bar)
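
The bar plot ranks features by the mean absolute SHAP value across cases. The sketch below reproduces that calculation on an invented SHAP matrix (the names and numbers are illustrative, not results from the lab model).

```python
import numpy as np

# Invented SHAP values for 4 cases x 3 features (illustration only)
toy_names = ["age", "chol", "oldpeak"]
toy_shap = np.array([
    [ 0.10, -0.02,  0.05],
    [-0.08,  0.01,  0.07],
    [ 0.12, -0.03, -0.06],
    [-0.10,  0.02,  0.02],
])

# Global importance: average magnitude of each feature's contribution
mean_abs = np.abs(toy_shap).mean(axis=0)

# Rank from most to least important
order = np.argsort(mean_abs)[::-1]
ranking = [(toy_names[j], round(float(mean_abs[j]), 3)) for j in order]
print(ranking)
```

Taking the absolute value first matters: a feature that pushes some cases up and others down would otherwise average out to nearly zero and look unimportant.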


In [None]:
shap.summary_plot(shap_pos, X_test_dense, feature_names=feature_names, plot_type="bar", show=True)


### 7.3 Dependence plot — how one feature affects prediction

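Pick one feature (often a top feature from the bar plot) and plot its value (x-axis) against its SHAP value (y-axis) across all test cases. By default, SHAP colors the points by a second feature that it estimates to interact most strongly with the first.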


In [None]:
# Choose a feature to inspect (edit this if you want)
feature_to_plot = feature_names[0]  # change to a more interesting feature after you see the bar plot
print("Feature:", feature_to_plot)

shap.dependence_plot(
    feature_to_plot,
    shap_pos,
    X_test_dense,
    feature_names=feature_names,
    show=True
)


## 8) Local explanation — why this case got this prediction

We will explain two cases:
- A **high-risk** case (highest predicted probability)
- A **low-risk** case (lowest predicted probability)
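
A waterfall plot starts at the baseline and adds one feature contribution at a time until it reaches the model output for that case. The sketch below prints the same idea as text, using invented numbers and feature names (none of these values come from the lab model).

```python
import numpy as np

# Invented values for one case (illustration only)
toy_names = ["age", "chol", "oldpeak", "cp_asymptomatic", "sex_Male"]
toy_base = 0.45  # baseline (expected) model output
toy_contrib = np.array([0.08, -0.03, 0.12, 0.15, -0.02])

# Sort by magnitude so the largest contributions are listed first
order = np.argsort(np.abs(toy_contrib))[::-1]

running = toy_base
print(f"{'baseline':<18} {running:.2f}")
for j in order:
    running += toy_contrib[j]
    direction = "up" if toy_contrib[j] > 0 else "down"
    print(f"{toy_names[j]:<18} {toy_contrib[j]:+.2f} ({direction}) -> {running:.2f}")

# The running total ends at the model output for this case
toy_prediction = toy_base + toy_contrib.sum()
```

Reading the list from the top answers the "which features push risk up or down the most" questions directly.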


In [None]:
# Get predicted probabilities for each test case
test_probs = clf.predict_proba(X_test)[:, 1]

high_idx = int(np.argmax(test_probs))
low_idx  = int(np.argmin(test_probs))

print("High-risk predicted probability:", test_probs[high_idx])
print("Low-risk predicted probability :", test_probs[low_idx])

# Summarize the two cases, then show their original (human-readable) input rows
display(pd.DataFrame({
    "case": ["HIGH risk", "LOW risk"],
    "predicted_prob": [test_probs[high_idx], test_probs[low_idx]],
}))
print("\nHIGH-risk case inputs:")
display(X_test.iloc[high_idx:high_idx+1])
print("\nLOW-risk case inputs:")
display(X_test.iloc[low_idx:low_idx+1])


In [None]:
# Waterfall plots (local explanations)
# explainer.expected_value is the baseline (roughly the model's average output);
# each feature's SHAP value pushes the prediction up or down from it.

# expected_value can be a scalar, or a list/array with one entry per class
base = explainer.expected_value
if isinstance(base, (list, np.ndarray)) and len(np.atleast_1d(base)) == 2:
    base_pos = base[1]
else:
    base_pos = float(np.atleast_1d(base)[0])

# High-risk case
exp_high = shap.Explanation(
    values=shap_pos[high_idx],
    base_values=base_pos,
    data=X_test_dense[high_idx],
    feature_names=feature_names
)

# Low-risk case
exp_low = shap.Explanation(
    values=shap_pos[low_idx],
    base_values=base_pos,
    data=X_test_dense[low_idx],
    feature_names=feature_names
)

print("Waterfall — HIGH risk")
shap.plots.waterfall(exp_high, max_display=12)

print("\nWaterfall — LOW risk")
shap.plots.waterfall(exp_low, max_display=12)


## 9) Student tasks (submit screenshots + short answers)

1. **Global:** Take a screenshot of the **bar plot** and write 2–3 sentences:
   - Name the top 3 features and the direction in which each influences risk (read the direction from the summary plot).

2. **Local:** Take screenshots of both **waterfall plots** and write:
   - Which 2 features push risk up the most for the high-risk case?
   - Which 2 features push risk down the most for the low-risk case?

3. **Dependence:** Create one dependence plot for a top feature and explain what you observe.

### Reminder
SHAP explains what the **model** learned from this dataset. It does **not** prove medical causality.
