
[MRG] ENH Provide a pos_label parameter in plot_roc_curve #17651

Merged: 20 commits, Jul 6, 2020
Changes from 12 commits
doc/whats_new/v0.24.rst (5 additions, 0 deletions)
@@ -117,6 +117,11 @@ Changelog
  :func:`metrics.median_absolute_error`. :pr:`17225` by
:user:`Lucy Liu <lucyleeow>`.

- |Enhancement| Add a `pos_label` parameter in
  :func:`metrics.plot_roc_curve` to specify the positive
  class when computing the ROC AUC statistics.
:pr:`17651` by :user:`Clara Matos <claramatos>`.

:mod:`sklearn.model_selection`
..............................

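A minimal sketch of the feature this changelog entry describes, assuming scikit-learn 0.24 with this PR merged (the dataset and estimator are illustrative, not taken from the PR):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_roc_curve
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=10_000).fit(X_train, y_train)

# Compute the ROC curve for class 0 (malignant) instead of the
# default positive class `clf.classes_[1]`.
plot_roc_curve(clf, X_test, y_test, pos_label=0)
```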
sklearn/metrics/_plot/base.py (153 additions, 0 deletions)
@@ -1,3 +1,8 @@
import numpy as np

from sklearn.base import is_classifier


def _check_classifier_response_method(estimator, response_method):
"""Return prediction method from the response_method

@@ -38,3 +43,151 @@ def _check_classifier_response_method(estimator, response_method):
estimator.__class__.__name__))

return prediction_method


def _get_target_scores(X, estimator, response_method, pos_label=None):
[Review comment, Member] Using the word "scores" here may get confused with metrics and scorers. Maybe `_get_response` is clearer?

"""Return target scores and positive label.

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.

estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.

    response_method : {'auto', 'predict_proba', 'decision_function'}
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.

pos_label : str or int, default=None
The class considered as the positive class when computing
        the metrics. By default, `estimator.classes_[1]` is
considered as the positive class.

Returns
-------
    y_pred : ndarray of shape (n_samples,)
Target scores calculated from the provided response_method
and pos_label.

    pos_label : str or int
The class considered as the positive class when computing
the metrics.
"""
classification_error = (
"{} should be a binary classifier".format(estimator.__class__.__name__)
)

if not is_classifier(estimator):
raise ValueError(classification_error)

prediction_method = _check_classifier_response_method(
estimator, response_method)

y_pred = prediction_method(X)

if pos_label is not None and pos_label not in estimator.classes_:
raise ValueError(
f"The class provided by 'pos_label' is unknown. Got "
f"{pos_label} instead of one of {estimator.classes_}"
)

    if y_pred.ndim != 1:  # `predict_proba`
        if y_pred.shape[1] != 2:
            raise ValueError(classification_error)
        if pos_label is None:
            pos_label = estimator.classes_[1]
            y_pred = y_pred[:, 1]
        else:
            # take the probability column of the class flagged as positive;
            # index with a scalar so the selected column stays 1d
            class_idx = np.flatnonzero(estimator.classes_ == pos_label)[0]
            y_pred = y_pred[:, class_idx]
    else:  # `decision_function`
        if pos_label is None:
            pos_label = estimator.classes_[1]
        elif pos_label == estimator.classes_[0]:
            # flip the sign so that higher scores still favor `pos_label`
            y_pred *= -1

return y_pred, pos_label
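A short usage sketch of the helper above (hedged: it assumes the private import path below and a fitted binary classifier):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics._plot.base import _get_target_scores

X, y = make_classification(n_classes=2, random_state=0)
clf = LogisticRegression().fit(X, y)

# `predict_proba` branch: the probability column for `pos_label` is selected.
scores, pos = _get_target_scores(X, clf, "predict_proba", pos_label=0)
assert scores.shape == (X.shape[0],) and pos == 0

# `decision_function` branch: scores are negated when `pos_label` is
# `clf.classes_[0]`, so higher scores still point at the positive class.
scores, pos = _get_target_scores(X, clf, "decision_function", pos_label=0)
assert np.allclose(scores, -clf.decision_function(X))
```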


class CurveDisplay:
"""Metrics visualization base class.

Parameters
    ----------
estimator_name : str, default=None
Name of estimator. If None, then the estimator name is not shown.

pos_label : str or int, default=None
[Review thread]

Member: Having pos_label here makes this base class less generic. We would only be able to use this to plot curves related to classification.

Contributor (author): Are there any plot classes you believe should use the CurveDisplay class?

Member (@glemaitre, Jun 23, 2020): @thomasjpfan Do we have curves for regression? I assume we could infer from y what type of data we are handling and whether this is a classification problem; if it is not, we can discard pos_label.

Member: I'd be happy with renaming it to ClassificationCurveDisplay. I forgot I had this comment still here.
The class considered as the positive class when computing
        the metrics. By default, `estimator.classes_[1]` is
considered as the positive class.

Attributes
----------
line_ : matplotlib Artist
Metrics curve.

ax_ : matplotlib Axes
Axes with the curve.

figure_ : matplotlib Figure
Figure containing the curve.
"""

def __init__(self, estimator_name=None, pos_label=None):
self.estimator_name = estimator_name
self.pos_label = pos_label

def _setup_display(self, x, y, line_kwargs,
xlabel, ylabel, loc, ax=None):
"""Setup visualization.

Parameters
----------
x, y : array-like or scalar
The horizontal / vertical coordinates of the data points.

line_kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

        xlabel : str
            Label of the horizontal axis.

        ylabel : str
            Label of the vertical axis.

loc : str
Location of the legend.

ax : Matplotlib Axes, default=None
Axes object to plot on. If `None`, a new figure and axes is
created.

Returns
-------
        display : :class:`CurveDisplay`
Object that stores computed values.
"""
import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots()

self.line_, = ax.plot(x, y, **line_kwargs)
info_pos_label = (f" (Positive label: {self.pos_label})"
if self.pos_label is not None else "")
xlabel += info_pos_label
ylabel += info_pos_label
ax.set(xlabel=xlabel, ylabel=ylabel)

if "label" in line_kwargs:
ax.legend(loc=loc)

self.ax_ = ax
self.figure_ = ax.figure
return self
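To see how the base class is meant to be consumed, here is a hypothetical subclass (`ExampleCurveDisplay` is a made-up name, assumed to live next to `CurveDisplay` above) that wires its data through `_setup_display`:

```python
class ExampleCurveDisplay(CurveDisplay):
    """Plots precomputed (x, y) points with the shared display logic."""

    def __init__(self, x, y, estimator_name=None, pos_label=None):
        super().__init__(estimator_name, pos_label)
        self.x = x
        self.y = y

    def plot(self, ax=None, name=None, **kwargs):
        name = self.estimator_name if name is None else name
        line_kwargs = {} if name is None else {"label": name}
        line_kwargs.update(**kwargs)
        # Delegates axis creation, labeling, and legend to the base class.
        return self._setup_display(
            x=self.x, y=self.y, line_kwargs=line_kwargs,
            xlabel="x", ylabel="y", loc="lower right", ax=ax,
        )
```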
sklearn/metrics/_plot/precision_recall_curve.py (40 additions, 36 deletions)
@@ -1,14 +1,14 @@
-from .base import _check_classifier_response_method
+from .base import _get_target_scores
+from .base import CurveDisplay

from .. import average_precision_score
from .. import precision_recall_curve

from ...utils import check_matplotlib_support
from ...utils.validation import _deprecate_positional_args
-from ...base import is_classifier


-class PrecisionRecallDisplay:
+class PrecisionRecallDisplay(CurveDisplay):
"""Precision Recall visualization.

    It is recommended to use :func:`~sklearn.metrics.plot_precision_recall_curve`
@@ -30,6 +30,13 @@ class PrecisionRecallDisplay:
estimator_name : str, default=None
Name of estimator. If None, then the estimator name is not shown.

pos_label : str or int, default=None
The class considered as the positive class when computing the precision
        and recall metrics. By default, `estimator.classes_[1]` is considered
as the positive class.

.. versionadded:: 0.24

Attributes
----------
line_ : matplotlib Artist
@@ -59,12 +66,13 @@ class PrecisionRecallDisplay:
>>> disp = PrecisionRecallDisplay(precision=precision, recall=recall)
>>> disp.plot() # doctest: +SKIP
"""

    def __init__(self, precision, recall, *,
-                average_precision=None, estimator_name=None):
+                average_precision=None, estimator_name=None, pos_label=None):
+       super().__init__(estimator_name, pos_label)
        self.precision = precision
        self.recall = recall
        self.average_precision = average_precision
-       self.estimator_name = estimator_name

@_deprecate_positional_args
def plot(self, ax=None, *, name=None, **kwargs):
@@ -91,10 +99,6 @@ def plot(self, ax=None, *, name=None, **kwargs):
Object that stores computed values.
"""
        check_matplotlib_support("PrecisionRecallDisplay.plot")
-       import matplotlib.pyplot as plt
-
-       if ax is None:
-           fig, ax = plt.subplots()

name = self.estimator_name if name is None else name

@@ -109,21 +113,21 @@ def plot(self, ax=None, *, name=None, **kwargs):
line_kwargs["label"] = name
line_kwargs.update(**kwargs)

-       self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
-       ax.set(xlabel="Recall", ylabel="Precision")
-
-       if "label" in line_kwargs:
-           ax.legend(loc='lower left')
-
-       self.ax_ = ax
-       self.figure_ = ax.figure
-       return self
+       return self._setup_display(
+           x=self.recall,
+           y=self.precision,
+           line_kwargs=line_kwargs,
+           xlabel="Recall",
+           ylabel="Precision",
+           loc="lower left",
+           ax=ax
+       )


@_deprecate_positional_args
def plot_precision_recall_curve(estimator, X, y, *,
sample_weight=None, response_method="auto",
-                                name=None, ax=None, **kwargs):
+                                name=None, ax=None, pos_label=None, **kwargs):
"""Plot Precision Recall Curve for binary classifiers.

Extra keyword arguments will be passed to matplotlib's `plot`.
@@ -159,6 +163,13 @@ def plot_precision_recall_curve(estimator, X, y, *,
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

pos_label : str or int, default=None
The class considered as the positive class when computing the precision
        and recall metrics. By default, `estimator.classes_[1]` is considered
as the positive class.

.. versionadded:: 0.24

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -174,31 +185,24 @@
"""
check_matplotlib_support("plot_precision_recall_curve")

-    classification_error = ("{} should be a binary classifier".format(
-        estimator.__class__.__name__))
-    if not is_classifier(estimator):
-        raise ValueError(classification_error)
-
-    prediction_method = _check_classifier_response_method(estimator,
-                                                          response_method)
-    y_pred = prediction_method(X)
-
-    if y_pred.ndim != 1:
-        if y_pred.shape[1] != 2:
-            raise ValueError(classification_error)
-        else:
-            y_pred = y_pred[:, 1]
+    y_pred, pos_label = _get_target_scores(
+        X, estimator, response_method, pos_label=pos_label)

-    pos_label = estimator.classes_[1]
precision, recall, _ = precision_recall_curve(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)
average_precision = average_precision_score(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)

name = name if name is not None else estimator.__class__.__name__

viz = PrecisionRecallDisplay(
-       precision=precision, recall=recall,
-       average_precision=average_precision, estimator_name=name
+       precision=precision,
+       recall=recall,
+       average_precision=average_precision,
+       estimator_name=name,
+       pos_label=pos_label,
)

return viz.plot(ax=ax, name=name, **kwargs)
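Finally, a hedged end-to-end sketch of the updated function (dataset and estimator are illustrative):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import plot_precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LinearSVC(max_iter=10_000).fit(X_train, y_train)  # exposes only `decision_function`

# Precision and recall are computed for class 0 rather than the default
# `clf.classes_[1]`; the axis labels gain " (Positive label: 0)".
plot_precision_recall_curve(clf, X_test, y_test, pos_label=0)
```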