[MRG] ENH Provide a pos_label parameter in plot_roc_curve #17651

Merged — 20 commits, merged Jul 6, 2020. Changes from all commits.
5 changes: 5 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -174,6 +174,11 @@ Changelog
optional in the matplotlib plot by setting colorbar=False. :pr:`17192` by
:user:`Avi Gupta <avigupta2612>`.

- |Enhancement| Add the `pos_label` parameter to
:func:`metrics.plot_roc_curve` to specify the positive
class used when computing the ROC AUC statistics.
:pr:`17651` by :user:`Clara Matos <claramatos>`.

:mod:`sklearn.model_selection`
..............................

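As a quick illustration of the changelog entry above, here is a minimal usage sketch of the new parameter (the dataset, classifier, and chosen label are illustrative assumptions, not part of the PR):

    from sklearn.datasets import load_breast_cancer
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import plot_roc_curve
    from sklearn.model_selection import train_test_split

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression(max_iter=5000).fit(X_train, y_train)

    # New in 0.24: pick which class is treated as positive when computing
    # the ROC curve; by default, clf.classes_[1] is used.
    plot_roc_curve(clf, X_test, y_test, pos_label=0)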
74 changes: 74 additions & 0 deletions sklearn/metrics/_plot/base.py
@@ -1,3 +1,8 @@
import numpy as np

from sklearn.base import is_classifier


def _check_classifier_response_method(estimator, response_method):
"""Return prediction method from the response_method

@@ -38,3 +43,72 @@ def _check_classifier_response_method(estimator, response_method):
estimator.__class__.__name__))

return prediction_method
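As an aside, a brief sketch of this helper's contract (it is a private helper, so the import path below is internal API; the classifier and data are illustrative):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics._plot.base import _check_classifier_response_method

    X, y = make_classification(random_state=0)
    clf = LogisticRegression().fit(X, y)

    # 'auto' tries predict_proba first, then falls back to decision_function.
    method = _check_classifier_response_method(clf, 'auto')
    assert method == clf.predict_proba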


def _get_response(X, estimator, response_method, pos_label=None):
"""Return response and positive label.

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.

estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.

response_method : {'auto', 'predict_proba', 'decision_function'}
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.

pos_label : str or int, default=None
The class considered as the positive class when computing
the metrics. By default, `estimator.classes_[1]` is
considered as the positive class.

Returns
-------
y_pred : ndarray of shape (n_samples,)
Target scores calculated from the provided response_method
and pos_label.

pos_label : str or int
The class considered as the positive class when computing
the metrics.
"""
classification_error = (
"{} should be a binary classifier".format(estimator.__class__.__name__)
)

if not is_classifier(estimator):
raise ValueError(classification_error)

prediction_method = _check_classifier_response_method(
estimator, response_method)

y_pred = prediction_method(X)

if pos_label is not None and pos_label not in estimator.classes_:
raise ValueError(
f"The class provided by 'pos_label' is unknown. Got "
f"{pos_label} instead of one of {estimator.classes_}"
)

if y_pred.ndim != 1: # `predict_proba`
if y_pred.shape[1] != 2:
raise ValueError(classification_error)
if pos_label is None:
pos_label = estimator.classes_[1]
y_pred = y_pred[:, 1]
else:
class_idx = np.flatnonzero(estimator.classes_ == pos_label)
y_pred = y_pred[:, class_idx]
else:  # `decision_function`
if pos_label is None:
pos_label = estimator.classes_[1]
elif pos_label == estimator.classes_[0]:
y_pred *= -1

return y_pred, pos_label
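To make the new helper's behavior concrete, a hedged sketch of `_get_response` (again a private import path, used here only for illustration; the data and labels are assumptions):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics._plot.base import _get_response

    X, y = make_classification(random_state=0)
    clf = LogisticRegression().fit(X, y)

    # predict_proba branch: the column matching pos_label is selected.
    y_proba, pos_label = _get_response(X, clf, 'predict_proba', pos_label=0)

    # decision_function branch: scores are negated when pos_label is
    # clf.classes_[0], so larger scores always favor the positive class.
    y_score, pos_label = _get_response(X, clf, 'decision_function', pos_label=0)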
60 changes: 19 additions & 41 deletions sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,13 +1,10 @@
import numpy as np

from .base import _check_classifier_response_method
from .base import _get_response

from .. import average_precision_score
from .. import precision_recall_curve

from ...utils import check_matplotlib_support
from ...utils.validation import _deprecate_positional_args
from ...base import is_classifier


class PrecisionRecallDisplay:
@@ -67,12 +64,13 @@ class PrecisionRecallDisplay:
>>> disp = PrecisionRecallDisplay(precision=precision, recall=recall)
>>> disp.plot() # doctest: +SKIP
"""

def __init__(self, precision, recall, *,
average_precision=None, estimator_name=None, pos_label=None):
self.estimator_name = estimator_name
self.precision = precision
self.recall = recall
self.average_precision = average_precision
self.estimator_name = estimator_name
self.pos_label = pos_label

@_deprecate_positional_args
@@ -100,10 +98,6 @@ def plot(self, ax=None, *, name=None, **kwargs):
Object that stores computed values.
"""
check_matplotlib_support("PrecisionRecallDisplay.plot")
import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots()

name = self.estimator_name if name is None else name

@@ -118,15 +112,21 @@ def plot(self, ax=None, *, name=None, **kwargs):
line_kwargs["label"] = name
line_kwargs.update(**kwargs)

import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots()

self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
info_pos_label = (f" (Positive label: {self.pos_label})"
if self.pos_label is not None else "")

xlabel = "Recall" + info_pos_label
ylabel = "Precision" + info_pos_label
ax.set(xlabel=xlabel, ylabel=ylabel)

if "label" in line_kwargs:
ax.legend(loc='lower left')
ax.legend(loc="lower left")

self.ax_ = ax
self.figure_ = ax.figure
@@ -194,46 +194,24 @@ def plot_precision_recall_curve(estimator, X, y, *,
"""
check_matplotlib_support("plot_precision_recall_curve")

classification_error = ("{} should be a binary classifier".format(
estimator.__class__.__name__))
if not is_classifier(estimator):
raise ValueError(classification_error)

prediction_method = _check_classifier_response_method(estimator,
response_method)
y_pred = prediction_method(X)

if pos_label is not None and pos_label not in estimator.classes_:
raise ValueError(
f"The class provided by 'pos_label' is unknown. Got "
f"{pos_label} instead of one of {estimator.classes_}"
)

if y_pred.ndim != 1: # `predict_proba`
if y_pred.shape[1] != 2:
raise ValueError(classification_error)
if pos_label is None:
pos_label = estimator.classes_[1]
y_pred = y_pred[:, 1]
else:
class_idx = np.flatnonzero(estimator.classes_ == pos_label)
y_pred = y_pred[:, class_idx]
else: # `decision_function`
if pos_label is None:
pos_label = estimator.classes_[1]
elif pos_label == estimator.classes_[0]:
y_pred *= -1
y_pred, pos_label = _get_response(
X, estimator, response_method, pos_label=pos_label)

precision, recall, _ = precision_recall_curve(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)
average_precision = average_precision_score(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)

name = name if name is not None else estimator.__class__.__name__

viz = PrecisionRecallDisplay(
precision=precision, recall=recall,
average_precision=average_precision, estimator_name=name,
precision=precision,
recall=recall,
average_precision=average_precision,
estimator_name=name,
pos_label=pos_label,
)

return viz.plot(ax=ax, name=name, **kwargs)
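The refactor above routes plot_precision_recall_curve through the shared _get_response helper; `pos_label` was already part of this function's signature, as the removed block shows. A usage sketch (illustrative data and label choice):

    from sklearn.datasets import load_breast_cancer
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import plot_precision_recall_curve
    from sklearn.model_selection import train_test_split

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression(max_iter=5000).fit(X_train, y_train)

    # Axis labels now include the positive label, e.g.
    # "Precision (Positive label: 0)".
    disp = plot_precision_recall_curve(clf, X_test, y_test, pos_label=0)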
75 changes: 45 additions & 30 deletions sklearn/metrics/_plot/roc_curve.py
@@ -1,9 +1,9 @@
from .base import _get_response

from .. import auc
from .. import roc_curve

from .base import _check_classifier_response_method
from ...utils import check_matplotlib_support
from ...base import is_classifier
from ...utils.validation import _deprecate_positional_args


@@ -29,6 +29,13 @@ class RocCurveDisplay:
estimator_name : str, default=None
Name of estimator. If None, the estimator name is not shown.

pos_label : str or int, default=None
The class considered as the positive class when computing the ROC AUC
metrics. By default, `estimator.classes_[1]` is considered
as the positive class.

.. versionadded:: 0.24

Attributes
----------
line_ : matplotlib Artist
@@ -54,11 +61,14 @@
>>> display.plot() # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None):

def __init__(self, *, fpr, tpr,
roc_auc=None, estimator_name=None, pos_label=None):
self.estimator_name = estimator_name
self.fpr = fpr
self.tpr = tpr
self.roc_auc = roc_auc
self.estimator_name = estimator_name
self.pos_label = pos_label

@_deprecate_positional_args
def plot(self, ax=None, *, name=None, **kwargs):
@@ -82,10 +92,6 @@ def plot(self, ax=None, *, name=None, **kwargs):
Object that stores computed values.
"""
check_matplotlib_support('RocCurveDisplay.plot')
import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots()

name = self.estimator_name if name is None else name

@@ -99,12 +105,21 @@

line_kwargs.update(**kwargs)

self.line_ = ax.plot(self.fpr, self.tpr, **line_kwargs)[0]
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots()

self.line_, = ax.plot(self.fpr, self.tpr, **line_kwargs)
info_pos_label = (f" (Positive label: {self.pos_label})"
if self.pos_label is not None else "")

xlabel = "False Positive Rate" + info_pos_label
ylabel = "True Positive Rate" + info_pos_label
ax.set(xlabel=xlabel, ylabel=ylabel)

if "label" in line_kwargs:
ax.legend(loc='lower right')
ax.legend(loc="lower right")

self.ax_ = ax
self.figure_ = ax.figure
@@ -114,7 +129,7 @@ def plot(self, ax=None, *, name=None, **kwargs):
@_deprecate_positional_args
def plot_roc_curve(estimator, X, y, *, sample_weight=None,
drop_intermediate=True, response_method="auto",
name=None, ax=None, **kwargs):
name=None, ax=None, pos_label=None, **kwargs):
"""Plot Receiver operating characteristic (ROC) curve.

Extra keyword arguments will be passed to matplotlib's `plot`.
@@ -155,6 +170,13 @@ def plot_roc_curve(estimator, X, y, *, sample_weight=None,
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

pos_label : str or int, default=None
The class considered as the positive class when computing the ROC AUC
metrics. By default, `estimator.classes_[1]` is considered
as the positive class.

.. versionadded:: 0.24

Returns
-------
display : :class:`~sklearn.metrics.RocCurveDisplay`
@@ -181,29 +203,22 @@
"""
check_matplotlib_support('plot_roc_curve')

classification_error = (
"{} should be a binary classifier".format(estimator.__class__.__name__)
)
if not is_classifier(estimator):
raise ValueError(classification_error)

prediction_method = _check_classifier_response_method(estimator,
response_method)
y_pred = prediction_method(X)

if y_pred.ndim != 1:
if y_pred.shape[1] != 2:
raise ValueError(classification_error)
else:
y_pred = y_pred[:, 1]
y_pred, pos_label = _get_response(
X, estimator, response_method, pos_label=pos_label)

pos_label = estimator.classes_[1]
fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate)
roc_auc = auc(fpr, tpr)

name = estimator.__class__.__name__ if name is None else name

viz = RocCurveDisplay(
fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name
fpr=fpr,
tpr=tpr,
roc_auc=roc_auc,
estimator_name=name,
pos_label=pos_label
)

return viz.plot(ax=ax, name=name, **kwargs)
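Finally, a sketch of the display object used directly with the new pos_label attribute (the scores below are toy values, not from the PR):

    import numpy as np
    from sklearn.metrics import RocCurveDisplay, auc, roc_curve

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])

    fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
    display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc(fpr, tpr),
                              estimator_name='toy estimator', pos_label=1)
    display.plot()  # axis labels read "... (Positive label: 1)"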
1 change: 0 additions & 1 deletion sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@@ -18,7 +18,6 @@
from sklearn.utils import shuffle
from sklearn.compose import make_column_transformer


# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
pytestmark = pytest.mark.filterwarnings(
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"