# --- etna/analysis/plotters.py -------------------------------------------------


class ComponentsMode(str, Enum):
    """Enum for components plotting modes."""

    per_component = "per-component"
    joint = "joint"

    @classmethod
    def _missing_(cls, value):
        # Called by ``Enum`` when ``value`` matches no member; the error lists every supported mode.
        supported_modes = ", ".join(repr(m.value) for m in cls)
        raise NotImplementedError(f"{value} is not a valid {cls.__name__}. Supported modes: {supported_modes}")


def plot_forecast_decomposition(
    forecast_ts: "TSDataset",
    test_ts: Optional["TSDataset"] = None,
    mode: Union[Literal["per-component"], Literal["joint"]] = "per-component",
    segments: Optional[List[str]] = None,
    columns_num: int = 1,
    figsize: Tuple[int, int] = (10, 5),
    show_grid: bool = False,
):
    """
    Plot of prediction and its components.

    For every segment the forecast (and the true target, when ``test_ts`` is given) is drawn first,
    then every ``target_component_*`` column found in ``forecast_ts``. Components are drawn in
    alphabetical order so that the layout and colors are reproducible between runs.
    Neither ``forecast_ts`` nor ``test_ts`` is modified.

    Parameters
    ----------
    forecast_ts:
        forecasted TSDataset with timeseries data, single-forecast mode
    test_ts:
        TSDataset with timeseries data
    mode:
        Components plotting type

        #. ``per-component`` -- plot each component in separate axes

        #. ``joint`` -- plot all the components in the same axis

    segments:
        segments to plot; if not given plot all the segments
    columns_num:
        number of graphics columns
    figsize:
        size of the figure per subplot with one segment in inches
    show_grid:
        whether to show grid for each chart

    Raises
    ------
    ValueError:
        if components aren't present in ``forecast_ts``
    NotImplementedError:
        unknown ``mode`` is given
    """
    components_mode = ComponentsMode(mode)

    if segments is None:
        segments = list(forecast_ts.columns.get_level_values("segment").unique())

    column_names = set(forecast_ts.columns.get_level_values("feature"))
    # a set has no stable iteration order, sort to get the same axes and colors on every run
    components = sorted(match_target_components(column_names))

    if not components:
        raise ValueError("No components were detected in the provided `forecast_ts`.")

    if components_mode == ComponentsMode.joint:
        num_plots = len(segments)
    else:
        # plotting target and forecast separately from components, thus +1 for each segment
        num_plots = math.ceil(len(segments) / columns_num) * columns_num * (len(components) + 1)

    _, ax = prepare_axes(num_plots=num_plots, columns_num=columns_num, figsize=figsize, set_grid=show_grid)

    alpha = 0.5 if components_mode == ComponentsMode.joint else 1.0
    # walk the grid of axes column by column
    ax_array = np.asarray(ax).reshape(-1, columns_num).T.ravel()

    i = 0
    for segment in segments:
        segment_forecast_df = forecast_ts[:, segment, :][segment].sort_values(by="timestamp")

        ax_array[i].set_title(segment)
        ax_array[i].plot(segment_forecast_df.index.values, segment_forecast_df["target"].values, label="forecast")

        if test_ts is not None:
            # sort a per-segment copy instead of re-ordering the caller's dataset in place
            segment_test_df = test_ts[:, segment, :][segment].sort_values(by="timestamp")
            ax_array[i].plot(segment_test_df.index.values, segment_test_df["target"].values, label="target")
        else:
            # skip color for target: an empty, unlabeled line consumes one entry of the color cycle
            # through the public API and doesn't appear in the legend
            ax_array[i].plot([], [])

        for component in components:
            if components_mode == ComponentsMode.per_component:
                # finalize the current axis and move to a fresh one for this component
                ax_array[i].legend(loc="upper left")
                ax_array[i].set_xticklabels([])
                i += 1

            ax_array[i].plot(
                segment_forecast_df.index.values, segment_forecast_df[component].values, label=component, alpha=alpha
            )

        # only the last axis of the segment keeps its (rotated) timestamp labels
        ax_array[i].tick_params("x", rotation=45)
        ax_array[i].legend(loc="upper left")
        i += 1


# --- etna/datasets/utils.py ----------------------------------------------------


def match_target_components(features: Set[str]) -> Set[str]:
    """Find target components in a set of features.

    A feature is treated as a target component when its name starts with ``target_component_``.
    """
    return {feature for feature in features if feature.startswith("target_component_")}
@pytest.mark.parametrize(
    "features,answer",
    [
        (set(), set()),
        ({"a", "b"}, set()),
        (
            {"target_component_a", "a", "b", "target_component_c", "target", "target_0.95"},
            {"target_component_a", "target_component_c"},
        ),
    ],
)
def test_match_target_components(features, answer):
    """Check that exactly the ``target_component_*`` names are selected from the feature set."""
    assert match_target_components(features) == answer