diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 6db20bff58fc3..5e3b0dd71d33f 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -1490,3 +1490,18 @@ def test_make_scorer_deprecation(deprecated_params, new_params, warn_msg): assert deprecated_roc_auc_scorer(classifier, X, y) == pytest.approx( roc_auc_scorer(classifier, X, y) ) + + +@pytest.mark.parametrize("enable_metadata_routing", [True, False]) +def test_metadata_routing_multimetric_metadata_routing(enable_metadata_routing): + """Test multimetric scorer works with and without metadata routing enabled when + there is no actual metadata to pass. + + Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28256 + """ + X, y = make_classification(n_samples=50, n_features=10, random_state=0) + estimator = EstimatorWithFitAndPredict().fit(X, y) + + multimetric_scorer = _MultimetricScorer(scorers={"acc": get_scorer("accuracy")}) + with config_context(enable_metadata_routing=enable_metadata_routing): + multimetric_scorer(estimator, X, y) diff --git a/sklearn/tests/test_metadata_routing.py b/sklearn/tests/test_metadata_routing.py index cad5fbd78e5e3..34078a59e0529 100644 --- a/sklearn/tests/test_metadata_routing.py +++ b/sklearn/tests/test_metadata_routing.py @@ -239,6 +239,22 @@ class InvalidObject: process_routing(InvalidObject(), "fit", groups=my_groups) +@pytest.mark.parametrize("method", METHODS) +@pytest.mark.parametrize("default", [None, "default", []]) +def test_process_routing_empty_params_get_with_default(method, default): + empty_params = {} + routed_params = process_routing(ConsumingClassifier(), "fit", **empty_params) + + # Behaviour should be an empty dictionary returned for each method when retrieved. + params_for_method = routed_params[method] + assert isinstance(params_for_method, dict) + assert set(params_for_method.keys()) == set(METHODS) + + # No default to `get` should be equivalent to the default + default_params_for_method = routed_params.get(method, default=default) + assert default_params_for_method == params_for_method + + def test_simple_metadata_routing(): # Tests that metadata is properly routed diff --git a/sklearn/utils/_metadata_requests.py b/sklearn/utils/_metadata_requests.py index 00c0e2023e78c..8b99012d7b0fb 100644 --- a/sklearn/utils/_metadata_requests.py +++ b/sklearn/utils/_metadata_requests.py @@ -1082,8 +1082,12 @@ def _serialize(self): def __iter__(self): if self._self_request: - yield "$self_request", RouterMappingPair( - mapping=MethodMapping.from_str("one-to-one"), router=self._self_request + yield ( + "$self_request", + RouterMappingPair( + mapping=MethodMapping.from_str("one-to-one"), + router=self._self_request, + ), ) for name, route_mapping in self._route_mappings.items(): yield (name, route_mapping) @@ -1530,7 +1534,7 @@ def process_routing(_obj, _method, /, **kwargs): # an empty dict on routed_params.ANYTHING.ANY_METHOD. class EmptyRequest: def get(self, name, default=None): - return default if default else {} + return Bunch(**{method: dict() for method in METHODS}) def __getitem__(self, name): return Bunch(**{method: dict() for method in METHODS})