Feature/historical retrain on condition #1139

Merged · 17 commits · Aug 30, 2022
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -5,6 +5,8 @@ Darts is still in an early development phase and we cannot always guarantee back

## [Unreleased](https://github.com/unit8co/darts/tree/master)

- Added support for retraining model(s) every `n` iterations and on a custom condition in the `historical_forecasts` method of the `ForecastingModel` abstract class. Addressed issues [#135](https://github.com/unit8co/darts/issues/135) and [#623](https://github.com/unit8co/darts/issues/623) by [Francesco Bruzzesi](https://github.com/fbruzzesi).

[Full Changelog](https://github.com/unit8co/darts/compare/0.21.0...master)


@@ -797,4 +799,4 @@ All implementations of `GlobalForecastingModel`s support multivariate time serie

## Older releases

[Full Changelog](https://github.com/unit8co/darts/compare/f618c4536bf7ed6e3b6a2239fbca4e3089736426...0.2.3)
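
For context, a minimal usage sketch of the feature this PR adds (the series, model, and values below are illustrative, not taken from the PR):

```python
from darts.models import LinearRegressionModel
from darts.utils.timeseries_generation import linear_timeseries

series = linear_timeseries(length=100)
model = LinearRegressionModel(lags=[-1, -2, -3])

# retrain the model only on every 5th forecasting iteration,
# instead of before every single prediction (retrain=True, the default)
historical_fc = model.historical_forecasts(
    series,
    start=0.5,
    forecast_horizon=1,
    stride=1,
    retrain=5,
)
```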
81 changes: 71 additions & 10 deletions darts/models/forecasting/forecasting_model.py
@@ -33,6 +33,7 @@
_build_tqdm_iterator,
_historical_forecasts_general_checks,
_parallel_apply,
_retrain_wrapper,
_with_sanity_checks,
)
from darts.utils.data.encoders import SequentialEncoder
@@ -303,7 +304,7 @@ def historical_forecasts(
start: Union[pd.Timestamp, float, int] = 0.5,
forecast_horizon: int = 1,
stride: int = 1,
retrain: bool = True,
retrain: Union[bool, int, Callable[..., bool]] = True,
overlap_end: bool = False,
last_points_only: bool = True,
verbose: bool = False,
@@ -357,8 +358,22 @@
stride
The number of time steps between two consecutive predictions.
retrain
Whether to retrain the model for every prediction or not. Not all models support setting
`retrain` to `False`. Notably, this is supported by neural networks based models.
Whether, and on which condition, to retrain the model before predicting.
This parameter supports 3 different datatypes: ``bool``, (positive) ``int``, and
``Callable`` (returning a ``bool``).
In the case of ``bool``: retrain the model at each step (`True`), or never retrain it (`False`).
In the case of ``int``: the model is retrained every `retrain` iterations.
In the case of ``Callable``: the model is retrained whenever the callable returns `True`.
Arguments passed to the callable are as follows:

- `pred_time (pd.Timestamp or int)`: timestamp of forecast time (end of the training series)
- `train_series (TimeSeries)`: train series up to `pred_time`
- `past_covariates (TimeSeries)`: past_covariates series up to `pred_time`
- `future_covariates (TimeSeries)`: future_covariates series up
to `min(pred_time + series.freq * forecast_horizon, series.end_time())`

Note: some models require being retrained at every step
and do not support anything other than `retrain=True`.
overlap_end
Whether the returned forecasts can go beyond the series' end or not
last_points_only
@@ -384,10 +399,11 @@
# only GlobalForecastingModels support historical forecasts without retraining the model
base_class_name = self.__class__.__base__.__name__
raise_if(
not retrain and not self._supports_non_retrainable_historical_forecasts(),
(isinstance(retrain, Callable) or int(retrain) != 1)
and (not self._supports_non_retrainable_historical_forecasts()),
f"{base_class_name} does not support historical forecastings with `retrain` set to `False`. "
f"For now, this is only supported with GlobalForecastingModels such as TorchForecastingModels. "
f"Fore more information, read the documentation for `retrain` in `historical_forecastings()`",
f"For more information, read the documentation for `retrain` in `historical_forecasts()`",
logger,
)

@@ -409,6 +425,25 @@
logger,
)

if isinstance(retrain, bool) or (isinstance(retrain, int) and retrain >= 0):
retrain_func = _retrain_wrapper(
lambda counter: counter % int(retrain) == 0 if retrain else False
)

elif isinstance(retrain, Callable):
retrain_func = _retrain_wrapper(retrain)

else:
raise_log(
ValueError(
"`retrain` argument must be either `bool`, positive `int` or `Callable` (returning `bool`)"
),
logger,
)
retrain_func_signature = tuple(
inspect.signature(retrain_func).parameters.keys()
)

# prepare the start parameter -> pd.Timestamp
start = series.get_timestamp_at_point(start)

@@ -435,15 +470,27 @@
last_points_values = []

# iterate and forecast
for pred_time in iterator:
for _counter, pred_time in enumerate(iterator):
# build the training series
train = series.drop_after(pred_time)
if train_length and len(train) > train_length:
train = train[-train_length:]

# train_cov = covariates.drop_after(pred_time) if covariates else None

if retrain or not self._fit_called:
if (not self._fit_called) or retrain_func(
counter=_counter,
pred_time=pred_time,
train_series=train,
past_covariates=past_covariates.drop_after(pred_time)
if past_covariates and ("past_covariates" in retrain_func_signature)
else None,
future_covariates=future_covariates.drop_after(
min(pred_time + series.freq * forecast_horizon, series.end_time())
)
if future_covariates and ("future_covariates" in retrain_func_signature)
else None,
):
self._fit_wrapper(
series=train,
past_covariates=past_covariates,
@@ -498,7 +545,7 @@ def backtest(
start: Union[pd.Timestamp, float, int] = 0.5,
forecast_horizon: int = 1,
stride: int = 1,
retrain: bool = True,
retrain: Union[bool, int, Callable[..., bool]] = True,
overlap_end: bool = False,
last_points_only: bool = False,
metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
@@ -556,8 +603,22 @@
stride
The number of time steps between two consecutive training sets.
retrain
Whether to retrain the model for every prediction or not. Not all models support setting
`retrain` to `False`. Notably, this is supported by neural networks based models.
Whether, and on which condition, to retrain the model before predicting.
This parameter supports 3 different datatypes: ``bool``, (positive) ``int``, and
``Callable`` (returning a ``bool``).
In the case of ``bool``: retrain the model at each step (`True`), or never retrain it (`False`).
In the case of ``int``: the model is retrained every `retrain` iterations.
In the case of ``Callable``: the model is retrained whenever the callable returns `True`.
Arguments passed to the callable are as follows:

- `pred_time (pd.Timestamp or int)`: timestamp of forecast time (end of the training series)
- `train_series (TimeSeries)`: train series up to `pred_time`
- `past_covariates (TimeSeries)`: past_covariates series up to `pred_time`
- `future_covariates (TimeSeries)`: future_covariates series up
to `min(pred_time + series.freq * forecast_horizon, series.end_time())`

Note: some models require being retrained at every step
and do not support anything other than `retrain=True`.
overlap_end
Whether the returned forecasts can go beyond the series' end or not
last_points_only
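To illustrate the `Callable` variant documented above, here is a sketch of a condition-based retrain function. Per the implementation, the callable may declare any subset of `counter`, `pred_time`, `train_series`, `past_covariates`, and `future_covariates`, since `_retrain_wrapper` forwards only the keyword arguments the callable names. The condition and the name `retrain_on_new_month` are illustrative, not part of the PR:

```python
import pandas as pd
from darts import TimeSeries

def retrain_on_new_month(pred_time: pd.Timestamp, train_series: TimeSeries) -> bool:
    """Retrain whenever the forecast time lands on the first day of a month."""
    # only these two parameters are declared, so the wrapper does not
    # forward counter / past_covariates / future_covariates to this function
    return pred_time.is_month_start

# reusing `model` and `series` from the earlier sketch (datetime-indexed series)
historical_fc = model.historical_forecasts(series, retrain=retrain_on_new_month)
```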
118 changes: 118 additions & 0 deletions darts/tests/models/forecasting/test_local_forecasting_models.py
@@ -1,6 +1,8 @@
import os
import shutil
import tempfile
from typing import Callable
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
@@ -376,3 +378,119 @@ def test_statsmodels_dual_models(self):
# check backtesting with retrain=False
model: TransferableDualCovariatesForecastingModel = model_cls(**kwargs)
model.backtest(series1, future_covariates=exog1, retrain=False)

@patch("typing.Callable")
def test_backtest_retrain(
self,
patch_retrain_func,
):
"""
Test backtest method with different retrain arguments
"""

series = self.ts_pass_train

lr_univ_args = {"lags": [-1, -2, -3]}

lr_multi_args = {
"lags": [-1, -2, -3],
"lags_past_covariates": [-1, -2, -3],
}
params = [ # tuple of (model, retrain-able, multivariate, retrain parameter, model type)
(ExponentialSmoothing(), False, False, "hello", "ForecastingModel"),
(ExponentialSmoothing(), False, False, True, "ForecastingModel"),
(ExponentialSmoothing(), False, False, -2, "ForecastingModel"),
(ExponentialSmoothing(), False, False, 2, "ForecastingModel"),
(
ExponentialSmoothing(),
False,
False,
patch_retrain_func,
"ForecastingModel",
),
(
LinearRegressionModel(**lr_univ_args),
True,
False,
True,
"GlobalForecastingModel",
),
(
LinearRegressionModel(**lr_univ_args),
True,
False,
2,
"GlobalForecastingModel",
),
(
LinearRegressionModel(**lr_univ_args),
True,
False,
patch_retrain_func,
"GlobalForecastingModel",
),
(
LinearRegressionModel(**lr_multi_args),
True,
True,
True,
"GlobalForecastingModel",
),
(
LinearRegressionModel(**lr_multi_args),
True,
True,
2,
"GlobalForecastingModel",
),
(
LinearRegressionModel(**lr_multi_args),
True,
True,
patch_retrain_func,
"GlobalForecastingModel",
),
]

for model_cls, retrainable, multivariate, retrain, model_type in params:

if (
not isinstance(retrain, (int, bool, Callable))
or (isinstance(retrain, int) and retrain < 0)
or (isinstance(retrain, (Callable)) and (not retrainable))
or ((retrain != 1) and (not retrainable))
):
with self.assertRaises(ValueError):
_ = model_cls.historical_forecasts(series, retrain=retrain)

else:

if isinstance(retrain, Mock):
# resets patch_retrain_func call_count to 0
retrain.call_count = 0
retrain.side_effect = [True, False] * (len(series) // 2)

fit_method_to_patch = f"darts.models.forecasting.forecasting_model.{model_type}._fit_wrapper"
predict_method_to_patch = f"darts.models.forecasting.forecasting_model.{model_type}._predict_wrapper"

with patch(fit_method_to_patch) as patch_fit_method:
with patch(
predict_method_to_patch, side_effect=series
) as patch_predict_method:

# Set _fit_called attribute to True, otherwise retrain function is never called
model_cls._fit_called = True

# run backtest
_ = model_cls.historical_forecasts(
series,
past_covariates=series if multivariate else None,
retrain=retrain,
)

assert patch_predict_method.call_count > 1
assert patch_fit_method.call_count > 1

if isinstance(retrain, Mock):
# check that patch_retrain_func has been called at each iteration
assert retrain.call_count > 1
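
As a side note on the mocking pattern used above, a small self-contained sketch of how `side_effect` drives the mocked retrain callable (values here are illustrative):

```python
from unittest.mock import Mock

retrain = Mock()
retrain.side_effect = [True, False] * 2  # consecutive calls pop from this list

calls = [retrain(counter=i) for i in range(4)]
assert calls == [True, False, True, False]
assert retrain.call_count == 4  # every invocation is recorded
```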
1 change: 1 addition & 0 deletions darts/utils/__init__.py
@@ -6,6 +6,7 @@
_build_tqdm_iterator,
_historical_forecasts_general_checks,
_parallel_apply,
_retrain_wrapper,
_with_sanity_checks,
retain_period_common_to_all,
)
19 changes: 19 additions & 0 deletions darts/utils/utils.py
@@ -316,3 +316,22 @@ def _check_quantiles(quantiles):
"quantiles lower than `q=0.5` need to share same difference to `0.5` as quantiles "
"higher than `q=0.5`",
)


def _retrain_wrapper(func: Callable[..., bool]):
"""Utility function that keeps original signature in `retrain` function param in `historical_forecasts` method"""

@wraps(func)
def wrapper(*args, **kwargs):

original_signature = tuple(signature(func).parameters.keys())
result = func(
*args, **{k: v for k, v in kwargs.items() if k in original_signature}
)

if not isinstance(result, bool):
raise_log(ValueError("Return value of `retrain` must be bool"), logger)

return result

return wrapper
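
A quick sketch of the wrapper's behaviour, assuming Darts with this PR applied is importable (`_retrain_wrapper` is a private helper, so this is for illustration only):

```python
from darts.utils.utils import _retrain_wrapper

# the lambda declares only `counter`, so all other keyword arguments
# passed by `historical_forecasts` are filtered out before the call
every_third = _retrain_wrapper(lambda counter: counter % 3 == 0)

assert every_third(counter=6, pred_time=None, train_series=None) is True
assert every_third(counter=7, pred_time=None, train_series=None) is False
```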