Skip to content

Commit

Permalink
MNT Refactor scorer using _get_response
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Oct 9, 2020
1 parent 193670c commit cc27a27
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 213 deletions.
136 changes: 136 additions & 0 deletions sklearn/metrics/_base.py
Expand Up @@ -16,6 +16,7 @@

import numpy as np

from ..base import is_classifier
from ..utils import check_array, check_consistent_length
from ..utils.multiclass import type_of_target

Expand Down Expand Up @@ -249,3 +250,138 @@ def _check_pos_label_consistency(pos_label, y_true):
pos_label = 1.0

return pos_label


def _check_classifier_response_method(estimator, response_method):
"""Return prediction method from the `response_method`.
Parameters
----------
estimator : estimator instance
Classifier to check.
response_method : {'auto', 'predict_proba', 'decision_function', 'predict'}
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next and :term:`predict` last.
Returns
-------
prediction_method : callable
Prediction method of estimator.
"""

possible_response_methods = (
"predict", "predict_proba", "decision_function", "auto"
)
if response_method not in possible_response_methods:
raise ValueError(
f"response_method must be one of "
f"{','.join(possible_response_methods)}."
)

error_msg = "response method {} is not defined in {}"
if response_method != "auto":
prediction_method = getattr(estimator, response_method, None)
if prediction_method is None:
raise ValueError(
error_msg.format(response_method, estimator.__class__.__name__)
)
else:
predict_proba = getattr(estimator, 'predict_proba', None)
decision_function = getattr(estimator, 'decision_function', None)
predict = getattr(estimator, 'predict', None)
prediction_method = predict_proba or decision_function or predict
if prediction_method is None:
raise ValueError(
error_msg.format(
"decision_function, predict_proba or predict",
estimator.__class__.__name__
)
)

return prediction_method


def _get_response(
    estimator,
    X,
    y_true,
    response_method,
    pos_label=None,
):
    """Return response and positive label.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier. An estimator that is
        not a classifier is also accepted, in which case its `predict`
        output is returned as-is.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y_true : array-like of shape (n_samples,)
        The true label.

    response_method : {'auto', 'predict_proba', 'decision_function', \
            'predict'}
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next and :term:`predict` last.
        For an estimator that is not a classifier, only 'predict' and
        'auto' are accepted.

    pos_label : str or int, default=None
        The class considered as the positive class when computing
        the metrics. By default, when `y_true` is binary,
        `estimator.classes_[-1]` is considered as the positive class.

    Returns
    -------
    y_pred : ndarray
        Target scores calculated from the provided response_method
        and pos_label. For a binary `y_true`, the `predict_proba` output
        is reduced to the column of `pos_label` and the
        `decision_function` output is negated when `pos_label` is
        `estimator.classes_[0]`. Otherwise the output of the prediction
        method is returned unchanged.

    pos_label : str, int or None
        The class considered as the positive class when computing
        the metrics. `None` when the estimator is not a classifier, or
        when no `pos_label` was given and `y_true` is not binary.

    Raises
    ------
    ValueError
        If `pos_label` is given but is not one of `estimator.classes_`, if
        `y_true` is binary and `predict_proba` returns a single column, or
        if the estimator is not a classifier and `response_method` is
        neither 'predict' nor 'auto'.
    """
    if not is_classifier(estimator):
        # Non-classifiers have no notion of class scores or positive label.
        if response_method not in ("predict", "auto"):
            raise ValueError(
                f"{estimator.__class__.__name__} should be a classifier"
            )
        return estimator.predict(X), None

    y_type = type_of_target(y_true)
    classes = estimator.classes_
    prediction_method = _check_classifier_response_method(
        estimator, response_method
    )
    y_pred = prediction_method(X)

    if pos_label is not None and pos_label not in classes:
        raise ValueError(
            f"pos_label={pos_label} is not a valid label: It should be "
            f"one of {classes}"
        )
    elif pos_label is None and y_type == "binary":
        pos_label = classes[-1]

    if prediction_method.__name__ == "predict_proba":
        if y_type == "binary" and y_pred.shape[1] <= 2:
            if y_pred.shape[1] == 2:
                # Coerce to an array so the element-wise comparison also
                # works when `classes_` is a plain sequence.
                col_idx = np.flatnonzero(
                    np.asarray(classes) == pos_label
                )[0]
                y_pred = y_pred[:, col_idx]
            else:
                err_msg = (
                    f"Got predict_proba of shape {y_pred.shape}, but need "
                    f"classifier with two classes."
                )
                raise ValueError(err_msg)
    elif prediction_method.__name__ == "decision_function":
        if y_type == "binary" and pos_label == classes[0]:
            # Flip the sign so larger scores favour `pos_label`. Build a
            # new array rather than negating in place: the buffer returned
            # by the estimator may still be referenced elsewhere.
            y_pred = -1 * y_pred

    return y_pred, pos_label
114 changes: 0 additions & 114 deletions sklearn/metrics/_plot/base.py

This file was deleted.

2 changes: 1 addition & 1 deletion sklearn/metrics/_plot/det_curve.py
@@ -1,6 +1,6 @@
import scipy as sp

from .base import _get_response
from .._base import _get_response

from .. import det_curve

Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,4 +1,4 @@
from .base import _get_response
from .._base import _get_response

from .. import average_precision_score
from .. import precision_recall_curve
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/_plot/roc_curve.py
@@ -1,4 +1,4 @@
from .base import _get_response
from .._base import _get_response

from .. import auc
from .. import roc_curve
Expand Down

0 comments on commit cc27a27

Please sign in to comment.