Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Adds Plotting API to Partial Dependence #14646

Merged
merged 36 commits into from Sep 20, 2019
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
fa93505
WIP
thomasjpfan Aug 8, 2019
d1dc257
Merge remote-tracking branch 'upstream/master' into plotting_api_part…
thomasjpfan Aug 12, 2019
1dac62c
WIP Adding tests
thomasjpfan Aug 12, 2019
bc45123
WIP Doc
thomasjpfan Aug 13, 2019
9d67f6a
CLN Creates _plot folder
thomasjpfan Aug 13, 2019
5129289
CLN Adds _plot into setup
thomasjpfan Aug 13, 2019
6f5407e
CLN Fixes bugs
thomasjpfan Aug 13, 2019
8bc2433
DOC Adds dev docs
thomasjpfan Aug 14, 2019
4aa273c
ENH Updates API
thomasjpfan Aug 14, 2019
5263235
Merge remote-tracking branch 'upstream/master' into plotting_api_part…
thomasjpfan Aug 15, 2019
045aa8f
CLN Address comments
thomasjpfan Aug 16, 2019
75f8ed8
TST Removes check
thomasjpfan Aug 16, 2019
5fddc5f
DOC Address comments
thomasjpfan Aug 16, 2019
396f1cb
DOC Uses normal decision tree in example
thomasjpfan Aug 16, 2019
bd121c3
CLN Refactor array-like
thomasjpfan Aug 16, 2019
422992e
STY Minor
thomasjpfan Aug 16, 2019
ca532d7
DOC Update
thomasjpfan Aug 16, 2019
7316ef5
TST trigger
thomasjpfan Aug 16, 2019
9b916cf
DOC Update docstring
thomasjpfan Aug 26, 2019
ec47bd8
Merge remote-tracking branch 'upstream/master' into plotting_api_part…
thomasjpfan Aug 27, 2019
170a1b5
DOC Adds new plotting api section
thomasjpfan Aug 27, 2019
9ad9477
DOC Uses :: syntax
thomasjpfan Aug 27, 2019
cab552a
CLN Moves files around
thomasjpfan Sep 4, 2019
0f5dee1
DOC Updates plotting.rst
thomasjpfan Sep 5, 2019
e7965d9
DOC Address comments
thomasjpfan Sep 5, 2019
24ee8b0
CLN Address comments
thomasjpfan Sep 5, 2019
d339328
CLN Address comments
thomasjpfan Sep 5, 2019
f7e946d
Merge remote-tracking branch 'upstream/master' into plotting_api_part…
thomasjpfan Sep 5, 2019
480bcef
DOC More links
thomasjpfan Sep 5, 2019
e9e5d3a
DOC More words
thomasjpfan Sep 5, 2019
4b50a8b
BUG Fix setup file
thomasjpfan Sep 7, 2019
f77f5d1
DOC Fix
thomasjpfan Sep 9, 2019
612a80e
REV Remove unrelated commit
thomasjpfan Sep 12, 2019
a8182d8
Merge remote-tracking branch 'upstream/master' into plotting_api_part…
thomasjpfan Sep 12, 2019
6c3a6c9
Merge remote-tracking branch 'upstream/master' into plotting_api_part…
thomasjpfan Sep 19, 2019
42e0c9b
CLN Address @glemaitre comments
thomasjpfan Sep 19, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 0 additions & 50 deletions doc/developers/contributing.rst
Expand Up @@ -1070,53 +1070,3 @@ make this task easier and faster (in no particular order).
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely
useful to see every occurrence of a pattern (e.g. a function call or a
variable) in the code base.


.. _plotting_api:

Plotting API
============

Scikit-learn defines a simple API for creating visualizations for machine
learning. The key features of this API is to run calculations once and to have
the flexibility to adjust the visualizations after the fact. This logic is
encapsulated into a display object where the computed data is stored and
the plotting is done in a `plot` method. The display object's `__init__`
method contains only the data needed to create the visualization. The `plot`
method takes in parameters that only have to do with visualization, such as a
matplotlib axes. The `plot` method will store the matplotlib artists as
attributes allowing for style adjustments through the display object. A
`plot_*` helper function accepts parameters to do the computation and the
parameters used for plotting. After the helper function creates the display
object with the computed values, it calls the display's plot method. Note
that the `plot` method defines attributes related to matplotlib, such as the
line artist. This allows for customizations after calling the `plot` method.

