Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Prototype new plotting API for bar plot, accept and return axes #3523

Merged
merged 14 commits into from
Feb 28, 2024
157 changes: 49 additions & 108 deletions notebooks/api_examples/plots/bar.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion scripts/run_notebooks_timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
]

allow_to_timeout = [
Path("api_examples/plots/bar.ipynb"),
Path("api_examples/plots/beeswarm.ipynb"),
Path("api_examples/plots/image.ipynb"),
Path("api_examples/plots/scatter.ipynb"),
Expand Down
116 changes: 71 additions & 45 deletions shap/plots/_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,49 @@
# TODO: improve the bar chart to look better like the waterfall plot with numbers inside the bars when they fit
# TODO: Have the Explanation object track enough data so that we can tell (and so show) how many instances are in each cohort
def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clustering_cutoff=0.5,
merge_cohorts=False, show_data="auto", show=True):
connortann marked this conversation as resolved.
Show resolved Hide resolved
show_data="auto", ax=None, show=True):
"""Create a bar plot of a set of SHAP values.

If a single sample is passed, then we plot the SHAP values as a bar chart. If an
:class:`.Explanation` with many samples is passed, then we plot the mean absolute
value for each feature column as a bar chart.


Parameters
----------
shap_values : shap.Explanation or shap.Cohorts or dictionary of shap.Explanation objects
A single row of a SHAP :class:`.Explanation` object (i.e. ``shap_values[0]``) or
a multi-row Explanation object that we want to summarize.
Passing a multi-row :class:`.Explanation` object creates a global
feature importance plot.

Passing a single row of an explanation (i.e. ``shap_values[0]``) creates
a local feature importance plot.

Passing a dictionary of Explanation objects will create a multiple-bar
plot with one bar type for each of the cohorts represented by the
explanation objects.
max_display : int
How many top features to include in the bar plot (default is 10).

order : OpChain or numpy.ndarray
A function that returns a sort ordering given a matrix of SHAP values
and an axis, or a direct sample ordering given as an ``numpy.ndarray``.
connortann marked this conversation as resolved.
Show resolved Hide resolved

By default, take the absolute value.
clustering: Optional np.ndarray
connortann marked this conversation as resolved.
Show resolved Hide resolved
A partition tree, as returned by ``shap.utils.hclust``
clustering_cutoff: float
Controls how much of the clustering structure is displayed.
show_data: bool or str
Controls if data values are shown as part of the y tick labels. If
"auto", we show the data only when there are no transforms.
ax: matplotlib Axes
Axes object to draw the plot onto, otherwise uses the current Axes.
show : bool
Whether ``matplotlib.pyplot.show()`` is called before returning.
Setting this to ``False`` allows the plot
to be customized further after it has been created.

Returns
-------
ax: matplotlib Axes
Returns the Axes object with the plot drawn onto it. Only returned if ``show=False``.

Examples
--------

See `bar plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/bar.html>`_.

"""
Expand Down Expand Up @@ -97,16 +115,21 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu
else:
partition_tree = clustering
if partition_tree is not None:
assert partition_tree.shape[1] == 4, "The clustering provided by the Explanation object does not seem to be a partition tree (which is all shap.plots.bar supports)!"
if len(partition_tree.shape) != 2 or partition_tree.shape[1] != 4:
raise TypeError(
"The clustering provided by the Explanation object does not seem to be a "
"partition tree, which is all shap.plots.bar supports."
)
op_history = cohort_exps[0].op_history
values = np.array([cohort_exps[i].values for i in range(len(cohort_exps))])

if len(values[0]) == 0:
raise Exception("The passed Explanation is empty! (so there is nothing to plot)")
raise ValueError("The passed Explanation is empty, so there is nothing to plot.")

# we show the data on auto only when there are no transforms
# we show the data on auto only when there are no transforms (excluding getitem calls)
if show_data == "auto":
show_data = len(op_history) == 0
transforms = [t for t in op_history if t.get("name") != "__getitem__"]
show_data = len(transforms) == 0
Comment on lines +129 to +132
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is one functional change that I needed to make in order to reproduce more closely the plots in the API example notebook.

