Skip to content

Commit

Permalink
FIX EmptyRequest.get defaults to Bunch of METHODS (#28371)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman authored and glemaitre committed Feb 13, 2024
1 parent e4d7d9a commit b2e231e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
15 changes: 15 additions & 0 deletions sklearn/metrics/tests/test_score_objects.py
Expand Up @@ -1490,3 +1490,18 @@ def test_make_scorer_deprecation(deprecated_params, new_params, warn_msg):
assert deprecated_roc_auc_scorer(classifier, X, y) == pytest.approx(
roc_auc_scorer(classifier, X, y)
)


@pytest.mark.parametrize("enable_metadata_routing", [True, False])
def test_metadata_routing_multimetric_metadata_routing(enable_metadata_routing):
"""Test multimetric scorer works with and without metadata routing enabled when
there is no actual metadata to pass.
Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28256
"""
X, y = make_classification(n_samples=50, n_features=10, random_state=0)
estimator = EstimatorWithFitAndPredict().fit(X, y)

multimetric_scorer = _MultimetricScorer(scorers={"acc": get_scorer("accuracy")})
with config_context(enable_metadata_routing=enable_metadata_routing):
multimetric_scorer(estimator, X, y)
16 changes: 16 additions & 0 deletions sklearn/tests/test_metadata_routing.py
Expand Up @@ -239,6 +239,22 @@ class InvalidObject:
process_routing(InvalidObject(), "fit", groups=my_groups)


@pytest.mark.parametrize("method", METHODS)
@pytest.mark.parametrize("default", [None, "default", []])
def test_process_routing_empty_params_get_with_default(method, default):
empty_params = {}
routed_params = process_routing(ConsumingClassifier(), "fit", **empty_params)

# Behaviour should be an empty dictionary returned for each method when retrieved.
params_for_method = routed_params[method]
assert isinstance(params_for_method, dict)
assert set(params_for_method.keys()) == set(METHODS)

# No default to `get` should be equivalent to the default
default_params_for_method = routed_params.get(method, default=default)
assert default_params_for_method == params_for_method


def test_simple_metadata_routing():
# Tests that metadata is properly routed

Expand Down
10 changes: 7 additions & 3 deletions sklearn/utils/_metadata_requests.py
Expand Up @@ -1082,8 +1082,12 @@ def _serialize(self):

def __iter__(self):
if self._self_request:
yield "$self_request", RouterMappingPair(
mapping=MethodMapping.from_str("one-to-one"), router=self._self_request
yield (
"$self_request",
RouterMappingPair(
mapping=MethodMapping.from_str("one-to-one"),
router=self._self_request,
),
)
for name, route_mapping in self._route_mappings.items():
yield (name, route_mapping)
Expand Down Expand Up @@ -1530,7 +1534,7 @@ def process_routing(_obj, _method, /, **kwargs):
# an empty dict on routed_params.ANYTHING.ANY_METHOD.
class EmptyRequest:
def get(self, name, default=None):
return default if default else {}
return Bunch(**{method: dict() for method in METHODS})

def __getitem__(self, name):
return Bunch(**{method: dict() for method in METHODS})
Expand Down

0 comments on commit b2e231e

Please sign in to comment.