Skip to content

Commit

Permalink
FIX validate properly zero_division=np.nan when used in parallel proc…
Browse files Browse the repository at this point in the history
…essing (#27573)
  • Loading branch information
glemaitre committed Oct 23, 2023
1 parent 1df363c commit d53756e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 6 deletions.
18 changes: 12 additions & 6 deletions sklearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,8 @@ def zero_one_loss(y_true, y_pred, *, normalize=True, sample_weight=None):
],
"sample_weight": ["array-like", None],
"zero_division": [
Options(Real, {0.0, 1.0, np.nan}),
Options(Real, {0.0, 1.0}),
"nan",
StrOptions({"warn"}),
],
},
Expand Down Expand Up @@ -1260,7 +1261,8 @@ def f1_score(
],
"sample_weight": ["array-like", None],
"zero_division": [
Options(Real, {0.0, 1.0, np.nan}),
Options(Real, {0.0, 1.0}),
"nan",
StrOptions({"warn"}),
],
},
Expand Down Expand Up @@ -1542,7 +1544,8 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
"warn_for": [list, tuple, set],
"sample_weight": ["array-like", None],
"zero_division": [
Options(Real, {0.0, 1.0, np.nan}),
Options(Real, {0.0, 1.0}),
"nan",
StrOptions({"warn"}),
],
},
Expand Down Expand Up @@ -1979,7 +1982,8 @@ class after being classified as negative. This is the case when the
],
"sample_weight": ["array-like", None],
"zero_division": [
Options(Real, {0.0, 1.0, np.nan}),
Options(Real, {0.0, 1.0}),
"nan",
StrOptions({"warn"}),
],
},
Expand Down Expand Up @@ -2149,7 +2153,8 @@ def precision_score(
],
"sample_weight": ["array-like", None],
"zero_division": [
Options(Real, {0.0, 1.0, np.nan}),
Options(Real, {0.0, 1.0}),
"nan",
StrOptions({"warn"}),
],
},
Expand Down Expand Up @@ -2412,7 +2417,8 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=Fals
"digits": [Interval(Integral, 0, None, closed="left")],
"output_dict": ["boolean"],
"zero_division": [
Options(Real, {0.0, 1.0, np.nan}),
Options(Real, {0.0, 1.0}),
"nan",
StrOptions({"warn"}),
],
},
Expand Down
27 changes: 27 additions & 0 deletions sklearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
hinge_loss,
jaccard_score,
log_loss,
make_scorer,
matthews_corrcoef,
multilabel_confusion_matrix,
precision_recall_fscore_support,
Expand All @@ -35,7 +36,9 @@
zero_one_loss,
)
from sklearn.metrics._classification import _check_targets
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelBinarizer, label_binarize
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._testing import (
assert_allclose,
Expand Down Expand Up @@ -2802,3 +2805,27 @@ def test_classification_metric_pos_label_types(metric, classes):
y_pred = y_true.copy()
result = metric(y_true, y_pred, pos_label=pos_label)
assert not np.any(np.isnan(result))


@pytest.mark.parametrize(
"scoring",
[
make_scorer(f1_score, zero_division=np.nan),
make_scorer(fbeta_score, beta=2, zero_division=np.nan),
make_scorer(precision_score, zero_division=np.nan),
make_scorer(recall_score, zero_division=np.nan),
],
)
def test_classification_metric_division_by_zero_nan_validaton(scoring):
"""Check that we validate `np.nan` properly for classification metrics.
With `n_jobs=2` in cross-validation, the `np.nan` used for the singleton will be
different in the sub-process and we should not use the `is` operator but
`math.isnan`.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/27563
"""
X, y = datasets.make_classification(random_state=0)
classifier = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
cross_val_score(classifier, X, y, scoring=scoring, n_jobs=2, error_score="raise")
3 changes: 3 additions & 0 deletions sklearn/utils/_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
- the string "boolean"
- the string "verbose"
- the string "cv_object"
- the string "nan"
- a MissingValues object representing markers for missing values
- a HasMethods object, representing method(s) an object must have
- a Hidden object, representing a constraint not meant to be exposed to the user
Expand Down Expand Up @@ -137,6 +138,8 @@ def make_constraint(constraint):
constraint = make_constraint(constraint.constraint)
constraint.hidden = True
return constraint
if isinstance(constraint, str) and constraint == "nan":
return _NanConstraint()
raise ValueError(f"Unknown constraint type: {constraint}")


Expand Down
3 changes: 3 additions & 0 deletions sklearn/utils/tests/test_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_CVObjects,
_InstancesOf,
_IterablesNotString,
_NanConstraint,
_NoneConstraint,
_PandasNAConstraint,
_RandomStates,
Expand Down Expand Up @@ -387,6 +388,7 @@ def test_generate_valid_param(constraint):
(Real, 0.5),
("boolean", False),
("verbose", 1),
("nan", np.nan),
(MissingValues(), -1),
(MissingValues(), -1.0),
(MissingValues(), None),
Expand Down Expand Up @@ -420,6 +422,7 @@ def test_is_satisfied_by(constraint_declaration, value):
(MissingValues(numeric_only=True), MissingValues),
(HasMethods("fit"), HasMethods),
("cv_object", _CVObjects),
("nan", _NanConstraint),
],
)
def test_make_constraint(constraint_declaration, expected_constraint_class):
Expand Down

0 comments on commit d53756e

Please sign in to comment.