Skip to content

Commit

Permalink
Add STLTransform (#158)
Browse files Browse the repository at this point in the history
Co-authored-by: martin <33594071+martins0n@users.noreply.github.com>
  • Loading branch information
Mr-Geekman and martins0n committed Oct 11, 2021
1 parent a0ffd6a commit 8d09303
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- TrendTransform ([#139](https://github.com/tinkoff-ai/etna-ts/pull/139))
- Running notebooks in ci ([#134](https://github.com/tinkoff-ai/etna-ts/issues/134))
- Cluster plotter to EDA ([#169](https://github.com/tinkoff-ai/etna-ts/pull/169))
- STLTransform class ([#158](https://github.com/tinkoff-ai/etna-ts/pull/158))

### Changed
- Delete offset from WindowStatisticsTransform ([#111](https://github.com/tinkoff-ai/etna-ts/pull/111))
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Expand Up @@ -27,4 +27,5 @@
from etna.transforms.statistics import MinTransform
from etna.transforms.statistics import QuantileTransform
from etna.transforms.statistics import StdTransform
from etna.transforms.stl import STLTransform
from etna.transforms.trend import TrendTransform
187 changes: 187 additions & 0 deletions etna/transforms/stl.py
@@ -0,0 +1,187 @@
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union

import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.base.tsa_model import TimeSeriesModel
from statsmodels.tsa.exponential_smoothing.ets import ETSModel
from statsmodels.tsa.forecasting.stl import STLForecast
from statsmodels.tsa.forecasting.stl import STLForecastResults

from etna.transforms.base import PerSegmentWrapper
from etna.transforms.base import Transform


class _OneSegmentSTLTransform(Transform):
    """Remove trend and seasonality from a single segment with ``STLForecast``."""

    def __init__(
        self,
        in_column: str,
        period: int,
        model: Union[str, TimeSeriesModel] = "arima",
        robust: bool = False,
        model_kwargs: Optional[Dict[str, Any]] = None,
        stl_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Init _OneSegmentSTLTransform.

        Parameters
        ----------
        in_column:
            name of processed column
        period:
            size of seasonality
        model:
            model to predict trend, default options are:
            1. "arima": `ARIMA(data, 1, 1, 0)` (default)
            2. "holt": `ETSModel(data, trend='add')`
            Custom model should be a subclass of statsmodels.tsa.base.tsa_model.TimeSeriesModel
            and have method `get_prediction` (not just `predict`)
        robust:
            flag indicating whether to use robust version of STL
        model_kwargs:
            parameters for the model like in statsmodels.tsa.forecasting.stl.STLForecast;
            for the "arima" and "holt" options the defaults listed above are used
            only when no parameters are given
        stl_kwargs:
            additional parameters for statsmodels.tsa.forecasting.stl.STLForecast

        Raises
        ------
        ValueError:
            if `model` is an unknown string or is neither a string nor a TimeSeriesModel
        """
        if model_kwargs is None:
            model_kwargs = {}
        if stl_kwargs is None:
            stl_kwargs = {}

        self.in_column = in_column
        self.period = period

        if isinstance(model, str):
            # Built-in defaults are applied only when the caller gave no model
            # parameters, so explicitly passed model_kwargs are never discarded.
            if model == "arima":
                self.model = ARIMA
                if not model_kwargs:
                    model_kwargs = {"order": (1, 1, 0)}
            elif model == "holt":
                self.model = ETSModel
                if not model_kwargs:
                    model_kwargs = {"trend": "add"}
            else:
                raise ValueError(f"Not a valid option for model: {model}")
        elif isinstance(model, TimeSeriesModel) or (
            isinstance(model, type) and issubclass(model, TimeSeriesModel)
        ):
            # A TimeSeriesModel subclass (the documented custom-model form) is accepted;
            # instances are still let through for backward compatibility.
            self.model = model
        else:
            raise ValueError("Model should be a string or TimeSeriesModel")

        self.robust = robust
        self.model_kwargs = model_kwargs
        self.stl_kwargs = stl_kwargs
        self.fit_results: Optional[STLForecastResults] = None

    def fit(self, df: pd.DataFrame) -> "_OneSegmentSTLTransform":
        """
        Perform STL decomposition and fit trend model.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result: _OneSegmentSTLTransform
            instance after processing
        """
        model = STLForecast(
            df[self.in_column],
            self.model,
            model_kwargs=self.model_kwargs,
            period=self.period,
            robust=self.robust,
            **self.stl_kwargs,
        )
        self.fit_results = model.fit()
        return self

    def _get_season_trend(self, df: pd.DataFrame) -> pd.Series:
        """Return the fitted prediction over the index range covered by ``df``."""
        if self.fit_results is None:
            raise ValueError("Transform is not fitted! Fit the Transform before calling transform method.")
        prediction = self.fit_results.get_prediction(start=df.index.min(), end=df.index.max())
        return prediction.predicted_mean

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Subtract trend and seasonal component.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result: pd.DataFrame
            Copy of the dataframe with the prediction subtracted from `in_column`

        Raises
        ------
        ValueError:
            if the transform has not been fitted
        """
        result = df.copy()
        result[self.in_column] -= self._get_season_trend(df)
        return result

    def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add trend and seasonal component.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result: pd.DataFrame
            Copy of the dataframe with the prediction added back to `in_column`

        Raises
        ------
        ValueError:
            if the transform has not been fitted
        """
        result = df.copy()
        result[self.in_column] += self._get_season_trend(df)
        return result


class STLTransform(PerSegmentWrapper):
    """Transform that uses statsmodels.tsa.forecasting.stl.STLForecast to subtract season and trend from the data."""

    def __init__(
        self,
        in_column: str,
        period: int,
        model: Union[str, TimeSeriesModel] = "arima",
        robust: bool = False,
        model_kwargs: Optional[Dict[str, Any]] = None,
        stl_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Init STLTransform.

        Parameters
        ----------
        in_column:
            name of processed column
        period:
            size of seasonality
        model:
            model to predict trend, custom model or one of:
            1. "arima": `ARIMA(data, 1, 1, 0)` (default)
            2. "holt": `ETSModel(data, trend='add')`
        robust:
            flag indicating whether to use robust version of STL
        model_kwargs:
            parameters for the model like in statsmodels.tsa.forecasting.stl.STLForecast
        stl_kwargs:
            additional parameters for statsmodels.tsa.forecasting.stl.STLForecast
        """
        self.in_column = in_column
        self.period = period
        self.model = model
        self.robust = robust
        self.model_kwargs = model_kwargs
        self.stl_kwargs = stl_kwargs
        # All arguments are forwarded unchanged to the one-segment transform
        # that is handed to PerSegmentWrapper.
        super().__init__(
            transform=_OneSegmentSTLTransform(
                in_column=self.in_column,
                period=self.period,
                model=self.model,
                robust=self.robust,
                model_kwargs=self.model_kwargs,
                stl_kwargs=self.stl_kwargs,
            )
        )
109 changes: 109 additions & 0 deletions tests/test_transforms/test_stl_transform.py
@@ -0,0 +1,109 @@
import numpy as np
import pandas as pd
import pytest

from etna.datasets.tsdataset import TSDataset
from etna.models import NaiveModel
from etna.transforms.stl import STLTransform
from etna.transforms.stl import _OneSegmentSTLTransform


def add_trend(series: pd.Series, coef: float = 1) -> pd.Series:
    """Return a new series equal to ``series`` plus a linear trend ``coef * position``."""
    linear_part = np.arange(len(series)) * coef
    return series + linear_part


def add_seasonality(series: pd.Series, period: int, magnitude: float) -> pd.Series:
    """Return a new series equal to ``series`` plus a sine wave of the given period and magnitude."""
    positions = np.arange(len(series))
    wave = np.sin(2 * np.pi * positions / period) * magnitude
    return series + wave


def get_one_df(coef: float, period: int, magnitude: float) -> pd.DataFrame:
    """Build a daily frame (2020-01-01 .. 2020-03-01) whose target is sine seasonality plus a linear trend."""
    df = pd.DataFrame({"timestamp": pd.date_range(start="2020-01-01", end="2020-03-01", freq="D")})
    df["target"] = 0
    seasonal = add_seasonality(df["target"], period=period, magnitude=magnitude)
    df["target"] = add_trend(seasonal, coef=coef)
    return df


@pytest.fixture
def df_trend_seasonal_one_segment() -> pd.DataFrame:
    """One-segment frame with trend and weekly seasonality, indexed by timestamp."""
    df = get_one_df(coef=0.1, period=7, magnitude=1)
    # set_index returns a new frame; the original code discarded it, so the
    # fixture silently kept the default integer index.
    return df.set_index("timestamp")


@pytest.fixture
def ts_trend_seasonal() -> TSDataset:
    """Two-segment daily dataset; each segment has its own trend slope and seasonal magnitude."""
    segment_params = (("segment_1", 0.1, 1), ("segment_2", 0.05, 2))
    frames = []
    for segment_name, coef, magnitude in segment_params:
        frame = get_one_df(coef=coef, period=7, magnitude=magnitude)
        frame["segment"] = segment_name
        frames.append(frame)
    classic_df = pd.concat(frames, ignore_index=True)
    return TSDataset(TSDataset.to_dataset(classic_df), freq="D")


@pytest.mark.parametrize("model", ["arima", "holt"])
def test_transform_one_segment(df_trend_seasonal_one_segment, model):
    """Check that the one-segment transform leaves a target within 0.2 of zero."""
    stl = _OneSegmentSTLTransform(in_column="target", period=7, model=model)
    residual = stl.fit_transform(df_trend_seasonal_one_segment)["target"]
    np.testing.assert_allclose(residual, 0, atol=0.2)


@pytest.mark.parametrize("model", ["arima", "holt"])
def test_transform_multi_segments(ts_trend_seasonal, model):
    """Check that the transform leaves every segment's target within 0.2 of zero."""
    stl = STLTransform(in_column="target", period=7, model=model)
    ts_trend_seasonal.fit_transform(transforms=[stl])
    residuals = ts_trend_seasonal[:, :, "target"]
    np.testing.assert_allclose(residuals, 0, atol=0.2)


@pytest.mark.parametrize("model", ["arima", "holt"])
def test_inverse_transform_one_segment(df_trend_seasonal_one_segment, model):
    """Test that transform + inverse_transform don't change dataframe."""
    transform = _OneSegmentSTLTransform(in_column="target", period=7, model=model)
    df_transformed = transform.fit_transform(df_trend_seasonal_one_segment)
    df_inverse_transformed = transform.inverse_transform(df_transformed)
    # (x - s) + s is not guaranteed to be bit-identical to x in floating point,
    # so compare with a tight tolerance instead of exact equality.
    np.testing.assert_allclose(
        df_inverse_transformed["target"],
        df_trend_seasonal_one_segment["target"],
        rtol=1e-7,
        atol=1e-9,
        equal_nan=False,
    )


@pytest.mark.parametrize("model", ["arima", "holt"])
def test_inverse_transform_multi_segments(ts_trend_seasonal, model):
    """Test that transform + inverse_transform don't change tsdataset."""
    transform = STLTransform(in_column="target", period=7, model=model)
    dataframe_initial = ts_trend_seasonal.to_pandas()
    ts_trend_seasonal.fit_transform(transforms=[transform])
    ts_trend_seasonal.inverse_transform()
    for segment in ts_trend_seasonal.segments:
        # Subtracting and re-adding the same float series may differ in the last
        # bits, so use a tight tolerance rather than exact equality.
        np.testing.assert_allclose(
            ts_trend_seasonal[:, segment, "target"],
            dataframe_initial.loc[:, pd.IndexSlice[segment, "target"]],
            rtol=1e-7,
            atol=1e-9,
            equal_nan=False,
        )


@pytest.mark.parametrize("model_stl", ["arima", "holt"])
def test_forecast(ts_trend_seasonal, model_stl):
    """Check that a NaiveModel forecast on STL-transformed train data is within 0.1 of the last three points."""
    timestamps = ts_trend_seasonal.index
    ts_train, ts_test = ts_trend_seasonal.train_test_split(
        timestamps[0],
        timestamps[-4],
        timestamps[-3],
        timestamps[-1],
    )
    stl = STLTransform(in_column="target", period=7, model=model_stl)
    ts_train.fit_transform(transforms=[stl])
    naive_model = NaiveModel()
    naive_model.fit(ts_train)
    ts_forecast = naive_model.forecast(ts_train.make_future(3))
    for segment in ts_forecast.segments:
        predicted = ts_forecast[:, segment, "target"]
        expected = ts_test[:, segment, "target"]
        np.testing.assert_allclose(predicted, expected, atol=0.1)

0 comments on commit 8d09303

Please sign in to comment.