Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/local and global models in EnsembleModel #1745

Merged
merged 24 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c5786cb
feat: remove restriction in EnsembleModel, models can be a mix of Loc…
madtoinou May 4, 2023
9dc9001
feat: EnsembleModel accepts a mixture of local and global models for …
madtoinou May 5, 2023
7467a4c
feat: updated unittests
madtoinou May 5, 2023
0ca18ca
doc: fix typo in docstring, SeasonalityMode must be imported from dar…
madtoinou May 5, 2023
838b40c
doc: updated changelog
madtoinou May 5, 2023
bad67e0
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 8, 2023
f3d060e
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 15, 2023
608fb03
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 16, 2023
5035cf4
feat: logger info when all the models in the ensemble do not support …
madtoinou May 17, 2023
e7fa8e4
fix: typo, using parenthesis to call proterty method
madtoinou May 17, 2023
e48e754
Apply suggestions from code review
madtoinou May 22, 2023
ce6606b
fix: made the covariates handling in ensemble model more transparent,…
madtoinou May 22, 2023
62397ef
Merge branch 'feat/local-and-global-ensemble' of https://github.com/u…
madtoinou May 22, 2023
fe80f65
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 22, 2023
9011104
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 23, 2023
a0cf04e
Update CHANGELOG.md
dennisbader May 23, 2023
0423418
Update CHANGELOG.md
dennisbader May 23, 2023
1efbcc1
Update CHANGELOG.md
dennisbader May 23, 2023
29d39be
Apply suggestions from code review
madtoinou May 24, 2023
93a9cd7
fix: addressed reviewer comments, added show_warning arg to ensemble_…
madtoinou May 24, 2023
d4a5370
fix typo
madtoinou May 24, 2023
79ee5ff
fix: improve warning synthax
madtoinou May 24, 2023
33fd7c5
Merge branch 'master' into feat/local-and-global-ensemble
madtoinou May 24, 2023
6a0970d
Merge branch 'master' into feat/local-and-global-ensemble
dennisbader May 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ We do our best to avoid the introduction of breaking changes,
but cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "🔴".

## [Unreleased](https://github.com/unit8co/darts/tree/master)

