
[MRG] Adds plot_precision_recall_curve #14936

Merged (30 commits) on Nov 11, 2019

Conversation

@thomasjpfan (Member) commented on Sep 9, 2019:

Reference Issues/PRs

Related to #7116

What does this implement/fix? Explain your changes.

This PR adds plot_precision_recall_curve.

Any other comments?

Only supports binary classifiers.
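
A minimal usage sketch of the API this PR adds (the dataset and estimator are illustrative, not taken from the PR): plot_precision_recall_curve takes a fitted binary classifier plus the data, draws the curve, and returns a PrecisionRecallDisplay.

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_precision_recall_curve

X, y = make_classification(n_classes=2, random_state=0)
clf = LogisticRegression().fit(X, y)

# Returns a PrecisionRecallDisplay; binary classifiers only.
disp = plot_precision_recall_curve(clf, X, y)
plt.show()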

@glemaitre (Contributor) left a comment:

Looks good, only a couple of changes.

Review threads (all since resolved) were opened on:

examples/model_selection/plot_precision_recall.py
sklearn/metrics/_plot/precision_recall.py
sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@glemaitre added this to REVIEWED AND WAITING FOR CHANGES in Guillaume's pet on Sep 13, 2019

Parameters
-----------
precision : ndarray of shape (n_thresholds + 1, )
@glemaitre commented on Sep 20, 2019, with a suggested change:

- precision : ndarray of shape (n_thresholds + 1, )
+ precision : ndarray of shape (n_thresholds + 1,)

@glemaitre (Contributor) left a comment:

You can also add an entry in what's new.


if y_pred.ndim != 1:
    if y_pred.shape[1] > 2:
        raise ValueError("Estimator should solve a "
@glemaitre commented on Sep 20, 2019:

isn't it possible to use check_classification_targets?
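
For context, a small sketch of what the suggestion points at (illustrative only, not the code the PR ended up with): check_classification_targets validates the targets y and rejects continuous ones, so a separate binary check would still be needed.

from sklearn.utils.multiclass import check_classification_targets, type_of_target

y = [0, 1, 1, 0]
check_classification_targets(y)    # raises ValueError for regression-style targets
if type_of_target(y) != "binary":  # hypothetical extra guard for this PR
    raise ValueError("Estimator should solve a binary classification problem")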

@amueller (Member) commented on Sep 20, 2019:

conflicts ;)


y_pred = prediction_method(X)

if is_predict_proba and y_pred.ndim != 1:
@amueller commented on Sep 25, 2019:

If is_predict_proba is true, y_pred.ndim is never 1, right?
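
The point being made, in a runnable sketch (illustrative data): predict_proba always returns a 2-D array of shape (n_samples, n_classes), even for binary problems, whereas decision_function returns a 1-D array in the binary case.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_classes=2, random_state=0)
clf = LogisticRegression().fit(X, y)

print(clf.predict_proba(X).ndim)      # 2 -- shape (n_samples, n_classes)
print(clf.decision_function(X).ndim)  # 1 -- shape (n_samples,)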

plot_precision_recall_curve(clf, X, y)

msg = "Estimator should solve a binary classification problem"
y_binary = y == 1
@amueller commented on Sep 25, 2019:

I don't understand why this raises this error, both semantically and in terms of why the code does so. I thought the code checked y_pred, which we're not changing here, right?

@amueller (Member) commented on Sep 25, 2019:

Looks good apart from nitpicks.

@glemaitre (Contributor) commented on Oct 2, 2019:

The error raised does not match.

@NicolasHug (Member) commented on Nov 6, 2019:

The user guide link of plot_precision_recall_curve is wrong: there's no point in linking to the visualization API user guide. Also, some of the links are broken.


@amueller (Member) commented on Nov 6, 2019:

See #15405 (comment).

The easy fix is removing it and inferring it from the estimator. The better fix is to make sure we correctly slice the predict_proba / decision_function output.
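
Roughly what that "better fix" could look like (a sketch under assumed semantics, not the code this PR merged): locate pos_label in estimator.classes_ to pick the predict_proba column, and flip the sign of decision_function scores when pos_label is the negative class.

import numpy as np

def _slice_response(estimator, X, pos_label):
    # Hypothetical helper, not part of the PR.
    if hasattr(estimator, "predict_proba"):
        col = np.flatnonzero(estimator.classes_ == pos_label)[0]
        return estimator.predict_proba(X)[:, col]
    y_score = estimator.decision_function(X)
    # decision_function scores refer to classes_[1]; invert otherwise.
    return y_score if pos_label == estimator.classes_[1] else -y_score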

@thomasjpfan (Member, author) commented on Nov 6, 2019:

Went with removing pos_label and inferring it from the estimator.

@amueller (Member) commented on Nov 6, 2019:

Then we should do the same for plot_roc_curve and open an issue to do the fix for the next release?

@qinhanmin2014 (Member) left a comment:

Perhaps it's better to keep consistent with plot_roc_curve/RocCurveDisplay (API and code).

@@ -71,6 +71,7 @@ Functions

.. autosummary::

   metrics.plot_precision_recall_curve
@qinhanmin2014 commented on Nov 7, 2019:

Alphabetic order?

@@ -82,5 +83,6 @@ Display Objects

.. autosummary::

   metrics.PrecisionRecallDisplay
@qinhanmin2014 commented on Nov 7, 2019:

Alphabetic order?

@@ -79,6 +79,8 @@

from ._plot.roc_curve import plot_roc_curve
from ._plot.roc_curve import RocCurveDisplay
from ._plot.precision_recall import plot_precision_recall_curve
@qinhanmin2014 commented on Nov 7, 2019:

Rename the file to precision_recall_curve.py?

    Axes object to plot on. If `None`, a new figure and axes is
    created.

label_name : str, default=None
@qinhanmin2014 commented on Nov 7, 2019:

Why is it different from RocCurveDisplay?

Parameters
-----------
precision : ndarray of shape (n_thresholds + 1,)
@qinhanmin2014 commented on Nov 7, 2019:

n_thresholds is not defined in this context.
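
For reference, the shape comes from sklearn.metrics.precision_recall_curve, which returns precision and recall arrays that are one element longer than the thresholds array (example data as in the scikit-learn docs):

from sklearn.metrics import precision_recall_curve

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
precision, recall, thresholds = precision_recall_curve(y_true, y_score)
assert len(precision) == len(recall) == len(thresholds) + 1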

line_kwargs.update(**kwargs)

self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
ax.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05],
@qinhanmin2014 commented on Nov 7, 2019:

Rely on the default xlim/ylim?

@thomasjpfan (Member, author) commented on Nov 7, 2019:

For the x axis, going without is fine, but for the y axis I think it is kind of important because of the scaling:

[screenshot: curve with ylim explicitly set]

[screenshot: curve with ylim not set]

precision : ndarray of shape (n_thresholds + 1,)
    Precision values.

recall : ndarray of shape (n_thresholds + 1,)
@qinhanmin2014 commented on Nov 7, 2019:

n_thresholds is not defined in this context.

    :term:`predict_proba` is tried first and if it does not exist
    :term:`decision_function` is tried next.

label_name : str, default=None
@qinhanmin2014 commented on Nov 7, 2019:

Not consistent with plot_roc_curve.

@thomasjpfan (Member, author) commented on Nov 7, 2019:

Changed this to `name` to be consistent with plot_roc_curve.

@NicolasHug (Member) left a comment:

Looks good, but it needs to link to the user guide, with small updates.

It is recommended to use :func:`~sklearn.metrics.plot_precision_recall_curve`
to create a visualizer. All parameters are stored as attributes.
@NicolasHug commented on Nov 7, 2019:

Add a link to the visualization user guide.

"""Plot Precision Recall Curve for binary classifers.
Extra keyword arguments will be passed to matplotlib's `plot`.
Copy link
Member

@NicolasHug NicolasHug Nov 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug (Member) left a comment:

Not sure why tests are failing, but LGTM.

@qinhanmin2014 (Member) left a comment:

I think the tests are failing because we no longer set xlim and ylim manually, but we haven't updated the test.

I feel a little uncomfortable that plot_roc_curve and plot_precision_recall_curve are written in different ways, e.g., we introduce is_predict_proba in plot_precision_recall_curve but not in plot_roc_curve. If we keep these two functions consistent, it will be much easier to maintain, but perhaps it's not so important.

@thomasjpfan (Member, author) commented on Nov 8, 2019:

> If we keep these two functions consistent, it will be much easier to maintain, but perhaps it's not so important.

I refactored the response method checking into a _check_classifer_response_method helper that can be used by plot_roc_curve. We can have a follow-up PR to make plot_roc_curve use it as well, to keep the error messages and code consistent.
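
A sketch of what such a shared helper might look like (assumed semantics based on this thread, not the exact merged code): resolve "auto" to predict_proba first, then decision_function, and raise a consistent error otherwise.

def _check_classifer_response_method(estimator, response_method):
    # Validate the argument, mirroring the check quoted later in this thread.
    if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError("response_method must be 'predict_proba', "
                         "'decision_function' or 'auto'")
    if response_method != "auto":
        prediction_method = getattr(estimator, response_method, None)
    else:
        prediction_method = (getattr(estimator, "predict_proba", None)
                             or getattr(estimator, "decision_function", None))
    if prediction_method is None:
        raise ValueError("{} has none of the requested response "
                         "methods".format(type(estimator).__name__))
    return prediction_method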

@qinhanmin2014 (Member) left a comment:

Should we rename the files to _precision_recall_curve.py and _roc_curve.py?

@@ -0,0 +1,40 @@
def _check_classifer_response_method(estimator, response_method):
@qinhanmin2014 commented on Nov 9, 2019:

Is it good to include things in __init__.py? Perhaps base.py?
Let's update plot_roc_curve in this PR?

@thomasjpfan (Member, author) commented on Nov 10, 2019:

> Should we rename the files to _precision_recall_curve.py and _roc_curve.py?

Since they are both in _plot, either way works for me.

> Let's update plot_roc_curve in this PR?

Done.

@@ -180,18 +181,8 @@ def plot_roc_curve(estimator, X, y, sample_weight=None,
    else:
        raise ValueError(classification_error)

    if response_method != "auto":
@qinhanmin2014 commented on Nov 10, 2019:

We also need to remove the following lines above:

if response_method not in ("predict_proba", "decision_function", "auto"):
    raise ValueError("response_method must be 'predict_proba', "
                     "'decision_function' or 'auto'")

@qinhanmin2014 merged commit 968252d into scikit-learn:master on Nov 11, 2019. All 20 checks passed.
Interpretability / Plotting / Interactive dev automation moved this from Review in progress to Done on Nov 11, 2019.
Meeting Issues automation moved this from Review in progress to Done on Nov 11, 2019.
@glemaitre moved this from TO REVIEW to REVIEWED AND WAITING FOR CHANGES in Guillaume's pet on Nov 14, 2019.
@glemaitre moved this from REVIEWED AND WAITING FOR CHANGES to MERGED in Guillaume's pet on Nov 14, 2019.
panpiort8 pushed a commit to panpiort8/scikit-learn that referenced this issue Mar 3, 2020