For example, the `RocCurveDisplay` defines the following methods and
attributes:

.. code-block:: python

class RocCurveDisplay:
def __init__(self, fpr, tpr, roc_auc, estimator_name):
...
self.fpr = fpr
self.tpr = tpr
self.roc_auc = roc_auc
self.estimator_name = estimator_name

def plot(self, ax=None, name=None, **kwargs):
...
self.line_ = ...
self.ax_ = ax
self.figure_ = ax.figure_

def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
drop_intermediate=True, response_method="auto",
name=None, ax=None, **kwargs):
# do computation
viz = RocCurveDisplay(fpr, tpr, roc_auc,
estimator.__class__.__name__)
return viz.plot(ax=ax, name=name, **kwargs)

Read more in the :ref:`User Guide <visualizations>`.
1 change: 1 addition & 0 deletions doc/developers/index.rst
Expand Up @@ -16,3 +16,4 @@ Developer's Guide
performance
advanced_installation
maintainer
plotting
90 changes: 90 additions & 0 deletions doc/developers/plotting.rst
@@ -0,0 +1,90 @@
.. _plotting_api:

================================
Developing with the Plotting API
================================

Scikit-learn defines a simple API for creating visualizations for machine
learning. The key features of this API is to run calculations once and to have
the flexibility to adjust the visualizations after the fact. This section is
intended for developers who wish to develop or maintain plotting tools. For
usage, users should refer to the :ref`User Guide <visualizations>`.

Plotting API Overview
---------------------

This logic is encapsulated into a display object where the computed data is
stored and the plotting is done in a `plot` method. The display object's
`__init__` method contains only the data needed to create the visualization.
The `plot` method takes in parameters that only have to do with visualization,
such as a matplotlib axes. The `plot` method will store the matplotlib artists
as attributes allowing for style adjustments through the display object. A
`plot_*` helper function accepts parameters to do the computation and the
parameters used for plotting. After the helper function creates the display
object with the computed values, it calls the display's plot method. Note that
the `plot` method defines attributes related to matplotlib, such as the line
artist. This allows for customizations after calling the `plot` method.

For example, the `RocCurveDisplay` defines the following methods and
attributes::

class RocCurveDisplay:
def __init__(self, fpr, tpr, roc_auc, estimator_name):
...
self.fpr = fpr
self.tpr = tpr
self.roc_auc = roc_auc
self.estimator_name = estimator_name

def plot(self, ax=None, name=None, **kwargs):
...
self.line_ = ...
self.ax_ = ax
self.figure_ = ax.figure_

def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
drop_intermediate=True, response_method="auto",
name=None, ax=None, **kwargs):
# do computation
viz = RocCurveDisplay(fpr, tpr, roc_auc,
estimator.__class__.__name__)
return viz.plot(ax=ax, name=name, **kwargs)

Read more in :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py`
and the :ref:`User Guide <visualizations>`.

Plotting with Multiple Axes
---------------------------

Some of the plotting tools like
:func:`~sklearn.inspection.plot_partial_dependence` and
:class:`~sklearn.inspection.PartialDependenceDisplay` support plottong on
multiple axes. Two different scenarios are supported:

1. If a list of axes is passed in, `plot` will check if the number of axes is
consistent with the number of axes it expects and then draws on those axes. 2.
If a single axes is passed in, that axes defines a space for multiple axes to
be placed. In this case, we suggest using matplotlib's
`~matplotlib.gridspec.GridSpecFromSubplotSpec` to split up the space::

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec

fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())

ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])

By default, the `ax` keyword in `plot` is `None`. In this case, the single
axes is created and the gridspec api is used to create the regions to plot in.

See for example, :func:`~sklearn.inspection.plot_partial_dependence` which
plots multiple lines and contours using this API. The axes defining the
bounding box is saved in a `bounding_ax_` attribute. The individual axes
created are stored in an `axes_` ndarray, corresponding to the axes position on
the grid. Positions that are not used are set to `None`. Furthermore, the
matplotlib Artists are stored in `lines_` and `contours_` where the key is the
position on the grid. When a list of axes is passed in, the `axes_`, `lines_`,
and `contours_` is a 1d ndarray corresponding to the list of axes passed in.
6 changes: 6 additions & 0 deletions doc/modules/classes.rst
Expand Up @@ -665,6 +665,12 @@ Plotting

.. currentmodule:: sklearn

.. autosummary::
:toctree: generated/
:template: class.rst

inspection.PartialDependenceDisplay
glemaitre marked this conversation as resolved.
Show resolved Hide resolved

.. autosummary::
:toctree: generated/
:template: function.rst
Expand Down
3 changes: 3 additions & 0 deletions doc/visualizations.rst
Expand Up @@ -59,6 +59,7 @@ values of the curves.
.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py`
* :ref:`sphx_glr_auto_examples_plot_partial_dependence_visualization_api.py`

