Skip to content

Commit

Permalink
[BUG] Fix dynamic make_forecasting_scorer for newer sklearn metri…
Browse files Browse the repository at this point in the history
…cs (#5717)

Fixes #5715.

The issue with `sklearn` metrics is coming from their decorator which no
longer allows inspection of admissible arguments via `python` native
`sigature`.

Question for reviewers: is there an `sklearn` native way to inspect
permissible signature, without reading private attriutes?
  • Loading branch information
fkiraly committed Jan 16, 2024
1 parent 01e08f9 commit 4a70683
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
23 changes: 23 additions & 0 deletions sktime/performance_metrics/forecasting/_classes.py
Expand Up @@ -573,6 +573,10 @@ def _evaluate(self, y_true, y_pred, **kwargs):
else:
func = self.func

return self._evaluate_func(func=func, y_true=y_true, y_pred=y_pred, **params)

def _evaluate_func(self, func, y_true, y_pred, **params):
"""Call func with kwargs subset to func parameters."""
# import here for now to avoid interaction with getmembers in tests
# todo: clean up ancient getmembers in test_metrics_classes
from functools import partial
Expand All @@ -584,6 +588,15 @@ def _evaluate(self, y_true, y_pred, **kwargs):
func_params = func_params.intersection(params.keys())
params = {key: params[key] for key in func_params}

# deal with sklearn specific parameter constraints
# as these are a decorator, they obfuscate python native inspection
# via signature, so have to be dealt with separately
if hasattr(func, "_skl_parameter_constraints"):
constr = func._skl_parameter_constraints
if isinstance(constr, dict):
constr_params = set(constr.keys()).intersection(params.keys())
params = {key: params[key] for key in constr_params}

res = func(y_true=y_true, y_pred=y_pred, **params)
return res

Expand All @@ -609,6 +622,16 @@ def __init__(

self.set_tags(**{"lower_is_better": lower_is_better})

def _evaluate(self, y_true, y_pred, **kwargs):
"""Evaluate the desired metric on given inputs."""
# this dict should contain all parameters
params = kwargs
params.update({"multioutput": self.multioutput, "multilevel": self.multilevel})

func = self.func

return self._evaluate_func(func=func, y_true=y_true, y_pred=y_pred, **params)

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Expand Down
15 changes: 15 additions & 0 deletions sktime/performance_metrics/forecasting/tests/test_metrics.py
Expand Up @@ -129,3 +129,18 @@ def test_make_scorer():
scorer = make_forecasting_scorer(rmsle, name="RMSLE")

scorer.evaluate(pd.Series([1, 2, 3]), pd.Series([1, 2, 4]))


def test_make_scorer_sklearn():
"""Test make_forecasting_scorer and the failure case in #5715.
Naive adaptation fails on newer sklearn versions due to
decoration with sklearn's custom input constraint wrapper.
"""
from sklearn.metrics import mean_absolute_error

from sktime.performance_metrics.forecasting import make_forecasting_scorer

scorer = make_forecasting_scorer(mean_absolute_error, name="RMSLE")

scorer.evaluate(pd.Series([1, 2, 3]), pd.Series([1, 2, 4]))

0 comments on commit 4a70683

Please sign in to comment.