In the API example notebook, many plots have show_data=True. This adds the small grey numbers on the left axis. It looks like this behaviour was then changed at some point to only show data values when there are no transforms applied.

However, I think show_data still makes sense as True for a single explanation when the only operation applied has been __getitem__.

Before (show_data = False)

image

After (show_data = True)

image


# TODO: Rather than just show the "1st token", "2nd token", etc. it would be better to show the "Instance 0's 1st but", etc
if issubclass(type(feature_names), str):
Expand Down Expand Up @@ -208,36 +231,38 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu
if num_features < len(values[0]):
yticklabels[-1] = "Sum of %d other features" % num_cut

# compute our figure size based on how many features we are showing
row_height = 0.5
pl.gcf().set_size_inches(8, num_features * row_height * np.sqrt(len(values)) + 1.5)
if ax is None:
ax = pl.gca()
# Only modify the figure size if ax was not passed in
# compute our figure size based on how many features we are showing
fig = pl.gcf()
row_height = 0.5
fig.set_size_inches(8, num_features * row_height * np.sqrt(len(values)) + 1.5)

# if negative values are present then we draw a vertical line to mark 0, otherwise the axis does this for us...
negative_values_present = np.sum(values[:,feature_order[:num_features]] < 0) > 0
if negative_values_present:
pl.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1)
ax.axvline(0, 0, 1, color="#000000", linestyle="-", linewidth=1, zorder=1)

# draw the bars
patterns = (None, '\\\\', '++', 'xx', '////', '*', 'o', 'O', '.', '-')
total_width = 0.7
bar_width = total_width / len(values)
for i in range(len(values)):
ypos_offset = - ((i - len(values) / 2) * bar_width + bar_width / 2)
pl.barh(
ax.barh(
y_pos + ypos_offset, values[i,feature_inds],
bar_width, align='center',
color=[colors.blue_rgb if values[i,feature_inds[j]] <= 0 else colors.red_rgb for j in range(len(y_pos))],
hatch=patterns[i], edgecolor=(1,1,1,0.8), label=f"{cohort_labels[i]} [{cohort_sizes[i] if i < len(cohort_sizes) else None}]"
)

# draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks)
pl.yticks(list(y_pos) + list(y_pos + 1e-8), yticklabels + [t.split('=')[-1] for t in yticklabels], fontsize=13)
ax.set_yticks(list(y_pos) + list(y_pos + 1e-8), yticklabels + [t.split('=')[-1] for t in yticklabels], fontsize=13)

xlen = pl.xlim()[1] - pl.xlim()[0]
fig = pl.gcf()
ax = pl.gca()
xlen = ax.get_xlim()[1] - ax.get_xlim()[0]
#xticks = ax.get_xticks()
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted())
width = bbox.width
bbox_to_xscale = xlen/width

