Skip to content

Visualization for forecast decomposition #1129

Merged
merged 6 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Method `set_params` to change parameters of ETNA objects [#1102](https://github.com/tinkoff-ai/etna/pull/1102)
- Function `plot_forecast_decomposition` to visualize a forecast together with its target components [#1129](https://github.com/tinkoff-ai/etna/pull/1129)
Mr-Geekman marked this conversation as resolved.
Show resolved Hide resolved
-
### Changed

Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_feature_relevance
from etna.analysis.plotters import plot_forecast
from etna.analysis.plotters import plot_forecast_decomposition
from etna.analysis.plotters import plot_holidays
from etna.analysis.plotters import plot_imputation
from etna.analysis.plotters import plot_metric_per_segment
Expand Down
114 changes: 114 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from etna.analysis.feature_selection import AGGREGATION_FN
from etna.analysis.feature_selection import AggregationMode
from etna.analysis.utils import prepare_axes
from etna.datasets.utils import match_target_components
from etna.transforms import Transform

if TYPE_CHECKING:
Expand Down Expand Up @@ -1958,3 +1959,116 @@ def update(**kwargs):
plt.show()

interact(update, **sliders)


class ComponentsMode(str, Enum):
    """Available layouts for drawing forecast components."""

    per_component = "per-component"
    together = "together"

    @classmethod
    def _missing_(cls, value):
        """Reject a value that matches no member, listing every supported mode in the error."""
        supported_modes = ", ".join(repr(member.value) for member in cls)
        raise NotImplementedError(f"{value} is not a valid {cls.__name__}. Supported modes: {supported_modes}")


def plot_forecast_decomposition(
    forecast_ts: "TSDataset",
    test_ts: Optional["TSDataset"] = None,
    mode: Union[Literal["per-component"], Literal["together"]] = "per-component",
    segments: Optional[List[str]] = None,
    columns_num: int = 1,
    figsize: Tuple[int, int] = (10, 5),
    show_grid: bool = False,
) -> None:
    """
    Plot of prediction and its components.

    Parameters
    ----------
    forecast_ts:
        forecasted TSDataset with timeseries data, single-forecast mode
    test_ts:
        TSDataset with timeseries data; it is only read, never modified
    mode:
        Components plotting type

        #. ``per-component`` -- plot each component in separate axes

        #. ``together`` -- plot all the components in the same axis

    segments:
        segments to plot; if not given plot all the segments
    columns_num:
        number of graphics columns; when mode=``per-component`` the charts of one segment
        (forecast/target first, then one chart per component) follow each other down a column
    figsize:
        size of the figure per subplot with one segment in inches
    show_grid:
        whether to show grid for each chart

    Raises
    ------
    ValueError:
        if components aren't present in ``forecast_ts``
    NotImplementedError:
        unknown ``mode`` is given
    """
    components_mode = ComponentsMode(mode)

    if segments is None:
        segments = list(forecast_ts.columns.get_level_values("segment").unique())

    column_names = set(forecast_ts.columns.get_level_values("feature"))
    # a set has no stable iteration order, so sort to keep chart order and colors reproducible
    components = sorted(match_target_components(column_names))

    if not components:
        raise ValueError("No components were detected in the provided `forecast_ts`.")

    if components_mode == ComponentsMode.together:
        num_plots = len(segments)
    else:
        # every segment needs one chart for target/forecast plus one chart per component;
        # the segment count is rounded up to a multiple of `columns_num` so all columns have equal height
        num_plots = math.ceil(len(segments) / columns_num) * columns_num * (len(components) + 1)

    _, ax = prepare_axes(num_plots=num_plots, columns_num=columns_num, figsize=figsize, set_grid=show_grid)

    alpha = 0.5 if components_mode == ComponentsMode.together else 1.0
    # walk the grid column by column, so consecutive charts of one segment are stacked vertically
    # NOTE(review): assumes `prepare_axes` returns a row-major grid that is exactly `columns_num` wide -- confirm
    ax_array = np.asarray(ax).reshape(-1, columns_num).T.ravel()

    i = 0
    for segment in segments:
        segment_forecast_df = forecast_ts[:, segment, :][segment].sort_values(by="timestamp")

        ax_array[i].set_title(segment)

        ax_array[i].plot(segment_forecast_df.index.values, segment_forecast_df["target"].values, label="forecast")

        if test_ts is not None:
            # sort a per-segment frame instead of sorting `test_ts.df` in place: the caller's dataset stays untouched
            segment_test_df = test_ts[:, segment, :][segment].sort_values(by="timestamp")
            ax_array[i].plot(segment_test_df.index.values, segment_test_df["target"].values, label="target")
        else:
            # skip color for target: an empty unlabeled line consumes one entry of the property cycle
            # through the public API (the private `_get_lines.prop_cycler` is missing in newer matplotlib)
            ax_array[i].plot([], [])

        for component in components:
            if components_mode == ComponentsMode.per_component:
                # finish the current chart and move on to a fresh one for this component
                ax_array[i].legend(loc="upper left")
                ax_array[i].set_xticklabels([])
                i += 1

            ax_array[i].plot(
                segment_forecast_df.index.values, segment_forecast_df[component].values, label=component, alpha=alpha
            )

        # only the last chart of a segment keeps its x tick labels
        ax_array[i].tick_params("x", rotation=45)
        ax_array[i].legend(loc="upper left")
        i += 1
5 changes: 5 additions & 0 deletions etna/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def match_target_quantiles(features: Set[str]) -> Set[str]:
return {i for i in list(features) if pattern.match(i) is not None}


def match_target_components(features: Set[str]) -> Set[str]:
    """Select the feature names that denote target components.

    A name is selected when it starts with the ``target_component_`` prefix.
    """
    component_prefix = "target_component_"
    return {feature for feature in features if feature.startswith(component_prefix)}


def get_target_with_quantiles(columns: pd.Index) -> Set[str]:
"""Find "target" column and target quantiles among dataframe columns."""
column_names = set(columns.get_level_values(level="feature"))
Expand Down
18 changes: 18 additions & 0 deletions tests/test_datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.datasets.utils import _TorchDataset
from etna.datasets.utils import get_level_dataframe
from etna.datasets.utils import get_target_with_quantiles
from etna.datasets.utils import match_target_components
from etna.datasets.utils import set_columns_wide


Expand Down Expand Up @@ -186,6 +187,7 @@ def test_set_columns_wide(
({"a", "b", "target"}, {"target"}),
({"a", "b", "target", "target_0.5"}, {"target", "target_0.5"}),
({"a", "b", "target", "target_0.5", "target1"}, {"target", "target_0.5"}),
({"target_component_a", "a", "b", "target_component_c", "target", "target_0.95"}, {"target", "target_0.95"}),
),
)
def test_get_target_with_quantiles(segments, columns, answer):
Expand Down Expand Up @@ -242,3 +244,19 @@ def test_get_level_dataframe_segm_errors(
source_level_segments=source_level_segments,
target_level_segments=target_level_segments,
)


@pytest.mark.parametrize(
    "features,answer",
    [
        (set(), set()),
        ({"a", "b"}, set()),
        (
            {"target_component_a", "a", "b", "target_component_c", "target", "target_0.95"},
            {"target_component_a", "target_component_c"},
        ),
    ],
)
def test_match_target_components(features, answer):
    """Check that the matched component names equal the expected set."""
    assert match_target_components(features) == answer