Available Plotting Utilities
============================
Expand All @@ -70,6 +71,7 @@ Functions

.. autosummary::

inspection.plot_partial_dependence
metrics.plot_roc_curve


Expand All @@ -80,4 +82,5 @@ Display Objects

.. autosummary::

inspection.PartialDependenceDisplay
metrics.RocCurveDisplay
5 changes: 3 additions & 2 deletions examples/inspection/plot_partial_dependence.py
Expand Up @@ -111,7 +111,7 @@
fig = plt.gcf()
fig.suptitle('Partial dependence of house value on non-location features\n'
'for the California housing dataset, with MLPRegressor')
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.subplots_adjust(wspace=0.8, hspace=0.3)

##############################################################################
# Partial Dependence computation for Gradient Boosting
Expand Down Expand Up @@ -150,7 +150,8 @@
fig = plt.gcf()
fig.suptitle('Partial dependence of house value on non-location features\n'
'for the California housing dataset, with Gradient Boosting')
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
fig.subplots_adjust(wspace=0.8, hspace=0.3)


##############################################################################
# Analysis of the plots
Expand Down
138 changes: 138 additions & 0 deletions examples/plot_partial_dependence_visualization_api.py
@@ -0,0 +1,138 @@
"""
=========================================
Advanced Plotting With Partial Dependence
=========================================
The :func:`~sklearn.inspection.plot_partial_dependence` function returns a
:class:`~sklearn.inspection.PartialDependenceDisplay` object that can be used
for plotting without needing to recalculate the partial dependence. In this
example, we show how to plot partial dependence plots and how to quickly
customize the plot with the Visualization API.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we use Display API now? (also example file name?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not changed the name of how to reference the API. I would still say "visualization API". "Display API" sounds overly generic.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, then lowercase (I thought this was a class ^^)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
customize the plot with the Visualization API.
customize the plot with the visualization API.


.. note::

See also :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py`

"""
print(__doc__)

import matplotlib.pyplot as plt
from sklearn.datasets import load_boston
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeRegressor
from sklearn.inspection import plot_partial_dependence


##############################################################################
# Train models on the boston housing price dataset
# ================================================
#
# First, we train a decision tree and a multi-layer perceptron on the boston
# housing price dataset.

boston = load_boston()
X, y = boston.data, boston.target
feature_names = boston.feature_names

tree = DecisionTreeRegressor()
mlp = make_pipeline(StandardScaler(),
MLPRegressor(hidden_layer_sizes=(100, 100),
tol=1e-2, max_iter=500, random_state=0))
tree.fit(X, y)
mlp.fit(X, y)


##############################################################################
# Plotting partial dependence for two features
# ============================================
#
# We plot partial dependence curves for features "LSTAT" and "RM" for
# the decision tree. With two features,
# :func:`~sklearn.inspection.plot_partial_dependence` expects to plot two
# curves. Here the plot function place a grid of two plots using the space
# defined by `ax` .
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Decision Tree")
tree_disp = plot_partial_dependence(tree, X, ["LSTAT", "RM"],
feature_names=feature_names, ax=ax)

##############################################################################
# The partial depdendence curves can be plotted for the multi-layer perceptron.
# In this case, `line_kw` is passed to
# :func:`~sklearn.inspection.plot_partial_dependence` to change the color of
# the curve.
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Multi-layer Perceptron")
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT", "RM"],
feature_names=feature_names, ax=ax,
line_kw={"c": "red"})

##############################################################################
# Plotting partial dependence of the two models together
# ======================================================
#
# The `tree_disp` and `mlp_disp`
# :class:`~sklearn.inspection.PartialDependenceDisplay` objects contain all the
# computed information needed to recreate the partial dependence curves. This
# means we can easily create additional plots without needing to recompute the
# curves.
#
# One way to plot the curves is to place them in the same figure, with the
# curves of each model on each row. First, we create a figure with two axes
# within two rows and one column. The two axes are passed to the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the example but it's pretty condensed and it combines multiple functionalities at once so it's hard to understand how things interact independently.

