-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: martin <33594071+martins0n@users.noreply.github.com>
- Loading branch information
1 parent
a0ffd6a
commit 8d09303
Showing
4 changed files
with
298 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
from typing import Any | ||
from typing import Dict | ||
from typing import Optional | ||
from typing import Union | ||
|
||
import pandas as pd | ||
from statsmodels.tsa.arima.model import ARIMA | ||
from statsmodels.tsa.base.tsa_model import TimeSeriesModel | ||
from statsmodels.tsa.exponential_smoothing.ets import ETSModel | ||
from statsmodels.tsa.forecasting.stl import STLForecast | ||
from statsmodels.tsa.forecasting.stl import STLForecastResults | ||
|
||
from etna.transforms.base import PerSegmentWrapper | ||
from etna.transforms.base import Transform | ||
|
||
|
||
class _OneSegmentSTLTransform(Transform):
    """Subtract the STL seasonal component and a modelled trend from one column of a single segment."""

    def __init__(
        self,
        in_column: str,
        period: int,
        model: Union[str, TimeSeriesModel] = "arima",
        robust: bool = False,
        model_kwargs: Optional[Dict[str, Any]] = None,
        stl_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Init _OneSegmentSTLTransform.

        Parameters
        ----------
        in_column:
            name of processed column
        period:
            size of seasonality
        model:
            model to predict trend, default options are:

            1. "arima": ``ARIMA(data, order=(1, 1, 0))`` (default)

            2. "holt": ``ETSModel(data, trend='add')``

            Custom model should be a subclass of statsmodels.tsa.base.tsa_model.TimeSeriesModel
            whose results have method ``get_prediction`` (not just ``predict``).
            Both the model class itself and an instance pass the type check.
        robust:
            flag indicating whether to use robust version of STL
        model_kwargs:
            parameters for the model like in statsmodels.tsa.forecasting.stl.STLForecast;
            for the string options the defaults listed above are used only
            when no ``model_kwargs`` are given
        stl_kwargs:
            additional parameters for statsmodels.tsa.forecasting.stl.STLForecast

        Raises
        ------
        ValueError:
            if ``model`` is an unknown string, or is neither a string nor a TimeSeriesModel
        """
        if model_kwargs is None:
            model_kwargs = {}
        if stl_kwargs is None:
            stl_kwargs = {}

        self.in_column = in_column
        self.period = period

        if isinstance(model, str):
            if model == "arima":
                self.model = ARIMA
                # Fall back to the documented default only when the caller
                # did not configure the model; never discard user kwargs.
                if not model_kwargs:
                    model_kwargs = {"order": (1, 1, 0)}
            elif model == "holt":
                self.model = ETSModel
                if not model_kwargs:
                    model_kwargs = {"trend": "add"}
            else:
                raise ValueError(f"Not a valid option for model: {model}")
        elif isinstance(model, TimeSeriesModel) or (
            isinstance(model, type) and issubclass(model, TimeSeriesModel)
        ):
            # STLForecast calls ``model(endog, **model_kwargs)``, so a model
            # class must be accepted here in addition to instances.
            self.model = model
        else:
            raise ValueError("Model should be a string or TimeSeriesModel")

        self.robust = robust
        self.model_kwargs = model_kwargs
        self.stl_kwargs = stl_kwargs
        # Populated by ``fit``; ``None`` means the transform is not fitted yet.
        self.fit_results: Optional[STLForecastResults] = None

    def fit(self, df: pd.DataFrame) -> "_OneSegmentSTLTransform":
        """
        Perform STL decomposition and fit trend model.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result: _OneSegmentSTLTransform
            instance after processing
        """
        model = STLForecast(
            df[self.in_column],
            self.model,
            model_kwargs=self.model_kwargs,
            period=self.period,
            robust=self.robust,
            **self.stl_kwargs,
        )
        self.fit_results = model.fit()
        return self

    def _season_trend(self, df: pd.DataFrame) -> pd.Series:
        """Predict the season + trend component over the index range spanned by ``df``."""
        if self.fit_results is None:
            raise ValueError("Transform is not fitted! Call fit before transform or inverse_transform.")
        prediction = self.fit_results.get_prediction(start=df.index.min(), end=df.index.max())
        return prediction.predicted_mean

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Subtract trend and seasonal component.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result: pd.DataFrame
            Copy of ``df`` with season and trend removed from ``in_column``

        Raises
        ------
        ValueError:
            if the transform has not been fitted
        """
        season_trend = self._season_trend(df)
        result = df.copy()
        result[self.in_column] -= season_trend
        return result

    def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add trend and seasonal component.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result: pd.DataFrame
            Copy of ``df`` with season and trend added back to ``in_column``

        Raises
        ------
        ValueError:
            if the transform has not been fitted
        """
        season_trend = self._season_trend(df)
        result = df.copy()
        result[self.in_column] += season_trend
        return result
|
||
|
||
class STLTransform(PerSegmentWrapper):
    """Transform that uses statsmodels.tsa.forecasting.stl.STLForecast to subtract season and trend from the data."""

    def __init__(
        self,
        in_column: str,
        period: int,
        model: Union[str, TimeSeriesModel] = "arima",
        robust: bool = False,
        model_kwargs: Optional[Dict[str, Any]] = None,
        stl_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Init STLTransform.

        Parameters
        ----------
        in_column:
            name of processed column
        period:
            size of seasonality
        model:
            model to predict trend, custom model or one of:

            1. "arima": ``ARIMA(data, order=(1, 1, 0))`` (default)

            2. "holt": ``ETSModel(data, trend='add')``
        robust:
            flag indicating whether to use robust version of STL
        model_kwargs:
            parameters for the model like in statsmodels.tsa.forecasting.stl.STLForecast
        stl_kwargs:
            additional parameters for statsmodels.tsa.forecasting.stl.STLForecast
        """
        self.in_column = in_column
        self.period = period
        self.model = model
        self.robust = robust
        self.model_kwargs = model_kwargs
        self.stl_kwargs = stl_kwargs
        # All arguments are forwarded unchanged to the one-segment transform
        # that is handed to the wrapper.
        super().__init__(
            transform=_OneSegmentSTLTransform(
                in_column=self.in_column,
                period=self.period,
                model=self.model,
                robust=self.robust,
                model_kwargs=self.model_kwargs,
                stl_kwargs=self.stl_kwargs,
            )
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from etna.datasets.tsdataset import TSDataset | ||
from etna.models import NaiveModel | ||
from etna.transforms.stl import STLTransform | ||
from etna.transforms.stl import _OneSegmentSTLTransform | ||
|
||
|
||
def add_trend(series: pd.Series, coef: float = 1) -> pd.Series:
    """Return a new series equal to ``series`` plus a linear ramp ``position * coef``."""
    # The ramp is positional (0, 1, 2, ...), independent of the series index.
    ramp = np.arange(len(series)) * coef
    return series + ramp
|
||
|
||
def add_seasonality(series: pd.Series, period: int, magnitude: float) -> pd.Series:
    """Return a new series equal to ``series`` plus a sine wave of the given period and amplitude."""
    # Phase is computed from the position in the series, starting at zero.
    phase = 2 * np.pi * np.arange(len(series)) / period
    return series + np.sin(phase) * magnitude
|
||
|
||
def get_one_df(coef: float, period: int, magnitude: float) -> pd.DataFrame:
    """Build a daily frame (2020-01-01 .. 2020-03-01) whose target is a sine seasonality plus a linear trend."""
    frame = pd.DataFrame({"timestamp": pd.date_range(start="2020-01-01", end="2020-03-01", freq="D")})
    frame["target"] = 0
    seasonal = add_seasonality(frame["target"], period=period, magnitude=magnitude)
    frame["target"] = add_trend(seasonal, coef=coef)
    return frame
|
||
|
||
@pytest.fixture
def df_trend_seasonal_one_segment() -> pd.DataFrame:
    """One-segment frame with trend and weekly seasonality, indexed by timestamp."""
    df = get_one_df(coef=0.1, period=7, magnitude=1)
    # ``set_index`` returns a new frame; the result must be kept, otherwise
    # the fixture silently keeps the default integer index.
    df = df.set_index("timestamp")
    return df
|
||
|
||
@pytest.fixture
def ts_trend_seasonal() -> TSDataset:
    """Two-segment daily dataset; the segments differ in trend slope and seasonal amplitude."""
    first = get_one_df(coef=0.1, period=7, magnitude=1).assign(segment="segment_1")
    second = get_one_df(coef=0.05, period=7, magnitude=2).assign(segment="segment_2")
    long_format = pd.concat([first, second], ignore_index=True)
    wide_format = TSDataset.to_dataset(long_format)
    return TSDataset(wide_format, freq="D")
|
||
|
||
@pytest.mark.parametrize("model", ["arima", "holt"])
def test_transform_one_segment(df_trend_seasonal_one_segment, model):
    """Check that the one-segment transform leaves a near-zero residual."""
    stl = _OneSegmentSTLTransform(in_column="target", period=7, model=model)
    residual = stl.fit_transform(df_trend_seasonal_one_segment)["target"]
    np.testing.assert_allclose(residual, 0, atol=0.2)
|
||
|
||
@pytest.mark.parametrize("model", ["arima", "holt"])
def test_transform_multi_segments(ts_trend_seasonal, model):
    """Check that the transform leaves a near-zero residual in every segment."""
    stl = STLTransform(in_column="target", period=7, model=model)
    ts_trend_seasonal.fit_transform(transforms=[stl])
    residuals = ts_trend_seasonal[:, :, "target"]
    np.testing.assert_allclose(residuals, 0, atol=0.2)
|
||
|
||
@pytest.mark.parametrize("model", ["arima", "holt"])
def test_inverse_transform_one_segment(df_trend_seasonal_one_segment, model):
    """Test that transform + inverse_transform don't change dataframe."""
    transform = _OneSegmentSTLTransform(in_column="target", period=7, model=model)
    df_transformed = transform.fit_transform(df_trend_seasonal_one_segment)
    df_inverse_transformed = transform.inverse_transform(df_transformed)
    # Subtracting and re-adding floats is not guaranteed to be bit-exact,
    # so compare with a tight tolerance instead of ``==``.
    np.testing.assert_allclose(
        df_inverse_transformed["target"],
        df_trend_seasonal_one_segment["target"],
        rtol=1e-9,
        atol=1e-9,
    )
|
||
|
||
@pytest.mark.parametrize("model", ["arima", "holt"])
def test_inverse_transform_multi_segments(ts_trend_seasonal, model):
    """Test that transform + inverse_transform don't change tsdataset."""
    transform = STLTransform(in_column="target", period=7, model=model)
    # Snapshot taken before transforming, used as the expected values below.
    dataframe_initial = ts_trend_seasonal.to_pandas()
    ts_trend_seasonal.fit_transform(transforms=[transform])
    ts_trend_seasonal.inverse_transform()
    for segment in ts_trend_seasonal.segments:
        # Subtracting and re-adding floats is not guaranteed to be bit-exact,
        # so compare with a tight tolerance instead of ``==``.
        np.testing.assert_allclose(
            ts_trend_seasonal[:, segment, "target"],
            dataframe_initial.loc[:, pd.IndexSlice[segment, "target"]],
            rtol=1e-9,
            atol=1e-9,
        )
|
||
|
||
@pytest.mark.parametrize("model_stl", ["arima", "holt"])
def test_forecast(ts_trend_seasonal, model_stl):
    """Check a 3-step NaiveModel forecast on STL-transformed data against the held-out tail."""
    stl = STLTransform(in_column="target", period=7, model=model_stl)
    # Train on everything except the last three timestamps; test on those three.
    timestamps = ts_trend_seasonal.index
    ts_train, ts_test = ts_trend_seasonal.train_test_split(
        timestamps[0],
        timestamps[-4],
        timestamps[-3],
        timestamps[-1],
    )
    ts_train.fit_transform(transforms=[stl])
    naive = NaiveModel()
    naive.fit(ts_train)
    ts_forecast = naive.forecast(ts_train.make_future(3))
    for segment in ts_forecast.segments:
        predicted = ts_forecast[:, segment, "target"]
        expected = ts_test[:, segment, "target"]
        np.testing.assert_allclose(predicted, expected, atol=0.1)