From 4a70683f63733ca52a1800252e235c26fafd8ce8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 16 Jan 2024 14:05:57 +0100 Subject: [PATCH] [BUG] Fix dynamic `make_forecasting_scorer` for newer `sklearn` metrics (#5717) Fixes https://github.com/sktime/sktime/issues/5715. The issue with `sklearn` metrics is coming from their decorator which no longer allows inspection of admissible arguments via `python` native `signature`. Question for reviewers: is there an `sklearn` native way to inspect permissible signature, without reading private attributes? --- .../forecasting/_classes.py | 23 +++++++++++++++++++ .../forecasting/tests/test_metrics.py | 15 ++++++++++++ 2 files changed, 38 insertions(+) diff --git a/sktime/performance_metrics/forecasting/_classes.py b/sktime/performance_metrics/forecasting/_classes.py index 08d3313837f..e8311bd5378 100644 --- a/sktime/performance_metrics/forecasting/_classes.py +++ b/sktime/performance_metrics/forecasting/_classes.py @@ -573,6 +573,10 @@ def _evaluate(self, y_true, y_pred, **kwargs): else: func = self.func + return self._evaluate_func(func=func, y_true=y_true, y_pred=y_pred, **params) + + def _evaluate_func(self, func, y_true, y_pred, **params): + """Call func with kwargs subset to func parameters.""" # import here for now to avoid interaction with getmembers in tests # todo: clean up ancient getmembers in test_metrics_classes from functools import partial @@ -584,6 +588,15 @@ def _evaluate(self, y_true, y_pred, **kwargs): func_params = func_params.intersection(params.keys()) params = {key: params[key] for key in func_params} + # deal with sklearn specific parameter constraints + # as these are a decorator, they obfuscate python native inspection + # via signature, so have to be dealt with separately + if hasattr(func, "_skl_parameter_constraints"): + constr = func._skl_parameter_constraints + if isinstance(constr, dict): + constr_params = set(constr.keys()).intersection(params.keys()) + params = {key: params[key] for key in constr_params} + res = func(y_true=y_true, y_pred=y_pred, **params) return res @@ -609,6 +622,16 @@ def __init__( self.set_tags(**{"lower_is_better": lower_is_better}) + def _evaluate(self, y_true, y_pred, **kwargs): + """Evaluate the desired metric on given inputs.""" + # this dict should contain all parameters + params = kwargs + params.update({"multioutput": self.multioutput, "multilevel": self.multilevel}) + + func = self.func + + return self._evaluate_func(func=func, y_true=y_true, y_pred=y_pred, **params) + @classmethod def get_test_params(cls, parameter_set="default"): """Return testing parameter settings for the estimator. diff --git a/sktime/performance_metrics/forecasting/tests/test_metrics.py b/sktime/performance_metrics/forecasting/tests/test_metrics.py index 1b2d07c4fcd..12f9709bf8f 100644 --- a/sktime/performance_metrics/forecasting/tests/test_metrics.py +++ b/sktime/performance_metrics/forecasting/tests/test_metrics.py @@ -129,3 +129,18 @@ def test_make_scorer(): scorer = make_forecasting_scorer(rmsle, name="RMSLE") scorer.evaluate(pd.Series([1, 2, 3]), pd.Series([1, 2, 4])) + + +def test_make_scorer_sklearn(): + """Test make_forecasting_scorer and the failure case in #5715. + + Naive adaptation fails on newer sklearn versions due to + decoration with sklearn's custom input constraint wrapper. + """ + from sklearn.metrics import mean_absolute_error + + from sktime.performance_metrics.forecasting import make_forecasting_scorer + + scorer = make_forecasting_scorer(mean_absolute_error, name="RMSLE") + + scorer.evaluate(pd.Series([1, 2, 3]), pd.Series([1, 2, 4]))