Skip to content

Commit

Permalink
move series length check in forcasting model
Browse files Browse the repository at this point in the history
  • Loading branch information
grll committed Jul 24, 2020
1 parent 257305c commit 189c097
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions darts/models/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def fit(self) -> None:
Implements behavior that should happen when calling the `fit` method of every forcasting model regardless of
wether they are univariate or multivariate.
"""
for series in (self.training_series, self.covariate_series, self.target_series):
if series is not None:
raise_if_not(len(series) >= self.min_train_series_length,
"Train series only contains {} elements but {} model requires at least {} entries"
.format(len(series), str(self), self.min_train_series_length))
self._fit_called = True

@abstractmethod
Expand Down Expand Up @@ -131,9 +136,6 @@ def _make_fitable_series(self,
# general checks on covariate / target series
raise_if_not(all(covariate_series.time_index() == target_series.time_index()), "Covariate and target "
"timeseries must have same time indices.")
raise_if_not(len(covariate_series) >= self.min_train_series_length,
"Train series only contains {} elements but {} model requires at least {} entries"
.format(len(covariate_series), str(self), self.min_train_series_length))

return covariate_series, target_series

Expand Down

0 comments on commit 189c097

Please sign in to comment.