Skip to content

Commit

Permalink
Fix/ensemble predict with series (#1357)
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Nov 13, 2022
1 parent a0ebdfd commit 2590a69
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 36 deletions.
24 changes: 10 additions & 14 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(

super().__init__()
self.models = models
self.is_single_series = None

def fit(
self,
Expand All @@ -83,16 +82,16 @@ def fit(
logger,
)

self.is_single_series = isinstance(series, TimeSeries)
is_single_series = isinstance(series, TimeSeries)

# check that if timeseries is single series, then covariates are as well and vice versa
error = False

if past_covariates is not None:
error = self.is_single_series != isinstance(past_covariates, TimeSeries)
error = is_single_series != isinstance(past_covariates, TimeSeries)

if future_covariates is not None:
error = self.is_single_series != isinstance(future_covariates, TimeSeries)
error = is_single_series != isinstance(future_covariates, TimeSeries)

raise_if(
error,
Expand Down Expand Up @@ -125,6 +124,7 @@ def _make_multiple_predictions(
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
num_samples: int = 1,
):
is_single_series = isinstance(series, TimeSeries) or series is None
predictions = [
model._predict_wrapper(
n=n,
Expand All @@ -135,11 +135,11 @@ def _make_multiple_predictions(
)
for model in self.models
]

if self.is_single_series:
return self._stack_ts_seq(predictions)
else:
return self._stack_ts_multiseq(predictions)
return (
self._stack_ts_seq(predictions)
if is_single_series
else self._stack_ts_multiseq(predictions)
)

def predict(
self,
Expand All @@ -165,11 +165,7 @@ def predict(
future_covariates=future_covariates,
num_samples=num_samples,
)

if self.is_single_series:
return self.ensemble(predictions)
else:
return self.ensemble(predictions, series)
return self.ensemble(predictions, series=series)

@abstractmethod
def ensemble(
Expand Down
16 changes: 9 additions & 7 deletions darts/models/forecasting/regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.timeseries import TimeSeries
from darts.utils.utils import seq2series, series2seq

logger = get_logger(__name__)

Expand Down Expand Up @@ -90,7 +91,8 @@ def fit(
)

# spare train_n_points points to serve as regression target
if self.is_single_series:
is_single_series = isinstance(series, TimeSeries)
if is_single_series:
train_n_points_too_big = len(self.training_series) <= self.train_n_points
else:
train_n_points_too_big = any(
Expand All @@ -104,7 +106,7 @@ def fit(
logger,
)

if self.is_single_series:
if is_single_series:
forecast_training = self.training_series[: -self.train_n_points]
regression_target = self.training_series[-self.train_n_points :]
else:
Expand Down Expand Up @@ -156,15 +158,15 @@ def ensemble(
predictions: Union[TimeSeries, Sequence[TimeSeries]],
series: Optional[Sequence[TimeSeries]] = None,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
if self.is_single_series:
predictions = [predictions]
series = [series]

is_single_series = isinstance(series, TimeSeries) or series is None
predictions = series2seq(predictions)
series = series2seq(series) if series is not None else [None]

ensembled = [
self.regression_model.predict(
n=len(prediction), series=serie, future_covariates=prediction
)
for serie, prediction in zip(series, predictions)
]

return ensembled[0] if self.is_single_series else ensembled
return seq2series(ensembled) if is_single_series else ensembled
67 changes: 67 additions & 0 deletions darts/tests/models/forecasting/test_ensemble_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from darts.logging import get_logger
from darts.models import (
ExponentialSmoothing,
LinearRegressionModel,
NaiveDrift,
NaiveEnsembleModel,
NaiveSeasonal,
Expand Down Expand Up @@ -148,6 +149,72 @@ def test_fit_univar_ts_with_covariates_for_local_models(self):
with self.assertRaises(ValueError):
naive.fit(self.series1, self.series2)

def test_predict_with_target(self):
    """Check the container type returned by ``predict`` for each way of passing targets.

    The ensemble is fitted once on a single series and once on a list of
    series. In both cases a bare ``TimeSeries`` target must give back a
    ``TimeSeries`` and a list of targets must give back a list of equal
    length. Calling ``predict`` without a target is expected to work only
    for the model fitted on a single series; the model fitted on a list is
    expected to raise ``ValueError``.
    """
    full_series = self.series1
    train_series = full_series[:25]

    def check_explicit_targets(model):
        # a bare target series -> a bare forecast series
        single = model.predict(
            n=5, series=full_series, past_covariates=full_series
        )
        self.assertTrue(isinstance(single, TimeSeries))
        # a list of targets (two elements, then one) -> a list of that length
        for n_targets in (2, 1):
            forecasts = model.predict(
                n=5,
                series=[full_series] * n_targets,
                past_covariates=[full_series] * n_targets,
            )
            self.assertTrue(
                isinstance(forecasts, list) and len(forecasts) == n_targets
            )

    # model fitted on a single series
    model = self.get_global_ensembe_model()
    model.fit(train_series, past_covariates=full_series)
    # forecasting past the end of the training series needs no explicit target
    forecast = model.predict(n=5, past_covariates=full_series)
    self.assertTrue(isinstance(forecast, TimeSeries))
    check_explicit_targets(model)

    # model fitted on multiple series
    model = self.get_global_ensembe_model()
    model.fit([train_series] * 2, past_covariates=[full_series] * 2)
    # without an explicit target the call must fail
    with self.assertRaises(ValueError):
        model.predict(n=5, past_covariates=full_series)
    check_explicit_targets(model)

@staticmethod
def get_global_ensembe_model(output_chunk_length=5):
    """Build a ``NaiveEnsembleModel`` of two identically configured linear models.

    Both members are ``LinearRegressionModel`` instances using the same lags
    for the target and for the past covariates.

    Parameters
    ----------
    output_chunk_length
        Forwarded unchanged to each ``LinearRegressionModel``.
    """
    shared_lags = [-1, -2, -5]
    members = [
        LinearRegressionModel(
            lags=shared_lags,
            lags_past_covariates=shared_lags,
            output_chunk_length=output_chunk_length,
        )
        for _ in range(2)
    ]
    return NaiveEnsembleModel(models=members)


if __name__ == "__main__":
import unittest
Expand Down
92 changes: 77 additions & 15 deletions darts/tests/models/forecasting/test_regression_ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
from darts import TimeSeries
from darts.logging import get_logger
from darts.metrics import rmse
from darts.models import NaiveDrift, NaiveSeasonal
from darts.models import (
LinearRegressionModel,
NaiveDrift,
NaiveSeasonal,
RandomForest,
RegressionEnsembleModel,
RegressionModel,
)
from darts.tests.base_test_class import DartsBaseTestClass
from darts.tests.models.forecasting.test_ensemble_models import _make_ts
from darts.tests.models.forecasting.test_regression_models import train_test_split
Expand All @@ -19,14 +26,7 @@
try:
import torch

from darts.models import (
BlockRNNModel,
LinearRegressionModel,
RandomForest,
RegressionEnsembleModel,
RegressionModel,
RNNModel,
)
from darts.models import BlockRNNModel, RNNModel

TORCH_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -85,7 +85,25 @@ def get_global_models(self, output_chunk_length=5):
),
]

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
@staticmethod
def get_global_ensembe_model(output_chunk_length=5):
    """Build a ``RegressionEnsembleModel`` of two identically configured linear models.

    Both forecasting members are ``LinearRegressionModel`` instances using
    the same lags for the target and for the past covariates; the ensemble
    is created with ``regression_train_n_points=10``.

    Parameters
    ----------
    output_chunk_length
        Forwarded unchanged to each ``LinearRegressionModel``.
    """
    shared_lags = [-1, -2, -5]
    members = [
        LinearRegressionModel(
            lags=shared_lags,
            lags_past_covariates=shared_lags,
            output_chunk_length=output_chunk_length,
        )
        for _ in range(2)
    ]
    return RegressionEnsembleModel(
        forecasting_models=members,
        regression_train_n_points=10,
    )

def test_accepts_different_regression_models(self):
regr1 = LinearRegression()
regr2 = RandomForestRegressor()
Expand All @@ -101,7 +119,6 @@ def test_accepts_different_regression_models(self):
model.fit(series=self.combined)
model.predict(10)

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_accepts_one_model(self):
regr1 = LinearRegression()
regr2 = RandomForest(lags_future_covariates=[0])
Expand All @@ -115,12 +132,11 @@ def test_accepts_one_model(self):
model.fit(series=self.combined)
model.predict(10)

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_train_n_points(self):
regr = LinearRegressionModel(lags_future_covariates=[0])

# same values
ensemble = RegressionEnsembleModel(self.get_local_models(), 5, regr)
_ = RegressionEnsembleModel(self.get_local_models(), 5, regr)

# too big value to perform the split
ensemble = RegressionEnsembleModel(self.get_local_models(), 100)
Expand Down Expand Up @@ -182,7 +198,54 @@ def test_train_predict_global_models_multivar_with_covariates(self):
ensemble.fit(self.seq1, self.cov1)
ensemble.predict(10, self.seq2, self.cov2)

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_predict_with_target(self):
    """Check the container type returned by ``predict`` for each way of passing targets.

    The regression ensemble is fitted once on a single series and once on a
    list of series. In both cases a bare ``TimeSeries`` target must give
    back a ``TimeSeries`` and a list of targets must give back a list of
    equal length. Calling ``predict`` without a target is expected to work
    only for the model fitted on a single series; the model fitted on a
    list is expected to raise ``ValueError``.
    """
    full_series = self.combined
    train_series = full_series[:25]

    def check_explicit_targets(model):
        # a bare target series -> a bare forecast series
        single = model.predict(
            n=5, series=full_series, past_covariates=full_series
        )
        self.assertTrue(isinstance(single, TimeSeries))
        # a list of targets (two elements, then one) -> a list of that length
        for n_targets in (2, 1):
            forecasts = model.predict(
                n=5,
                series=[full_series] * n_targets,
                past_covariates=[full_series] * n_targets,
            )
            self.assertTrue(
                isinstance(forecasts, list) and len(forecasts) == n_targets
            )

    # model fitted on a single series
    model = self.get_global_ensembe_model()
    model.fit(train_series, past_covariates=full_series)
    # forecasting past the end of the training series needs no explicit target
    forecast = model.predict(n=5, past_covariates=full_series)
    self.assertTrue(isinstance(forecast, TimeSeries))
    check_explicit_targets(model)

    # model fitted on multiple series
    model = self.get_global_ensembe_model()
    model.fit([train_series] * 2, past_covariates=[full_series] * 2)
    # without an explicit target the call must fail
    with self.assertRaises(ValueError):
        model.predict(n=5, past_covariates=full_series)
    check_explicit_targets(model)

def helper_test_models_accuracy(
self, model_instance, n, series, past_covariates, min_rmse
):
Expand All @@ -201,7 +264,6 @@ def helper_test_models_accuracy(
f"Model was not able to denoise data. A rmse score of {current_rmse} was recorded.",
)

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def denoising_input(self):
np.random.seed(self.RANDOM_SEED)

Expand Down

0 comments on commit 2590a69

Please sign in to comment.