Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] statsforecast 1.6.0 compatibility - fix argument differences between sktime and statsforecast #5393

Merged
merged 2 commits into from Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Expand Up @@ -132,7 +132,7 @@ all_extras = [
"seaborn>=0.11.0",
"seasonal",
"skpro>=2.0.0,<2.1.0",
"statsforecast>=0.5.2,<1.6.0; python_version < '3.11'",
"statsforecast>=0.5.2,<1.7.0",
"statsmodels>=0.12.1",
"stumpy>=1.5.1; python_version < '3.11'",
"tbats>=1.1.0",
Expand Down Expand Up @@ -168,7 +168,7 @@ all_extras_pandas2 = [
"seaborn>=0.11.0",
"seasonal",
"skpro>=2.0.0,<2.1.0",
"statsforecast>=0.5.2,<1.6.0; python_version < '3.11'",
"statsforecast>=0.5.2,<1.7.0",
"statsmodels>=0.12.1",
"stumpy>=1.5.1; python_version < '3.11'",
"tbats>=1.1.0",
Expand Down
96 changes: 78 additions & 18 deletions sktime/forecasting/base/adapters/_generalised_statsforecast.py
Expand Up @@ -43,30 +43,77 @@ def __init__(self):
def _get_statsforecast_class(self):
raise NotImplementedError("abstract method")

def _get_statsforecast_params(self):
def _get_statsforecast_params(self) -> dict:
return self.get_params()

def _get_init_statsforecast_params(self):
"""Return parameters in __init__ statsforecast forecaster.

Return a list of parameters in the __init__ method from
the statsforecast forecaster class used in the sktime adapter.
"""
statsforecast_class = self._get_statsforecast_class()
return list(signature(statsforecast_class.__init__).parameters.keys())

def _get_validated_statsforecast_params(self):
sktime_params = self._get_statsforecast_params()
sktime_default_params = self.get_param_defaults().keys()
statsforecast_params = self._get_init_statsforecast_params()

for sktime_param in sktime_params.keys():
if sktime_param not in statsforecast_params:
sktime_params.pop(sktime_param)
if sktime_param not in sktime_default_params:
warn(
f"Keyword argument '{sktime_param}' will be omitted as it is"
f" not found in the __init__ method "
f"from {self._get_statsforecast_class()}. "
f"Check your statsforecast version"
f"to find out the right API parameters."
)
return sktime_params
def _get_statsforecast_default_params(self) -> dict:
"""Get default parameters for the statsforecast forecaster.

This will in general be different from self.get_param_defaults(),
as the set or names of inner parameters can differ.

For parameters without defaults, will use the parameter
of self instead.
"""
self_params = self.get_params(deep=False)
self_default_params = self.get_param_defaults()
self_params.update(self_default_params)
cls_with_defaults = type(self)(**self_params)
return cls_with_defaults._get_statsforecast_params()

def _get_validated_statsforecast_params(self) -> dict:
"""Return parameter dict with only parameters accepted by statsforecast API.

Checks if the parameters passed to the statsforecast forecaster
are valid in the __init__ method of the aforementioned forecaster.
If the parameter is not there it will just not be passed. Furthermore
if the parameter is modified by the sktime user,
he will be notified that the parameter does not exist
anymore in the version installed of statsforecast by the user.

"""
params_sktime_to_statsforecast: dict = self._get_statsforecast_params()
params_sktime_to_statsforecast_default: dict = (
self._get_statsforecast_default_params()
)
statsforecast_init_params = set(self._get_init_statsforecast_params())

# Filter sktime_params to only include keys in statsforecast_params
filtered_sktime_params = {
key: value
for key, value in params_sktime_to_statsforecast.items()
if key in statsforecast_init_params
}

non_default_params = [
p
for p in params_sktime_to_statsforecast
if params_sktime_to_statsforecast[p]
!= params_sktime_to_statsforecast_default[p]
]
# Find parameters not in statsforecast_params or sktime_default_params
param_diff = set(non_default_params) - statsforecast_init_params

if param_diff:
params_str = ", ".join([f'"{param}"' for param in param_diff])
warning_message = (
f"Keyword arguments {params_str} "
f"will be omitted as they are not found in the __init__ method from "
f"{self._get_statsforecast_class()}. Check your statsforecast version "
f"to find out the right API parameters."
)
warn(warning_message)

return filtered_sktime_params

def _instantiate_model(self):
cls = self._get_statsforecast_class()
Expand Down Expand Up @@ -400,10 +447,17 @@ def __init__(self, estimator):
super().__init__()

self.estimator = estimator
self.prediction_intervals = None

def __repr__(self):
return "StatsForecastBackAdapter"

def new(self):
"""Make new instance of back-adapter."""
_self = type(self).__new__(type(self))
_self.__dict__.update(self.__dict__)
return _self

def fit(self, y, X=None):
"""Fit to training data.

Expand Down Expand Up @@ -442,6 +496,12 @@ def predict(self, h, X=None, level=None):
mean = self.estimator.predict(fh=range(1, h + 1), X=X)[:, 0]
if level is None:
return {"mean": mean}
# if a level is passed, and if prediction_intervals has not been instantiated
# yet
elif self.prediction_intervals is None:
from statsforecast.utils import ConformalIntervals

self.prediction_intervals = ConformalIntervals(h=h)

level = sorted(level)
coverage = [round(1 - (_l / 100), 2) for _l in level]
Expand Down
24 changes: 24 additions & 0 deletions sktime/forecasting/statsforecast.py
Expand Up @@ -595,6 +595,12 @@ class StatsForecastMSTL(_GeneralisedStatsForecastAdapter):
trend_forecaster : estimator, optional, default=StatsForecastAutoETS()
Sktime estimator used to make univariate forecasts. Multivariate estimators are
not supported.
stl_kwargs : dict, optional
Extra arguments to pass to [`statsmodels.tsa.seasonal.STL`]
(https://www.statsmodels.org/dev/generated/statsmodels.tsa.seasonal.STL.html#statsmodels.tsa.seasonal.STL).
The `period` and `seasonal` arguments are reserved.
pred_int_kwargs : dict, optional
Extra arguments to pass to [`statsforecast.utils.ConformalIntervals`].

References
----------
Expand Down Expand Up @@ -623,6 +629,8 @@ def __init__(
self,
season_length: Union[int, List[int]],
trend_forecaster=None,
stl_kwargs: Optional[Dict] = None,
pred_int_kwargs: Optional[Dict] = None,
):
super().__init__()

Expand All @@ -634,6 +642,8 @@ def __init__(
self._trend_forecaster = clone(trend_forecaster)
else:
self._trend_forecaster = StatsForecastAutoETS(model="ZZN")
self.stl_kwargs = stl_kwargs
self.pred_int_kwargs = pred_int_kwargs

# checks if trend_forecaster is already wrapped with
# StatsForecastBackAdapter
Expand All @@ -650,6 +660,14 @@ def __init__(
"forecaster."
)

# check if prediction interval kwargs are passed
if self.pred_int_kwargs:
from statsforecast.utils import ConformalIntervals

self._trend_forecaster.prediction_intervals = ConformalIntervals(
**self.pred_int_kwargs
)

def _get_statsforecast_class(self):
from statsforecast.models import MSTL

Expand Down Expand Up @@ -695,6 +713,12 @@ def get_test_params(cls, parameter_set="default"):
{
"season_length": 4,
},
{
"season_length": 4,
"pred_int_kwargs": {
"n_windows": 2,
},
},
]
except ModuleNotFoundError:
from sktime.forecasting.naive import NaiveForecaster
Expand Down