## 9. Final model evaluation and interpretation

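Now that the model is trained, we return to our constraints, one of which is high interpretability. We want to understand the model’s decisions so that we can compare them with existing biomedical knowledge. This is precisely why we avoided neural networks: their internal weights and activations are very difficult to interpret, and a black-box model offers little insight into how specific clinical variables drive each individual prediction. With a tree-based method like XGBoost, we can use post-hoc explanation techniques such as SHAP (SHapley Additive exPlanations) to quantify each feature’s contribution to every individual prediction. SHAP’s `TreeExplainer` computes these values exactly and efficiently for tree ensembles. For an XGBoost classifier it expresses them by default in raw-margin (log-odds) units, so a base value plus the per-feature contributions adds up to the model’s output for that patient.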
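
To make the idea of a per-feature contribution concrete, the sketch below computes exact Shapley values by brute-force enumeration of feature subsets for a toy model. This is not how `TreeExplainer` works internally: TreeSHAP exploits the tree structure to avoid this exponential enumeration. The risk function, its weights, the background cohort, and the patient are all hypothetical. "Missing" features are filled in by averaging over a background dataset, which is an interventional value function and only one of several possible choices. The final check illustrates additivity: the contributions sum to the difference between this patient's prediction and the average prediction over the background.

```python
from itertools import combinations
from math import factorial

import numpy as np


def coalition_value(f, x, background, subset):
    """Average model output when features in `subset` are fixed to x's values
    and all other features keep their values from each background row."""
    data = background.copy()
    idx = list(subset)
    data[:, idx] = x[idx]
    return f(data).mean()


def exact_shapley(f, x, background):
    """Brute-force Shapley values for one input x (exponential in n_features)."""
    n = x.shape[0]
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in combinations(others, size):
                # Shapley weight: |S|! (n - |S| - 1)! / n!
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                gain = (coalition_value(f, x, background, subset + (i,))
                        - coalition_value(f, x, background, subset))
                phi[i] += weight * gain
    return phi


# Hypothetical toy "risk score" over three standardized features
# (think HbA1c, BMI, age); the weights are made up for illustration.
def toy_risk(X):
    return 0.8 * X[:, 0] + 0.5 * X[:, 1] + 0.2 * X[:, 2] + 0.3 * X[:, 0] * X[:, 1]


rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))   # hypothetical reference cohort
patient = np.array([1.5, 0.5, -1.0])    # hypothetical patient

phi = exact_shapley(toy_risk, patient, background)
baseline = toy_risk(background).mean()
prediction = toy_risk(patient[None, :])[0]
print("Shapley values:", np.round(phi, 3))
# Additivity: baseline + sum of contributions recovers the prediction
print(np.isclose(baseline + phi.sum(), prediction))  # True
```

For a purely linear model this reduces to the familiar closed form, each feature contributing its weight times its deviation from the background mean.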

In the next step, we compute SHAP values for every prediction, highlighting which biomarkers, demographic factors, or lab results push the model toward one outcome or another. With those values in hand, we can rank the top contributors for each patient and then aggregate them across a cohort (for example, by averaging each feature’s absolute SHAP value) to see which features the model relies on most overall. Finally, we compare this model-derived ranking with the risk factors that clinicians typically consider, checking both for alignment (e.g., HbA1c or BMI appearing near the top) and for surprising discrepancies (e.g., a lab value that clinicians don’t usually emphasize). This helps ensure that our predictive pipeline not only performs well but also remains transparent, trustworthy, and grounded in real-world medical expertise.
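
The per-patient ranking, cohort aggregation, and comparison with clinical knowledge described above can be sketched as follows. The sketch assumes the cohort's SHAP values are already collected in an array of shape `(n_patients, n_features)`; for a binary XGBoost model, `shap.TreeExplainer(model).shap_values(X_cohort)` returns such an array. The helper names are hypothetical. The feature names other than HbA1c and BMI, the SHAP values, and the clinician list are made up purely for illustration. Mean absolute SHAP value is one common global importance measure, not the only one.

```python
import numpy as np
import pandas as pd


def top_k_per_patient(shap_matrix, feature_names, k=3):
    """For each patient (row), return the k features with the largest |SHAP|."""
    order = np.argsort(-np.abs(shap_matrix), axis=1)[:, :k]
    return [[feature_names[j] for j in row] for row in order]


def global_importance(shap_matrix, feature_names):
    """Rank features by mean absolute SHAP value across the cohort."""
    scores = np.abs(shap_matrix).mean(axis=0)
    return pd.Series(scores, index=feature_names).sort_values(ascending=False)


def compare_with_clinical(ranking, clinical_factors, k=3):
    """Split the model's top-k features into expected ones and surprises,
    and list clinical factors that did not make the model's top k."""
    top = list(ranking.index[:k])
    expected = [f for f in top if f in clinical_factors]
    surprising = [f for f in top if f not in clinical_factors]
    missing = [f for f in clinical_factors if f not in top]
    return expected, surprising, missing


# Made-up SHAP values for 4 patients x 5 features (illustration only)
features = ["HbA1c", "BMI", "age", "ALT", "sodium"]
shap_matrix = np.array([
    [ 0.90,  0.40, 0.10, -0.30,  0.05],
    [ 0.70, -0.20, 0.30,  0.60,  0.00],
    [-0.50,  0.60, 0.20,  0.40, -0.10],
    [ 0.80,  0.10, 0.05,  0.50,  0.02],
])

print(top_k_per_patient(shap_matrix, features, k=2))
# [['HbA1c', 'BMI'], ['HbA1c', 'ALT'], ['BMI', 'HbA1c'], ['HbA1c', 'ALT']]
ranking = global_importance(shap_matrix, features)
print(ranking.round(3))
print(compare_with_clinical(ranking, ["HbA1c", "BMI", "age"], k=3))
# (['HbA1c', 'BMI'], ['ALT'], ['age'])
```

Ranking by absolute value ignores direction, so a feature that strongly raises risk for some patients and lowers it for others still ranks high; inspecting the signed values per patient remains necessary before drawing clinical conclusions.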

```python
import numpy as np
import pandas as pd
import shap


def interpret_model_prediction(model, input_row, feature_names=None, class_names=None, plot=True):
    """
    Interpret a model prediction for a single input row using SHAP.

    Args:
        model: Trained XGBoost model (e.g., XGBClassifier).
        input_row: pd.Series or 1-row pd.DataFrame with features.
        feature_names: List of feature names (optional). If None, inferred from the model or input_row.
        class_names: Sequence of class names (optional). If None, numeric class indices are used.
        plot: If True, display a SHAP force plot (matplotlib).

    Returns:
        pred_class: Predicted class label or index.
        pred_proba: Predicted probability for each class (numpy array).
        shap_values: SHAP values in raw-margin (log-odds) units (1D numpy array).
            For multiclass models they explain the predicted class; for binary
            models they explain the positive class (index 1), whichever is predicted.
    """

    # 1. Ensure input_row is a single-row DataFrame
    if isinstance(input_row, pd.Series):
        # infer_objects() restores numeric dtypes if the Series was object-typed
        input_row = input_row.to_frame().T.infer_objects()
    elif not isinstance(input_row, pd.DataFrame) or input_row.shape[0] != 1:
        raise ValueError("input_row must be a pandas Series or a 1-row DataFrame")

    # 2. Determine feature names. The booster's feature_names is None when the
    #    model was trained on a NumPy array, so fall back to the input's columns.
    if feature_names is None:
        try:
            feature_names = model.get_booster().feature_names
        except Exception:
            feature_names = None
        if feature_names is None:
            feature_names = list(input_row.columns)
    feature_names = list(feature_names)
    # Restrict and reorder input_row to those features (KeyError if columns are missing)
    input_row = input_row[feature_names]

    # 3. Predict probabilities and class
    pred_proba = model.predict_proba(input_row)[0]  # shape: (n_classes,)
    pred_class_idx = int(np.argmax(pred_proba))
    # "is not None" also works when class_names is a NumPy array
    pred_class = class_names[pred_class_idx] if class_names is not None else pred_class_idx

    print(f"Predicted class: {pred_class}")
    print(f"Probabilities: {pred_proba}")

    # 4. Compute SHAP values (TreeExplainer explains the raw margin by default)
    explainer = shap.TreeExplainer(model)
    shap_values_all = explainer.shap_values(input_row)

    # 4a. The output layout depends on the model and the SHAP version:
    #     - multiclass, older SHAP: list of (1, n_features) arrays, one per class
    #     - multiclass, newer SHAP: array of shape (1, n_features, n_classes)
    #     - binary/single-output: array of shape (1, n_features) for the positive class
    if isinstance(shap_values_all, list):
        shap_values = np.asarray(shap_values_all[pred_class_idx])[0]
    elif np.ndim(shap_values_all) == 3:
        shap_values = shap_values_all[0, :, pred_class_idx]
    else:
        shap_values = shap_values_all[0]

    # 5. Identify the top contributing features (largest absolute SHAP value)
    top_indices = np.argsort(-np.abs(shap_values))[:10]
    print("\nTop contributing features:")
    for idx in top_indices:
        print(f"  {feature_names[idx]}: {shap_values[idx]:.4f}")

    # 6. (Optional) SHAP force plot with the matplotlib backend
    if plot:
        # expected_value is a scalar for single-output models and a
        # list/array with one entry per class for multiclass models
        base_values = np.ravel(explainer.expected_value)
        expected_value = base_values[pred_class_idx] if base_values.size > 1 else base_values[0]
        shap.force_plot(
            expected_value,
            shap_values,
            input_row.iloc[0],
            feature_names=feature_names,
            matplotlib=True,
            show=True,
        )

    return pred_class, pred_proba, shap_values
```
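
For a binary XGBoost classifier, the SHAP values returned by `interpret_model_prediction` are in log-odds units and explain the positive class. A quick sanity check, sketched below under that assumption, is that the base value plus the sum of the SHAP values, passed through the logistic (sigmoid) function, reproduces the predicted probability of the positive class. The helper names and the numbers in the usage line are hypothetical. In practice the base value would come from `explainer.expected_value`, which the function above does not return, so you would need to build a `shap.TreeExplainer` yourself. Because log(p0/p1) = -log(p1/p0), negating the base value and the SHAP values re-expresses them as contributions toward class 0.

```python
import numpy as np


def sigmoid(z):
    """Logistic function: maps log-odds to a probability."""
    return 1.0 / (1.0 + np.exp(-z))


def check_binary_additivity(base_value, shap_values, proba_positive, atol=1e-4):
    """Hypothetical sanity check for a binary model explained in log-odds:
    sigmoid(base_value + sum(SHAP)) should match P(class 1)."""
    log_odds = base_value + np.sum(shap_values)
    return bool(np.isclose(sigmoid(log_odds), proba_positive, atol=atol))


def contributions_toward_class0(base_value, shap_values):
    """Negate base value and SHAP values so they explain log(p0 / p1)."""
    return -base_value, -np.asarray(shap_values)


# Made-up values for illustration (not from a real model):
# base -1.0 plus contributions summing to 1.0 gives log-odds 0, i.e. P(class 1) = 0.5
print(check_binary_additivity(-1.0, np.array([0.7, 0.5, -0.2]), 0.5))  # True
```

A small tolerance is used because XGBoost computes margins in single precision, so exact equality should not be expected.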