[Full Changelog](https://github.com/unit8co/darts/compare/0.24.0...master)

### For users of the library:

**Improved**
- Added support for `PathLike` to the `save()` and `load()` functions of `ForecastingModel`. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
- General model improvements:
- Added support for `PathLike` to the `save()` and `load()` functions of all non-deep learning based models. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
- Improvements to `EnsembleModel`:
- Model creation parameter `forecasting_models` now supports a mix of `LocalForecastingModel` and `GlobalForecastingModel` (single `TimeSeries` training/inference only, due to the local models). [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
- Future and past covariates can now be used even if `forecasting_models` have different covariates support. The covariates passed to `fit()`/`predict()` are used only by models that support it. [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).

## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
### For users of the library:
Expand Down
18 changes: 9 additions & 9 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def __init__(

Naive implementation of `EnsembleModel`
Returns the average of all predictions of the constituent models

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.
"""
super().__init__(models)

Expand All @@ -184,15 +187,12 @@ def fit(
future_covariates=future_covariates,
)
for model in self.models:
if self.is_global_ensemble:
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)
else:
model.fit(series=series)
kwargs = dict(series=series)
if model.supports_past_covariates:
kwargs["past_covariates"] = past_covariates
if model.supports_future_covariates:
kwargs["future_covariates"] = future_covariates
model.fit(**kwargs)

return self

Expand Down
118 changes: 95 additions & 23 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.forecasting.forecasting_model import (
ForecastingModel,
GlobalForecastingModel,
LocalForecastingModel,
)
Expand All @@ -22,31 +23,41 @@ class EnsembleModel(GlobalForecastingModel):
Ensemble models take in a list of forecasting models and ensemble their predictions
to make a single one according to the rule defined by their `ensemble()` method.

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the models supporting them.

Parameters
----------
models
List of forecasting models whose predictions to ensemble
"""

def __init__(
self, models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]]
):
def __init__(self, models: List[ForecastingModel]):
raise_if_not(
isinstance(models, list) and models,
"Cannot instantiate EnsembleModel with an empty list of models",
logger,
)

is_local_ensemble = all(
isinstance(model, LocalForecastingModel) for model in models
)
self.is_global_ensemble = all(
is_local_model = [isinstance(model, LocalForecastingModel) for model in models]
is_global_model = [
isinstance(model, GlobalForecastingModel) for model in models
)
]

self.is_local_ensemble = all(is_local_model)
self.is_global_ensemble = all(is_global_model)

raise_if_not(
is_local_ensemble or self.is_global_ensemble,
"All models must be of the same type: either GlobalForecastingModel, or LocalForecastingModel.",
all(
[
local_model or global_model
for local_model, global_model in zip(
is_local_model, is_global_model
)
]
),
"All models must be of type `GlobalForecastingModel`, or `LocalForecastingModel`. "
"Also, make sure that all models in `models` are instantiated.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger,
)

Expand All @@ -60,6 +71,21 @@ def __init__(
super().__init__()
self.models = models

if self.supports_past_covariates and not self._full_past_covariates_support():
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger.info(
"Some models in the ensemble do not support past covariates, the past covariates will be "
"provided only to the models supporting them when calling fit/predict."
)

if (
self.supports_future_covariates
and not self._full_future_covariates_support()
):
logger.info(
"Some models in the ensemble do not support future covariates, the future covariates will be "
"provided only to the models supporting them when calling fit/predict."
)

def fit(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
Expand All @@ -71,34 +97,42 @@ def fit(
Note that `EnsembleModel.fit()` does NOT call `fit()` on each of its constituent forecasting models.
It is left to classes inheriting from EnsembleModel to do so appropriately when overriding `fit()`
"""

is_single_series = isinstance(series, TimeSeries)

# local models OR mix of local and global models
raise_if(
not self.is_global_ensemble and not isinstance(series, TimeSeries),
"The models are of type LocalForecastingModel, which does not support training on multiple series.",
not self.is_global_ensemble and not is_single_series,
"The models contain at least one LocalForecastingModel, which does not support training on multiple "
"series.",
logger,
)
raise_if(
not self.is_global_ensemble and past_covariates is not None,
self.is_local_ensemble and past_covariates is not None,
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
"The models are of type LocalForecastingModel, which does not support past covariates.",
logger,
)

is_single_series = isinstance(series, TimeSeries)

# check that if timeseries is single series, than covariates are as well and vice versa
error = False
# check that if timeseries is single series, that covariates are as well and vice versa
error_past_cov = False
error_future_cov = False

if past_covariates is not None:
error = is_single_series != isinstance(past_covariates, TimeSeries)
error_past_cov = is_single_series != isinstance(past_covariates, TimeSeries)

if future_covariates is not None:
error = is_single_series != isinstance(future_covariates, TimeSeries)
error_future_cov = is_single_series != isinstance(
future_covariates, TimeSeries
)

raise_if(
error,
"Both series and covariates have to be either univariate or multivariate.",
error_past_cov or error_future_cov,
"Both series and covariates have to be either single TimeSeries or sequences of TimeSeries.",
logger,
)

self._verify_past_future_covariates(past_covariates, future_covariates)

super().fit(series, past_covariates, future_covariates)

return self
Expand All @@ -125,12 +159,17 @@ def _make_multiple_predictions(
num_samples: int = 1,
):
is_single_series = isinstance(series, TimeSeries) or series is None
# maximize covariate usage
predictions = [
model._predict_wrapper(
n=n,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
num_samples=num_samples,
)
for model in self.models
Expand Down Expand Up @@ -160,6 +199,8 @@ def predict(
verbose=verbose,
)

self._verify_past_future_covariates(past_covariates, future_covariates)

predictions = self._make_multiple_predictions(
n=n,
series=series,
Expand Down Expand Up @@ -229,3 +270,34 @@ def find_max_lag_or_none(lag_id, aggregator) -> Optional[int]:

def _is_probabilistic(self) -> bool:
return all([model._is_probabilistic() for model in self.models])

@property
def supports_past_covariates(self) -> bool:
return any([model.supports_past_covariates for model in self.models])

@property
def supports_future_covariates(self) -> bool:
return any([model.supports_future_covariates for model in self.models])

def _full_past_covariates_support(self) -> bool:
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
return all([model.supports_past_covariates for model in self.models])

def _full_future_covariates_support(self) -> bool:
return all([model.supports_future_covariates for model in self.models])

def _verify_past_future_covariates(self, past_covariates, future_covariates):
"""
Verify that any non-None covariates comply with the model type.
"""
raise_if(
past_covariates is not None and not self.supports_past_covariates,
"Some past_covariates have been provided to a EnsembleModel containing no models "
"supporting such covariates.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger,
)
raise_if(
future_covariates is not None and not self.supports_future_covariates,
"Some future_covariates have been provided to a Ensemble model containing no models "
"supporting such covariates.",
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
logger,
)
23 changes: 14 additions & 9 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@

from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.forecasting.ensemble_model import EnsembleModel
from darts.models.forecasting.forecasting_model import (
GlobalForecastingModel,
LocalForecastingModel,
)
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.timeseries import TimeSeries
Expand All @@ -23,9 +20,7 @@
class RegressionEnsembleModel(EnsembleModel):
def __init__(
self,
forecasting_models: Union[
List[LocalForecastingModel], List[GlobalForecastingModel]
],
forecasting_models: List[ForecastingModel],
regression_train_n_points: int,
regression_model=None,
):
Expand All @@ -38,6 +33,11 @@ def __init__(
as in :class:`RegressionModel`, where the regression model is used to produce forecasts based on the
lagged series.

If `future_covariates` or `past_covariates` are provided at training or inference time,
they will be passed only to the forecasting models supporting them.

The regression model does not leverage the covariates passed to ``fit()`` and ``predict()``.

Parameters
----------
forecasting_models
Expand Down Expand Up @@ -115,10 +115,15 @@ def fit(
)

for model in self.models:
# maximize covariate usage
model._fit_wrapper(
series=forecast_training,
past_covariates=past_covariates,
future_covariates=future_covariates,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
)

predictions = self._make_multiple_predictions(
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(

`season_mode` must be a ``SeasonalityMode`` Enum member.

You can access the Enum with ``from darts import SeasonalityMode``.
You can access the Enum with ``from darts.utils.utils import SeasonalityMode``.
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
Expand Down
35 changes: 34 additions & 1 deletion darts/tests/models/forecasting/test_ensemble_models.py
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
NaiveEnsembleModel,
NaiveSeasonal,
RegressionEnsembleModel,
StatsForecastAutoARIMA,
Theta,
)
from darts.tests.base_test_class import DartsBaseTestClass
Expand Down Expand Up @@ -84,8 +85,10 @@ def test_extreme_lag_inference(self):
def test_input_models_local_models(self):
with self.assertRaises(ValueError):
NaiveEnsembleModel([])
# models are not instantiated
with self.assertRaises(ValueError):
NaiveEnsembleModel([NaiveDrift, NaiveSeasonal, Theta, ExponentialSmoothing])
# one model is not instantiated
with self.assertRaises(ValueError):
NaiveEnsembleModel(
[NaiveDrift(), NaiveSeasonal, Theta(), ExponentialSmoothing()]
Expand Down Expand Up @@ -136,6 +139,9 @@ def test_stochastic_ensemble(self):

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_input_models_global_models(self):
# one model is not instantiated
with self.assertRaises(ValueError):
NaiveEnsembleModel([RNNModel(12), TCNModel(10, 2), NBEATSModel])
NaiveEnsembleModel([RNNModel(12), TCNModel(10, 2), NBEATSModel(10, 2)])

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
Expand Down Expand Up @@ -183,8 +189,35 @@ def test_call_predict_global_models_multivariate_input_with_covariates(self):

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_input_models_mixed(self):
# NaiveDrift is local, RNNModel is global
naive_ensemble = NaiveEnsembleModel([NaiveDrift(), RNNModel(12, n_epochs=1)])
# ensemble is neither local, nor global
self.assertFalse(naive_ensemble.is_local_ensemble)
self.assertFalse(naive_ensemble.is_global_ensemble)

# ensemble contains one local model, no support for multiple ts fit
with self.assertRaises(ValueError):
naive_ensemble.fit([self.series1, self.series2])

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_mixed_models_with_covariates(self):
dennisbader marked this conversation as resolved.
Show resolved Hide resolved
naive_ensemble_one_covs = NaiveEnsembleModel(
[NaiveDrift(), RNNModel(12, n_epochs=1)]
)
# none of the models support past covariates
with self.assertRaises(ValueError):
naive_ensemble_one_covs.fit(self.series1, past_covariates=self.series2)
# only RNN supports future covariates
naive_ensemble_one_covs.fit(self.series1, future_covariates=self.series2)

naive_ensemble_future_covs = NaiveEnsembleModel(
[StatsForecastAutoARIMA(), RNNModel(12, n_epochs=1)]
)
# none of the models support past covariates
with self.assertRaises(ValueError):
NaiveEnsembleModel([NaiveDrift(), Theta(), RNNModel(12)])
naive_ensemble_future_covs.fit(self.series1, past_covariates=self.series2)
# both models supports future covariates
naive_ensemble_future_covs.fit(self.series1, future_covariates=self.series2)

def test_fit_multivar_ts_with_local_models(self):
naive = NaiveEnsembleModel(
Expand Down