Skip to content

Commit

Permalink
[BUG] allow unused parameters in metric when using `make_forecasting_…
Browse files Browse the repository at this point in the history
…scorer` (#4833)

Fixes #4827 by ignoring
parameters that are not present in the default params of a metric.
  • Loading branch information
fkiraly committed Jul 13, 2023
1 parent bea55f1 commit 22b8854
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
7 changes: 4 additions & 3 deletions sktime/performance_metrics/forecasting/_classes.py
Expand Up @@ -575,6 +575,7 @@ def _evaluate(self, y_true, y_pred, **kwargs):
if getfullargspec(func).varkw is None:
func_params = signature(func).parameters.keys()
func_params = set(func_params).difference(["y_true", "y_pred"])
func_params = func_params.intersection(params.keys())
params = {key: params[key] for key in func_params}

res = func(y_true=y_true, y_pred=y_pred, **params)
Expand Down Expand Up @@ -657,12 +658,12 @@ def make_forecasting_scorer(
multioutput="uniform_average",
multilevel="uniform_average",
):
"""Create a metric class from a metric functions.
"""Create a metric class from a metric function.
Parameters
----------
func
Function to convert to a forecasting scorer class.
func : callable
Callable to convert to a forecasting scorer class.
Score function (or loss function) with signature ``func(y, y_pred, **kwargs)``.
name : str, default=None
Name to use for the forecasting scorer loss class.
Expand Down
16 changes: 16 additions & 0 deletions sktime/performance_metrics/forecasting/tests/test_metrics.py
Expand Up @@ -4,6 +4,7 @@
# since the numpy output print changes between versions

import numpy as np
import pandas as pd


def test_gmse_class():
Expand Down Expand Up @@ -106,3 +107,18 @@ def test_linex_function():
assert np.allclose(
mean_linex_error(y_true, y_pred, multioutput=[0.3, 0.7]), 0.30917568000716666
)


def test_make_scorer():
"""Test make_forecasting_scorer and the failure case in #4827."""
import functools

from sklearn.metrics import mean_squared_log_error

from sktime.performance_metrics.forecasting import make_forecasting_scorer

rmsle = functools.partial(mean_squared_log_error, squared=False)

scorer = make_forecasting_scorer(rmsle, name="RMSLE")

scorer.evaluate(pd.Series([1, 2, 3]), pd.Series([1, 2, 4]))

0 comments on commit 22b8854

Please sign in to comment.