Skip to content

Visualization for forecast decomposition #1129

Merged
merged 6 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Method `set_params` to change parameters of ETNA objects [#1102](https://github.com/tinkoff-ai/etna/pull/1102)
- Function `plot_forecast_decomposition` [#1129](https://github.com/tinkoff-ai/etna/pull/1129)
-
### Changed

Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from etna.analysis.plotters import plot_correlation_matrix
from etna.analysis.plotters import plot_feature_relevance
from etna.analysis.plotters import plot_forecast
from etna.analysis.plotters import plot_forecast_decomposition
from etna.analysis.plotters import plot_holidays
from etna.analysis.plotters import plot_imputation
from etna.analysis.plotters import plot_metric_per_segment
Expand Down
114 changes: 114 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from etna.analysis.feature_selection import AGGREGATION_FN
from etna.analysis.feature_selection import AggregationMode
from etna.analysis.utils import prepare_axes
from etna.datasets.utils import match_target_components
from etna.transforms import Transform

if TYPE_CHECKING:
Expand Down Expand Up @@ -1958,3 +1959,116 @@ def update(**kwargs):
plt.show()

interact(update, **sliders)


class ComponentsMode(str, Enum):
    """Supported layouts for plotting forecast components."""

    # each component is drawn in its own axes
    per_component = "per-component"
    # all components share a single axes
    joint = "joint"

    @classmethod
    def _missing_(cls, value):
        """Reject an unknown mode, listing the supported values in the error."""
        supported_modes = ", ".join(repr(member.value) for member in cls)
        raise NotImplementedError(f"{value} is not a valid {cls.__name__}. Supported modes: {supported_modes}")


def plot_forecast_decomposition(
    forecast_ts: "TSDataset",
    test_ts: Optional["TSDataset"] = None,
    mode: Union[Literal["per-component"], Literal["joint"]] = "per-component",
    segments: Optional[List[str]] = None,
    columns_num: int = 1,
    figsize: Tuple[int, int] = (10, 5),
    show_grid: bool = False,
):
    """
    Plot of prediction and its components.

    Parameters
    ----------
    forecast_ts:
        forecasted TSDataset with timeseries data, single-forecast mode
    test_ts:
        TSDataset with timeseries data; it is only read, the passed dataset is not modified
    mode:
        Components plotting type

        #. ``per-component`` -- plot each component in separate axes

        #. ``joint`` -- plot all the components in the same axis

    segments:
        segments to plot; if not given plot all the segments
    columns_num:
        number of graphics columns; when mode=``per-component`` all plots will be in the single column
    figsize:
        size of the figure per subplot with one segment in inches
    show_grid:
        whether to show grid for each chart

    Raises
    ------
    ValueError:
        if components aren't present in ``forecast_ts``
    NotImplementedError:
        unknown ``mode`` is given
    """
    components_mode = ComponentsMode(mode)

    if segments is None:
        segments = list(forecast_ts.columns.get_level_values("segment").unique())

    column_names = set(forecast_ts.columns.get_level_values("feature"))
    # sort the matched names: iteration order of a set of strings is not stable between
    # interpreter runs, which would shuffle component colors and subplot order
    components = sorted(match_target_components(column_names))

    if not components:
        raise ValueError("No components were detected in the provided `forecast_ts`.")

    if components_mode == ComponentsMode.joint:
        num_plots = len(segments)
    else:
        # plotting target and forecast separately from components, thus +1 for each segment
        num_plots = math.ceil(len(segments) / columns_num) * columns_num * (len(components) + 1)

    _, ax = prepare_axes(num_plots=num_plots, columns_num=columns_num, figsize=figsize, set_grid=show_grid)

    # components overlap in the joint mode, so they are drawn semi-transparent there
    alpha = 0.5 if components_mode == ComponentsMode.joint else 1.0
    # walk the axes column by column, so consecutive plots of a segment are stacked vertically
    # NOTE(review): assumes `prepare_axes` returns the axes flattened in row-major order -- confirm
    ax_array = np.asarray(ax).reshape(-1, columns_num).T.ravel()

    i = 0
    for segment in segments:
        segment_forecast_df = forecast_ts[:, segment, :][segment].sort_values(by="timestamp")

        ax_array[i].set_title(segment)

        ax_array[i].plot(segment_forecast_df.index.values, segment_forecast_df["target"].values, label="forecast")

        if test_ts is not None:
            # sort a per-segment copy instead of sorting `test_ts.df` in place,
            # so the caller's dataset is left untouched
            segment_test_df = test_ts[:, segment, :][segment].sort_values(by="timestamp")
            ax_array[i].plot(segment_test_df.index.values, segment_test_df["target"].values, label="target")
        else:
            # skip color for target: draw and immediately remove an empty line to advance
            # the color cycle via public API (the private `_get_lines.prop_cycler`
            # attribute is not available in newer matplotlib releases)
            (skipped_line,) = ax_array[i].plot([], [])
            skipped_line.remove()

        for component in components:
            if components_mode == ComponentsMode.per_component:
                # finish the current axes and move on to a fresh one for this component
                ax_array[i].legend(loc="upper left")
                ax_array[i].set_xticklabels([])
                i += 1

            ax_array[i].plot(
                segment_forecast_df.index.values, segment_forecast_df[component].values, label=component, alpha=alpha
            )

        # only the last axes of a segment keeps its x tick labels
        ax_array[i].tick_params("x", rotation=45)
        ax_array[i].legend(loc="upper left")
        i += 1
5 changes: 5 additions & 0 deletions etna/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def match_target_quantiles(features: Set[str]) -> Set[str]:
return {i for i in list(features) if pattern.match(i) is not None}


def match_target_components(features: Set[str]) -> Set[str]:
    """Find target components in a set of features.

    Parameters
    ----------
    features:
        names of the features to search in

    Returns
    -------
    :
        subset of ``features`` whose names start with ``"target_component_"``
    """
    return {feature for feature in features if feature.startswith("target_component_")}


def get_target_with_quantiles(columns: pd.Index) -> Set[str]:
"""Find "target" column and target quantiles among dataframe columns."""
column_names = set(columns.get_level_values(level="feature"))
Expand Down
18 changes: 18 additions & 0 deletions tests/test_datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.datasets.utils import _TorchDataset
from etna.datasets.utils import get_level_dataframe
from etna.datasets.utils import get_target_with_quantiles
from etna.datasets.utils import match_target_components
from etna.datasets.utils import set_columns_wide


Expand Down Expand Up @@ -186,6 +187,7 @@ def test_set_columns_wide(
({"a", "b", "target"}, {"target"}),
({"a", "b", "target", "target_0.5"}, {"target", "target_0.5"}),
({"a", "b", "target", "target_0.5", "target1"}, {"target", "target_0.5"}),
({"target_component_a", "a", "b", "target_component_c", "target", "target_0.95"}, {"target", "target_0.95"}),
),
)
def test_get_target_with_quantiles(segments, columns, answer):
Expand Down Expand Up @@ -242,3 +244,19 @@ def test_get_level_dataframe_segm_errors(
source_level_segments=source_level_segments,
target_level_segments=target_level_segments,
)


@pytest.mark.parametrize(
    "features,answer",
    (
        (set(), set()),
        ({"a", "b"}, set()),
        (
            {"target_component_a", "a", "b", "target_component_c", "target", "target_0.95"},
            {"target_component_a", "target_component_c"},
        ),
    ),
)
def test_match_target_components(features, answer):
    """Check that exactly the expected component names are picked out of the features."""
    assert match_target_components(features) == answer