Skip to content

Commit

Permalink
Prediction intervals visualization (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ama16 committed Feb 18, 2022
1 parent 305a7c8 commit 95a988c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `FutureMixin` into some transforms ([#361](https://github.com/tinkoff-ai/etna/pull/361))
- Regressors updating in TSDataset transform loops ([#374](https://github.com/tinkoff-ai/etna/pull/374))
- Regressors handling in TSDataset `make_future` and `train_test_split` ([#447](https://github.com/tinkoff-ai/etna/pull/447))
-
- Prediction intervals visualization in `plot_forecast` ([#538](https://github.com/tinkoff-ai/etna/pull/538))
-
- Add plot_time_series_with_change_points function ([#534](https://github.com/tinkoff-ai/etna/pull/534))
-
Expand Down
65 changes: 63 additions & 2 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import math
import warnings
from typing import TYPE_CHECKING
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

Expand All @@ -26,6 +28,8 @@ def plot_forecast(
n_train_samples: Optional[int] = None,
columns_num: int = 2,
figsize: Tuple[int, int] = (10, 5),
prediction_intervals: bool = False,
quantiles: Optional[Sequence[float]] = None,
):
"""
Plot of prediction for forecast pipeline.
Expand All @@ -46,6 +50,10 @@ def plot_forecast(
number of graphics columns
figsize:
size of the figure per subplot with one segment in inches
prediction_intervals:
if True prediction intervals will be drawn
quantiles:
list of quantiles to draw
"""
if not segments:
segments = list(set(forecast_ts.columns.get_level_values("segment")))
Expand All @@ -57,6 +65,21 @@ def plot_forecast(
_, ax = plt.subplots(rows_num, columns_num, figsize=figsize, constrained_layout=True)
ax = np.array([ax]).ravel()

if prediction_intervals:
cols = [
col
for col in forecast_ts.columns.get_level_values("feature").unique().tolist()
if col.startswith("target_0.")
]
existing_quantiles = [float(col[7:]) for col in cols]
if quantiles is None:
quantiles = sorted(existing_quantiles)
else:
non_existent = set(quantiles) - set(existing_quantiles)
if len(non_existent):
warnings.warn(f"Quantiles {non_existent} do not exist in forecast dataset. They will be dropped.")
quantiles = sorted(list(set(quantiles).intersection(set(existing_quantiles))))

if train_ts is not None:
train_ts.df.sort_values(by="timestamp", inplace=True)
if test_ts is not None:
Expand Down Expand Up @@ -86,8 +109,46 @@ def plot_forecast(
if (train_ts is not None) and (n_train_samples != 0):
ax[i].plot(plot_df.index.values, plot_df.target.values, label="train")
if test_ts is not None:
ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, label="test")
ax[i].plot(segment_forecast_df.index.values, segment_forecast_df.target.values, label="forecast")
ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test")
ax[i].plot(segment_forecast_df.index.values, segment_forecast_df.target.values, color="r", label="forecast")

if prediction_intervals and quantiles is not None:
alpha = np.linspace(0, 1, len(quantiles) // 2 + 2)[1:-1]
for quantile in range(len(quantiles) // 2):
values_low = segment_forecast_df["target_" + str(quantiles[quantile])].values
values_high = segment_forecast_df["target_" + str(quantiles[-quantile - 1])].values
if quantile == len(quantiles) // 2 - 1:
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_high,
facecolor="g",
alpha=alpha[quantile],
label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval",
)
else:
values_next = segment_forecast_df["target_" + str(quantiles[quantile + 1])].values
ax[i].fill_between(
segment_forecast_df.index.values,
values_low,
values_next,
facecolor="g",
alpha=alpha[quantile],
label=f"{quantiles[quantile]}-{quantiles[-quantile-1]} prediction interval",
)
values_prev = segment_forecast_df["target_" + str(quantiles[-quantile - 2])].values
ax[i].fill_between(
segment_forecast_df.index.values, values_high, values_prev, facecolor="g", alpha=alpha[quantile]
)
if len(quantiles) % 2 != 0:
values = segment_forecast_df["target_" + str(quantiles[len(quantiles) // 2])].values
ax[i].plot(
segment_forecast_df.index.values,
values,
"--",
c="orange",
label=f"{quantiles[len(quantiles)//2]} quantile",
)
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
ax[i].legend()
Expand Down

0 comments on commit 95a988c

Please sign in to comment.