Skip to content

Commit

Permalink
Merge pull request #767 from floidgilbert/decision-plot-legend
Browse files Browse the repository at this point in the history
decision plot multioutput, legend, new_base_value, etc.
  • Loading branch information
slundberg committed Aug 24, 2019
2 parents 5f8b718 + 1e845d9 commit 0c608c5
Show file tree
Hide file tree
Showing 7 changed files with 1,549 additions and 138 deletions.
3 changes: 2 additions & 1 deletion README.md
Expand Up @@ -299,7 +299,8 @@ An implementation of Kernel SHAP, a model agnostic method to estimate SHAP value

These notebooks comprehensively demonstrate how to use specific functions and objects.

- [`shap.decision_plot`](https://slundberg.github.io/shap/notebooks/plots/decision_plot.html)
- [`shap.decision_plot` and `shap.multioutput_decision_plot`](https://slundberg.github.io/shap/notebooks/plots
/decision_plot.html)

- [`shap.dependence_plot`](https://slundberg.github.io/shap/notebooks/plots/dependence_plot.html)

Expand Down
812 changes: 751 additions & 61 deletions docs/notebooks/plots/decision_plot.html

Large diffs are not rendered by default.

Binary file added notebooks/plots/data/heart.pickle
Binary file not shown.
604 changes: 542 additions & 62 deletions notebooks/plots/decision_plot.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion shap/__init__.py
Expand Up @@ -10,7 +10,7 @@
from .explainers.linear import LinearExplainer
from .explainers.partition import PartitionExplainer
from .plots.summary import summary_plot
from .plots.decision import decision_plot
from .plots.decision import decision_plot, multioutput_decision_plot
from .plots.dependence import dependence_plot
from .plots.force import force_plot, initjs, save_html
from .plots.image import image_plot
Expand Down
140 changes: 127 additions & 13 deletions shap/plots/decision.py
Expand Up @@ -16,7 +16,26 @@
pass

from . import colors, labels
from ..common import convert_to_link, LogitLink, hclust_ordering
from ..common import convert_to_link, hclust_ordering, LogitLink


def __change_shap_base_value(base_value, new_base_value, shap_values) -> np.ndarray:
"""Shift SHAP base value to a new value. This function assumes that `base_value` and `new_base_value` are scalars
and that `shap_values` is a two or three dimensional array.
"""
# matrix of shap_values
if shap_values.ndim == 2:
return shap_values + (base_value - new_base_value) / shap_values.shape[1]

# cube of shap_interaction_values
main_effects = shap_values.shape[1]
all_effects = main_effects * (main_effects + 1) // 2
temp = (base_value - new_base_value) / all_effects / 2 # divided by 2 because interaction effects are halved
shap_values = shap_values + temp
# Add the other half to the main effects on the diagonal
idx = np.diag_indices_from(shap_values[0])
shap_values[:, idx[0], idx[1]] += temp
return shap_values


def __decision_plot_matplotlib(
Expand All @@ -35,7 +54,9 @@ def __decision_plot_matplotlib(
color_bar,
auto_size_plot,
title,
show
show,
legend_labels,
legend_location,
):
"""matplotlib rendering for decision_plot()"""

Expand Down Expand Up @@ -65,14 +86,16 @@ def __decision_plot_matplotlib(
m = cm.ScalarMappable(cmap=plot_color)
m.set_clim(xlim)
y_pos = np.arange(0, feature_display_count + 1)
lines = []
for i in range(cumsum.shape[0]):
pl.plot(
o = pl.plot(
cumsum[i, :],
y_pos,
color=m.to_rgba(cumsum[i, -1], alpha),
linewidth=linewidth[i],
linestyle=linestyle[i]
)
lines.append(o[0])

# determine font size. if ' *\n' character sequence is found (as in interaction labels), use a smaller
# font. we don't shrink the font for all interaction plots because if an interaction term is not
Expand Down Expand Up @@ -139,6 +162,9 @@ def __decision_plot_matplotlib(
if ascending:
pl.gca().invert_yaxis()

if legend_labels is not None:
ax.legend(handles=lines, labels=legend_labels, loc=legend_location)

if show:
pl.show()

Expand All @@ -148,7 +174,7 @@ class DecisionPlotResult:
plots with the same scale and feature ordering.
"""

def __init__(self, shap_values, feature_names, feature_idx, xlim):
def __init__(self, base_value, shap_values, feature_names, feature_idx, xlim):
"""
Example
-------
Expand All @@ -159,10 +185,15 @@ def __init__(self, shap_values, feature_names, feature_idx, xlim):
Parameters
----------
base_value : float
The base value used in the plot. For multioutput models,
this will be the mean of the base values. This will inherit `new_base_value` if specified.
shap_values : numpy.ndarray
The `shap_values` passed to decision_plot re-ordered based on `feature_order`. If SHAP interaction values
are passed to decision_plot, `shap_values` is a 2D (matrix) representation of the interactions. See
`feature_names` to locate the feature positions.
`feature_names` to locate the feature positions. If `new_base_value` is specified, the SHAP values are
relative to the new base value.
feature_names : list of str
The feature names used in the plot in the order specified in the decision_plot parameter `feature_order`.
Expand All @@ -175,6 +206,7 @@ def __init__(self, shap_values, feature_names, feature_idx, xlim):
The x-axis limits. This attributed can be used to specify the same x-axis in multiple decision plots.
"""
self.base_value = base_value
self.shap_values = shap_values
self.feature_names = feature_names
self.feature_idx = feature_idx
Expand All @@ -200,20 +232,23 @@ def decision_plot(
xlim=None,
show=True,
return_objects=False,
ignore_warnings=False
ignore_warnings=False,
new_base_value=None,
legend_labels=None,
legend_location="best",
) -> Union[DecisionPlotResult, None]:
"""Visualize model decisions using cumulative SHAP values. Each colored line in the plot represents the model
prediction for a single observation. Note that plotting too many samples at once can make the plot unintelligible.
Parameters
----------
base_value : float or numpy.ndarray
This is the reference value that the feature contributions start from. For SHAP values it should
be the value of explainer.expected_value.
This is the reference value that the feature contributions start from. Usually, this is
explainer.expected_value.
shap_values : numpy.ndarray
Matrix of SHAP values (# features) or (# samples x # features) from explainer.shap_values(). Or cube of SHAP
interaction values (# samples x # features x # features). from explainer.shap_interaction_values().
interaction values (# samples x # features x # features) from explainer.shap_interaction_values().
features : numpy.array or pandas.Series or pandas.DataFrame or numpy.ndarray or list
Matrix of feature values (# features) or (# samples x # features). This provides the values of all the
Expand Down Expand Up @@ -280,6 +315,18 @@ def decision_plot(
Plotting many data points or too many features at a time may be slow, or may create very large plots. Set
this argument to `True` to override hard-coded limits that prevent plotting large amounts of data.
new_base_value : float
SHAP values are relative to a base value; by default, the expected value of the model's raw predictions. Use
`new_base_value` to shift the base value to an arbitrary value (e.g. the cutoff point for a binary
classification task).
legend_labels : list of str
List of legend labels. If `None`, legend will not be shown.
legend_location : str
Legend location. Any of "best", "upper right", "upper left", "lower left", "lower right", "right",
"center left", "center right", "lower center", "upper center", "center".
Returns
-------
Returns a DecisionPlotResult object if `return_objects=True`. Returns `None` otherwise (the default).
Expand All @@ -297,9 +344,11 @@ def decision_plot(
if type(base_value) == np.ndarray and len(base_value) == 1:
base_value = base_value[0]

if isinstance(shap_values, list):
raise TypeError("The shap_values arg looks like multi output. Try shap_values[i].")
if isinstance(base_value, list) or isinstance(shap_values, list):
raise TypeError("Looks like multi output. Try base_value[i] and shap_values[i], "
"or use shap.multioutput_decision_plot().")

# validate shap_values
if not isinstance(shap_values, np.ndarray):
raise TypeError("The shap_values arg is the wrong type. Try explainer.shap_values().")

Expand Down Expand Up @@ -397,6 +446,11 @@ def decision_plot(
feature_display_range.step
)

# apply new_base_value
if new_base_value is not None:
shap_values = __change_shap_base_value(base_value, new_base_value, shap_values)
base_value = new_base_value

# use feature_display_range to determine which features will be plotted. convert feature_display_range to
# ascending indices and expand by one in the negative direction. why? we are plotting the change in prediction
# for every feature. this requires that we include the value previous to the first displayed feature
Expand Down Expand Up @@ -438,6 +492,7 @@ def decision_plot(
# convert values based on link and update x-axis extents
create_xlim = xlim is None
link = convert_to_link(link)
base_value_saved = base_value
if isinstance(link, LogitLink):
base_value = link.finv(base_value)
cumsum = link.finv(cumsum)
Expand Down Expand Up @@ -480,7 +535,66 @@ def decision_plot(
color_bar,
auto_size_plot,
title,
show
show,
legend_labels,
legend_location,
)

return DecisionPlotResult(shap_values, feature_names, feature_idx, xlim) if return_objects else None
if not return_objects:
return None

return DecisionPlotResult(base_value_saved, shap_values, feature_names, feature_idx, xlim)


def multioutput_decision_plot(base_values, shap_values, row_index, **kwargs) -> Union[DecisionPlotResult, None]:
"""Decision plot for multioutput models. Plots all outputs for a single observation. By default, the plotted base
value will be the mean of base_values unless new_base_value is specified. Supports both SHAP values and SHAP
interaction values.
Parameters
----------
base_values : list of float
This is the reference value that the feature contributions start from. Use explainer.expected_value.
shap_values : list of numpy.ndarray
A multioutput list of SHAP matrices or SHAP cubes from explainer.shap_values() or
explainer.shap_interaction_values(), respectively.
row_index : int
The integer index of the row to plot.
**kwargs : Any
Arguments to be passed on to decision_plot().
Returns
-------
Returns a DecisionPlotResult object if `return_objects=True`. Returns `None` otherwise (the default).
"""

if not (isinstance(base_values, list) and isinstance(shap_values, list)):
raise ValueError("The base_values and shap_values args expect lists.")

# convert arguments to arrays for simpler handling
base_values = np.array(base_values)
if not ((base_values.ndim == 1) or (np.issubdtype(base_values.dtype, np.number))):
raise ValueError("The base_values arg should be a list of scalars.")
shap_values = np.array(shap_values)
if shap_values.ndim not in [3, 4]:
raise ValueError("The shap_values arg should be a list of two or three dimensional SHAP arrays.")
if shap_values.shape[0] != base_values.shape[0]:
raise ValueError("The base_values output length is different than shap_values.")

# shift shap base values to mean of base values
base_values_mean = base_values.mean()
for i in range(shap_values.shape[0]):
shap_values[i] = __change_shap_base_value(base_values[i], base_values_mean, shap_values[i])

# select the feature row corresponding to row_index
if (kwargs is not None) and ("features" in kwargs):
features = kwargs["features"]
if isinstance(features, np.ndarray) and (features.ndim == 2):
kwargs["features"] = features[[row_index]]
elif str(type(features)) == "<class 'pandas.core.frame.DataFrame'>":
kwargs["features"] = features.iloc[row_index]

return decision_plot(base_values_mean, shap_values[:, row_index, :], **kwargs)

0 comments on commit 0c608c5

Please sign in to comment.