Skip to content

Add forecast components handling to base classes of models #1158

Merged
merged 16 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- Target components logic into base classes of models ([#1158](https://github.com/tinkoff-ai/etna/pull/1158))
- Target components logic to TSDataset ([#1153](https://github.com/tinkoff-ai/etna/pull/1153))
- Methods `save` and `load` to HierarchicalPipeline ([#1096](https://github.com/tinkoff-ai/etna/pull/1096))
- New data access methods in `TSDataset` : `update_columns_from_pandas`, `add_columns_from_pandas`, `drop_features` ([#809](https://github.com/tinkoff-ai/etna/pull/809))
Expand Down
54 changes: 46 additions & 8 deletions etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ def context_size(self) -> int:
return 0

@abstractmethod
def forecast(self, ts: TSDataset) -> TSDataset:
def forecast(self, ts: TSDataset, return_components: bool = False) -> TSDataset:
"""Make predictions.

Parameters
----------
ts:
Dataset with features
return_components:
If True additionally returns forecast components

Returns
-------
Expand All @@ -109,13 +111,15 @@ def forecast(self, ts: TSDataset) -> TSDataset:
pass

@abstractmethod
def predict(self, ts: TSDataset) -> TSDataset:
def predict(self, ts: TSDataset, return_components: bool = False) -> TSDataset:
"""Make predictions with using true values as autoregression context if possible (teacher forcing).

Parameters
----------
ts:
Dataset with features
return_components:
If True additionally returns prediction components

Returns
-------
Expand All @@ -129,7 +133,7 @@ class NonPredictionIntervalContextRequiredAbstractModel(AbstractModel):
"""Interface for models that don't support prediction intervals and need context for prediction."""

@abstractmethod
def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset:
def forecast(self, ts: TSDataset, prediction_size: int, return_components: bool = False) -> TSDataset:
"""Make predictions.

Parameters
Expand All @@ -139,6 +143,8 @@ def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset:
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context for models that require it.
return_components:
If True additionally returns forecast components

Returns
-------
Expand All @@ -148,7 +154,7 @@ def forecast(self, ts: TSDataset, prediction_size: int) -> TSDataset:
pass

@abstractmethod
def predict(self, ts: TSDataset, prediction_size: int) -> TSDataset:
def predict(self, ts: TSDataset, prediction_size: int, return_components: bool = False) -> TSDataset:
"""Make predictions with using true values as autoregression context if possible (teacher forcing).

Parameters
Expand All @@ -158,6 +164,8 @@ def predict(self, ts: TSDataset, prediction_size: int) -> TSDataset:
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context for models that require it.
return_components:
If True additionally returns prediction components

Returns
-------
Expand All @@ -180,7 +188,11 @@ def context_size(self) -> int:

@abstractmethod
def forecast(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
self,
ts: TSDataset,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
return_components: bool = False,
) -> TSDataset:
"""Make predictions.

Expand All @@ -192,6 +204,8 @@ def forecast(
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval
return_components:
If True additionally returns forecast components

Returns
-------
Expand All @@ -202,7 +216,11 @@ def forecast(

@abstractmethod
def predict(
self, ts: TSDataset, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975)
self,
ts: TSDataset,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
return_components: bool = False,
) -> TSDataset:
"""Make predictions with using true values as autoregression context if possible (teacher forcing).

Expand All @@ -214,6 +232,8 @@ def predict(
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval
return_components:
If True additionally returns prediction components

Returns
-------
Expand All @@ -233,6 +253,7 @@ def forecast(
prediction_size: int,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
return_components: bool = False,
) -> TSDataset:
"""Make predictions.

Expand All @@ -247,6 +268,8 @@ def forecast(
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval
return_components:
If True additionally returns forecast components

Returns
-------
Expand All @@ -262,6 +285,7 @@ def predict(
prediction_size: int,
prediction_interval: bool = False,
quantiles: Sequence[float] = (0.025, 0.975),
return_components: bool = False,
) -> TSDataset:
"""Make predictions with using true values as autoregression context if possible (teacher forcing).

Expand All @@ -276,6 +300,8 @@ def predict(
If True returns prediction interval for forecast
quantiles:
Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval
return_components:
If True additionally returns prediction components

Returns
-------
Expand Down Expand Up @@ -604,7 +630,7 @@ def raw_predict(self, torch_dataset: "Dataset") -> Dict[Tuple[str, str], np.ndar
return predictions_dict

@log_decorator
def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
def forecast(self, ts: "TSDataset", prediction_size: int, return_components: bool = False) -> "TSDataset":
"""Make predictions.

This method will make autoregressive predictions.
Expand All @@ -616,12 +642,17 @@ def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context.
return_components:
If True additionally returns forecast components

Returns
-------
:
Dataset with predictions
"""
if return_components:
raise NotImplementedError("This mode isn't currently implemented!")

test_dataset = ts.to_torch_dataset(
make_samples=functools.partial(
self.net.make_samples, encoder_length=self.encoder_length, decoder_length=prediction_size
Expand All @@ -636,7 +667,12 @@ def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
return future_ts

@log_decorator
def predict(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
def predict(
self,
ts: "TSDataset",
prediction_size: int,
return_components: bool = False,
) -> "TSDataset":
"""Make predictions.

This method will make predictions using true values instead of predicted on a previous step.
Expand All @@ -649,6 +685,8 @@ def predict(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
prediction_size:
Number of last timestamps to leave after making prediction.
Previous timestamps will be used as a context.
return_components:
If True additionally returns prediction components

Returns
-------
Expand Down
Loading