For example it's a bit confusing at first how you can end up with a 2 by 2 grid while you only asked for 1 column.

I think it's missing a very simple plot with just 1 feature where you combine the curves of the 2 models. Then

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a "Plotting partial dependence for one feature" section at the end. With the current API, there are two ways to plot a single feature:

Option 1

tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
                                    feature_names=feature_names)
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
                                   feature_names=feature_names,
                                   ax=tree_disp.axes_, line_kw={"c": "red"})

Option 2

_, ax = plt.subplots()
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
                                    feature_names=feature_names, ax=[ax])
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
                                   feature_names=feature_names,
                                   ax=[ax], line_kw={"c": "red"})

For the example, I went with option 1. The "nicer" way to do this is:

Possible option that is being disallowed.

_, ax = plt.subplots()
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
                                    feature_names=feature_names, ax=ax)
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
                                   feature_names=feature_names,
                                   ax=ax, line_kw={"c": "red"})

The first call will call ax.set_visible(False) to denote that the space has been used. The second call will see that the axes is not visible and raise an error. We can technically support this "single feature" and "single axes" case, which I think will add another layer of complexity to the API i.e. "If ax is a single axes and len(features) == 1, then we behave differently"

# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` functions of
# `tree_disp` and `mlp_disp`. The given axes will be used by the plotting
# function to draw the partial dependence. The resulting plot places the
# decision tree partial dependence curves in the first row of the
# multi-layer perceptron in the second row.

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
tree_disp.plot(ax=ax1)
ax1.set_title("Decision Tree")
mlp_disp.plot(ax=ax2, line_kw={"c": "red"})
ax2.set_title("Multi-layer Perceptron")

##############################################################################
# Another way to compare the curves is to plot them on top of each other. Here,
# we create a figure with one row and two columns. The axes are passed into the
# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` function as list,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` function as list,
# :func:`~sklearn.inspection.PartialDependenceDisplay.plot` function as a list,

# which will plot the partial dependence curves of each model on the same axes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a sentence saying that the length of axes must be equal to the number of features we plot?

# The length of the axes list must be equal to the number of plots drawn.

# sphinx_gallery_thumbnail_number = 4
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6))
tree_disp.plot(ax=[ax1, ax2], line_kw={"label": "Decision Tree"})
mlp_disp.plot(ax=[ax1, ax2], line_kw={"label": "Multi-layer Perceptron",
"c": "red"})
ax1.legend()
ax2.legend()

##############################################################################
# `tree_disp.axes_` is a numpy array container the axes used to draw th
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# `tree_disp.axes_` is a numpy array container the axes used to draw th
# `tree_disp.axes_` is a numpy array container the axes used to draw the

# partial dependence plots. This can be passed to `mlp_disp` to have the same
# affect of drawing the plots on top of each other. Furthermore, the
# `mlp_disp.figure_` stores the figure, which allows for resizing the figure
# after calling `plot`.

tree_disp.plot(line_kw={"label": "Decision Tree"})
mlp_disp.plot(line_kw={"label": "Multi-layer Perceptron", "c": "red"},
ax=tree_disp.axes_)
tree_disp.figure_.set_size_inches(10, 6)
tree_disp.axes_[0, 0].legend()
tree_disp.axes_[0, 1].legend()
plt.show()


##############################################################################
# Plotting partial dependence for one feature
# ===========================================
#
# Here we plot the partial dependence curves for a single feature, "LSTAT", on
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Here we plot the partial dependence curves for a single feature, "LSTAT", on
# Here, we plot the partial dependence curves for a single feature, "LSTAT", on

# the same axes. In this case, `tree_disp.axes_` is passed into the second
# plot function.
tree_disp = plot_partial_dependence(tree, X, ["LSTAT"],
feature_names=feature_names)
mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"],
feature_names=feature_names,
ax=tree_disp.axes_, line_kw={"c": "red"})
4 changes: 3 additions & 1 deletion sklearn/inspection/__init__.py
@@ -1,10 +1,12 @@
"""The :mod:`sklearn.inspection` module includes tools for model inspection."""
from .partial_dependence import partial_dependence
from .partial_dependence import plot_partial_dependence
from .partial_dependence import PartialDependenceDisplay
from .permutation_importance import permutation_importance

__all__ = [
'partial_dependence',
'plot_partial_dependence',
'permutation_importance'
'permutation_importance',
'PartialDependenceDisplay'
]