Skip to content

plot trend #565

Merged
merged 12 commits into from
Mar 9, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Prediction intervals visualization in `plot_forecast` ([#538](https://github.com/tinkoff-ai/etna/pull/538))
-
- Add plot_time_series_with_change_points function ([#534](https://github.com/tinkoff-ai/etna/pull/534))
-
- Add plot_trend ([#565](https://github.com/tinkoff-ai/etna/pull/565))
- Add find_change_points function ([#521](https://github.com/tinkoff-ai/etna/pull/521))
-
- Add plot_residuals ([#539](https://github.com/tinkoff-ai/etna/pull/539))
Expand Down
1 change: 1 addition & 0 deletions etna/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from etna.analysis.plotters import plot_forecast
from etna.analysis.plotters import plot_residuals
from etna.analysis.plotters import plot_time_series_with_change_points
from etna.analysis.plotters import plot_trend
70 changes: 70 additions & 0 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

if TYPE_CHECKING:
from etna.datasets import TSDataset
from etna.transforms.decomposition.change_points_trend import ChangePointsTrendTransform
from etna.transforms.decomposition.detrend import LinearTrendTransform
from etna.transforms.decomposition.detrend import TheilSenTrendTransform
from etna.transforms.decomposition.stl import STLTransform


def prepare_axes(segments: List[str], columns_num: int, figsize: Tuple[int, int]) -> Sequence[matplotlib.axes.Axes]:
Expand Down Expand Up @@ -729,3 +733,69 @@ def plot_residuals(
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
ax[i].set_xlabel(feature)


TrendTransformType = Union[
"ChangePointsTrendTransform", "LinearTrendTransform", "TheilSenTrendTransform", "STLTransform"
]


def __get_labels_names(trend_transform, segments):
Copy link
Contributor

@iKintosh iKintosh Mar 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __get_labels_names(trend_transform, segments):
def _get_labels_names(trend_transform, segments):

https://stackoverflow.com/a/1301369/7415703

"""If only unique transform classes are used then show their short names (without parameters). Otherwise show their full repr as label."""
from etna.transforms.decomposition.detrend import LinearTrendTransform
from etna.transforms.decomposition.detrend import TheilSenTrendTransform

labels = [transform.__repr__() for transform in trend_transform]
labels_short = [i[: i.find("(")] for i in labels]
if len(np.unique(labels_short)) == len(labels_short):
labels = labels_short
linear_coeffs = dict(zip(segments, ["" for i in range(len(segments))]))
if len(trend_transform) == 1 and isinstance(trend_transform[0], (LinearTrendTransform, TheilSenTrendTransform)):
for seg in segments:
linear_coeffs[seg] = ", k=" + str(trend_transform[0].segment_transforms[seg]._linear_model.coef_[0])
return labels, linear_coeffs


def plot_trend(
ts: "TSDataset",
trend_transform: Union["TrendTransformType", List["TrendTransformType"]],
segments: Optional[List[str]] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
):
"""Plot series and trend from trend transform for this series.

If only unique transform classes are used then show their short names (without parameters). Otherwise show their full repr as label

Parameters
----------
ts:
dataframe of timeseries that was used for trend plot
trend_transform:
trend transform or list of trend transforms to apply
segments:
segments to use
columns_num:
number of columns in subplots
figsize:
size of the figure per subplot with one segment in inches
"""
if not segments:
segments = list(set(ts.columns.get_level_values("segment")))

ax = prepare_axes(segments=segments, columns_num=columns_num, figsize=figsize)
df = ts.df

if not isinstance(trend_transform, list):
trend_transform = [trend_transform]

df_detrend = [transform.fit_transform(df.copy()) for transform in trend_transform]
labels, linear_coeffs = __get_labels_names(trend_transform, segments)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
labels, linear_coeffs = __get_labels_names(trend_transform, segments)
labels, linear_coeffs = _get_labels_names(trend_transform, segments)

Same as above https://stackoverflow.com/a/1301369/7415703


for i, segment in enumerate(segments):
ax[i].plot(df[segment]["target"], label="Initial series")
for label, df_now in zip(labels, df_detrend):
ax[i].plot(df[segment, "target"] - df_now[segment, "target"], label=label + linear_coeffs[segment])
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
ax[i].legend()