From 9cf9f530c5f4a291270abeceece9388f7d16998f Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Tue, 22 Feb 2022 12:27:59 +0300 Subject: [PATCH 1/8] plot trend --- etna/analysis/__init__.py | 1 + etna/analysis/plotters.py | 58 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/etna/analysis/__init__.py b/etna/analysis/__init__.py index 1c34b26f7..03f329eb2 100644 --- a/etna/analysis/__init__.py +++ b/etna/analysis/__init__.py @@ -24,3 +24,4 @@ from etna.analysis.plotters import plot_forecast from etna.analysis.plotters import plot_residuals from etna.analysis.plotters import plot_time_series_with_change_points +from etna.analysis.plotters import plot_trend diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 14b1d0548..7df12ff69 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -19,6 +19,10 @@ import plotly.graph_objects as go import seaborn as sns +from etna.transforms import ChangePointsTrendTransform +from etna.transforms import LinearTrendTransform +from etna.transforms import STLTransform +from etna.transforms import TheilSenTrendTransform from etna.transforms import Transform if TYPE_CHECKING: @@ -729,3 +733,57 @@ def plot_residuals( ax[i].set_title(segment) ax[i].tick_params("x", rotation=45) ax[i].set_xlabel(feature) + + +TrendTransformType = Union[ChangePointsTrendTransform, LinearTrendTransform, TheilSenTrendTransform, STLTransform] + + +def plot_trend( + ts: "TSDataset", + trend_transform: Union["TrendTransformType", List["TrendTransformType"]], + segments: Optional[List[str]] = None, + columns_num: int = 2, + figsize: Tuple[int, int] = (10, 5), +): + """Plot series and trend from trend transform for this series. + + Parameters + ---------- + ts: + dataframe of timeseries that was used for trend plot + trend_transform: + trend transform or list of trend transforms to apply + segments: + segments to use + columns_num: + number of columns in subplots + figsize: + size of the figure per subplot with one segment in inches + """ + if not segments: + segments = list(set(ts.columns.get_level_values("segment"))) + + ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize) + df = ts.df + linear_coeffs = dict(zip(segments, ["" for i in range(len(segments))])) + if isinstance(trend_transform, list): + df_detrend = [transform.fit_transform(df.copy()) for transform in trend_transform] + labels = [transform.__repr__() for transform in trend_transform] + labels_short = [i[: i.find("(")] for i in labels] + if len(np.unique(labels_short)) == len(labels_short): + labels = labels_short + else: + df_detrend = [trend_transform.fit_transform(df.copy())] + labels = [trend_transform.__repr__()[: trend_transform.__repr__().find("(")]] + if isinstance(trend_transform, LinearTrendTransform) or isinstance(trend_transform, TheilSenTrendTransform): + for seg in segments: + + linear_coeffs[seg] = ", k=" + str(trend_transform.segment_transforms[seg]._linear_model.coef_[0]) + + for i, segment in enumerate(segments): + ax[i].plot(df[segment]["target"], label="Initial series") + for label, df_now in zip(labels, df_detrend): + ax[i].plot(df[segment, "target"] - df_now[segment, "target"], label=label + linear_coeffs[segment]) + ax[i].set_title(segment) + ax[i].tick_params("x", rotation=45) + ax[i].legend() From 87a11859e1b0e4bd65ef51e8323e970677d037ba Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Tue, 22 Feb 2022 12:48:26 +0300 Subject: [PATCH 2/8] Fix imports --- CHANGELOG.md | 2 +- etna/analysis/plotters.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 856e0078c..8be525d85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add plot_time_series_with_change_points function ([#534](https://github.com/tinkoff-ai/etna/pull/534)) - - Add find_change_points function ([#521](https://github.com/tinkoff-ai/etna/pull/521)) -- +- Add plot_trend ([#565](https://github.com/tinkoff-ai/etna/pull/565)) - Add plot_residuals ([#539](https://github.com/tinkoff-ai/etna/pull/539)) - - Create `PerSegmentBaseModel`, `PerSegmentPredictionIntervalModel` ([#537](https://github.com/tinkoff-ai/etna/pull/537)) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 7df12ff69..3c9eb0ace 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -19,14 +19,14 @@ import plotly.graph_objects as go import seaborn as sns -from etna.transforms import ChangePointsTrendTransform -from etna.transforms import LinearTrendTransform -from etna.transforms import STLTransform -from etna.transforms import TheilSenTrendTransform from etna.transforms import Transform if TYPE_CHECKING: from etna.datasets import TSDataset + from etna.transforms import ChangePointsTrendTransform + from etna.transforms import LinearTrendTransform + from etna.transforms import STLTransform + from etna.transforms import TheilSenTrendTransform def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int]) -> Sequence[matplotlib.axes.Axes]: @@ -735,7 +735,9 @@ def plot_residuals( ax[i].set_xlabel(feature) -TrendTransformType = Union[ChangePointsTrendTransform, LinearTrendTransform, TheilSenTrendTransform, STLTransform] +TrendTransformType = Union[ + "ChangePointsTrendTransform", "LinearTrendTransform", "TheilSenTrendTransform", "STLTransform" +] def plot_trend( From e47ef8d1fb3a6f2ca1f21749dc0368ee25a58eab Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Tue, 22 Feb 2022 12:52:16 +0300 Subject: [PATCH 3/8] changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8be525d85..d37c18f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,9 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Prediction intervals visualization in `plot_forecast` ([#538](https://github.com/tinkoff-ai/etna/pull/538)) - - Add plot_time_series_with_change_points function ([#534](https://github.com/tinkoff-ai/etna/pull/534)) -- -- Add find_change_points function ([#521](https://github.com/tinkoff-ai/etna/pull/521)) - Add plot_trend ([#565](https://github.com/tinkoff-ai/etna/pull/565)) +- Add find_change_points function ([#521](https://github.com/tinkoff-ai/etna/pull/521)) +- - Add plot_residuals ([#539](https://github.com/tinkoff-ai/etna/pull/539)) - - Create `PerSegmentBaseModel`, `PerSegmentPredictionIntervalModel` ([#537](https://github.com/tinkoff-ai/etna/pull/537)) From 708791c401dc5066431217ad6e09417f0602c65a Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Mon, 28 Feb 2022 13:37:24 +0300 Subject: [PATCH 4/8] fix comments --- etna/analysis/plotters.py | 43 +++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 3c9eb0ace..7baeb507f 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -20,13 +20,13 @@ import seaborn as sns from etna.transforms import Transform +from etna.transforms.decomposition.change_points_trend import ChangePointsTrendTransform +from etna.transforms.decomposition.detrend import LinearTrendTransform +from etna.transforms.decomposition.detrend import TheilSenTrendTransform +from etna.transforms.decomposition.stl import STLTransform if TYPE_CHECKING: from etna.datasets import TSDataset - from etna.transforms import ChangePointsTrendTransform - from etna.transforms import LinearTrendTransform - from etna.transforms import STLTransform - from etna.transforms import TheilSenTrendTransform def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int]) -> Sequence[matplotlib.axes.Axes]: @@ -740,6 +740,19 @@ def plot_residuals( ] +def __get_labels_names(trend_transform, segments): + """If only unique transform classes are used then show their short names (without parameters). Otherwise show their full repr as label""" + labels = [transform.__repr__() for transform in trend_transform] + labels_short = [i[: i.find("(")] for i in labels] + if len(np.unique(labels_short)) == len(labels_short): + labels = labels_short + linear_coeffs = dict(zip(segments, ["" for i in range(len(segments))])) + if len(trend_transform) == 1 and isinstance(trend_transform[0], (LinearTrendTransform, TheilSenTrendTransform)): + for seg in segments: + linear_coeffs[seg] = ", k=" + str(trend_transform[0].segment_transforms[seg]._linear_model.coef_[0]) + return labels, linear_coeffs + + def plot_trend( ts: "TSDataset", trend_transform: Union["TrendTransformType", List["TrendTransformType"]], @@ -749,6 +762,8 @@ def plot_trend( ): """Plot series and trend from trend transform for this series. + If only unique transform classes are used then show their short names (without parameters). Otherwise show their full repr as label + Parameters ---------- ts: @@ -767,20 +782,12 @@ def plot_trend( ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize) df = ts.df - linear_coeffs = dict(zip(segments, ["" for i in range(len(segments))])) - if isinstance(trend_transform, list): - df_detrend = [transform.fit_transform(df.copy()) for transform in trend_transform] - labels = [transform.__repr__() for transform in trend_transform] - labels_short = [i[: i.find("(")] for i in labels] - if len(np.unique(labels_short)) == len(labels_short): - labels = labels_short - else: - df_detrend = [trend_transform.fit_transform(df.copy())] - labels = [trend_transform.__repr__()[: trend_transform.__repr__().find("(")]] - if isinstance(trend_transform, LinearTrendTransform) or isinstance(trend_transform, TheilSenTrendTransform): - for seg in segments: - - linear_coeffs[seg] = ", k=" + str(trend_transform.segment_transforms[seg]._linear_model.coef_[0]) + + if not isinstance(trend_transform, list): + trend_transform = [trend_transform] + + df_detrend = [transform.fit_transform(df.copy()) for transform in trend_transform] + labels, linear_coeffs = __get_labels_names(trend_transform, segments) for i, segment in enumerate(segments): ax[i].plot(df[segment]["target"], label="Initial series") From f102cff5e929cfef2b2f3627a5c93573d26f7250 Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Mon, 28 Feb 2022 14:53:00 +0300 Subject: [PATCH 5/8] fix imports --- etna/analysis/plotters.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 7baeb507f..e7c3c1da7 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -20,13 +20,13 @@ import seaborn as sns from etna.transforms import Transform -from etna.transforms.decomposition.change_points_trend import ChangePointsTrendTransform -from etna.transforms.decomposition.detrend import LinearTrendTransform -from etna.transforms.decomposition.detrend import TheilSenTrendTransform -from etna.transforms.decomposition.stl import STLTransform if TYPE_CHECKING: from etna.datasets import TSDataset + from etna.transforms.decomposition.change_points_trend import ChangePointsTrendTransform + from etna.transforms.decomposition.detrend import LinearTrendTransform + from etna.transforms.decomposition.detrend import TheilSenTrendTransform + from etna.transforms.decomposition.stl import STLTransform def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int]) -> Sequence[matplotlib.axes.Axes]: @@ -741,7 +741,10 @@ def plot_residuals( def __get_labels_names(trend_transform, segments): - """If only unique transform classes are used then show their short names (without parameters). Otherwise show their full repr as label""" + """If only unique transform classes are used then show their short names (without parameters). Otherwise show their full repr as label.""" + from etna.transforms.decomposition.detrend import LinearTrendTransform + from etna.transforms.decomposition.detrend import TheilSenTrendTransform + labels = [transform.__repr__() for transform in trend_transform] labels_short = [i[: i.find("(")] for i in labels] if len(np.unique(labels_short)) == len(labels_short): From 4869e12ba8b2111cc9cfb9870f4295331ded6d24 Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Sat, 5 Mar 2022 11:45:56 +0300 Subject: [PATCH 6/8] minor fix --- etna/analysis/plotters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index e7c3c1da7..062b188d2 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -740,7 +740,7 @@ def plot_residuals( ] -def __get_labels_names(trend_transform, segments): +def _get_labels_names(trend_transform, segments): """If only unique transform classes are used then show their short names (without parameters). Otherwise show their full repr as label.""" from etna.transforms.decomposition.detrend import LinearTrendTransform from etna.transforms.decomposition.detrend import TheilSenTrendTransform @@ -790,7 +790,7 @@ def plot_trend( trend_transform = [trend_transform] df_detrend = [transform.fit_transform(df.copy()) for transform in trend_transform] - labels, linear_coeffs = __get_labels_names(trend_transform, segments) + labels, linear_coeffs = _get_labels_names(trend_transform, segments) for i, segment in enumerate(segments): ax[i].plot(df[segment]["target"], label="Initial series") From 3375261879f6b09f8fa995c421f3e4dcd3fc3480 Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Wed, 9 Mar 2022 11:04:23 +0300 Subject: [PATCH 7/8] fix representation of k --- etna/analysis/plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index 062b188d2..07d29faea 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -752,7 +752,7 @@ def _get_labels_names(trend_transform, segments): linear_coeffs = dict(zip(segments, ["" for i in range(len(segments))])) if len(trend_transform) == 1 and isinstance(trend_transform[0], (LinearTrendTransform, TheilSenTrendTransform)): for seg in segments: - linear_coeffs[seg] = ", k=" + str(trend_transform[0].segment_transforms[seg]._linear_model.coef_[0]) + linear_coeffs[seg] = ", k=" + f'{trend_transform[0].segment_transforms[seg]._linear_model.coef_[0]:g}' return labels, linear_coeffs From 4e4855913697885c3d01142c6507f6721d4da0c3 Mon Sep 17 00:00:00 2001 From: Artyom Makhin Date: Wed, 9 Mar 2022 11:13:21 +0300 Subject: [PATCH 8/8] lint --- etna/analysis/plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/etna/analysis/plotters.py b/etna/analysis/plotters.py index c73ef0f56..89445fbe1 100644 --- a/etna/analysis/plotters.py +++ b/etna/analysis/plotters.py @@ -752,7 +752,7 @@ def _get_labels_names(trend_transform, segments): linear_coeffs = dict(zip(segments, ["" for i in range(len(segments))])) if len(trend_transform) == 1 and isinstance(trend_transform[0], (LinearTrendTransform, TheilSenTrendTransform)): for seg in segments: - linear_coeffs[seg] = ", k=" + f'{trend_transform[0].segment_transforms[seg]._linear_model.coef_[0]:g}' + linear_coeffs[seg] = ", k=" + f"{trend_transform[0].segment_transforms[seg]._linear_model.coef_[0]:g}" return labels, linear_coeffs