diff --git a/doc/model_selection.rst b/doc/model_selection.rst index 25cd2b655ccc5..522544aefc820 100644 --- a/doc/model_selection.rst +++ b/doc/model_selection.rst @@ -14,5 +14,6 @@ Model selection and evaluation modules/cross_validation modules/grid_search + modules/classification_threshold modules/model_evaluation modules/learning_curve diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 55336389f93d5..804546eababef 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1248,6 +1248,17 @@ Hyper-parameter optimizers model_selection.RandomizedSearchCV model_selection.HalvingRandomSearchCV +Post-fit model tuning +--------------------- + +.. currentmodule:: sklearn + +.. autosummary:: + :toctree: generated/ + :template: class.rst + + model_selection.FixedThresholdClassifier + model_selection.TunedThresholdClassifierCV Model validation ---------------- diff --git a/doc/modules/classification_threshold.rst b/doc/modules/classification_threshold.rst new file mode 100644 index 0000000000000..712a094a43246 --- /dev/null +++ b/doc/modules/classification_threshold.rst @@ -0,0 +1,156 @@ +.. currentmodule:: sklearn.model_selection + +.. _TunedThresholdClassifierCV: + +================================================== +Tuning the decision threshold for class prediction +================================================== + +Classification is best divided into two parts: + +* the statistical problem of learning a model to predict, ideally, class probabilities; +* the decision problem to take concrete action based on those probability predictions. + +Let's take a straightforward example related to weather forecasting: the first point is +related to answering "what is the chance that it will rain tomorrow?" while the second +point is related to answering "should I take an umbrella tomorrow?". + +When it comes to the scikit-learn API, the first point is addressed providing scores +using :term:`predict_proba` or :term:`decision_function`. The former returns conditional +probability estimates :math:`P(y|X)` for each class, while the latter returns a decision +score for each class. + +The decision corresponding to the labels are obtained with :term:`predict`. In binary +classification, a decision rule or action is then defined by thresholding the scores, +leading to the prediction of a single class label for each sample. For binary +classification in scikit-learn, class labels predictions are obtained by hard-coded +cut-off rules: a positive class is predicted when the conditional probability +:math:`P(y|X)` is greater than 0.5 (obtained with :term:`predict_proba`) or if the +decision score is greater than 0 (obtained with :term:`decision_function`). + +Here, we show an example that illustrates the relation between conditional +probability estimates :math:`P(y|X)` and class labels:: + + >>> from sklearn.datasets import make_classification + >>> from sklearn.tree import DecisionTreeClassifier + >>> X, y = make_classification(random_state=0) + >>> classifier = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y) + >>> classifier.predict_proba(X[:4]) + array([[0.94 , 0.06 ], + [0.94 , 0.06 ], + [0.0416..., 0.9583...], + [0.0416..., 0.9583...]]) + >>> classifier.predict(X[:4]) + array([0, 0, 1, 1]) + +While these hard-coded rules might at first seem reasonable as default behavior, they +are most certainly not ideal for most use cases. Let's illustrate with an example. 
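+
+Before moving to that example, note that the hard-coded cut-off can be made explicit:
+the predictions above are recovered by thresholding the estimated probability of the
+positive class at 0.5 (a minimal sketch reusing the `classifier` fitted above)::
+
+    >>> import numpy as np
+    >>> np.where(classifier.predict_proba(X[:4])[:, 1] > 0.5, 1, 0)
+    array([0, 0, 1, 1])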
+ +Consider a scenario where a predictive model is being deployed to assist +physicians in detecting tumors. In this setting, physicians will most likely be +interested in identifying all patients with cancer and not missing anyone with cancer so +that they can provide them with the right treatment. In other words, physicians +prioritize achieving a high recall rate. This emphasis on recall comes, of course, with +the trade-off of potentially more false-positive predictions, reducing the precision of +the model. That is a risk physicians are willing to take because the cost of a missed +cancer is much higher than the cost of further diagnostic tests. Consequently, when it +comes to deciding whether to classify a patient as having cancer or not, it may be more +beneficial to classify them as positive for cancer when the conditional probability +estimate is much lower than 0.5. + +Post-tuning the decision threshold +================================== + +One solution to address the problem stated in the introduction is to tune the decision +threshold of the classifier once the model has been trained. The +:class:`~sklearn.model_selection.TunedThresholdClassifierCV` tunes this threshold using +an internal cross-validation. The optimum threshold is chosen to maximize a given +metric. + +The following image illustrates the tuning of the decision threshold for a gradient +boosting classifier. While the vanilla and tuned classifiers provide the same +:term:`predict_proba` outputs and thus the same Receiver Operating Characteristic (ROC) +and Precision-Recall curves, the class label predictions differ because of the tuned +decision threshold. The vanilla classifier predicts the class of interest for a +conditional probability greater than 0.5 while the tuned classifier predicts the class +of interest for a very low probability (around 0.02). This decision threshold optimizes +a utility metric defined by the business (in this case an insurance company). + +.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_cost_sensitive_learning_002.png + :target: ../auto_examples/model_selection/plot_cost_sensitive_learning.html + :align: center + +Options to tune the decision threshold +-------------------------------------- + +The decision threshold can be tuned through different strategies controlled by the +parameter `scoring`. + +One way to tune the threshold is by maximizing a pre-defined scikit-learn metric. These +metrics can be found by calling the function :func:`~sklearn.metrics.get_scorer_names`. +By default, the balanced accuracy is the metric used but be aware that one should choose +a meaningful metric for their use case. + +.. note:: + + It is important to notice that these metrics come with default parameters, notably + the label of the class of interest (i.e. `pos_label`). Thus, if this label is not + the right one for your application, you need to define a scorer and pass the right + `pos_label` (and additional parameters) using the + :func:`~sklearn.metrics.make_scorer`. Refer to :ref:`scoring` to get + information to define your own scoring function. For instance, we show how to pass + the information to the scorer that the label of interest is `0` when maximizing the + :func:`~sklearn.metrics.f1_score`:: + + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import TunedThresholdClassifierCV + >>> from sklearn.metrics import make_scorer, f1_score + >>> X, y = make_classification( + ... 
n_samples=1_000, weights=[0.1, 0.9], random_state=0) + >>> pos_label = 0 + >>> scorer = make_scorer(f1_score, pos_label=pos_label) + >>> base_model = LogisticRegression() + >>> model = TunedThresholdClassifierCV(base_model, scoring=scorer) + >>> scorer(model.fit(X, y), X, y) + 0.88... + >>> # compare it with the internal score found by cross-validation + >>> model.best_score_ + 0.86... + +Important notes regarding the internal cross-validation +------------------------------------------------------- + +By default :class:`~sklearn.model_selection.TunedThresholdClassifierCV` uses a 5-fold +stratified cross-validation to tune the decision threshold. The parameter `cv` allows to +control the cross-validation strategy. It is possible to bypass cross-validation by +setting `cv="prefit"` and providing a fitted classifier. In this case, the decision +threshold is tuned on the data provided to the `fit` method. + +However, you should be extremely careful when using this option. You should never use +the same data for training the classifier and tuning the decision threshold due to the +risk of overfitting. Refer to the following example section for more details (cf. +:ref:`TunedThresholdClassifierCV_no_cv`). If you have limited resources, consider using +a float number for `cv` to limit to an internal single train-test split. + +The option `cv="prefit"` should only be used when the provided classifier was already +trained, and you just want to find the best decision threshold using a new validation +set. + +.. _FixedThresholdClassifier: + +Manually setting the decision threshold +--------------------------------------- + +The previous sections discussed strategies to find an optimal decision threshold. It is +also possible to manually set the decision threshold using the class +:class:`~sklearn.model_selection.FixedThresholdClassifier`. + +Examples +-------- + +- See the example entitled + :ref:`sphx_glr_auto_examples_model_selection_plot_tuned_decision_threshold.py`, + to get insights on the post-tuning of the decision threshold. +- See the example entitled + :ref:`sphx_glr_auto_examples_model_selection_plot_cost_sensitive_learning.py`, + to learn about cost-sensitive learning and decision threshold tuning. diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 9f53afd433ffc..3764bb98968cb 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -412,6 +412,13 @@ Changelog :mod:`sklearn.model_selection` .............................. +- |MajorFeature| :class:`model_selection.TunedThresholdClassifierCV` finds + the decision threshold of a binary classifier that maximizes a + classification metric through cross-validation. + :class:`model_selection.FixedThresholdClassifier` is an alternative when one wants + to use a fixed decision threshold without any tuning scheme. + :pr:`26120` by :user:`Guillaume Lemaitre `. + - |Enhancement| :term:`CV splitters ` that ignores the group parameter now raises a warning when groups are passed in to :term:`split`. :pr:`28210` by `Thomas Fan`_. 
diff --git a/examples/model_selection/plot_cost_sensitive_learning.py b/examples/model_selection/plot_cost_sensitive_learning.py new file mode 100644 index 0000000000000..7b64af48139f2 --- /dev/null +++ b/examples/model_selection/plot_cost_sensitive_learning.py @@ -0,0 +1,702 @@ +""" +============================================================== +Post-tuning the decision threshold for cost-sensitive learning +============================================================== + +Once a classifier is trained, the output of the :term:`predict` method outputs class +label predictions corresponding to a thresholding of either the :term:`decision +function` or the :term:`predict_proba` output. For a binary classifier, the default +threshold is defined as a posterior probability estimate of 0.5 or a decision score of +0.0. + +However, this default strategy is most likely not optimal for the task at hand. +Here, we use the "Statlog" German credit dataset [1]_ to illustrate a use case. +In this dataset, the task is to predict whether a person has a "good" or "bad" credit. +In addition, a cost-matrix is provided that specifies the cost of +misclassification. Specifically, misclassifying a "bad" credit as "good" is five +times more costly on average than misclassifying a "good" credit as "bad". + +We use the :class:`~sklearn.model_selection.TunedThresholdClassifierCV` to select the +cut-off point of the decision function that minimizes the provided business +cost. + +In the second part of the example, we further extend this approach by +considering the problem of fraud detection in credit card transactions: in this +case, the business metric depends on the amount of each individual transaction. +.. topic:: References + + .. [1] "Statlog (German Credit Data) Data Set", UCI Machine Learning Repository, + `Link + `_. + + .. [2] `Charles Elkan, "The Foundations of Cost-Sensitive Learning", + International joint conference on artificial intelligence. + Vol. 17. No. 1. Lawrence Erlbaum Associates Ltd, 2001. + `_ +""" + +# %% +# Cost-sensitive learning with constant gains and costs +# ----------------------------------------------------- +# +# In this first section, we illustrate the use of the +# :class:`~sklearn.model_selection.TunedThresholdClassifierCV` in a setting of +# cost-sensitive learning when the gains and costs associated to each entry of the +# confusion matrix are constant. We use the problematic presented in [2]_ using the +# "Statlog" German credit dataset [1]_. +# +# "Statlog" German credit dataset +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# We fetch the German credit dataset from OpenML. +import sklearn +from sklearn.datasets import fetch_openml + +sklearn.set_config(transform_output="pandas") + +german_credit = fetch_openml(data_id=31, as_frame=True, parser="pandas") +X, y = german_credit.data, german_credit.target + +# %% +# We check the feature types available in `X`. +X.info() + +# %% +# Many features are categorical and usually string-encoded. We need to encode +# these categories when we develop our predictive model. Let's check the targets. +y.value_counts() + +# %% +# Another observation is that the dataset is imbalanced. We would need to be careful +# when evaluating our predictive model and use a family of metrics that are adapted +# to this setting. +# +# In addition, we observe that the target is string-encoded. Some metrics +# (e.g. precision and recall) require to provide the label of interest also called +# the "positive label". 
Here, we define that our goal is to predict whether or not
+# a sample is a "bad" credit.
+pos_label, neg_label = "bad", "good"
+
+# %%
+# To carry out our analysis, we split our dataset using a single stratified split.
+from sklearn.model_selection import train_test_split
+
+X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
+
+# %%
+# We are ready to design our predictive model and the associated evaluation strategy.
+#
+# Evaluation metrics
+# ^^^^^^^^^^^^^^^^^^
+#
+# In this section, we define a set of metrics that we use later. To see
+# the effect of tuning the cut-off point, we evaluate the predictive model using
+# the Receiver Operating Characteristic (ROC) curve and the Precision-Recall curve.
+# The values reported on these plots are therefore the true positive rate (TPR),
+# also known as the recall or the sensitivity, and the false positive rate (FPR),
+# which is one minus the specificity (i.e. the true negative rate), for the ROC curve,
+# and the precision and recall for the Precision-Recall curve.
+#
+# Among these four metrics, scikit-learn does not provide a scorer for the FPR. We
+# therefore need to define a small custom function to compute it.
+from sklearn.metrics import confusion_matrix
+
+
+def fpr_score(y, y_pred, neg_label, pos_label):
+    cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
+    tn, fp, _, _ = cm.ravel()
+    tnr = tn / (tn + fp)
+    return 1 - tnr
+
+
+# %%
+# As previously stated, the "positive label" is not defined as the value "1" and calling
+# some of the metrics with this non-standard value raises an error. We need to
+# indicate the "positive label" to the metrics.
+#
+# We therefore need to define a scikit-learn scorer using
+# :func:`~sklearn.metrics.make_scorer` where the information is passed. We store all
+# the custom scorers in a dictionary. To use them, we need to pass the fitted model,
+# the data and the target on which we want to evaluate the predictive model.
+from sklearn.metrics import make_scorer, precision_score, recall_score
+
+tpr_score = recall_score  # TPR and recall are the same metric
+scoring = {
+    "precision": make_scorer(precision_score, pos_label=pos_label),
+    "recall": make_scorer(recall_score, pos_label=pos_label),
+    "fpr": make_scorer(fpr_score, neg_label=neg_label, pos_label=pos_label),
+    "tpr": make_scorer(tpr_score, pos_label=pos_label),
+}
+
+# %%
+# In addition, the original research [1]_ defines a custom business metric. We
+# call a "business metric" any metric function that aims at quantifying how the
+# predictions (correct or wrong) might impact the business value of deploying a
+# given machine learning model in a specific application context. For our
+# credit prediction task, the authors provide a custom cost-matrix which
+# encodes that classifying a "bad" credit as "good" is 5 times more costly on
+# average than the opposite: it is less costly for the financing institution to
+# not grant a credit to a potential customer that will not default (and
+# therefore miss a good customer that would have otherwise both reimbursed the
+# credit and paid interest) than to grant a credit to a customer that will
+# default.
+#
+# We define a Python function that weights the confusion matrix and returns the
+# overall cost.
+import numpy as np
+
+
+def credit_gain_score(y, y_pred, neg_label, pos_label):
+    cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
+    # The rows of the confusion matrix hold the counts of observed classes
+    # while the columns hold counts of predicted classes. 
Recall that here we + # consider "bad" as the positive class (second row and column). + # Scikit-learn model selection tools expect that we follow a convention + # that "higher" means "better", hence the following gain matrix assigns + # negative gains (costs) to the two kinds of prediction errors: + # - a gain of -1 for each false positive ("good" credit labeled as "bad"), + # - a gain of -5 for each false negative ("bad" credit labeled as "good"), + # The true positives and true negatives are assigned null gains in this + # metric. + # + # Note that theoretically, given that our model is calibrated and our data + # set representative and large enough, we do not need to tune the + # threshold, but can safely set it to the cost ration 1/5, as stated by Eq. + # (2) in Elkan paper [2]_. + gain_matrix = np.array( + [ + [0, -1], # -1 gain for false positives + [-5, 0], # -5 gain for false negatives + ] + ) + return np.sum(cm * gain_matrix) + + +scoring["cost_gain"] = make_scorer( + credit_gain_score, neg_label=neg_label, pos_label=pos_label +) +# %% +# Vanilla predictive model +# ^^^^^^^^^^^^^^^^^^^^^^^^ +# +# We use :class:`~sklearn.ensemble.HistGradientBoostingClassifier` as a predictive model +# that natively handles categorical features and missing values. +from sklearn.ensemble import HistGradientBoostingClassifier + +model = HistGradientBoostingClassifier( + categorical_features="from_dtype", random_state=0 +).fit(X_train, y_train) +model + +# %% +# We evaluate the performance of our predictive model using the ROC and Precision-Recall +# curves. +import matplotlib.pyplot as plt + +from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay + +fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 6)) + +PrecisionRecallDisplay.from_estimator( + model, X_test, y_test, pos_label=pos_label, ax=axs[0], name="GBDT" +) +axs[0].plot( + scoring["recall"](model, X_test, y_test), + scoring["precision"](model, X_test, y_test), + marker="o", + markersize=10, + color="tab:blue", + label="Default cut-off point at a probability of 0.5", +) +axs[0].set_title("Precision-Recall curve") +axs[0].legend() + +RocCurveDisplay.from_estimator( + model, + X_test, + y_test, + pos_label=pos_label, + ax=axs[1], + name="GBDT", + plot_chance_level=True, +) +axs[1].plot( + scoring["fpr"](model, X_test, y_test), + scoring["tpr"](model, X_test, y_test), + marker="o", + markersize=10, + color="tab:blue", + label="Default cut-off point at a probability of 0.5", +) +axs[1].set_title("ROC curve") +axs[1].legend() +_ = fig.suptitle("Evaluation of the vanilla GBDT model") + +# %% +# We recall that these curves give insights on the statistical performance of the +# predictive model for different cut-off points. For the Precision-Recall curve, the +# reported metrics are the precision and recall and for the ROC curve, the reported +# metrics are the TPR (same as recall) and FPR. +# +# Here, the different cut-off points correspond to different levels of posterior +# probability estimates ranging between 0 and 1. By default, `model.predict` uses a +# cut-off point at a probability estimate of 0.5. The metrics for such a cut-off point +# are reported with the blue dot on the curves: it corresponds to the statistical +# performance of the model when using `model.predict`. +# +# However, we recall that the original aim was to minimize the cost (or maximize the +# gain) as defined by the business metric. 
We can compute the value of the business +# metric: +print(f"Business defined metric: {scoring['cost_gain'](model, X_test, y_test)}") + +# %% +# At this stage we don't know if any other cut-off can lead to a greater gain. To find +# the optimal one, we need to compute the cost-gain using the business metric for all +# possible cut-off points and choose the best. This strategy can be quite tedious to +# implement by hand, but the +# :class:`~sklearn.model_selection.TunedThresholdClassifierCV` class is here to help us. +# It automatically computes the cost-gain for all possible cut-off points and optimizes +# for the `scoring`. +# +# .. _cost_sensitive_learning_example: +# +# Tuning the cut-off point +# ^^^^^^^^^^^^^^^^^^^^^^^^ +# +# We use :class:`~sklearn.model_selection.TunedThresholdClassifierCV` to tune the +# cut-off point. We need to provide the business metric to optimize as well as the +# positive label. Internally, the optimum cut-off point is chosen such that it maximizes +# the business metric via cross-validation. By default a 5-fold stratified +# cross-validation is used. +from sklearn.model_selection import TunedThresholdClassifierCV + +tuned_model = TunedThresholdClassifierCV( + estimator=model, + scoring=scoring["cost_gain"], + store_cv_results=True, # necessary to inspect all results +) +tuned_model.fit(X_train, y_train) +print(f"{tuned_model.best_threshold_=:0.2f}") + +# %% +# We plot the ROC and Precision-Recall curves for the vanilla model and the tuned model. +# Also we plot the cut-off points that would be used by each model. Because, we are +# reusing the same code later, we define a function that generates the plots. + + +def plot_roc_pr_curves(vanilla_model, tuned_model, *, title): + fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(21, 6)) + + linestyles = ("dashed", "dotted") + markerstyles = ("o", ">") + colors = ("tab:blue", "tab:orange") + names = ("Vanilla GBDT", "Tuned GBDT") + for idx, (est, linestyle, marker, color, name) in enumerate( + zip((vanilla_model, tuned_model), linestyles, markerstyles, colors, names) + ): + decision_threshold = getattr(est, "best_threshold_", 0.5) + PrecisionRecallDisplay.from_estimator( + est, + X_test, + y_test, + pos_label=pos_label, + linestyle=linestyle, + color=color, + ax=axs[0], + name=name, + ) + axs[0].plot( + scoring["recall"](est, X_test, y_test), + scoring["precision"](est, X_test, y_test), + marker, + markersize=10, + color=color, + label=f"Cut-off point at probability of {decision_threshold:.2f}", + ) + RocCurveDisplay.from_estimator( + est, + X_test, + y_test, + pos_label=pos_label, + linestyle=linestyle, + color=color, + ax=axs[1], + name=name, + plot_chance_level=idx == 1, + ) + axs[1].plot( + scoring["fpr"](est, X_test, y_test), + scoring["tpr"](est, X_test, y_test), + marker, + markersize=10, + color=color, + label=f"Cut-off point at probability of {decision_threshold:.2f}", + ) + + axs[0].set_title("Precision-Recall curve") + axs[0].legend() + axs[1].set_title("ROC curve") + axs[1].legend() + + axs[2].plot( + tuned_model.cv_results_["thresholds"], + tuned_model.cv_results_["scores"], + color="tab:orange", + ) + axs[2].plot( + tuned_model.best_threshold_, + tuned_model.best_score_, + "o", + markersize=10, + color="tab:orange", + label="Optimal cut-off point for the business metric", + ) + axs[2].legend() + axs[2].set_xlabel("Decision threshold (probability)") + axs[2].set_ylabel("Objective score (using cost-matrix)") + axs[2].set_title("Objective score as a function of the decision threshold") + 
fig.suptitle(title)
+
+
+# %%
+title = "Comparison of the cut-off point for the vanilla and tuned GBDT model"
+plot_roc_pr_curves(model, tuned_model, title=title)
+
+# %%
+# The first remark is that both classifiers have exactly the same ROC and
+# Precision-Recall curves. This is expected because, by default, the classifier is
+# fitted on the same training data. In a later section, we discuss in more detail the
+# available options regarding model refitting and cross-validation.
+#
+# The second remark is that the cut-off points of the vanilla and tuned model are
+# different. To understand why the tuned model has chosen this cut-off point, we can
+# look at the right-hand side plot, which shows the objective score, i.e. exactly our
+# business metric. We see that the optimum threshold corresponds to the
+# maximum of the objective score. This maximum is reached for a decision threshold
+# much lower than 0.5: the tuned model enjoys a much higher recall at the cost of
+# a significantly lower precision: the tuned model is much more eager to
+# predict the "bad" class label for a larger fraction of individuals.
+#
+# We can now check if choosing this cut-off point leads to a better score on the testing
+# set:
+print(f"Business defined metric: {scoring['cost_gain'](tuned_model, X_test, y_test)}")
+
+# %%
+# We observe that tuning the decision threshold improves our business gain by almost
+# a factor of 2.
+#
+# .. _TunedThresholdClassifierCV_no_cv:
+#
+# Consideration regarding model refitting and cross-validation
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# In the above experiment, we used the default setting of the
+# :class:`~sklearn.model_selection.TunedThresholdClassifierCV`. In particular, the
+# cut-off point is tuned using a 5-fold stratified cross-validation. Also, the
+# underlying predictive model is refitted on the entire training data once the cut-off
+# point is chosen.
+#
+# These two strategies can be changed by providing the `refit` and `cv` parameters.
+# For instance, one could provide a fitted `estimator` and set `cv="prefit"`, in which
+# case the cut-off point is found on the entire dataset provided at fitting time.
+# Also, the underlying classifier is not refitted when setting `refit=False`. Here, we
+# try such an experiment.
+model.fit(X_train, y_train)
+tuned_model.set_params(cv="prefit", refit=False).fit(X_train, y_train)
+print(f"{tuned_model.best_threshold_=:0.2f}")
+
+
+# %%
+# Then, we evaluate our model with the same approach as before:
+title = "Tuned GBDT model without refitting and using the entire dataset"
+plot_roc_pr_curves(model, tuned_model, title=title)
+
+# %%
+# We observe that the optimum cut-off point is different from the one found
+# in the previous experiment. If we look at the right-hand side plot, we
+# observe that the business gain has a large plateau of near-optimal 0 gain for a
+# large span of decision thresholds. This behavior is symptomatic of
+# overfitting. Because we disabled cross-validation, we tuned the cut-off point
+# on the same set that the model was trained on, and this is the reason for the
+# observed overfitting.
+#
+# This option should therefore be used with caution. One needs to make sure that the
+# data provided at fitting time to the
+# :class:`~sklearn.model_selection.TunedThresholdClassifierCV` is not the same as the
+# data used to train the underlying classifier. 
This can sometimes happen when the
+# idea is simply to tune the predictive model on a completely new validation set
+# without a costly complete refit.
+#
+# When cross-validation is too costly, a potential alternative is to use a
+# single train-test split by providing a floating-point number in the range `[0, 1]`
+# to the `cv` parameter. It splits the data into a training and a testing set. Let's
+# explore this option:
+tuned_model.set_params(cv=0.75).fit(X_train, y_train)
+
+# %%
+title = "Tuned GBDT model without refitting, using a single train-test split"
+plot_roc_pr_curves(model, tuned_model, title=title)
+
+# %%
+# Regarding the cut-off point, we observe that the optimum is similar to the repeated
+# cross-validation case. However, be aware that a single split does not account
+# for the variability of the fit/predict process and thus we are unable to know if there
+# is any variance in the cut-off point. The repeated cross-validation averages out
+# this effect.
+#
+# Another observation concerns the ROC and Precision-Recall curves of the tuned model.
+# As expected, these curves differ from those of the vanilla model, given that we
+# trained the underlying classifier on a subset of the data provided during fitting and
+# reserved a validation set for tuning the cut-off point.
+#
+# Cost-sensitive learning when gains and costs are not constant
+# -------------------------------------------------------------
+#
+# As stated in [2]_, gains and costs are generally not constant in real-world problems.
+# In this section, we use a similar example to the one in [2]_ for the problem of
+# detecting fraud in credit card transaction records.
+#
+# The credit card dataset
+# ^^^^^^^^^^^^^^^^^^^^^^^
+credit_card = fetch_openml(data_id=1597, as_frame=True, parser="pandas")
+credit_card.frame.info()
+
+# %%
+# The dataset contains information about credit card records, of which some are
+# fraudulent and others are legitimate. The goal is therefore to predict whether or
+# not a credit card record is fraudulent.
+columns_to_drop = ["Class"]
+data = credit_card.frame.drop(columns=columns_to_drop)
+target = credit_card.frame["Class"].astype(int)
+
+# %%
+# First, we check the class distribution of the dataset.
+target.value_counts(normalize=True)
+
+# %%
+# The dataset is highly imbalanced, with fraudulent transactions representing only
+# 0.17% of the data. Since we are interested in training a machine learning model, we
+# should also make sure that we have enough samples in the minority class to train
+# the model.
+target.value_counts()
+
+# %%
+# We observe that we have around 500 samples, which is on the low end of the number of
+# samples required to train a machine learning model. In addition to the target
+# distribution, we check the distribution of the amount of the
+# fraudulent transactions.
+fraud = target == 1
+amount_fraud = data["Amount"][fraud]
+_, ax = plt.subplots()
+ax.hist(amount_fraud, bins=100)
+ax.set_title("Amount of fraudulent transactions")
+_ = ax.set_xlabel("Amount ($)")
+
+# %%
+# Addressing the problem with a business metric
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Now, we create the business metric that depends on the amount of each transaction. We
+# define the cost matrix similarly to [2]_. Accepting a legitimate transaction provides
+# a gain of 2% of the amount of the transaction. However, accepting a fraudulent
+# transaction results in a loss of the amount of the transaction. 
As stated in [2]_, the
+# gain and loss related to refusals (of fraudulent and legitimate transactions) are not
+# trivial to define. Here, we consider that refusing a legitimate transaction is
+# estimated to cost $5 while refusing a fraudulent transaction is estimated to yield a
+# gain of $50 plus the amount of the transaction. Therefore, we define the
+# following function to compute the total benefit of a given decision:
+
+
+def business_metric(y_true, y_pred, amount):
+    mask_true_positive = (y_true == 1) & (y_pred == 1)
+    mask_true_negative = (y_true == 0) & (y_pred == 0)
+    mask_false_positive = (y_true == 0) & (y_pred == 1)
+    mask_false_negative = (y_true == 1) & (y_pred == 0)
+    fraudulent_refuse = (mask_true_positive.sum() * 50) + amount[
+        mask_true_positive
+    ].sum()
+    fraudulent_accept = -amount[mask_false_negative].sum()
+    legitimate_refuse = mask_false_positive.sum() * -5
+    legitimate_accept = (amount[mask_true_negative] * 0.02).sum()
+    return fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept
+
+
+# %%
+# From this business metric, we create a scikit-learn scorer that, given a fitted
+# classifier and a test set, computes the business metric. In this regard, we use
+# the :func:`~sklearn.metrics.make_scorer` factory. The variable `amount` is
+# additional metadata to be passed to the scorer and we need to use
+# :ref:`metadata routing ` to take into account this information.
+sklearn.set_config(enable_metadata_routing=True)
+business_scorer = make_scorer(business_metric).set_score_request(amount=True)
+
+# %%
+# So at this stage, we observe that the amount of the transaction is used twice: once
+# as a feature to train our predictive model and once as metadata to compute the
+# business metric and thus the statistical performance of our model. When used as a
+# feature, we are only required to have a column in `data` that contains the amount of
+# each transaction. To use this information as metadata, we need to have an external
+# variable that we can pass to the scorer or the model that internally routes this
+# metadata to the scorer. So let's create this variable.
+amount = credit_card.frame["Amount"].to_numpy()
+
+# %%
+# We first train a dummy classifier to have some baseline results.
+from sklearn.model_selection import train_test_split
+
+data_train, data_test, target_train, target_test, amount_train, amount_test = (
+    train_test_split(
+        data, target, amount, stratify=target, test_size=0.5, random_state=42
+    )
+)
+
+# %%
+from sklearn.dummy import DummyClassifier
+
+easy_going_classifier = DummyClassifier(strategy="constant", constant=0)
+easy_going_classifier.fit(data_train, target_train)
+benefit_cost = business_scorer(
+    easy_going_classifier, data_test, target_test, amount=amount_test
+)
+print(f"Benefit/cost of our easy-going classifier: ${benefit_cost:,.2f}")
+
+# %%
+# A classifier that predicts all transactions as legitimate would create a profit of
+# around $220,000. We make the same evaluation for a classifier that predicts all
+# transactions as fraudulent.
+intolerant_classifier = DummyClassifier(strategy="constant", constant=1)
+intolerant_classifier.fit(data_train, target_train)
+benefit_cost = business_scorer(
+    intolerant_classifier, data_test, target_test, amount=amount_test
+)
+print(f"Benefit/cost of our intolerant classifier: ${benefit_cost:,.2f}")
+
+# %%
+# Such a classifier creates a loss of around $670,000. A predictive model should allow
+# us to make a profit larger than $220,000. 
It is interesting to compare this business +# metric with another "standard" statistical metric such as the balanced accuracy. +from sklearn.metrics import get_scorer + +balanced_accuracy_scorer = get_scorer("balanced_accuracy") +print( + "Balanced accuracy of our easy-going classifier: " + f"{balanced_accuracy_scorer(easy_going_classifier, data_test, target_test):.3f}" +) +print( + "Balanced accuracy of our intolerant classifier: " + f"{balanced_accuracy_scorer(intolerant_classifier, data_test, target_test):.3f}" +) + +# %% +# This is not a surprise that the balanced accuracy is at 0.5 for both classifiers. +# However, we need to be careful in the rest of the evaluation: we potentially can +# obtain a model with a decent balanced accuracy that does not make any profit. +# In this case, the model would be harmful for our business. +# +# Let's now create a predictive model using a logistic regression without tuning the +# decision threshold. +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import GridSearchCV +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +logistic_regression = make_pipeline(StandardScaler(), LogisticRegression()) +param_grid = {"logisticregression__C": np.logspace(-6, 6, 13)} +model = GridSearchCV(logistic_regression, param_grid, scoring="neg_log_loss").fit( + data_train, target_train +) + +print( + "Benefit/cost of our logistic regression: " + f"${business_scorer(model, data_test, target_test, amount=amount_test):,.2f}" +) +print( + "Balanced accuracy of our logistic regression: " + f"{balanced_accuracy_scorer(model, data_test, target_test):.3f}" +) + +# %% +# By observing the balanced accuracy, we see that our predictive model is learning +# some associations between the features and the target. The business metric also shows +# that our model is beating the baseline in terms of profit and it would be already +# beneficial to use it instead of ignoring the fraud detection problem. +# +# Tuning the decision threshold +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Now the question is: is our model optimum for the type of decision that we want to do? +# Up to now, we did not optimize the decision threshold. We use the +# :class:`~sklearn.model_selection.TunedThresholdClassifierCV` to optimize the decision +# given our business scorer. To avoid a nested cross-validation, we will use the +# best estimator found during the previous grid-search. +tuned_model = TunedThresholdClassifierCV( + estimator=model.best_estimator_, + scoring=business_scorer, + thresholds=100, + n_jobs=2, +) + +# %% +# Since our business scorer requires the amount of each transaction, we need to pass +# this information in the `fit` method. The +# :class:`~sklearn.model_selection.TunedThresholdClassifierCV` is in charge of +# automatically dispatching this metadata to the underlying scorer. +tuned_model.fit(data_train, target_train, amount=amount_train) + +# %% +print( + "Benefit/cost of our logistic regression: " + f"${business_scorer(tuned_model, data_test, target_test, amount=amount_test):,.2f}" +) +print( + "Balanced accuracy of our logistic regression: " + f"{balanced_accuracy_scorer(tuned_model, data_test, target_test):.3f}" +) + +# %% +# We observe that tuning the decision threshold increases the expected profit of +# deploying our model as estimated by the business metric. +# Eventually, the balanced accuracy also increased. 
Note that it might not always be +# the case because the statistical metric is not necessarily a surrogate of the +# business metric. It is therefore important, whenever possible, optimize the decision +# threshold with respect to the business metric. +# +# Finally, the estimate of the business metric itself can be unreliable, in +# particular when the number of data points in the minority class is so small. +# Any business impact estimated by cross-validation of a business metric on +# historical data (offline evaluation) should ideally be confirmed by A/B testing +# on live data (online evaluation). Note however that A/B testing models is +# beyond the scope of the scikit-learn library itself. +# +# Manually setting the decision threshold instead of tuning it +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# In the previous example, we used the +# :class:`~sklearn.model_selection.TunedThresholdClassifierCV` to find the optimal +# decision threshold. However, in some cases, we might have some prior knowledge about +# the problem at hand and we might be happy to set the decision threshold manually. +# +# The class :class:`~sklearn.model_selection.FixedThresholdClassifier` allows us to +# manually set the decision threshold. At prediction time, it behave as the previous +# tuned model but no search is performed during the fitting process. +# +# Here, we will reuse the decision threshold found in the previous section to create a +# new model and check that it gives the same results. +from sklearn.model_selection import FixedThresholdClassifier + +model_fixed_threshold = FixedThresholdClassifier( + estimator=model, threshold=tuned_model.best_threshold_ +).fit(data_train, target_train) + +# %% +business_score = business_scorer( + model_fixed_threshold, data_test, target_test, amount=amount_test +) +print(f"Benefit/cost of our logistic regression: ${business_score:,.2f}") +print( + "Balanced accuracy of our logistic regression: " + f"{balanced_accuracy_scorer(model_fixed_threshold, data_test, target_test):.3f}" +) + +# %% +# We observe that we obtained the exact same results but the fitting process was much +# faster since we did not perform any search. diff --git a/examples/model_selection/plot_tuned_decision_threshold.py b/examples/model_selection/plot_tuned_decision_threshold.py new file mode 100644 index 0000000000000..7e997ee255e4d --- /dev/null +++ b/examples/model_selection/plot_tuned_decision_threshold.py @@ -0,0 +1,184 @@ +""" +====================================================== +Post-hoc tuning the cut-off point of decision function +====================================================== + +Once a binary classifier is trained, the :term:`predict` method outputs class label +predictions corresponding to a thresholding of either the :term:`decision_function` or +the :term:`predict_proba` output. The default threshold is defined as a posterior +probability estimate of 0.5 or a decision score of 0.0. However, this default strategy +may not be optimal for the task at hand. + +This example shows how to use the +:class:`~sklearn.model_selection.TunedThresholdClassifierCV` to tune the decision +threshold, depending on a metric of interest. +""" + +# %% +# The diabetes dataset +# -------------------- +# +# To illustrate the tuning of the decision threshold, we will use the diabetes dataset. +# This dataset is available on OpenML: https://www.openml.org/d/37. We use the +# :func:`~sklearn.datasets.fetch_openml` function to fetch this dataset. 
+from sklearn.datasets import fetch_openml
+
+diabetes = fetch_openml(data_id=37, as_frame=True, parser="pandas")
+data, target = diabetes.data, diabetes.target
+
+# %%
+# We look at the target to understand the type of problem we are dealing with.
+target.value_counts()
+
+# %%
+# We can see that we are dealing with a binary classification problem. Since the
+# labels are not encoded as 0 and 1, we make it explicit that we consider the class
+# labeled "tested_negative" as the negative class (which is also the most frequent)
+# and the class labeled "tested_positive" as the positive class:
+neg_label, pos_label = target.value_counts().index
+
+# %%
+# We can also observe that this binary problem is slightly imbalanced: we have
+# around twice as many samples from the negative class as from the positive class. When
+# it comes to evaluation, we should consider this aspect to interpret the results.
+#
+# Our vanilla classifier
+# ----------------------
+#
+# We define a basic predictive model composed of a scaler followed by a logistic
+# regression classifier.
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+
+model = make_pipeline(StandardScaler(), LogisticRegression())
+model
+
+# %%
+# We evaluate our model using cross-validation. We use the accuracy and the balanced
+# accuracy to report the performance of our model. The balanced accuracy is a metric
+# that is less sensitive to class imbalance and will allow us to put the accuracy
+# score in perspective.
+#
+# Cross-validation allows us to study the variance of the decision threshold across
+# different splits of the data. However, the dataset is rather small and it would be
+# detrimental to use more than 5 folds to evaluate the dispersion. Therefore, we use
+# a :class:`~sklearn.model_selection.RepeatedStratifiedKFold` where we apply several
+# repetitions of 5-fold cross-validation.
+import pandas as pd
+
+from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate
+
+scoring = ["accuracy", "balanced_accuracy"]
+cv_scores = [
+    "train_accuracy",
+    "test_accuracy",
+    "train_balanced_accuracy",
+    "test_balanced_accuracy",
+]
+cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=42)
+cv_results_vanilla_model = pd.DataFrame(
+    cross_validate(
+        model,
+        data,
+        target,
+        scoring=scoring,
+        cv=cv,
+        return_train_score=True,
+        return_estimator=True,
+    )
+)
+cv_results_vanilla_model[cv_scores].aggregate(["mean", "std"]).T
+
+# %%
+# Our predictive model succeeds in grasping the relationship between the data and the
+# target. The training and testing scores are close to each other, meaning that our
+# predictive model is not overfitting. We can also observe that the balanced accuracy is
+# lower than the accuracy, due to the class imbalance previously mentioned.
+#
+# For this classifier, we leave the decision threshold, used to convert the probability
+# of the positive class into a class prediction, at its default value of 0.5. However,
+# this threshold might not be optimal. If our interest is to maximize the balanced
+# accuracy, we should select another threshold that would maximize this metric.
+#
+# The :class:`~sklearn.model_selection.TunedThresholdClassifierCV` meta-estimator allows
+# tuning the decision threshold of a classifier given a metric of interest. 
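+#
+# To build intuition before using this meta-estimator, we can sketch by hand what such
+# a tuning procedure does: compute out-of-fold probability estimates, then sweep a grid
+# of candidate thresholds and measure the balanced accuracy of the resulting hard
+# predictions. This is only a rough illustration (a single round of out-of-fold
+# predictions and an arbitrary threshold grid), not the exact internal procedure of the
+# meta-estimator introduced below.
+import numpy as np
+
+from sklearn.metrics import balanced_accuracy_score
+from sklearn.model_selection import cross_val_predict
+
+proba = cross_val_predict(model, data, target, cv=5, method="predict_proba")
+# the columns of `proba` follow the sorted class labels, i.e. `np.unique(target)`
+pos_idx = np.flatnonzero(np.unique(target) == pos_label)[0]
+candidate_thresholds = np.linspace(0.1, 0.9, 81)
+balanced_accuracies = [
+    balanced_accuracy_score(
+        target, np.where(proba[:, pos_idx] >= threshold, pos_label, neg_label)
+    )
+    for threshold in candidate_thresholds
+]
+best_idx = np.argmax(balanced_accuracies)
+print(
+    f"Best threshold of the manual sweep: {candidate_thresholds[best_idx]:.2f} "
+    f"(balanced accuracy: {balanced_accuracies[best_idx]:.3f})"
+)
+
+# %%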
+# +# Tuning the decision threshold +# ----------------------------- +# +# We create a :class:`~sklearn.model_selection.TunedThresholdClassifierCV` and +# configure it to maximize the balanced accuracy. We evaluate the model using the same +# cross-validation strategy as previously. +from sklearn.model_selection import TunedThresholdClassifierCV + +tuned_model = TunedThresholdClassifierCV(estimator=model, scoring="balanced_accuracy") +cv_results_tuned_model = pd.DataFrame( + cross_validate( + tuned_model, + data, + target, + scoring=scoring, + cv=cv, + return_train_score=True, + return_estimator=True, + ) +) +cv_results_tuned_model[cv_scores].aggregate(["mean", "std"]).T + +# %% +# In comparison with the vanilla model, we observe that the balanced accuracy score +# increased. Of course, it comes at the cost of a lower accuracy score. It means that +# our model is now more sensitive to the positive class but makes more mistakes on the +# negative class. +# +# However, it is important to note that this tuned predictive model is internally the +# same model as the vanilla model: they have the same fitted coefficients. +import matplotlib.pyplot as plt + +vanilla_model_coef = pd.DataFrame( + [est[-1].coef_.ravel() for est in cv_results_vanilla_model["estimator"]], + columns=diabetes.feature_names, +) +tuned_model_coef = pd.DataFrame( + [est.estimator_[-1].coef_.ravel() for est in cv_results_tuned_model["estimator"]], + columns=diabetes.feature_names, +) + +fig, ax = plt.subplots(ncols=2, figsize=(12, 4), sharex=True, sharey=True) +vanilla_model_coef.boxplot(ax=ax[0]) +ax[0].set_ylabel("Coefficient value") +ax[0].set_title("Vanilla model") +tuned_model_coef.boxplot(ax=ax[1]) +ax[1].set_title("Tuned model") +_ = fig.suptitle("Coefficients of the predictive models") + +# %% +# Only the decision threshold of each model was changed during the cross-validation. +decision_threshold = pd.Series( + [est.best_threshold_ for est in cv_results_tuned_model["estimator"]], +) +ax = decision_threshold.plot.kde() +ax.axvline( + decision_threshold.mean(), + color="k", + linestyle="--", + label=f"Mean decision threshold: {decision_threshold.mean():.2f}", +) +ax.set_xlabel("Decision threshold") +ax.legend(loc="upper right") +_ = ax.set_title( + "Distribution of the decision threshold \nacross different cross-validation folds" +) + +# %% +# In average, a decision threshold around 0.32 maximizes the balanced accuracy, which is +# different from the default decision threshold of 0.5. Thus tuning the decision +# threshold is particularly important when the output of the predictive model +# is used to make decisions. Besides, the metric used to tune the decision threshold +# should be chosen carefully. Here, we used the balanced accuracy but it might not be +# the most appropriate metric for the problem at hand. The choice of the "right" metric +# is usually problem-dependent and might require some domain knowledge. Refer to the +# example entitled, +# :ref:`sphx_glr_auto_examples_model_selection_plot_cost_sensitive_learning.py`, +# for more details. diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index 19ca055a81b95..cf06ea1798dc8 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -193,6 +193,25 @@ def get_metadata_routing(self): class _BaseScorer(_MetadataRequester): + """Base scorer that is used as `scorer(estimator, X, y_true)`. + + Parameters + ---------- + score_func : callable + The score function to use. It will be called as + `score_func(y_true, y_pred, **kwargs)`. 
+ + sign : int + Either 1 or -1 to returns the score with `sign * score_func(estimator, X, y)`. + Thus, `sign` defined if higher scores are better or worse. + + kwargs : dict + Additional parameters to pass to the score function. + + response_method : str + The method to call on the estimator to get the response values. + """ + def __init__(self, score_func, sign, kwargs, response_method="predict"): self._score_func = score_func self._sign = sign diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py index d7d316d95ada4..c97d48f4b20b7 100644 --- a/sklearn/model_selection/__init__.py +++ b/sklearn/model_selection/__init__.py @@ -1,5 +1,9 @@ import typing +from ._classification_threshold import ( + FixedThresholdClassifier, + TunedThresholdClassifierCV, +) from ._plot import LearningCurveDisplay, ValidationCurveDisplay from ._search import GridSearchCV, ParameterGrid, ParameterSampler, RandomizedSearchCV from ._split import ( @@ -63,6 +67,8 @@ "StratifiedKFold", "StratifiedGroupKFold", "StratifiedShuffleSplit", + "FixedThresholdClassifier", + "TunedThresholdClassifierCV", "check_cv", "cross_val_predict", "cross_val_score", diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py new file mode 100644 index 0000000000000..d5a864da10653 --- /dev/null +++ b/sklearn/model_selection/_classification_threshold.py @@ -0,0 +1,1000 @@ +from collections.abc import MutableMapping +from numbers import Integral, Real + +import numpy as np + +from ..base import ( + BaseEstimator, + ClassifierMixin, + MetaEstimatorMixin, + _fit_context, + clone, +) +from ..exceptions import NotFittedError +from ..metrics import ( + check_scoring, + get_scorer_names, +) +from ..metrics._scorer import _BaseScorer +from ..utils import _safe_indexing +from ..utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions +from ..utils._response import _get_response_values_binary +from ..utils.metadata_routing import ( + MetadataRouter, + MethodMapping, + _raise_for_params, + process_routing, +) +from ..utils.metaestimators import available_if +from ..utils.multiclass import type_of_target +from ..utils.parallel import Parallel, delayed +from ..utils.validation import ( + _check_method_params, + _num_samples, + check_is_fitted, + indexable, +) +from ._split import StratifiedShuffleSplit, check_cv + + +def _estimator_has(attr): + """Check if we can delegate a method to the underlying estimator. + + First, we check the fitted estimator if available, otherwise we + check the unfitted estimator. + """ + + def check(self): + if hasattr(self, "estimator_"): + getattr(self.estimator_, attr) + else: + getattr(self.estimator, attr) + return True + + return check + + +def _threshold_scores_to_class_labels(y_score, threshold, classes, pos_label): + """Threshold `y_score` and return the associated class labels.""" + if pos_label is None: + map_thresholded_score_to_label = np.array([0, 1]) + else: + pos_label_idx = np.flatnonzero(classes == pos_label)[0] + neg_label_idx = np.flatnonzero(classes != pos_label)[0] + map_thresholded_score_to_label = np.array([neg_label_idx, pos_label_idx]) + + return classes[map_thresholded_score_to_label[(y_score >= threshold).astype(int)]] + + +class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator): + """Base class for binary classifiers that set a non-default decision threshold. 
+ + In this base class, we define the following interface: + + - the validation of common parameters in `fit`; + - the different prediction methods that can be used with the classifier. + + .. versionadded:: 1.5 + + Parameters + ---------- + estimator : estimator instance + The binary classifier, fitted or not, for which we want to optimize + the decision threshold used during `predict`. + + response_method : {"auto", "decision_function", "predict_proba"}, default="auto" + Methods by the classifier `estimator` corresponding to the + decision function for which we want to find a threshold. It can be: + + * if `"auto"`, it will try to invoke, for each classifier, + `"predict_proba"` or `"decision_function"` in that order. + * otherwise, one of `"predict_proba"` or `"decision_function"`. + If the method is not implemented by the classifier, it will raise an + error. + """ + + _required_parameters = ["estimator"] + _parameter_constraints: dict = { + "estimator": [ + HasMethods(["fit", "predict_proba"]), + HasMethods(["fit", "decision_function"]), + ], + "response_method": [StrOptions({"auto", "predict_proba", "decision_function"})], + } + + def __init__(self, estimator, *, response_method="auto"): + self.estimator = estimator + self.response_method = response_method + + @_fit_context( + # *ThresholdClassifier*.estimator is not validated yet + prefer_skip_nested_validation=False + ) + def fit(self, X, y, **params): + """Fit the classifier. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + + y : array-like of shape (n_samples,) + Target values. + + **params : dict + Parameters to pass to the `fit` method of the underlying + classifier. + + Returns + ------- + self : object + Returns an instance of self. + """ + _raise_for_params(params, self, None) + + X, y = indexable(X, y) + + y_type = type_of_target(y, input_name="y") + if y_type != "binary": + raise ValueError( + f"Only binary classification is supported. Unknown label type: {y_type}" + ) + + if self.response_method == "auto": + self._response_method = ["predict_proba", "decision_function"] + else: + self._response_method = self.response_method + + self._fit(X, y, **params) + + if hasattr(self.estimator_, "n_features_in_"): + self.n_features_in_ = self.estimator_.n_features_in_ + if hasattr(self.estimator_, "feature_names_in_"): + self.feature_names_in_ = self.estimator_.feature_names_in_ + + return self + + @property + def classes_(self): + """Classes labels.""" + return self.estimator_.classes_ + + @available_if(_estimator_has("predict_proba")) + def predict_proba(self, X): + """Predict class probabilities for `X` using the fitted estimator. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples and + `n_features` is the number of features. + + Returns + ------- + probabilities : ndarray of shape (n_samples, n_classes) + The class probabilities of the input samples. + """ + check_is_fitted(self, "estimator_") + return self.estimator_.predict_proba(X) + + @available_if(_estimator_has("predict_log_proba")) + def predict_log_proba(self, X): + """Predict logarithm class probabilities for `X` using the fitted estimator. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples and + `n_features` is the number of features. 
+ + Returns + ------- + log_probabilities : ndarray of shape (n_samples, n_classes) + The logarithm class probabilities of the input samples. + """ + check_is_fitted(self, "estimator_") + return self.estimator_.predict_log_proba(X) + + @available_if(_estimator_has("decision_function")) + def decision_function(self, X): + """Decision function for samples in `X` using the fitted estimator. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples and + `n_features` is the number of features. + + Returns + ------- + decisions : ndarray of shape (n_samples,) + The decision function computed the fitted estimator. + """ + check_is_fitted(self, "estimator_") + return self.estimator_.decision_function(X) + + def _more_tags(self): + return { + "binary_only": True, + "_xfail_checks": { + "check_classifiers_train": "Threshold at probability 0.5 does not hold", + "check_sample_weights_invariance": ( + "Due to the cross-validation and sample ordering, removing a sample" + " is not strictly equal to putting is weight to zero. Specific unit" + " tests are added for TunedThresholdClassifierCV specifically." + ), + }, + } + + +class FixedThresholdClassifier(BaseThresholdClassifier): + """Binary classifier that manually sets the decision threshold. + + This classifier allows to change the default decision threshold used for + converting posterior probability estimates (i.e. output of `predict_proba`) or + decision scores (i.e. output of `decision_function`) into a class label. + + Here, the threshold is not optimized and is set to a constant value. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.5 + + Parameters + ---------- + estimator : estimator instance + The binary classifier, fitted or not, for which we want to optimize + the decision threshold used during `predict`. + + threshold : {"auto"} or float, default="auto" + The decision threshold to use when converting posterior probability estimates + (i.e. output of `predict_proba`) or decision scores (i.e. output of + `decision_function`) into a class label. When `"auto"`, the threshold is set + to 0.5 if `predict_proba` is used as `response_method`, otherwise it is set to + 0 (i.e. the default threshold for `decision_function`). + + pos_label : int, float, bool or str, default=None + The label of the positive class. Used to process the output of the + `response_method` method. When `pos_label=None`, if `y_true` is in `{-1, 1}` or + `{0, 1}`, `pos_label` is set to 1, otherwise an error will be raised. + + response_method : {"auto", "decision_function", "predict_proba"}, default="auto" + Methods by the classifier `estimator` corresponding to the + decision function for which we want to find a threshold. It can be: + + * if `"auto"`, it will try to invoke `"predict_proba"` or `"decision_function"` + in that order. + * otherwise, one of `"predict_proba"` or `"decision_function"`. + If the method is not implemented by the classifier, it will raise an + error. + + Attributes + ---------- + estimator_ : estimator instance + The fitted classifier used when predicting. + + classes_ : ndarray of shape (n_classes,) + The class labels. + + n_features_in_ : int + Number of features seen during :term:`fit`. Only defined if the + underlying estimator exposes such an attribute when fit. + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. 
Only defined if the + underlying estimator exposes such an attribute when fit. + + See Also + -------- + sklearn.model_selection.TunedThresholdClassifierCV : Classifier that post-tunes + the decision threshold based on some metrics and using cross-validation. + sklearn.calibration.CalibratedClassifierCV : Estimator that calibrates + probabilities. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.metrics import confusion_matrix + >>> from sklearn.model_selection import FixedThresholdClassifier, train_test_split + >>> X, y = make_classification( + ... n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42 + ... ) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, stratify=y, random_state=42 + ... ) + >>> classifier = LogisticRegression(random_state=0).fit(X_train, y_train) + >>> print(confusion_matrix(y_test, classifier.predict(X_test))) + [[217 7] + [ 19 7]] + >>> classifier_other_threshold = FixedThresholdClassifier( + ... classifier, threshold=0.1, response_method="predict_proba" + ... ).fit(X_train, y_train) + >>> print(confusion_matrix(y_test, classifier_other_threshold.predict(X_test))) + [[184 40] + [ 6 20]] + """ + + _parameter_constraints: dict = { + **BaseThresholdClassifier._parameter_constraints, + "threshold": [StrOptions({"auto"}), Real], + "pos_label": [Real, str, "boolean", None], + } + + def __init__( + self, + estimator, + *, + threshold="auto", + pos_label=None, + response_method="auto", + ): + super().__init__(estimator=estimator, response_method=response_method) + self.pos_label = pos_label + self.threshold = threshold + + def _fit(self, X, y, **params): + """Fit the classifier. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + + y : array-like of shape (n_samples,) + Target values. + + **params : dict + Parameters to pass to the `fit` method of the underlying + classifier. + + Returns + ------- + self : object + Returns an instance of self. + """ + routed_params = process_routing(self, "fit", **params) + self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) + return self + + def predict(self, X): + """Predict the target of new samples. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The samples, as accepted by `estimator.predict`. + + Returns + ------- + class_labels : ndarray of shape (n_samples,) + The predicted class. + """ + check_is_fitted(self, "estimator_") + y_score, _, response_method_used = _get_response_values_binary( + self.estimator_, + X, + self._response_method, + pos_label=self.pos_label, + return_response_method_used=True, + ) + + if self.threshold == "auto": + decision_threshold = 0.5 if response_method_used == "predict_proba" else 0.0 + else: + decision_threshold = self.threshold + + return _threshold_scores_to_class_labels( + y_score, decision_threshold, self.classes_, self.pos_label + ) + + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + Returns + ------- + routing : MetadataRouter + A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating + routing information. 
+        """
+        router = MetadataRouter(owner=self.__class__.__name__).add(
+            estimator=self.estimator,
+            method_mapping=MethodMapping().add(callee="fit", caller="fit"),
+        )
+        return router
+
+
+class _CurveScorer(_BaseScorer):
+    """Scorer taking a continuous response and outputting a score for each threshold.
+
+    Parameters
+    ----------
+    score_func : callable
+        The score function to use. It will be called as
+        `score_func(y_true, y_pred, **kwargs)`.
+
+    sign : int
+        Either 1 or -1 to return the score with `sign * score_func(estimator, X, y)`.
+        Thus, `sign` defines whether higher scores are better or worse.
+
+    kwargs : dict
+        Additional parameters to pass to the score function.
+
+    thresholds : int or array-like
+        Related to the number of decision thresholds for which we want to compute the
+        score. If an integer, it will be used to generate `thresholds` thresholds
+        uniformly distributed between the minimum and maximum predicted scores. If an
+        array-like, it will be used as the thresholds.
+
+    response_method : str
+        The method to call on the estimator to get the response values.
+    """
+
+    def __init__(self, score_func, sign, kwargs, thresholds, response_method):
+        super().__init__(
+            score_func=score_func,
+            sign=sign,
+            kwargs=kwargs,
+            response_method=response_method,
+        )
+        self._thresholds = thresholds
+
+    @classmethod
+    def from_scorer(cls, scorer, response_method, thresholds):
+        """Create a continuous scorer from a normal scorer."""
+        instance = cls(
+            score_func=scorer._score_func,
+            sign=scorer._sign,
+            response_method=response_method,
+            thresholds=thresholds,
+            kwargs=scorer._kwargs,
+        )
+        # transfer the metadata request
+        instance._metadata_request = scorer._get_metadata_request()
+        return instance
+
+    def _score(self, method_caller, estimator, X, y_true, **kwargs):
+        """Evaluate predicted target values for X relative to y_true.
+
+        Parameters
+        ----------
+        method_caller : callable
+            Returns predictions given an estimator, method name, and other
+            arguments, potentially caching results.
+
+        estimator : object
+            Trained estimator to use for scoring.
+
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Test data that will be fed to estimator.predict.
+
+        y_true : array-like of shape (n_samples,)
+            Gold standard target values for X.
+
+        **kwargs : dict
+            Other parameters passed to the scorer. Refer to
+            :func:`set_score_request` for more details.
+
+        Returns
+        -------
+        scores : ndarray of shape (thresholds,)
+            The scores associated with each threshold.
+
+        potential_thresholds : ndarray of shape (thresholds,)
+            The potential thresholds used to compute the scores.
+        """
+        pos_label = self._get_pos_label()
+        y_score = method_caller(
+            estimator, self._response_method, X, pos_label=pos_label
+        )
+
+        scoring_kwargs = {**self._kwargs, **kwargs}
+        if isinstance(self._thresholds, Integral):
+            potential_thresholds = np.linspace(
+                np.min(y_score), np.max(y_score), self._thresholds
+            )
+        else:
+            potential_thresholds = np.asarray(self._thresholds)
+        score_thresholds = [
+            self._sign
+            * self._score_func(
+                y_true,
+                _threshold_scores_to_class_labels(
+                    y_score, th, estimator.classes_, pos_label
+                ),
+                **scoring_kwargs,
+            )
+            for th in potential_thresholds
+        ]
+        return np.array(score_thresholds), potential_thresholds
+
+
+def _fit_and_score_over_thresholds(
+    classifier,
+    X,
+    y,
+    *,
+    fit_params,
+    train_idx,
+    val_idx,
+    curve_scorer,
+    score_params,
+):
+    """Fit a classifier and compute the scores for different decision thresholds.
+
+    Parameters
+    ----------
+    classifier : estimator instance
+        The classifier to fit and use for scoring. If `classifier` is already fitted,
+        it will be used as is.
+
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        The entire dataset.
+
+    y : array-like of shape (n_samples,)
+        The entire target vector.
+
+    fit_params : dict
+        Parameters to pass to the `fit` method of the underlying classifier.
+
+    train_idx : ndarray of shape (n_train_samples,) or None
+        The indices of the training set. If `None`, `classifier` is expected to be
+        already fitted.
+
+    val_idx : ndarray of shape (n_val_samples,)
+        The indices of the validation set used to score `classifier`. If `train_idx`
+        is `None`, the entire set will be used.
+
+    curve_scorer : scorer instance
+        The scorer taking `classifier` and the validation set as input and outputting
+        decision thresholds and scores as a curve. Note that this is different from
+        the usual scorer that outputs a single score value:
+
+        * when `score_method` is one of the four constraint metrics, the curve scorer
+          will output a curve of two scores parametrized by the decision threshold,
+          e.g. TPR/TNR or precision/recall curves for each threshold;
+        * otherwise, the curve scorer will output a single score value for each
+          threshold.
+
+    score_params : dict
+        Parameters to pass to the `score` method of the underlying scorer.
+
+    Returns
+    -------
+    scores : ndarray of shape (thresholds,) or tuple of such arrays
+        The scores computed for each decision threshold. When TPR/TNR or
+        precision/recall are computed, `scores` is a tuple of two arrays.
+
+    potential_thresholds : ndarray of shape (thresholds,)
+        The decision thresholds used to compute the scores. They are returned in
+        ascending order.
+    """
+
+    if train_idx is not None:
+        X_train, X_val = _safe_indexing(X, train_idx), _safe_indexing(X, val_idx)
+        y_train, y_val = _safe_indexing(y, train_idx), _safe_indexing(y, val_idx)
+        fit_params_train = _check_method_params(X, fit_params, indices=train_idx)
+        score_params_val = _check_method_params(X, score_params, indices=val_idx)
+        classifier.fit(X_train, y_train, **fit_params_train)
+    else:  # prefit estimator, only a validation set is provided
+        X_val, y_val, score_params_val = X, y, score_params
+
+    return curve_scorer(classifier, X_val, y_val, **score_params_val)
+
+
+def _mean_interpolated_score(target_thresholds, cv_thresholds, cv_scores):
+    """Compute the mean interpolated score across folds by defining common thresholds.
+
+    Parameters
+    ----------
+    target_thresholds : ndarray of shape (thresholds,)
+        The thresholds to use to compute the mean score.
+
+    cv_thresholds : ndarray of shape (n_folds, thresholds_fold)
+        The thresholds used to compute the scores for each fold.
+
+    cv_scores : ndarray of shape (n_folds, thresholds_fold)
+        The scores computed for each threshold for each fold.
+
+    Returns
+    -------
+    mean_score : ndarray of shape (thresholds,)
+        The mean score across all folds for each target threshold.
+    """
+    return np.mean(
+        [
+            np.interp(target_thresholds, split_thresholds, split_score)
+            for split_thresholds, split_score in zip(cv_thresholds, cv_scores)
+        ],
+        axis=0,
+    )
+
+
+class TunedThresholdClassifierCV(BaseThresholdClassifier):
+    """Classifier that post-tunes the decision threshold using cross-validation.
+
+    This estimator post-tunes the decision threshold (cut-off point) that is
+    used for converting posterior probability estimates (i.e. output of
+    `predict_proba`) or decision scores (i.e. output of `decision_function`)
+    into a class label. The tuning is done by optimizing a binary metric,
+    potentially constrained by another metric.
+
+    Read more in the :ref:`User Guide <TunedThresholdClassifierCV>`.
+
+    .. versionadded:: 1.5
+
+    Parameters
+    ----------
+    estimator : estimator instance
+        The classifier, fitted or not, for which we want to optimize
+        the decision threshold used during `predict`.
+
+    scoring : str or callable, default="balanced_accuracy"
+        The objective metric to be optimized. Can be one of:
+
+        * a string associated with a scoring function for binary classification
+          (see model evaluation documentation);
+        * a scorer callable object created with :func:`~sklearn.metrics.make_scorer`;
+
+    response_method : {"auto", "decision_function", "predict_proba"}, default="auto"
+        Methods by the classifier `estimator` corresponding to the
+        decision function for which we want to find a threshold. It can be:
+
+        * if `"auto"`, it will try to invoke, for each classifier,
+          `"predict_proba"` or `"decision_function"` in that order.
+        * otherwise, one of `"predict_proba"` or `"decision_function"`.
+          If the method is not implemented by the classifier, it will raise an
+          error.
+
+    thresholds : int or array-like, default=100
+        The number of decision thresholds to use when discretizing the output of the
+        classifier `response_method`. Pass an array-like to manually specify the
+        thresholds to use.
+
+    cv : int, float, cross-validation generator, iterable or "prefit", default=None
+        Determines the cross-validation splitting strategy used to train the
+        classifier. Possible inputs for cv are:
+
+        * `None`, to use the default 5-fold stratified K-fold cross validation;
+        * An integer, to specify the number of folds in a stratified k-fold;
+        * A float, to specify a single shuffle split. The float should be in (0, 1)
+          and represents the size of the validation set;
+        * An object to be used as a cross-validation generator;
+        * An iterable yielding train, test splits;
+        * `"prefit"`, to bypass the cross-validation.
+
+        Refer to the :ref:`User Guide <cross_validation>` for the various
+        cross-validation strategies that can be used here.
+
+        .. warning::
+            Using `cv="prefit"` and passing the same dataset for fitting `estimator`
+            and tuning the cut-off point is subject to undesired overfitting. You can
+            refer to :ref:`TunedThresholdClassifierCV_no_cv` for an example.
+
+            This option should only be used when the set used to fit `estimator` is
+            different from the one used to tune the cut-off point (by calling
+            :meth:`TunedThresholdClassifierCV.fit`).
+
+    refit : bool, default=True
+        Whether or not to refit the classifier on the entire training set once
+        the decision threshold has been found.
+        Note that forcing `refit=False` on a cross-validation with more
+        than a single split will raise an error. Similarly, `refit=True` in
+        conjunction with `cv="prefit"` will raise an error.
+
+    n_jobs : int, default=None
+        The number of jobs to run in parallel. When `cv` represents a
+        cross-validation strategy, the fitting and scoring on each data split
+        is done in parallel. ``None`` means 1 unless in a
+        :obj:`joblib.parallel_backend` context. ``-1`` means using all
+        processors. See :term:`Glossary <n_jobs>` for more details.
+
+    random_state : int, RandomState instance or None, default=None
+        Controls the randomness of cross-validation when `cv` is a float.
+        See :term:`Glossary <random_state>`.
+
+    store_cv_results : bool, default=False
+        Whether to store all scores and thresholds computed during the cross-validation
+        process.
+ + Attributes + ---------- + estimator_ : estimator instance + The fitted classifier used when predicting. + + best_threshold_ : float + The new decision threshold. + + best_score_ : float or None + The optimal score of the objective metric, evaluated at `best_threshold_`. + + cv_results_ : dict or None + A dictionary containing the scores and thresholds computed during the + cross-validation process. Only exist if `store_cv_results=True`. The + keys are `"thresholds"` and `"scores"`. + + classes_ : ndarray of shape (n_classes,) + The class labels. + + n_features_in_ : int + Number of features seen during :term:`fit`. Only defined if the + underlying estimator exposes such an attribute when fit. + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Only defined if the + underlying estimator exposes such an attribute when fit. + + See Also + -------- + sklearn.model_selection.FixedThresholdClassifier : Classifier that uses a + constant threshold. + sklearn.calibration.CalibratedClassifierCV : Estimator that calibrates + probabilities. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.ensemble import RandomForestClassifier + >>> from sklearn.metrics import classification_report + >>> from sklearn.model_selection import TunedThresholdClassifierCV, train_test_split + >>> X, y = make_classification( + ... n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42 + ... ) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, stratify=y, random_state=42 + ... ) + >>> classifier = RandomForestClassifier(random_state=0).fit(X_train, y_train) + >>> print(classification_report(y_test, classifier.predict(X_test))) + precision recall f1-score support + + 0 0.94 0.99 0.96 224 + 1 0.80 0.46 0.59 26 + + accuracy 0.93 250 + macro avg 0.87 0.72 0.77 250 + weighted avg 0.93 0.93 0.92 250 + + >>> classifier_tuned = TunedThresholdClassifierCV( + ... classifier, scoring="balanced_accuracy" + ... ).fit(X_train, y_train) + >>> print( + ... f"Cut-off point found at {classifier_tuned.best_threshold_:.3f}" + ... ) + Cut-off point found at 0.342 + >>> print(classification_report(y_test, classifier_tuned.predict(X_test))) + precision recall f1-score support + + 0 0.96 0.95 0.96 224 + 1 0.61 0.65 0.63 26 + + accuracy 0.92 250 + macro avg 0.78 0.80 0.79 250 + weighted avg 0.92 0.92 0.92 250 + + """ + + _parameter_constraints: dict = { + **BaseThresholdClassifier._parameter_constraints, + "scoring": [ + StrOptions(set(get_scorer_names())), + callable, + MutableMapping, + ], + "thresholds": [Interval(Integral, 1, None, closed="left"), "array-like"], + "cv": [ + "cv_object", + StrOptions({"prefit"}), + Interval(RealNotInt, 0.0, 1.0, closed="neither"), + ], + "refit": ["boolean"], + "n_jobs": [Integral, None], + "random_state": ["random_state"], + "store_cv_results": ["boolean"], + } + + def __init__( + self, + estimator, + *, + scoring="balanced_accuracy", + response_method="auto", + thresholds=100, + cv=None, + refit=True, + n_jobs=None, + random_state=None, + store_cv_results=False, + ): + super().__init__(estimator=estimator, response_method=response_method) + self.scoring = scoring + self.thresholds = thresholds + self.cv = cv + self.refit = refit + self.n_jobs = n_jobs + self.random_state = random_state + self.store_cv_results = store_cv_results + + def _fit(self, X, y, **params): + """Fit the classifier and post-tune the decision threshold. 
+ + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + + y : array-like of shape (n_samples,) + Target values. + + **params : dict + Parameters to pass to the `fit` method of the underlying + classifier and to the `scoring` scorer. + + Returns + ------- + self : object + Returns an instance of self. + """ + if isinstance(self.cv, Real) and 0 < self.cv < 1: + cv = StratifiedShuffleSplit( + n_splits=1, test_size=self.cv, random_state=self.random_state + ) + elif self.cv == "prefit": + if self.refit is True: + raise ValueError("When cv='prefit', refit cannot be True.") + try: + check_is_fitted(self.estimator, "classes_") + except NotFittedError as exc: + raise NotFittedError( + """When cv='prefit', `estimator` must be fitted.""" + ) from exc + cv = self.cv + else: + cv = check_cv(self.cv, y=y, classifier=True) + if self.refit is False and cv.get_n_splits() > 1: + raise ValueError("When cv has several folds, refit cannot be False.") + + routed_params = process_routing(self, "fit", **params) + self._curve_scorer = self._get_curve_scorer() + + # in the following block, we: + # - define the final classifier `self.estimator_` and train it if necessary + # - define `classifier` to be used to post-tune the decision threshold + # - define `split` to be used to fit/score `classifier` + if cv == "prefit": + self.estimator_ = self.estimator + classifier = self.estimator_ + splits = [(None, range(_num_samples(X)))] + else: + self.estimator_ = clone(self.estimator) + classifier = clone(self.estimator) + splits = cv.split(X, y, **routed_params.splitter.split) + + if self.refit: + # train on the whole dataset + X_train, y_train, fit_params_train = X, y, routed_params.estimator.fit + else: + # single split cross-validation + train_idx, _ = next(cv.split(X, y, **routed_params.splitter.split)) + X_train = _safe_indexing(X, train_idx) + y_train = _safe_indexing(y, train_idx) + fit_params_train = _check_method_params( + X, routed_params.estimator.fit, indices=train_idx + ) + + self.estimator_.fit(X_train, y_train, **fit_params_train) + + cv_scores, cv_thresholds = zip( + *Parallel(n_jobs=self.n_jobs)( + delayed(_fit_and_score_over_thresholds)( + clone(classifier) if cv != "prefit" else classifier, + X, + y, + fit_params=routed_params.estimator.fit, + train_idx=train_idx, + val_idx=val_idx, + curve_scorer=self._curve_scorer, + score_params=routed_params.scorer.score, + ) + for train_idx, val_idx in splits + ) + ) + + if any(np.isclose(th[0], th[-1]) for th in cv_thresholds): + raise ValueError( + "The provided estimator makes constant predictions. Therefore, it is " + "impossible to optimize the decision threshold." 
+ ) + + # find the global min and max thresholds across all folds + min_threshold = min( + split_thresholds.min() for split_thresholds in cv_thresholds + ) + max_threshold = max( + split_thresholds.max() for split_thresholds in cv_thresholds + ) + if isinstance(self.thresholds, Integral): + decision_thresholds = np.linspace( + min_threshold, max_threshold, num=self.thresholds + ) + else: + decision_thresholds = np.asarray(self.thresholds) + + objective_scores = _mean_interpolated_score( + decision_thresholds, cv_thresholds, cv_scores + ) + best_idx = objective_scores.argmax() + self.best_score_ = objective_scores[best_idx] + self.best_threshold_ = decision_thresholds[best_idx] + if self.store_cv_results: + self.cv_results_ = { + "thresholds": decision_thresholds, + "scores": objective_scores, + } + + return self + + def predict(self, X): + """Predict the target of new samples. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The samples, as accepted by `estimator.predict`. + + Returns + ------- + class_labels : ndarray of shape (n_samples,) + The predicted class. + """ + check_is_fitted(self, "estimator_") + pos_label = self._curve_scorer._get_pos_label() + y_score, _ = _get_response_values_binary( + self.estimator_, + X, + self._response_method, + pos_label=pos_label, + ) + + return _threshold_scores_to_class_labels( + y_score, self.best_threshold_, self.classes_, pos_label + ) + + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + Returns + ------- + routing : MetadataRouter + A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add( + estimator=self.estimator, + method_mapping=MethodMapping().add(callee="fit", caller="fit"), + ) + .add( + splitter=self.cv, + method_mapping=MethodMapping().add(callee="split", caller="fit"), + ) + .add( + scorer=self._get_curve_scorer(), + method_mapping=MethodMapping().add(callee="score", caller="fit"), + ) + ) + return router + + def _get_curve_scorer(self): + """Get the curve scorer based on the objective metric used.""" + scoring = check_scoring(self.estimator, scoring=self.scoring) + curve_scorer = _CurveScorer.from_scorer( + scoring, self._response_method, self.thresholds + ) + return curve_scorer diff --git a/sklearn/model_selection/tests/test_classification_threshold.py b/sklearn/model_selection/tests/test_classification_threshold.py new file mode 100644 index 0000000000000..f64edb2563c76 --- /dev/null +++ b/sklearn/model_selection/tests/test_classification_threshold.py @@ -0,0 +1,684 @@ +import numpy as np +import pytest + +from sklearn.base import clone +from sklearn.datasets import ( + load_breast_cancer, + load_iris, + make_classification, + make_multilabel_classification, +) +from sklearn.dummy import DummyClassifier +from sklearn.ensemble import GradientBoostingClassifier +from sklearn.exceptions import NotFittedError +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + balanced_accuracy_score, + f1_score, + fbeta_score, + make_scorer, + recall_score, +) +from sklearn.model_selection import ( + FixedThresholdClassifier, + StratifiedShuffleSplit, + TunedThresholdClassifierCV, +) +from sklearn.model_selection._classification_threshold import ( + _CurveScorer, + _fit_and_score_over_thresholds, +) +from sklearn.pipeline import make_pipeline +from 
sklearn.preprocessing import StandardScaler
+from sklearn.svm import SVC
+from sklearn.tree import DecisionTreeClassifier
+from sklearn.utils._mocking import CheckingClassifier
+from sklearn.utils._testing import (
+    _convert_container,
+    assert_allclose,
+    assert_array_equal,
+)
+
+
+def test_curve_scorer():
+    """Check the behaviour of the `_CurveScorer` class."""
+    X, y = make_classification(random_state=0)
+    estimator = LogisticRegression().fit(X, y)
+    curve_scorer = _CurveScorer(
+        balanced_accuracy_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={},
+    )
+    scores, thresholds = curve_scorer(estimator, X, y)
+
+    assert thresholds.shape == scores.shape
+    # check that the thresholds are probabilities with extreme values close to 0 and 1.
+    # they are not exactly 0 and 1 because they are the extrema of the
+    # `estimator.predict_proba(X)` values.
+    assert 0 <= thresholds.min() <= 0.01
+    assert 0.99 <= thresholds.max() <= 1
+    # balanced accuracy should be between 0.5 and 1 when it is not adjusted
+    assert 0.5 <= scores.min() <= 1
+
+    # check that passing kwargs to the scorer works
+    curve_scorer = _CurveScorer(
+        balanced_accuracy_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"adjusted": True},
+    )
+    scores, thresholds = curve_scorer(estimator, X, y)
+
+    # balanced accuracy should be between 0 and 0.5 when it is adjusted
+    assert 0 <= scores.min() <= 0.5
+
+    # check that we can invert the sign of the score when dealing with a `neg_*` scorer
+    curve_scorer = _CurveScorer(
+        balanced_accuracy_score,
+        sign=-1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"adjusted": True},
+    )
+    scores, thresholds = curve_scorer(estimator, X, y)
+
+    assert all(scores <= 0)
+
+
+def test_curve_scorer_pos_label(global_random_seed):
+    """Check that we properly propagate the `pos_label` parameter to the scorer."""
+    n_samples = 30
+    X, y = make_classification(
+        n_samples=n_samples, weights=[0.9, 0.1], random_state=global_random_seed
+    )
+    estimator = LogisticRegression().fit(X, y)
+
+    curve_scorer = _CurveScorer(
+        recall_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"pos_label": 1},
+    )
+    scores_pos_label_1, thresholds_pos_label_1 = curve_scorer(estimator, X, y)
+
+    curve_scorer = _CurveScorer(
+        recall_score,
+        sign=1,
+        response_method="predict_proba",
+        thresholds=10,
+        kwargs={"pos_label": 0},
+    )
+    scores_pos_label_0, thresholds_pos_label_0 = curve_scorer(estimator, X, y)
+
+    # Since `pos_label` is forwarded to the curve_scorer, the thresholds are not equal.
+    assert not (thresholds_pos_label_1 == thresholds_pos_label_0).all()
+    # The min-max range for the thresholds is defined by the probabilities of the
+    # `pos_label` class (the column of `predict_proba`).
+    y_pred = estimator.predict_proba(X)
+    assert thresholds_pos_label_0.min() == pytest.approx(y_pred.min(axis=0)[0])
+    assert thresholds_pos_label_0.max() == pytest.approx(y_pred.max(axis=0)[0])
+    assert thresholds_pos_label_1.min() == pytest.approx(y_pred.min(axis=0)[1])
+    assert thresholds_pos_label_1.max() == pytest.approx(y_pred.max(axis=0)[1])
+
+    # The recall cannot be negative and `pos_label=1` should have a higher recall
+    # since there are fewer samples to be considered.
+ assert 0.0 < scores_pos_label_0.min() < scores_pos_label_1.min() + assert scores_pos_label_0.max() == pytest.approx(1.0) + assert scores_pos_label_1.max() == pytest.approx(1.0) + + +def test_fit_and_score_over_thresholds_curve_scorers(): + """Check that `_fit_and_score_over_thresholds` returns thresholds in ascending order + for the different accepted curve scorers.""" + X, y = make_classification(n_samples=100, random_state=0) + train_idx, val_idx = np.arange(50), np.arange(50, 100) + classifier = LogisticRegression() + + curve_scorer = _CurveScorer( + score_func=balanced_accuracy_score, + sign=1, + response_method="predict_proba", + thresholds=10, + kwargs={}, + ) + scores, thresholds = _fit_and_score_over_thresholds( + classifier, + X, + y, + fit_params={}, + train_idx=train_idx, + val_idx=val_idx, + curve_scorer=curve_scorer, + score_params={}, + ) + + assert np.all(thresholds[:-1] <= thresholds[1:]) + assert isinstance(scores, np.ndarray) + assert np.logical_and(scores >= 0, scores <= 1).all() + + +def test_fit_and_score_over_thresholds_prefit(): + """Check the behaviour with a prefit classifier.""" + X, y = make_classification(n_samples=100, random_state=0) + + # `train_idx is None` to indicate that the classifier is prefit + train_idx, val_idx = None, np.arange(50, 100) + classifier = DecisionTreeClassifier(random_state=0).fit(X, y) + # make sure that the classifier memorized the full dataset such that + # we get perfect predictions and thus match the expected score + assert classifier.score(X[val_idx], y[val_idx]) == pytest.approx(1.0) + + curve_scorer = _CurveScorer( + score_func=balanced_accuracy_score, + sign=1, + response_method="predict_proba", + thresholds=2, + kwargs={}, + ) + scores, thresholds = _fit_and_score_over_thresholds( + classifier, + X, + y, + fit_params={}, + train_idx=train_idx, + val_idx=val_idx, + curve_scorer=curve_scorer, + score_params={}, + ) + assert np.all(thresholds[:-1] <= thresholds[1:]) + assert_allclose(scores, [0.5, 1.0]) + + +@pytest.mark.usefixtures("enable_slep006") +def test_fit_and_score_over_thresholds_sample_weight(): + """Check that we dispatch the sample-weight to fit and score the classifier.""" + X, y = load_iris(return_X_y=True) + X, y = X[:100], y[:100] # only 2 classes + + # create a dataset and repeat twice the sample of class #0 + X_repeated, y_repeated = np.vstack([X, X[y == 0]]), np.hstack([y, y[y == 0]]) + # create a sample weight vector that is equivalent to the repeated dataset + sample_weight = np.ones_like(y) + sample_weight[:50] *= 2 + + classifier = LogisticRegression() + train_repeated_idx = np.arange(X_repeated.shape[0]) + val_repeated_idx = np.arange(X_repeated.shape[0]) + curve_scorer = _CurveScorer( + score_func=balanced_accuracy_score, + sign=1, + response_method="predict_proba", + thresholds=10, + kwargs={}, + ) + scores_repeated, thresholds_repeated = _fit_and_score_over_thresholds( + classifier, + X_repeated, + y_repeated, + fit_params={}, + train_idx=train_repeated_idx, + val_idx=val_repeated_idx, + curve_scorer=curve_scorer, + score_params={}, + ) + + train_idx, val_idx = np.arange(X.shape[0]), np.arange(X.shape[0]) + scores, thresholds = _fit_and_score_over_thresholds( + classifier.set_fit_request(sample_weight=True), + X, + y, + fit_params={"sample_weight": sample_weight}, + train_idx=train_idx, + val_idx=val_idx, + curve_scorer=curve_scorer.set_score_request(sample_weight=True), + score_params={"sample_weight": sample_weight}, + ) + + assert_allclose(thresholds_repeated, thresholds) + 
assert_allclose(scores_repeated, scores) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize("fit_params_type", ["list", "array"]) +def test_fit_and_score_over_thresholds_fit_params(fit_params_type): + """Check that we pass `fit_params` to the classifier when calling `fit`.""" + X, y = make_classification(n_samples=100, random_state=0) + fit_params = { + "a": _convert_container(y, fit_params_type), + "b": _convert_container(y, fit_params_type), + } + + classifier = CheckingClassifier(expected_fit_params=["a", "b"], random_state=0) + classifier.set_fit_request(a=True, b=True) + train_idx, val_idx = np.arange(50), np.arange(50, 100) + + curve_scorer = _CurveScorer( + score_func=balanced_accuracy_score, + sign=1, + response_method="predict_proba", + thresholds=10, + kwargs={}, + ) + _fit_and_score_over_thresholds( + classifier, + X, + y, + fit_params=fit_params, + train_idx=train_idx, + val_idx=val_idx, + curve_scorer=curve_scorer, + score_params={}, + ) + + +@pytest.mark.parametrize( + "data", + [ + make_classification(n_classes=3, n_clusters_per_class=1, random_state=0), + make_multilabel_classification(random_state=0), + ], +) +def test_tuned_threshold_classifier_no_binary(data): + """Check that we raise an informative error message for non-binary problem.""" + err_msg = "Only binary classification is supported." + with pytest.raises(ValueError, match=err_msg): + TunedThresholdClassifierCV(LogisticRegression()).fit(*data) + + +@pytest.mark.parametrize( + "params, err_type, err_msg", + [ + ( + {"cv": "prefit", "refit": True}, + ValueError, + "When cv='prefit', refit cannot be True.", + ), + ( + {"cv": 10, "refit": False}, + ValueError, + "When cv has several folds, refit cannot be False.", + ), + ( + {"cv": "prefit", "refit": False}, + NotFittedError, + "`estimator` must be fitted.", + ), + ], +) +def test_tuned_threshold_classifier_conflict_cv_refit(params, err_type, err_msg): + """Check that we raise an informative error message when `cv` and `refit` + cannot be used together. + """ + X, y = make_classification(n_samples=100, random_state=0) + with pytest.raises(err_type, match=err_msg): + TunedThresholdClassifierCV(LogisticRegression(), **params).fit(X, y) + + +@pytest.mark.parametrize( + "estimator", + [LogisticRegression(), SVC(), GradientBoostingClassifier(n_estimators=4)], +) +@pytest.mark.parametrize( + "response_method", ["predict_proba", "predict_log_proba", "decision_function"] +) +@pytest.mark.parametrize( + "ThresholdClassifier", [FixedThresholdClassifier, TunedThresholdClassifierCV] +) +def test_threshold_classifier_estimator_response_methods( + ThresholdClassifier, estimator, response_method +): + """Check that `TunedThresholdClassifierCV` exposes the same response methods as the + underlying estimator. 
+ """ + X, y = make_classification(n_samples=100, random_state=0) + + model = ThresholdClassifier(estimator=estimator) + assert hasattr(model, response_method) == hasattr(estimator, response_method) + + model.fit(X, y) + assert hasattr(model, response_method) == hasattr(estimator, response_method) + + if hasattr(model, response_method): + y_pred_cutoff = getattr(model, response_method)(X) + y_pred_underlying_estimator = getattr(model.estimator_, response_method)(X) + + assert_allclose(y_pred_cutoff, y_pred_underlying_estimator) + + +@pytest.mark.parametrize( + "response_method", ["auto", "decision_function", "predict_proba"] +) +def test_tuned_threshold_classifier_without_constraint_value(response_method): + """Check that `TunedThresholdClassifierCV` is optimizing a given objective + metric.""" + X, y = load_breast_cancer(return_X_y=True) + # remove feature to degrade performances + X = X[:, :5] + + # make the problem completely imbalanced such that the balanced accuracy is low + indices_pos = np.flatnonzero(y == 1) + indices_pos = indices_pos[: indices_pos.size // 50] + indices_neg = np.flatnonzero(y == 0) + + X = np.vstack([X[indices_neg], X[indices_pos]]) + y = np.hstack([y[indices_neg], y[indices_pos]]) + + lr = make_pipeline(StandardScaler(), LogisticRegression()).fit(X, y) + thresholds = 100 + model = TunedThresholdClassifierCV( + estimator=lr, + scoring="balanced_accuracy", + response_method=response_method, + thresholds=thresholds, + store_cv_results=True, + ) + score_optimized = balanced_accuracy_score(y, model.fit(X, y).predict(X)) + score_baseline = balanced_accuracy_score(y, lr.predict(X)) + assert score_optimized > score_baseline + assert model.cv_results_["thresholds"].shape == (thresholds,) + assert model.cv_results_["scores"].shape == (thresholds,) + + +def test_tuned_threshold_classifier_metric_with_parameter(): + """Check that we can pass a metric with a parameter in addition check that + `f_beta` with `beta=1` is equivalent to `f1` and different from `f_beta` with + `beta=2`. + """ + X, y = load_breast_cancer(return_X_y=True) + lr = make_pipeline(StandardScaler(), LogisticRegression()).fit(X, y) + model_fbeta_1 = TunedThresholdClassifierCV( + estimator=lr, scoring=make_scorer(fbeta_score, beta=1) + ).fit(X, y) + model_fbeta_2 = TunedThresholdClassifierCV( + estimator=lr, scoring=make_scorer(fbeta_score, beta=2) + ).fit(X, y) + model_f1 = TunedThresholdClassifierCV( + estimator=lr, scoring=make_scorer(f1_score) + ).fit(X, y) + + assert model_fbeta_1.best_threshold_ == pytest.approx(model_f1.best_threshold_) + assert model_fbeta_1.best_threshold_ != pytest.approx(model_fbeta_2.best_threshold_) + + +@pytest.mark.parametrize( + "response_method", ["auto", "decision_function", "predict_proba"] +) +@pytest.mark.parametrize( + "metric", + [ + make_scorer(balanced_accuracy_score), + make_scorer(f1_score, pos_label="cancer"), + ], +) +def test_tuned_threshold_classifier_with_string_targets(response_method, metric): + """Check that targets represented by str are properly managed. + Also, check with several metrics to be sure that `pos_label` is properly + dispatched. + """ + X, y = load_breast_cancer(return_X_y=True) + # Encode numeric targets by meaningful strings. We purposely designed the class + # names such that the `pos_label` is the first alphabetically sorted class and thus + # encoded as 0. 
+ classes = np.array(["cancer", "healthy"], dtype=object) + y = classes[y] + model = TunedThresholdClassifierCV( + estimator=make_pipeline(StandardScaler(), LogisticRegression()), + scoring=metric, + response_method=response_method, + thresholds=100, + ).fit(X, y) + assert_array_equal(model.classes_, np.sort(classes)) + y_pred = model.predict(X) + assert_array_equal(np.unique(y_pred), np.sort(classes)) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize("with_sample_weight", [True, False]) +def test_tuned_threshold_classifier_refit(with_sample_weight, global_random_seed): + """Check the behaviour of the `refit` parameter.""" + rng = np.random.RandomState(global_random_seed) + X, y = make_classification(n_samples=100, random_state=0) + if with_sample_weight: + sample_weight = rng.randn(X.shape[0]) + sample_weight = np.abs(sample_weight, out=sample_weight) + else: + sample_weight = None + + # check that `estimator_` if fitted on the full dataset when `refit=True` + estimator = LogisticRegression().set_fit_request(sample_weight=True) + model = TunedThresholdClassifierCV(estimator, refit=True).fit( + X, y, sample_weight=sample_weight + ) + + assert model.estimator_ is not estimator + estimator.fit(X, y, sample_weight=sample_weight) + assert_allclose(model.estimator_.coef_, estimator.coef_) + assert_allclose(model.estimator_.intercept_, estimator.intercept_) + + # check that `estimator_` was not altered when `refit=False` and `cv="prefit"` + estimator = LogisticRegression().set_fit_request(sample_weight=True) + estimator.fit(X, y, sample_weight=sample_weight) + coef = estimator.coef_.copy() + model = TunedThresholdClassifierCV(estimator, cv="prefit", refit=False).fit( + X, y, sample_weight=sample_weight + ) + + assert model.estimator_ is estimator + assert_allclose(model.estimator_.coef_, coef) + + # check that we train `estimator_` on the training split of a given cross-validation + estimator = LogisticRegression().set_fit_request(sample_weight=True) + cv = [ + (np.arange(50), np.arange(50, 100)), + ] # single split + model = TunedThresholdClassifierCV(estimator, cv=cv, refit=False).fit( + X, y, sample_weight=sample_weight + ) + + assert model.estimator_ is not estimator + if with_sample_weight: + sw_train = sample_weight[cv[0][0]] + else: + sw_train = None + estimator.fit(X[cv[0][0]], y[cv[0][0]], sample_weight=sw_train) + assert_allclose(model.estimator_.coef_, estimator.coef_) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize("fit_params_type", ["list", "array"]) +def test_tuned_threshold_classifier_fit_params(fit_params_type): + """Check that we pass `fit_params` to the classifier when calling `fit`.""" + X, y = make_classification(n_samples=100, random_state=0) + fit_params = { + "a": _convert_container(y, fit_params_type), + "b": _convert_container(y, fit_params_type), + } + + classifier = CheckingClassifier(expected_fit_params=["a", "b"], random_state=0) + classifier.set_fit_request(a=True, b=True) + model = TunedThresholdClassifierCV(classifier) + model.fit(X, y, **fit_params) + + +@pytest.mark.usefixtures("enable_slep006") +def test_tuned_threshold_classifier_cv_zeros_sample_weights_equivalence(): + """Check that passing removing some sample from the dataset `X` is + equivalent to passing a `sample_weight` with a factor 0.""" + X, y = load_iris(return_X_y=True) + # Scale the data to avoid any convergence issue + X = StandardScaler().fit_transform(X) + # Only use 2 classes and select samples such that 2-fold cross-validation + # split will 
lead to an equivalence with a `sample_weight` of 0 + X = np.vstack((X[:40], X[50:90])) + y = np.hstack((y[:40], y[50:90])) + sample_weight = np.zeros_like(y) + sample_weight[::2] = 1 + + estimator = LogisticRegression().set_fit_request(sample_weight=True) + model_without_weights = TunedThresholdClassifierCV(estimator, cv=2) + model_with_weights = clone(model_without_weights) + + model_with_weights.fit(X, y, sample_weight=sample_weight) + model_without_weights.fit(X[::2], y[::2]) + + assert_allclose( + model_with_weights.estimator_.coef_, model_without_weights.estimator_.coef_ + ) + + y_pred_with_weights = model_with_weights.predict_proba(X) + y_pred_without_weights = model_without_weights.predict_proba(X) + assert_allclose(y_pred_with_weights, y_pred_without_weights) + + +def test_tuned_threshold_classifier_thresholds_array(): + """Check that we can pass an array to `thresholds` and it is used as candidate + threshold internally.""" + X, y = make_classification(random_state=0) + estimator = LogisticRegression() + thresholds = np.linspace(0, 1, 11) + tuned_model = TunedThresholdClassifierCV( + estimator, + thresholds=thresholds, + response_method="predict_proba", + store_cv_results=True, + ).fit(X, y) + assert_allclose(tuned_model.cv_results_["thresholds"], thresholds) + + +@pytest.mark.parametrize("store_cv_results", [True, False]) +def test_tuned_threshold_classifier_store_cv_results(store_cv_results): + """Check that if `cv_results_` exists depending on `store_cv_results`.""" + X, y = make_classification(random_state=0) + estimator = LogisticRegression() + tuned_model = TunedThresholdClassifierCV( + estimator, store_cv_results=store_cv_results + ).fit(X, y) + if store_cv_results: + assert hasattr(tuned_model, "cv_results_") + else: + assert not hasattr(tuned_model, "cv_results_") + + +def test_tuned_threshold_classifier_cv_float(): + """Check the behaviour when `cv` is set to a float.""" + X, y = make_classification(random_state=0) + + # case where `refit=False` and cv is a float: the underlying estimator will be fit + # on the training set given by a ShuffleSplit. We check that we get the same model + # coefficients. + test_size = 0.3 + estimator = LogisticRegression() + tuned_model = TunedThresholdClassifierCV( + estimator, cv=test_size, refit=False, random_state=0 + ).fit(X, y) + tuned_model.fit(X, y) + + cv = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=0) + train_idx, val_idx = next(cv.split(X, y)) + cloned_estimator = clone(estimator).fit(X[train_idx], y[train_idx]) + + assert_allclose(tuned_model.estimator_.coef_, cloned_estimator.coef_) + + # case where `refit=True`, then the underlying estimator is fitted on the full + # dataset. + tuned_model.set_params(refit=True).fit(X, y) + cloned_estimator = clone(estimator).fit(X, y) + + assert_allclose(tuned_model.estimator_.coef_, cloned_estimator.coef_) + + +def test_tuned_threshold_classifier_error_constant_predictor(): + """Check that we raise a ValueError if the underlying classifier returns constant + probabilities such that we cannot find any threshold. 
+ """ + X, y = make_classification(random_state=0) + estimator = DummyClassifier(strategy="constant", constant=1) + tuned_model = TunedThresholdClassifierCV(estimator, response_method="predict_proba") + err_msg = "The provided estimator makes constant predictions" + with pytest.raises(ValueError, match=err_msg): + tuned_model.fit(X, y) + + +@pytest.mark.parametrize( + "response_method", ["auto", "predict_proba", "decision_function"] +) +def test_fixed_threshold_classifier_equivalence_default(response_method): + """Check that `FixedThresholdClassifier` has the same behaviour as the vanilla + classifier. + """ + X, y = make_classification(random_state=0) + classifier = LogisticRegression().fit(X, y) + classifier_default_threshold = FixedThresholdClassifier( + estimator=clone(classifier), response_method=response_method + ) + classifier_default_threshold.fit(X, y) + + # emulate the response method that should take into account the `pos_label` + if response_method in ("auto", "predict_proba"): + y_score = classifier_default_threshold.predict_proba(X)[:, 1] + threshold = 0.5 + else: # response_method == "decision_function" + y_score = classifier_default_threshold.decision_function(X) + threshold = 0.0 + + y_pred_lr = (y_score >= threshold).astype(int) + assert_allclose(classifier_default_threshold.predict(X), y_pred_lr) + + +@pytest.mark.parametrize( + "response_method, threshold", [("predict_proba", 0.7), ("decision_function", 2.0)] +) +@pytest.mark.parametrize("pos_label", [0, 1]) +def test_fixed_threshold_classifier(response_method, threshold, pos_label): + """Check that applying `predict` lead to the same prediction as applying the + threshold to the output of the response method. + """ + X, y = make_classification(n_samples=50, random_state=0) + logistic_regression = LogisticRegression().fit(X, y) + model = FixedThresholdClassifier( + estimator=clone(logistic_regression), + threshold=threshold, + response_method=response_method, + pos_label=pos_label, + ).fit(X, y) + + # check that the underlying estimator is the same + assert_allclose(model.estimator_.coef_, logistic_regression.coef_) + + # emulate the response method that should take into account the `pos_label` + if response_method == "predict_proba": + y_score = model.predict_proba(X)[:, pos_label] + else: # response_method == "decision_function" + y_score = model.decision_function(X) + y_score = y_score if pos_label == 1 else -y_score + + # create a mapping from boolean values to class labels + map_to_label = np.array([0, 1]) if pos_label == 1 else np.array([1, 0]) + y_pred_lr = map_to_label[(y_score >= threshold).astype(int)] + assert_allclose(model.predict(X), y_pred_lr) + + for method in ("predict_proba", "predict_log_proba", "decision_function"): + assert_allclose( + getattr(model, method)(X), getattr(logistic_regression, method)(X) + ) + assert_allclose( + getattr(model.estimator_, method)(X), + getattr(logistic_regression, method)(X), + ) + + +@pytest.mark.usefixtures("enable_slep006") +def test_fixed_threshold_classifier_metadata_routing(): + """Check that everything works with metadata routing.""" + X, y = make_classification(random_state=0) + sample_weight = np.ones_like(y) + sample_weight[::2] = 2 + classifier = LogisticRegression().set_fit_request(sample_weight=True) + classifier.fit(X, y, sample_weight=sample_weight) + classifier_default_threshold = FixedThresholdClassifier(estimator=clone(classifier)) + classifier_default_threshold.fit(X, y, sample_weight=sample_weight) + 
assert_allclose(classifier_default_threshold.estimator_.coef_, classifier.coef_) diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py index 16acabf03755b..0afed8c08cfaa 100644 --- a/sklearn/utils/_mocking.py +++ b/sklearn/utils/_mocking.py @@ -3,7 +3,13 @@ from ..base import BaseEstimator, ClassifierMixin from ..utils._metadata_requests import RequestMethod from .metaestimators import available_if -from .validation import _check_sample_weight, _num_samples, check_array, check_is_fitted +from .validation import ( + _check_sample_weight, + _num_samples, + check_array, + check_is_fitted, + check_random_state, +) class ArraySlicingWrapper: @@ -133,6 +139,7 @@ def __init__( foo_param=0, expected_sample_weight=None, expected_fit_params=None, + random_state=None, ): self.check_y = check_y self.check_y_params = check_y_params @@ -142,6 +149,7 @@ def __init__( self.foo_param = foo_param self.expected_sample_weight = expected_sample_weight self.expected_fit_params = expected_fit_params + self.random_state = random_state def _check_X_y(self, X, y=None, should_be_fitted=True): """Validate X and y and make extra check. @@ -243,7 +251,8 @@ def predict(self, X): """ if self.methods_to_check == "all" or "predict" in self.methods_to_check: X, y = self._check_X_y(X) - return self.classes_[np.zeros(_num_samples(X), dtype=int)] + rng = check_random_state(self.random_state) + return rng.choice(self.classes_, size=_num_samples(X)) def predict_proba(self, X): """Predict probabilities for each class. @@ -263,8 +272,10 @@ def predict_proba(self, X): """ if self.methods_to_check == "all" or "predict_proba" in self.methods_to_check: X, y = self._check_X_y(X) - proba = np.zeros((_num_samples(X), len(self.classes_))) - proba[:, 0] = 1 + rng = check_random_state(self.random_state) + proba = rng.randn(_num_samples(X), len(self.classes_)) + proba = np.abs(proba, out=proba) + proba /= np.sum(proba, axis=1)[:, np.newaxis] return proba def decision_function(self, X): @@ -286,14 +297,13 @@ def decision_function(self, X): or "decision_function" in self.methods_to_check ): X, y = self._check_X_y(X) + rng = check_random_state(self.random_state) if len(self.classes_) == 2: # for binary classifier, the confidence score is related to # classes_[1] and therefore should be null. - return np.zeros(_num_samples(X)) + return rng.randn(_num_samples(X)) else: - decision = np.zeros((_num_samples(X), len(self.classes_))) - decision[:, 0] = 1 - return decision + return rng.randn(_num_samples(X), len(self.classes_)) def score(self, X=None, Y=None): """Fake score. diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py index 0207cc1205120..0381c872a94b0 100644 --- a/sklearn/utils/_response.py +++ b/sklearn/utils/_response.py @@ -243,7 +243,9 @@ def _get_response_values( return y_pred, pos_label -def _get_response_values_binary(estimator, X, response_method, pos_label=None): +def _get_response_values_binary( + estimator, X, response_method, pos_label=None, return_response_method_used=False +): """Compute the response values of a binary classifier. Parameters @@ -266,6 +268,12 @@ def _get_response_values_binary(estimator, X, response_method, pos_label=None): the metrics. By default, `estimators.classes_[1]` is considered as the positive class. + return_response_method_used : bool, default=False + Whether to return the response method used to compute the response + values. + + .. 
versionadded:: 1.5 + Returns ------- y_pred : ndarray of shape (n_samples,) @@ -275,6 +283,12 @@ def _get_response_values_binary(estimator, X, response_method, pos_label=None): pos_label : int, float, bool or str The class considered as the positive class when computing the metrics. + + response_method_used : str + The response method used to compute the response values. Only returned + if `return_response_method_used` is `True`. + + .. versionadded:: 1.5 """ classification_error = "Expected 'estimator' to be a binary classifier." @@ -296,4 +310,5 @@ def _get_response_values_binary(estimator, X, response_method, pos_label=None): X, response_method, pos_label=pos_label, + return_response_method_used=return_response_method_used, ) diff --git a/sklearn/utils/tests/test_mocking.py b/sklearn/utils/tests/test_mocking.py index 9c66d1345bb6d..bd143855e6dcd 100644 --- a/sklearn/utils/tests/test_mocking.py +++ b/sklearn/utils/tests/test_mocking.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import assert_array_equal from scipy import sparse from sklearn.datasets import load_iris @@ -90,7 +90,7 @@ def test_checking_classifier(iris, input_type): assert clf.n_features_in_ == 4 y_pred = clf.predict(X) - assert_array_equal(y_pred, np.zeros(y_pred.size, dtype=int)) + assert all(pred in clf.classes_ for pred in y_pred) assert clf.score(X) == pytest.approx(0) clf.set_params(foo_param=10) @@ -98,13 +98,10 @@ def test_checking_classifier(iris, input_type): y_proba = clf.predict_proba(X) assert y_proba.shape == (150, 3) - assert_allclose(y_proba[:, 0], 1) - assert_allclose(y_proba[:, 1:], 0) + assert np.logical_and(y_proba >= 0, y_proba <= 1).all() y_decision = clf.decision_function(X) assert y_decision.shape == (150, 3) - assert_allclose(y_decision[:, 0], 1) - assert_allclose(y_decision[:, 1:], 0) # check the shape in case of binary classification first_2_classes = np.logical_or(y == 0, y == 1) @@ -114,12 +111,10 @@ def test_checking_classifier(iris, input_type): y_proba = clf.predict_proba(X) assert y_proba.shape == (100, 2) - assert_allclose(y_proba[:, 0], 1) - assert_allclose(y_proba[:, 1], 0) + assert np.logical_and(y_proba >= 0, y_proba <= 1).all() y_decision = clf.decision_function(X) assert y_decision.shape == (100,) - assert_allclose(y_decision, 0) @pytest.mark.parametrize("csr_container", CSR_CONTAINERS) diff --git a/sklearn/utils/tests/test_response.py b/sklearn/utils/tests/test_response.py index c84bf6030336a..858c16cca4df1 100644 --- a/sklearn/utils/tests/test_response.py +++ b/sklearn/utils/tests/test_response.py @@ -240,36 +240,60 @@ def test_get_response_error(estimator, X, y, err_msg, params): _get_response_values_binary(estimator, X, **params) -def test_get_response_predict_proba(): +@pytest.mark.parametrize("return_response_method_used", [True, False]) +def test_get_response_predict_proba(return_response_method_used): """Check the behaviour of `_get_response_values_binary` using `predict_proba`.""" classifier = DecisionTreeClassifier().fit(X_binary, y_binary) - y_proba, pos_label = _get_response_values_binary( - classifier, X_binary, response_method="predict_proba" + results = _get_response_values_binary( + classifier, + X_binary, + response_method="predict_proba", + return_response_method_used=return_response_method_used, ) - assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) - assert pos_label == 1 + assert_allclose(results[0], classifier.predict_proba(X_binary)[:, 1]) + assert 
results[1] == 1 + if return_response_method_used: + assert results[2] == "predict_proba" - y_proba, pos_label = _get_response_values_binary( - classifier, X_binary, response_method="predict_proba", pos_label=0 + results = _get_response_values_binary( + classifier, + X_binary, + response_method="predict_proba", + pos_label=0, + return_response_method_used=return_response_method_used, ) - assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) - assert pos_label == 0 + assert_allclose(results[0], classifier.predict_proba(X_binary)[:, 0]) + assert results[1] == 0 + if return_response_method_used: + assert results[2] == "predict_proba" -def test_get_response_decision_function(): +@pytest.mark.parametrize("return_response_method_used", [True, False]) +def test_get_response_decision_function(return_response_method_used): """Check the behaviour of `_get_response_values_binary` using decision_function.""" classifier = LogisticRegression().fit(X_binary, y_binary) - y_score, pos_label = _get_response_values_binary( - classifier, X_binary, response_method="decision_function" + results = _get_response_values_binary( + classifier, + X_binary, + response_method="decision_function", + return_response_method_used=return_response_method_used, ) - assert_allclose(y_score, classifier.decision_function(X_binary)) - assert pos_label == 1 + assert_allclose(results[0], classifier.decision_function(X_binary)) + assert results[1] == 1 + if return_response_method_used: + assert results[2] == "decision_function" - y_score, pos_label = _get_response_values_binary( - classifier, X_binary, response_method="decision_function", pos_label=0 + results = _get_response_values_binary( + classifier, + X_binary, + response_method="decision_function", + pos_label=0, + return_response_method_used=return_response_method_used, ) - assert_allclose(y_score, classifier.decision_function(X_binary) * -1) - assert pos_label == 0 + assert_allclose(results[0], classifier.decision_function(X_binary) * -1) + assert results[1] == 0 + if return_response_method_used: + assert results[2] == "decision_function" @pytest.mark.parametrize(