Skip to content

Commit

Permalink
[MRG] BUG Checks to number of axes in passed in ax more generically (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored and ogrisel committed Dec 19, 2019
1 parent a6c07f2 commit 556eb92
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 21 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v0.22.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ Changelog
- |Fix| :func:`utils.check_array` now correctly converts pandas DataFrame with
boolean columns to floats. :pr:`15797` by `Thomas Fan`_.

:mod:`sklearn.inspection`
.........................

- |Fix| :func:`inspection.plot_partial_dependence` and
:meth:`inspection.PartialDependenceDisplay.plot` now consistently checks
the number of axes passed in. :pr:`15760` by `Thomas Fan`_.

.. _changes_0_22:

Version 0.22.0
Expand Down
18 changes: 10 additions & 8 deletions sklearn/inspection/_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,12 @@ def convert_feature(fx):

features = tmp_features

if isinstance(ax, list):
if len(ax) != len(features):
raise ValueError("Expected len(ax) == len(features), "
"got len(ax) = {}".format(len(ax)))
# Early exit if the axes does not have the correct number of axes
if ax is not None and not isinstance(ax, plt.Axes):
axes = np.asarray(ax, dtype=object)
if axes.size != len(features):
raise ValueError("Expected ax to have {} axes, got {}".format(
len(features), axes.size))

for i in chain.from_iterable(features):
if i >= len(feature_names):
Expand Down Expand Up @@ -886,16 +888,16 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None):
axes_ravel[i] = self.figure_.add_subplot(spec)

else: # array-like
ax = check_array(ax, dtype=object, ensure_2d=False)
ax = np.asarray(ax, dtype=object)
if ax.size != n_features:
raise ValueError("Expected ax to have {} axes, got {}"
.format(n_features, ax.size))

if ax.ndim == 2:
n_cols = ax.shape[1]
else:
n_cols = None

if ax.ndim == 1 and ax.shape[0] != n_features:
raise ValueError("Expected len(ax) == len(features), "
"got len(ax) = {}".format(len(ax)))
self.bounding_ax_ = None
self.figure_ = ax.ravel()[0].figure
self.axes_ = ax
Expand Down
31 changes: 18 additions & 13 deletions sklearn/inspection/tests/test_plot_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,26 +222,31 @@ def test_plot_partial_dependence_passing_numpy_axes(pyplot, clf_boston,
assert len(disp2.axes_[0, 1].get_lines()) == 2


@pytest.mark.parametrize("nrows, ncols", [(2, 2), (3, 1)])
def test_plot_partial_dependence_incorrent_num_axes(pyplot, clf_boston,
boston):
grid_resolution = 25
fig, (ax1, ax2, ax3) = pyplot.subplots(1, 3)
boston, nrows, ncols):
grid_resolution = 5
fig, axes = pyplot.subplots(nrows, ncols)
axes_formats = [list(axes.ravel()), tuple(axes.ravel()), axes]

msg = r"Expected len\(ax\) == len\(features\), got len\(ax\) = 3"
with pytest.raises(ValueError, match=msg):
plot_partial_dependence(clf_boston, boston.data,
['CRIM', ('CRIM', 'ZN')],
grid_resolution=grid_resolution,
feature_names=boston.feature_names,
ax=[ax1, ax2, ax3])
msg = "Expected ax to have 2 axes, got {}".format(nrows * ncols)

disp = plot_partial_dependence(clf_boston, boston.data,
['CRIM', ('CRIM', 'ZN')],
['CRIM', 'ZN'],
grid_resolution=grid_resolution,
feature_names=boston.feature_names)

with pytest.raises(ValueError, match=msg):
disp.plot(ax=[ax1, ax2, ax3])
for ax_format in axes_formats:
with pytest.raises(ValueError, match=msg):
plot_partial_dependence(clf_boston, boston.data,
['CRIM', 'ZN'],
grid_resolution=grid_resolution,
feature_names=boston.feature_names,
ax=ax_format)

# with axes object
with pytest.raises(ValueError, match=msg):
disp.plot(ax=ax_format)


def test_plot_partial_dependence_with_same_axes(pyplot, clf_boston, boston):
Expand Down

0 comments on commit 556eb92

Please sign in to comment.