Skip to content

Commit

Permalink
Add stl_plot (#575)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman committed Mar 2, 2022
1 parent 1920257 commit a8fcfa3
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Add option `season_number` to DateFlagsTransform ([#567](https://github.com/tinkoff-ai/etna/pull/567))
-
-
- Add stl_plot ([#575](https://github.com/tinkoff-ai/etna/pull/575))
-
- Create `AbstaractPipeline` ([#573](https://github.com/tinkoff-ai/etna/pull/573))
-
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from etna.analysis.eda_utils import distribution_plot
from etna.analysis.eda_utils import sample_acf_plot
from etna.analysis.eda_utils import sample_pacf_plot
from etna.analysis.eda_utils import stl_plot
from etna.analysis.feature_relevance.relevance import ModelRelevanceTable
from etna.analysis.feature_relevance.relevance import RelevanceTable
from etna.analysis.feature_relevance.relevance import StatisticsRelevanceTable
Expand Down
75 changes: 75 additions & 0 deletions etna/analysis/eda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
import warnings
from itertools import combinations
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from matplotlib.ticker import MaxNLocator
from statsmodels.graphics import utils
from statsmodels.tsa.seasonal import STL

if TYPE_CHECKING:
from etna.datasets import TSDataset
Expand Down Expand Up @@ -221,3 +226,73 @@ def distribution_plot(
sns.boxplot(data=df_slice.sort_values(by="segment"), y="z", x="segment", ax=ax[i], fliersize=False)
ax[i].set_title(f"{period}")
i += 1


def stl_plot(
ts: "TSDataset",
in_column: str = "target",
period: Optional[int] = None,
segments: Optional[List[str]] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 10),
plot_kwargs: Optional[Dict[str, Any]] = None,
stl_kwargs: Optional[Dict[str, Any]] = None,
):
"""Plot STL decomposition for segments.
Parameters
----------
ts:
dataset with timeseries data
segments:
segments to plot
columns_num:
number of columns in subplots
figsize:
size of the figure per subplot with one segment in inches
plot_kwargs:
dictionary with parameters for plotting, `matplotlib.axes.Axes.plot` is used
stl_kwargs:
dictionary with parameters for STL decomposition, `statsmodels.tsa.seasonal.STL` is used
"""
if plot_kwargs is None:
plot_kwargs = {}
if stl_kwargs is None:
stl_kwargs = {}
if not segments:
segments = sorted(ts.segments)

segments_number = len(segments)
columns_num = min(columns_num, len(segments))
rows_num = math.ceil(segments_number / columns_num)

figsize = (figsize[0] * columns_num, figsize[1] * rows_num)
fig = plt.figure(figsize=figsize, constrained_layout=True)
subfigs = fig.subfigures(rows_num, columns_num)

df = ts.to_pandas()
for i, segment in enumerate(segments):
segment_df = df.loc[:, pd.IndexSlice[segment, :]][segment]
segment_df = segment_df[segment_df.first_valid_index() : segment_df.last_valid_index()]
decompose_result = STL(endog=segment_df[in_column], period=period, **stl_kwargs).fit()

# start plotting
subfigs.flat[i].suptitle(segment)
axs = subfigs.flat[i].subplots(4, 1, sharex=True)

# plot observed
axs.flat[0].plot(segment_df.index, decompose_result.observed, **plot_kwargs)
axs.flat[0].set_ylabel("Observed")

# plot trend
axs.flat[1].plot(segment_df.index, decompose_result.trend, **plot_kwargs)
axs.flat[1].set_ylabel("Trend")

# plot seasonal
axs.flat[2].plot(segment_df.index, decompose_result.seasonal, **plot_kwargs)
axs.flat[2].set_ylabel("Seasonal")

# plot residuals
axs.flat[3].plot(segment_df.index, decompose_result.resid, **plot_kwargs)
axs.flat[3].set_ylabel("Residual")
axs.flat[3].tick_params("x", rotation=45)

0 comments on commit a8fcfa3

Please sign in to comment.