Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] [ENH] Several updates in direct statsforecast interface estimators #5920

Merged
merged 16 commits into from Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Expand Up @@ -103,7 +103,7 @@ all_extras = [
"seaborn>=0.11",
"seasonal",
"skpro<2.2.0,>=2",
'statsforecast<1.7.0,>=0.5.2; python_version < "3.12"',
'statsforecast<1.8.0,>=1.0.0; python_version < "3.12"',
"statsmodels>=0.12.1",
'stumpy>=1.5.1; python_version < "3.11"',
'tbats>=1.1; python_version < "3.12"',
Expand Down Expand Up @@ -138,7 +138,7 @@ all_extras_pandas2 = [
"seaborn>=0.11",
"seasonal",
"skpro<2.2.0,>=2",
'statsforecast<1.7.0,>=0.5.2; python_version < "3.12"',
'statsforecast<1.8.0,>=1.0.0; python_version < "3.12"',
"statsmodels>=0.12.1",
'stumpy>=1.5.1; python_version < "3.11"',
'tbats>=1.1; python_version < "3.12"',
Expand Down
8 changes: 8 additions & 0 deletions sktime/forecasting/arch/_statsforecast_arch.py
Expand Up @@ -42,6 +42,7 @@ class StatsForecastGARCH(_GeneralisedStatsForecastAdapter):
"ignores-exogeneous-X": False,
"capability:pred_int": True,
"capability:pred_int:insample": True,
"python_dependencies": ["statsforecast>=1.5.0"],
}

def __init__(
Expand Down Expand Up @@ -102,9 +103,16 @@ class StatsForecastARCH(_GeneralisedStatsForecastAdapter):
"""

_tags = {
# packaging info
# --------------
"authors": ["eyjo"],
"maintainers": ["eyjo"],
# estimator type
# --------------
"ignores-exogeneous-X": False,
"capability:pred_int": True,
"capability:pred_int:insample": True,
"python_dependencies": ["statsforecast>=1.5.0"],
}

def __init__(
Expand Down
Expand Up @@ -8,8 +8,8 @@

from sktime.forecasting.base import BaseForecaster

__all__ = ["_GeneralisedStatsForecastAdapter"]
__author__ = ["yarnabrina"]
__all__ = ["_GeneralisedStatsForecastAdapter", "StatsForecastBackAdapter"]
__author__ = ["yarnabrina", "arnaujc91", "luca-miniati"]


class _GeneralisedStatsForecastAdapter(BaseForecaster):
Expand All @@ -18,7 +18,7 @@ class _GeneralisedStatsForecastAdapter(BaseForecaster):
_tags = {
# packaging info
# --------------
"authors": ["yarnabrina"],
"authors": ["yarnabrina", "arnaujc91"],
"maintainers": ["yarnabrina"],
"python_version": ">=3.8",
"python_dependencies": ["statsforecast"],
Expand Down Expand Up @@ -456,6 +456,7 @@ def __init__(self, estimator):
self.prediction_intervals = None

def __repr__(self):
"""Representation dunder."""
return "StatsForecastBackAdapter"

def new(self):
Expand Down Expand Up @@ -545,6 +546,7 @@ def predict_in_sample(self, level=None):
return self.format_pred_int("fitted", fitted, pred_int, coverage, level)

def format_pred_int(self, y_pred_name, y_pred, pred_int, coverage, level):
"""Convert prediction intervals into a StatsForecast-format dictionary."""
pred_int_prefix = "fitted-" if y_pred_name == "fitted" else ""

pred_int_no_lev = pred_int.droplevel(0, axis=1)
Expand Down
36 changes: 29 additions & 7 deletions sktime/forecasting/statsforecast.py
Expand Up @@ -188,6 +188,7 @@ class StatsForecastAutoARIMA(_GeneralisedStatsForecastAdapter):
"ignores-exogeneous-X": False,
"capability:pred_int": True,
"capability:pred_int:insample": True,
"python_dependencies": ["statsforecast>=1.0.0"],
}

def __init__(
Expand Down Expand Up @@ -376,6 +377,7 @@ class StatsForecastAutoTheta(_GeneralisedStatsForecastAdapter):
"ignores-exogeneous-X": True,
"capability:pred_int": True,
"capability:pred_int:insample": True,
"python_dependencies": ["statsforecast>=1.3.0"],
}

def __init__(
Expand Down Expand Up @@ -457,6 +459,8 @@ class StatsForecastAutoETS(_GeneralisedStatsForecastAdapter):
Controlling state-space-equations.
damped : bool
A parameter that 'dampens' the trend.
phi : float, optional (default=None)
Smoothing parameter for trend damping. Only used when `damped=True`.

References
----------
Expand All @@ -480,14 +484,20 @@ class StatsForecastAutoETS(_GeneralisedStatsForecastAdapter):
"ignores-exogeneous-X": True,
"capability:pred_int": True,
"capability:pred_int:insample": True,
"python_dependencies": ["statsforecast>=1.3.2"],
}

def __init__(
self, season_length: int = 1, model: str = "ZZZ", damped: Optional[bool] = None
self,
season_length: int = 1,
model: str = "ZZZ",
damped: Optional[bool] = None,
phi: Optional[float] = None,
):
self.season_length = season_length
self.model = model
self.damped = damped
self.phi = phi

super().__init__()

Expand All @@ -502,6 +512,7 @@ def _get_statsforecast_params(self):
"season_length": self.season_length,
"model": self.model,
"damped": self.damped,
"phi": self.phi,
}

@classmethod
Expand Down Expand Up @@ -572,6 +583,7 @@ class StatsForecastAutoCES(_GeneralisedStatsForecastAdapter):
"ignores-exogeneous-X": True,
"capability:pred_int": True,
"capability:pred_int:insample": True,
"python_dependencies": ["statsforecast>=1.1.0"],
}

def __init__(self, season_length: int = 1, model: str = "Z"):
Expand Down Expand Up @@ -784,8 +796,9 @@ class StatsForecastMSTL(_GeneralisedStatsForecastAdapter):
# estimator type
# --------------
"ignores-exogeneous-X": True,
"capability:pred_int": True,
"capability:pred_int:insample": True,
"capability:pred_int": False,
"capability:pred_int:insample": False,
"python_dependencies": ["statsforecast>=1.2.0"],
}

def __init__(
Expand All @@ -795,18 +808,27 @@ def __init__(
stl_kwargs: Optional[Dict] = None,
pred_int_kwargs: Optional[Dict] = None,
):
self.season_length = season_length
self.trend_forecaster = trend_forecaster
self.stl_kwargs = stl_kwargs
self.pred_int_kwargs = pred_int_kwargs

super().__init__()

# adapter class sets probabilistic capability as true
# because level is present in statsforecast signature
# happens in _check_supports_pred_int method
# manually overriding this temporarily
self.set_tags(
**{"capability:pred_int": False, "capability:pred_int:insample": False}
)

from sklearn.base import clone

self.trend_forecaster = trend_forecaster
self.season_length = season_length
if trend_forecaster:
self._trend_forecaster = clone(trend_forecaster)
else:
self._trend_forecaster = StatsForecastAutoETS(model="ZZN")
self.stl_kwargs = stl_kwargs
self.pred_int_kwargs = pred_int_kwargs

# checks if trend_forecaster is already wrapped with
# StatsForecastBackAdapter
Expand Down
3 changes: 3 additions & 0 deletions sktime/forecasting/tests/test_statsforecast.py
Expand Up @@ -14,6 +14,9 @@
from sktime.tests.test_switch import run_test_for_class


@pytest.mark.skip(
reason="probabilistic capability of StatsForecastMSTL is disabled, see #5703, #5920"
)
@pytest.mark.skipif(
not run_test_for_class(StatsForecastMSTL),
reason="run test only if softdeps are present and incrementally (if requested)",
Expand Down
2 changes: 2 additions & 0 deletions sktime/tests/_config.py
Expand Up @@ -189,6 +189,8 @@
"LTSFDLinearForecaster": ["test_predict_time_index_in_sample_full"],
"LTSFNLinearForecaster": ["test_predict_time_index_in_sample_full"],
"WEASEL": ["test_multiprocessing_idempotent"], # see 5658
# StatsForecastMSTL is failing in probabistic forecasts, see #5703, #5920
"StatsForecastMSTL": ["test_pred_int_tag", "test_statsforecast_mstl"],
}

# We use estimator tags in addition to class hierarchies to further distinguish
Expand Down