New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Adds plot_precision_recall_curve #14936
Changes from 9 commits
720f9ac
8ac4469
0b81383
e9c8131
241845f
3d86867
c7029a6
9bf152b
66deaac
10dc97e
affec16
99b18ed
1796020
ee62183
d7d448f
294c29a
dbe9a3a
3fade4b
fd1cc44
fdb60ae
7cbb1ac
abbbb9d
a589f92
c9b3d60
bfd5634
dbae9f8
2c3d78d
7736f77
91f0d05
a559342
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,7 @@ Functions | |
|
||
.. autosummary:: | ||
|
||
metrics.plot_precision_recall_curve | ||
metrics.plot_roc_curve | ||
|
||
|
||
|
@@ -80,4 +81,5 @@ Display Objects | |
|
||
.. autosummary:: | ||
|
||
metrics.PrecisionRecallDisplay | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. alphabetic order? |
||
metrics.RocCurveDisplay |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,6 +79,8 @@ | |
|
||
from ._plot.roc_curve import plot_roc_curve | ||
from ._plot.roc_curve import RocCurveDisplay | ||
from ._plot.precision_recall import plot_precision_recall_curve | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename the file to |
||
from ._plot.precision_recall import PrecisionRecallDisplay | ||
|
||
|
||
__all__ = [ | ||
|
@@ -135,7 +137,9 @@ | |
'pairwise_distances_argmin_min', | ||
'pairwise_distances_chunked', | ||
'pairwise_kernels', | ||
'plot_precision_recall_curve', | ||
'plot_roc_curve', | ||
'PrecisionRecallDisplay', | ||
'precision_recall_curve', | ||
'precision_recall_fscore_support', | ||
'precision_score', | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
from .. import average_precision_score | ||
from .. import precision_recall_curve | ||
|
||
from ...utils import check_matplotlib_support | ||
|
||
|
||
class PrecisionRecallDisplay:
    """Precision Recall visualization.

    It is recommended to use `sklearn.metrics.plot_precision_recall_curve`
    to create a visualizer. All parameters are stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    precision : ndarray
        Precision values.

    recall : ndarray
        Recall values.

    average_precision : float
        Average precision.

    estimator_name : str
        Name of estimator.

    Attributes
    ----------
    line_ : matplotlib Artist
        Precision recall curve.

    ax_ : matplotlib Axes
        Axes with precision recall curve.

    figure_ : matplotlib Figure
        Figure containing the curve.
    """

    def __init__(self, precision, recall, average_precision, estimator_name):
        # All inputs are stored verbatim; plotting happens lazily in `plot`.
        self.precision = precision
        self.recall = recall
        self.average_precision = average_precision
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's ``plot``.

        Parameters
        ----------
        ax : Matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of precision recall curve for labeling. If `None`, use the
            name of the estimator.

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support("PrecisionRecallDisplay.plot")
        # Imported lazily so the class can be instantiated without matplotlib.
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        if name is None:
            name = self.estimator_name

        # Caller-supplied kwargs override the defaults below.
        plot_kwargs = dict(
            label="{} (AP = {:0.2f})".format(name, self.average_precision),
            drawstyle="steps-post",
        )
        plot_kwargs.update(kwargs)

        (self.line_,) = ax.plot(self.recall, self.precision, **plot_kwargs)
        ax.set(xlabel="Recall", ylabel="Precision",
               xlim=[0.0, 1.0], ylim=[0.0, 1.05])
        ax.legend(loc='lower left')

        self.ax_ = ax
        self.figure_ = ax.figure
        return self
|
||
|
||
def plot_precision_recall_curve(estimator, X, y, pos_label=None,
                                sample_weight=None, response_method="auto",
                                name=None, ax=None, **kwargs):
    """Plot Precision Recall Curve.

    Extra keyword arguments will be passed to matplotlib's ``plot``.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    estimator : estimator instance
        Trained classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Target values.

    pos_label : int or str, default=None
        The label of the positive class.
        When `pos_label=None`, if y_true is in {-1, 1} or {0, 1},
        `pos_label` is set to 1, otherwise an error will be raised.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
            default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name for labeling curve. If `None`, the name of the
        estimator is used.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    Returns
    -------
    display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
        Object that stores computed values.
    """
    check_matplotlib_support("plot_precision_recall_curve")

    if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError("response_method must be 'predict_proba', "
                         "'decision_function' or 'auto'")

    if response_method == "auto":
        # Prefer predict_proba, fall back to decision_function.
        prediction_method = (getattr(estimator, 'predict_proba', None)
                             or getattr(estimator, 'decision_function', None))
        if prediction_method is None:
            raise ValueError('response methods not defined')
    else:
        prediction_method = getattr(estimator, response_method, None)
        if prediction_method is None:
            raise ValueError(
                "response method {} is not defined".format(response_method))

    y_pred = prediction_method(X)

    if y_pred.ndim != 1:
        # 2-D output means per-class scores (predict_proba): only a binary
        # problem is supported, and the positive-class column is kept.
        if y_pred.shape[1] > 2:
            raise ValueError("Estimator should solve a "
                             "binary classification problem")
        y_pred = y_pred[:, 1]

    precision, recall, _ = precision_recall_curve(
        y, y_pred, pos_label=pos_label, sample_weight=sample_weight)
    average_precision = average_precision_score(
        y, y_pred, sample_weight=sample_weight)
    viz = PrecisionRecallDisplay(precision, recall, average_precision,
                                 estimator.__class__.__name__)
    return viz.plot(ax=ax, name=name, **kwargs)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import pytest | ||
import numpy as np | ||
from numpy.testing import assert_allclose | ||
|
||
from sklearn.metrics import plot_precision_recall_curve | ||
from sklearn.metrics import average_precision_score | ||
from sklearn.metrics import precision_recall_curve | ||
from sklearn.datasets import load_breast_cancer | ||
from sklearn.datasets import load_iris | ||
from sklearn.tree import DecisionTreeClassifier | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
|
||
@pytest.fixture(scope="module")
def data_binary():
    """Binary-classification dataset shared by all tests in this module."""
    X, y = load_breast_cancer(return_X_y=True)
    return X, y
|
||
|
||
def test_error_non_binary(pyplot):
    # A multiclass problem must be rejected with an explicit error.
    X, y = load_iris(return_X_y=True)
    clf = DecisionTreeClassifier().fit(X, y)

    msg = "Estimator should solve a binary classification problem"
    with pytest.raises(ValueError, match=msg):
        plot_precision_recall_curve(clf, X, y)
|
||
|
||
@pytest.mark.parametrize(
    "response_method, msg",
    [("predict_proba", "response method predict_proba is not defined"),
     ("decision_function", "response method decision_function is not defined"),
     ("auto", "response methods not defined"),
     ("bad_method", "response_method must be 'predict_proba', "
                    "'decision_function' or 'auto'")])
def test_error_no_response(pyplot, data_binary, response_method, msg):
    # An object exposing neither predict_proba nor decision_function must
    # raise a method-specific error for every response_method value.
    X, y = data_binary

    class MyClassifier:
        pass

    with pytest.raises(ValueError, match=msg):
        plot_precision_recall_curve(MyClassifier(), X, y,
                                    response_method=response_method)
|
||
|
||
@pytest.mark.parametrize("response_method",
                         ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
def test_plot_precision_recall(pyplot, response_method, data_binary,
                               with_sample_weight):
    """Check stored values and matplotlib artists of the returned display.

    The display computed through ``response_method`` must match the values
    obtained by calling ``precision_recall_curve`` /
    ``average_precision_score`` directly on the same scores.
    """
    X, y = data_binary

    lr = LogisticRegression()
    lr.fit(X, y)

    if with_sample_weight:
        rng = np.random.RandomState(42)
        sample_weight = rng.randint(0, 4, size=X.shape[0])
    else:
        sample_weight = None

    # BUG FIX: the parametrized response_method was previously not forwarded,
    # so the plot always used the default 'auto' and the parametrization
    # only passed because LogisticRegression's predict_proba and
    # decision_function are monotonically related.
    viz = plot_precision_recall_curve(lr, X, y, alpha=0.8,
                                      sample_weight=sample_weight,
                                      response_method=response_method)

    y_score = getattr(lr, response_method)(X)
    if y_score.ndim == 2:
        # predict_proba returns one column per class; keep the positive one.
        y_score = y_score[:, 1]

    prec, recall, _ = precision_recall_curve(y, y_score,
                                             sample_weight=sample_weight)
    avg_prec = average_precision_score(y, y_score,
                                       sample_weight=sample_weight)

    assert_allclose(viz.precision, prec)
    assert_allclose(viz.recall, recall)
    assert_allclose(viz.average_precision, avg_prec)

    assert viz.estimator_name == "LogisticRegression"

    # cannot fail thanks to pyplot fixture
    import matplotlib as mpl  # noqa
    assert isinstance(viz.line_, mpl.lines.Line2D)
    assert viz.line_.get_alpha() == 0.8
    assert isinstance(viz.ax_, mpl.axes.Axes)
    assert isinstance(viz.figure_, mpl.figure.Figure)

    expected_label = "LogisticRegression (AP = {:0.2f})".format(avg_prec)
    assert viz.line_.get_label() == expected_label
    assert viz.ax_.get_xlabel() == "Recall"
    assert viz.ax_.get_ylabel() == "Precision"
    assert_allclose(viz.ax_.get_xlim(), [0.0, 1.0])
    assert_allclose(viz.ax_.get_ylim(), [0.0, 1.05])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alphabetic order?