Skip to content

Commit

Permalink
update mstl
Browse files Browse the repository at this point in the history
  • Loading branch information
luca-miniati committed Aug 18, 2023
1 parent 03f469d commit dc3e940
Showing 1 changed file with 125 additions and 24 deletions.
149 changes: 125 additions & 24 deletions sktime/forecasting/mstl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
__all__ = ["MSTL"]
__authors__ = ["luca-miniati"]

from typing import Optional
from typing import Dict, Optional, Sequence, Union

from sktime.forecasting.base.adapters import _StatsModelsAdapter
import pandas as pd

from sktime.transformations.base import BaseTransformer

class MSTL(_StatsModelsAdapter):

class MSTL(BaseTransformer):
"""Season-Trend decomposition using LOESS for multiple seasonalities.
Direct interface for `statsmodels.tsa.seasonal.MSTL`.
Expand All @@ -35,48 +37,144 @@ class MSTL(_StatsModelsAdapter):
Number of iterations to use to refine the seasonal component.
stl_kwargs : dict, optional
Arguments to pass to STL.
return_components : bool, default=False
if False, will return only the MSTL transformed series
if True, will return the transformed series, as well as three components
as variables in the returned multivariate series (DataFrame cols)
"transformed" - the transformed series
"seasonal" - the seasonal component
"trend" - the trend component
"resid" - the residuals after de-trending, de-seasonalizing
References
----------
[1] https://www.statsmodels.org/dev/generated/statsmodels.tsa.seasonal.MSTL.html
Examples
--------
>>> from sktime.forecasting.mstl import MSTL
>>> import matplotlib.pyplot as plt # doctest: +SKIP
>>> from sktime.datasets import load_airline
>>> import matplotlib.pyplot as plt
>>> from sktime.forecasting.mstl import MSTL
>>> y = load_airline()
>>> mstl = MSTL() # doctest: +SKIP
>>> res = mstl.fit(y) # doctest: +SKIP
>>> y.index = y.index.to_timestamp()
>>> mstl = MSTL(return_components=True) # doctest: +SKIP
>>> fitted = mstl.fit(y) # doctest: +SKIP
>>> res = fitted.transform(y) # doctest: +SKIP
>>> res.plot() # doctest: +SKIP
>>> plt.tight_layout()
>>> plt.show()
>>> plt.tight_layout() # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""

_tags = {
"scitype:transform-input": "Series",
"scitype:transform-output": "Series",
"scitype:instancewise": True,
"X_inner_mtype": "pd.Series",
"y_inner_mtype": "pd.Series",
"transform-returns-same-time-index": True,
"univariate-only": True,
"fit_is_empty": False,
"python_dependencies": "statsmodels",
}

def __init__(
self,
periods=None,
windows=None,
lmbda=None,
iterate=Optional[int],
stl_kwargs=Optional[dict],
periods: Optional[Union[int, Sequence[int]]] = None,
windows: Optional[Union[int, Sequence[int]]] = None,
lmbda: Optional[Union[float, str]] = None,
iterate: Optional[int] = 2,
stl_kwargs: Optional[Dict[str, Union[int, bool, None]]] = None,
return_components: bool = False,
):
super().__init__()

self.periods = periods
self.windows = windows
self.lmbda = lmbda
self.iterate = iterate
self.stl_kwargs = stl_kwargs
self.return_components = return_components
self._X = None

def _fit_forecaster(self, y, X=None):
super().__init__()

def _fit(self, X, y=None):
from statsmodels.tsa.seasonal import MSTL as _MSTL

self._forecaster = _MSTL(
y, self.periods, self.windows, self.lmbda, self.iterate, self.stl_kwargs
)
self._X = X

self.mstl_ = _MSTL(
X,
periods=self.periods,
windows=self.windows,
lmbda=self.lmbda,
iterate=self.iterate,
stl_kwargs=self.stl_kwargs,
).fit()

self.seasonal_ = pd.Series(self.mstl_.seasonal, index=X.index)
self.resid_ = pd.Series(self.mstl_.resid, index=X.index)
self.trend_ = pd.Series(self.mstl_.trend, index=X.index)

return self

self._fitted_forecaster = self._forecaster.fit()
def _transform(self, X, y=None):
"""Transform X and return a transformed version.
private _transform containing core logic, called from transform
Parameters
----------
X : Series, Panel, or Hierarchical data, of mtype X_inner_mtype
if X_inner_mtype is list, _transform must support all types in it
Data to be transformed
y : Series, Panel, or Hierarchical data, of mtype y_inner_mtype, default=None
Additional data, e.g., labels for transformation
Returns
-------
transformed version of X
"""
from statsmodels.tsa.seasonal import MSTL as _MSTL

# fit again if indices not seen, but don't store anything
if not X.index.equals(self._X.index):
X_full = X.combine_first(self._X)
new_mstl = _MSTL(
X_full.values,
periods=self.periods,
windows=self.windows,
lmbda=self.lmbda,
iterate=self.iterate,
stl_kwargs=self.stl_kwargs,
).fit()

ret_obj = self._make_return_object(X_full, new_mstl)
else:
ret_obj = self._make_return_object(X, self.mstl_)

return ret_obj

def _make_return_object(self, X, mstl):
# deseasonalize only
transformed = pd.Series(X.values - mstl.seasonal, index=X.index)
# transformed = pd.Series(X.values - stl.seasonal - stl.trend, index=X.index)

if self.return_components:
seasonal = pd.Series(mstl.seasonal, index=X.index)
resid = pd.Series(mstl.resid, index=X.index)
trend = pd.Series(mstl.trend, index=X.index)

ret = pd.DataFrame(
{
"transformed": transformed,
"seasonal": seasonal,
"trend": trend,
"resid": resid,
}
)
else:
ret = transformed

return ret

@classmethod
def get_test_params(cls, parameter_set="default"):
Expand All @@ -97,10 +195,13 @@ def get_test_params(cls, parameter_set="default"):
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params
"""
params1 = {}
params1 = {
"periods": 3,
"windows": 3,
}
params2 = {
"periods": [1, 12],
"windows": 9,
"periods": [3, 24],
"windows": [3, 9],
"lmbda": "auto",
"iterate": 10,
"stl_kwargs": {"trend_deg": 0},
Expand Down

0 comments on commit dc3e940

Please sign in to comment.