
[MRG] Adds plot_precision_recall_curve #14936

Merged: 30 commits, merged Nov 11, 2019

Changes below are shown from 9 of the 30 commits.

Commits
720f9ac  WIP (thomasjpfan, Aug 20, 2019)
8ac4469  DOC Uses plot_precision_recall in example (thomasjpfan, Aug 22, 2019)
0b81383  DOC Adds to userguide (thomasjpfan, Aug 22, 2019)
e9c8131  DOC style (thomasjpfan, Sep 3, 2019)
241845f  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Sep 3, 2019)
3d86867  DOC Better docs (thomasjpfan, Sep 5, 2019)
c7029a6  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Sep 5, 2019)
9bf152b  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Sep 9, 2019)
66deaac  CLN (thomasjpfan, Sep 9, 2019)
10dc97e  CLN Address @glemaitre comments (thomasjpfan, Sep 19, 2019)
affec16  CLN Address @glemaitre comments (thomasjpfan, Sep 20, 2019)
99b18ed  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Sep 20, 2019)
1796020  DOC Remove whatsnew (thomasjpfan, Sep 24, 2019)
ee62183  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Sep 24, 2019)
d7d448f  DOC Style (thomasjpfan, Sep 24, 2019)
294c29a  CLN Addresses @amuller comments (thomasjpfan, Sep 25, 2019)
dbe9a3a  CLN Addresses @amuller comments (thomasjpfan, Sep 25, 2019)
3fade4b  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Oct 2, 2019)
fd1cc44  TST Clearier error messages (thomasjpfan, Oct 2, 2019)
fdb60ae  TST Modify test name (thomasjpfan, Oct 2, 2019)
7cbb1ac  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Nov 6, 2019)
abbbb9d  BUG Quick fix (thomasjpfan, Nov 6, 2019)
a589f92  BUG Fix test (thomasjpfan, Nov 6, 2019)
c9b3d60  ENH Better error message (thomasjpfan, Nov 6, 2019)
bfd5634  CLN Address comments (thomasjpfan, Nov 7, 2019)
dbae9f8  Merge remote-tracking branch 'upstream/master' into plot_precision_re… (thomasjpfan, Nov 8, 2019)
2c3d78d  CLN Address comments (thomasjpfan, Nov 8, 2019)
7736f77  CLN Move to base (thomasjpfan, Nov 10, 2019)
91f0d05  CLN Unify response detection (thomasjpfan, Nov 10, 2019)
a559342  CLN Removes unneeded check (thomasjpfan, Nov 10, 2019)
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
@@ -1031,12 +1031,14 @@ See the :ref:`visualizations` section of the user guide for further details.
:toctree: generated/
:template: function.rst

metrics.plot_precision_recall_curve
metrics.plot_roc_curve

.. autosummary::
:toctree: generated/
:template: class.rst

metrics.PrecisionRecallDisplay
metrics.RocCurveDisplay


2 changes: 2 additions & 0 deletions doc/visualizations.rst
@@ -70,6 +70,7 @@ Functions

.. autosummary::

metrics.plot_precision_recall_curve
Review comment (Member): alphabetic order?

metrics.plot_roc_curve


@@ -80,4 +81,5 @@ Display Objects

.. autosummary::

metrics.PrecisionRecallDisplay
Review comment (Member): alphabetic order?

metrics.RocCurveDisplay
26 changes: 5 additions & 21 deletions examples/model_selection/plot_precision_recall.py
@@ -134,25 +134,12 @@
# Plot the Precision-Recall curve
# ................................
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import plot_precision_recall_curve
import matplotlib.pyplot as plt
from inspect import signature

precision, recall, _ = precision_recall_curve(y_test, y_score)

# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument
step_kwargs = ({'step': 'post'}
if 'step' in signature(plt.fill_between).parameters
else {})
plt.step(recall, precision, color='b', alpha=0.2,
where='post')
plt.fill_between(recall, precision, alpha=0.2, color='b', **step_kwargs)

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(
average_precision))
disp = plot_precision_recall_curve(classifier, X_test, y_test, color='b')
disp.ax_.set_title('2-class Precision-Recall curve: '
'AP={0:0.2f}'.format(average_precision))

###############################################################################
# In multi-label settings
@@ -212,10 +199,7 @@
#

plt.figure()
plt.step(recall['micro'], precision['micro'], color='b', alpha=0.2,
where='post')
plt.fill_between(recall["micro"], precision["micro"], alpha=0.2, color='b',
**step_kwargs)
plt.step(recall['micro'], precision['micro'], color='b', where='post')

plt.xlabel('Recall')
plt.ylabel('Precision')
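For reference, here is a condensed, runnable sketch of what the updated example now boils down to. The dataset setup is simplified relative to the real example script, so the data and estimator choices below are purely illustrative:

# Condensed sketch of the updated example; the dataset setup is illustrative
# and simpler than in the actual example script.
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import average_precision_score
from sklearn.metrics import plot_precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
classifier = LinearSVC(max_iter=10000).fit(X_train, y_train)

y_score = classifier.decision_function(X_test)
average_precision = average_precision_score(y_test, y_score)

# One call replaces the manual plt.step / plt.fill_between plotting that this
# diff removes; LinearSVC has no predict_proba, so the 'auto' response_method
# falls back to decision_function.
disp = plot_precision_recall_curve(classifier, X_test, y_test, color='b')
disp.ax_.set_title('2-class Precision-Recall curve: '
                   'AP={0:0.2f}'.format(average_precision))
plt.show()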
4 changes: 4 additions & 0 deletions sklearn/metrics/__init__.py
@@ -79,6 +79,8 @@

from ._plot.roc_curve import plot_roc_curve
from ._plot.roc_curve import RocCurveDisplay
from ._plot.precision_recall import plot_precision_recall_curve
Review comment (Member): rename the file to precision_recall_curve.py?

from ._plot.precision_recall import PrecisionRecallDisplay


__all__ = [
@@ -135,7 +137,9 @@
'pairwise_distances_argmin_min',
'pairwise_distances_chunked',
'pairwise_kernels',
'plot_precision_recall_curve',
'plot_roc_curve',
'PrecisionRecallDisplay',
'precision_recall_curve',
'precision_recall_fscore_support',
'precision_score',
172 changes: 172 additions & 0 deletions sklearn/metrics/_plot/precision_recall.py
@@ -0,0 +1,172 @@
from .. import average_precision_score
from .. import precision_recall_curve

from ...utils import check_matplotlib_support


class PrecisionRecallDisplay:
"""Precision Recall visualization.

It is recommended to use `sklearn.metrics.plot_precision_recall_curve` to
create a visualizer. All parameters are stored as attributes.

Read more in the :ref:`User Guide <visualizations>`.

Review comment (Member): Add link to Visualization UG

Parameters
----------
precision : ndarray
Precision values.

recall : ndarray
Recall values.

average_precision : float
Average precision.

estimator_name : str
Name of estimator.

Attributes
----------
line_ : matplotlib Artist
Precision recall curve.

ax_ : matplotlib Axes
Axes with precision recall curve.

figure_ : matplotlib Figure
Figure containing the curve.
"""

def __init__(self, precision, recall, average_precision, estimator_name):
self.precision = precision
self.recall = recall
self.average_precision = average_precision
self.estimator_name = estimator_name

def plot(self, ax=None, name=None, **kwargs):
"""Plot visualization

Extra keyword arguments will be passed to matplotlib's ``plot``.

Parameters
----------
ax : Matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.

name : str, default=None
Name of precision recall curve for labeling. If `None`, use the
name of the estimator.

Returns
-------
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
Object that stores computed values.
"""
check_matplotlib_support("PrecisionRecallDisplay.plot")
import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots()

name = self.estimator_name if name is None else name

line_kwargs = {
"label": "{} (AP = {:0.2f})".format(name, self.average_precision),
"drawstyle": "steps-post"
}
line_kwargs.update(**kwargs)

self.line_ = ax.plot(self.recall, self.precision, **line_kwargs)[0]
ax.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05],
Review comment (Member): rely on default xlim/ylim?
Review comment (Author): For x, going without it; for y, I think it is kind of important because of the scaling. [screenshots comparing the plot with ylim explicitly set vs. left at the default]

xlim=[0.0, 1.0])
ax.legend(loc='lower left')

self.ax_ = ax
self.figure_ = ax.figure
return self


def plot_precision_recall_curve(estimator, X, y, pos_label=None,
sample_weight=None, response_method="auto",
name=None, ax=None, **kwargs):
"""Plot Precision Recall Curve.

Extra keyword arguments will be passed to matplotlib's ``plot``.

Read more in the :ref:`User Guide <visualizations>`.

Parameters
----------
estimator : estimator instance
Trained classifier.

X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.

y : array-like of shape (n_samples,)
Target values.

pos_label : int or str, default=None
The label of the positive class.
When `pos_label=None`, if y_true is in {-1, 1} or {0, 1},
`pos_label` is set to 1, otherwise an error will be raised.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

response_method : {'predict_proba', 'decision_function', 'auto'} \
default='auto'
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.

name : str, default=None
Name for labeling curve. If `None`, the name of the
estimator is used.

ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

Returns
-------
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
Object that stores computed values.
"""
check_matplotlib_support("plot_precision_recall_curve")

if response_method not in ("predict_proba", "decision_function", "auto"):
raise ValueError("response_method must be 'predict_proba', "
"'decision_function' or 'auto'")

if response_method != "auto":
prediction_method = getattr(estimator, response_method, None)
if prediction_method is None:
raise ValueError(
"response method {} is not defined".format(response_method))
else:
predict_proba = getattr(estimator, 'predict_proba', None)
decision_function = getattr(estimator, 'decision_function', None)
prediction_method = predict_proba or decision_function

if prediction_method is None:
raise ValueError('response methods not defined')

y_pred = prediction_method(X)

if y_pred.ndim != 1:
if y_pred.shape[1] > 2:
raise ValueError("Estimator should solve a "
Review comment (Member): You could fail earlier, before computing y_pred, by checking the target y using check_classification_targets.
Review comment (Member): Isn't it possible to use check_classification_targets?
Review comment (Author): Since this only supports binary classification, a type_of_target(y) == 'binary' check was added before the compute.
Review comment (Member): Hm, we could also check if the classifier is binary, but that might be overkill? Alternatively, we could plot it for a single class even in a multiclass setting by getting the probabilities that correspond to the class we're interested in (that wouldn't necessarily work for decision_function, though?). I'm ok with doing this one as a first version, though.

"binary classification problem")
y_pred = y_pred[:, 1]

precision, recall, _ = precision_recall_curve(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)
average_precision = average_precision_score(y, y_pred,
sample_weight=sample_weight)
viz = PrecisionRecallDisplay(precision, recall, average_precision,
estimator.__class__.__name__)
return viz.plot(ax=ax, name=name, **kwargs)
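Because the function returns the Display, several curves can share one Axes through the `ax` argument, which is a key use of the Display objects. A minimal sketch, assuming two fitted binary classifiers (the estimators and data below are illustrative, not part of this diff):

# Minimal sketch: overlaying two precision-recall curves on one Axes.
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_precision_recall_curve
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
lr = LogisticRegression(max_iter=5000).fit(X_train, y_train)
rf = RandomForestClassifier(random_state=0).fit(X_train, y_train)

# The first call creates a fresh figure; the second draws into the same
# Axes, so both curves share one legend with their AP scores.
disp = plot_precision_recall_curve(lr, X_test, y_test)
plot_precision_recall_curve(rf, X_test, y_test, ax=disp.ax_)
plt.show()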
93 changes: 93 additions & 0 deletions sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@@ -0,0 +1,93 @@
import pytest
import numpy as np
from numpy.testing import assert_allclose

from sklearn.metrics import plot_precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression


@pytest.fixture(scope="module")
def data_binary():
return load_breast_cancer(return_X_y=True)


def test_error_non_binary(pyplot):
X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier()
clf.fit(X, y)

msg = "Estimator should solve a binary classification problem"
with pytest.raises(ValueError, match=msg):
plot_precision_recall_curve(clf, X, y)


@pytest.mark.parametrize(
"response_method, msg",
[("predict_proba", "response method predict_proba is not defined"),
("decision_function", "response method decision_function is not defined"),
("auto", "response methods not defined"),
("bad_method", "response_method must be 'predict_proba', "
"'decision_function' or 'auto'")])
def test_error_no_response(pyplot, data_binary, response_method, msg):
X, y = data_binary

class MyClassifier:
pass

clf = MyClassifier()

with pytest.raises(ValueError, match=msg):
plot_precision_recall_curve(clf, X, y, response_method=response_method)


@pytest.mark.parametrize("response_method",
["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
def test_plot_precision_recall(pyplot, response_method, data_binary,
with_sample_weight):
X, y = data_binary

lr = LogisticRegression()
lr.fit(X, y)

if with_sample_weight:
rng = np.random.RandomState(42)
sample_weight = rng.randint(0, 4, size=X.shape[0])
else:
sample_weight = None

viz = plot_precision_recall_curve(lr, X, y, alpha=0.8,
sample_weight=sample_weight)

y_score = getattr(lr, response_method)(X)
if y_score.ndim == 2:
y_score = y_score[:, 1]

prec, recall, _ = precision_recall_curve(y, y_score,
sample_weight=sample_weight)
avg_prec = average_precision_score(y, y_score, sample_weight=sample_weight)

assert_allclose(viz.precision, prec)
assert_allclose(viz.recall, recall)
assert_allclose(viz.average_precision, avg_prec)

assert viz.estimator_name == "LogisticRegression"

# import cannot fail thanks to the pyplot fixture
import matplotlib as mpl # noqa
assert isinstance(viz.line_, mpl.lines.Line2D)
assert viz.line_.get_alpha() == 0.8
assert isinstance(viz.ax_, mpl.axes.Axes)
assert isinstance(viz.figure_, mpl.figure.Figure)

expected_label = "LogisticRegression (AP = {:0.2f})".format(avg_prec)
assert viz.line_.get_label() == expected_label
assert viz.ax_.get_xlabel() == "Recall"
assert viz.ax_.get_ylabel() == "Precision"
assert_allclose(viz.ax_.get_xlim(), [0.0, 1.0])
assert_allclose(viz.ax_.get_ylim(), [0.0, 1.05])
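The Display can also be constructed directly from precomputed values, mirroring what plot_precision_recall_curve does internally; a minimal sketch (dataset and estimator again illustrative):

# Building PrecisionRecallDisplay by hand from precomputed values, the same
# steps plot_precision_recall_curve performs internally.
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
lr = LogisticRegression(max_iter=5000).fit(X_train, y_train)

y_score = lr.predict_proba(X_test)[:, 1]  # probability of the positive class
precision, recall, _ = precision_recall_curve(y_test, y_score)
ap = average_precision_score(y_test, y_score)

disp = PrecisionRecallDisplay(precision, recall, ap, "LogisticRegression")
disp.plot()  # creates a new figure; pass ax=... to draw on an existing Axes
plt.show()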