Expand All @@ -246,21 +271,21 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu
for j in range(len(y_pos)):
ind = feature_order[j]
if values[i,ind] < 0:
pl.text(
ax.text(
values[i,ind] - (5/72)*bbox_to_xscale, y_pos[j] + ypos_offset, format_value(values[i,ind], '%+0.02f'),
horizontalalignment='right', verticalalignment='center', color=colors.blue_rgb,
fontsize=12
)
else:
pl.text(
ax.text(
values[i,ind] + (5/72)*bbox_to_xscale, y_pos[j] + ypos_offset, format_value(values[i,ind], '%+0.02f'),
horizontalalignment='left', verticalalignment='center', color=colors.red_rgb,
fontsize=12
)

# put horizontal lines for each feature row
for i in range(num_features):
pl.axhline(i+1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1)
ax.axhline(i+1, color="#888888", lw=0.5, dashes=(1, 5), zorder=-1)

if features is not None:
features = list(features)
Expand All @@ -273,33 +298,34 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu
except Exception:
pass # features[i] must not be a number

pl.gca().xaxis.set_ticks_position('bottom')
pl.gca().yaxis.set_ticks_position('none')
pl.gca().spines['right'].set_visible(False)
pl.gca().spines['top'].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
if negative_values_present:
pl.gca().spines['left'].set_visible(False)
pl.gca().tick_params('x', labelsize=11)
ax.spines['left'].set_visible(False)
ax.tick_params('x', labelsize=11)

xmin,xmax = pl.gca().get_xlim()
ymin,ymax = pl.gca().get_ylim()
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
x_buffer = (xmax-xmin)*0.05

if negative_values_present:
pl.gca().set_xlim(xmin - (xmax-xmin)*0.05, xmax + (xmax-xmin)*0.05)
ax.set_xlim(xmin - x_buffer, xmax + x_buffer)
else:
pl.gca().set_xlim(xmin, xmax + (xmax-xmin)*0.05)
ax.set_xlim(xmin, xmax + x_buffer)

# if features is None:
# pl.xlabel(labels["GLOBAL_VALUE"], fontsize=13)
# else:
pl.xlabel(xlabel, fontsize=13)
ax.set_xlabel(xlabel, fontsize=13)

if len(values) > 1:
pl.legend(fontsize=12)
ax.legend(fontsize=12)

# color the y tick labels that have the feature values as gray
# (these fall behind the black ones with just the feature name)
tick_labels = pl.gca().yaxis.get_majorticklabels()
tick_labels = ax.yaxis.get_majorticklabels()
for i in range(num_features):
tick_labels[i].set_color("#999999")

Expand All @@ -311,15 +337,15 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu
ylines,xlines = dendrogram_coords(feature_pos, partition_tree)

# plot the distance cut line above which we don't show tree edges
xmin,xmax = pl.xlim()
xmin,xmax = ax.get_xlim()
xlines_min,xlines_max = np.min(xlines),np.max(xlines)
ct_line_pos = (clustering_cutoff / (xlines_max - xlines_min)) * 0.1 * (xmax - xmin) + xmax
pl.text(
ax.text(
ct_line_pos + 0.005 * (xmax - xmin), (ymax - ymin)/2, "Clustering cutoff = " + format_value(clustering_cutoff, '%0.02f'),
horizontalalignment='left', verticalalignment='center', color="#999999",
fontsize=12, rotation=-90
)
line = pl.axvline(ct_line_pos, color="#dddddd", dashes=(1, 1))
line = ax.axvline(ct_line_pos, color="#dddddd", dashes=(1, 1))
line.set_clip_on(False)

for (xline, yline) in zip(xlines, ylines):
Expand All @@ -332,7 +358,7 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu

# only draw if we are not going past the bottom of the plot
if yline.max() < max_display:
lines = pl.plot(
lines = ax.plot(
xv * 0.1 * (xmax - xmin) + xmax,
max_display - np.array(yline),
color="#999999"
Expand All @@ -343,7 +369,7 @@ def bar(shap_values, max_display=10, order=Explanation.abs, clustering=None, clu
if show:
pl.show()
else:
return pl.gcf()
return ax


def bar_legacy(shap_values, features=None, feature_names=None, max_display=None, show=True):
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 34 additions & 0 deletions tests/plots/test_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,37 @@ def test_simple_bar_with_cohorts_dict():
)
plt.tight_layout()
return fig


@pytest.mark.mpl_image_compare
def test_simple_bar_local_feature_importance(explainer):
"""Bar plot with single row of SHAP values"""
shap_values = explainer(explainer.data)
fig = plt.figure()
shap.plots.bar(shap_values[0], show=False)
plt.tight_layout()
return fig


@pytest.mark.mpl_image_compare
def test_simple_bar_with_clustering(explainer):
"""Bar plot with clustering"""
shap_values = explainer(explainer.data)
clustering = shap.utils.hclust(explainer.data, metric="cosine")
fig = plt.figure()
shap.plots.bar(shap_values, clustering=clustering, show=False)
plt.tight_layout()
return fig


def test_bar_raises_error_for_invalid_clustering(explainer):
shap_values = explainer(explainer.data)
clustering = np.array([1,2,3])
with pytest.raises(TypeError, match="does not seem to be a partition tree"):
shap.plots.bar(shap_values, clustering=clustering, show=False)


def test_bar_raises_error_for_empty_explanation(explainer):
shap_values = explainer(explainer.data)
with pytest.raises(ValueError, match="The passed Explanation is empty"):
shap.plots.bar(shap_values[0:0], show=False)