Skip to content

Commit

Permalink
add fit_called checks in global models (#944)
Browse files Browse the repository at this point in the history
* add fit_called checks in global models

* adapt TFT tests
  • Loading branch information
hrzn committed May 18, 2022
1 parent 699bf17 commit adb66fd
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
7 changes: 3 additions & 4 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def predict(self, n: int, num_samples: int = 1) -> TimeSeries:
if not self._fit_called:
raise_log(
ValueError(
"The model must be fit before calling `predict()`."
"For global models, if `predict()` is called without specifying a series,"
"The model must be fit before calling predict(). "
"For global models, if predict() is called without specifying a series, "
"the model must have been fit on a single training series."
),
logger,
Expand Down Expand Up @@ -1014,8 +1014,7 @@ def predict(
If `series` is given and is a sequence of several time series, this function returns
a sequence where each element contains the corresponding `n` points forecasts.
"""
if series is None and past_covariates is None and future_covariates is None:
super().predict(n, num_samples)
super().predict(n, num_samples)
if self._expect_past_covariates and past_covariates is None:
raise_log(
ValueError(
Expand Down
11 changes: 9 additions & 2 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
raise_log,
suppress_lightning_warnings,
)
from darts.models.forecasting.forecasting_model import GlobalForecastingModel
from darts.models.forecasting.forecasting_model import (
ForecastingModel,
GlobalForecastingModel,
)
from darts.models.forecasting.pl_forecasting_module import PLForecastingModule
from darts.timeseries import TimeSeries
from darts.utils.data.encoders import SequentialEncoder
Expand Down Expand Up @@ -830,7 +833,7 @@ def fit_from_dataset(
self
Fitted model.
"""

self._fit_called = True
self._verify_train_dataset_type(train_dataset)
raise_if(
len(train_dataset) == 0,
Expand Down Expand Up @@ -1173,6 +1176,10 @@ def predict_from_dataset(
Sequence[TimeSeries]
Returns one or more forecasts for time series.
"""

# We need to call super's super's method directly, because GlobalForecastingModel expects series:
ForecastingModel.predict(self, n, num_samples)

self._verify_inference_dataset_type(input_series_dataset)

# check that covariates and dimensions are matching what we had during training
Expand Down
10 changes: 7 additions & 3 deletions darts/tests/models/forecasting/test_TFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,14 @@ def helper_fit_predict(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=100,
num_samples=(100 if model._is_probabilistic() else 1),
)

if isinstance(y_hat, TimeSeries):
y_hat = y_hat.quantile_timeseries(0.5)
y_hat = y_hat.quantile_timeseries(0.5) if y_hat.n_samples > 1 else y_hat
else:
y_hat = [ts.quantile_timeseries(0.5) for ts in y_hat]
y_hat = [
ts.quantile_timeseries(0.5) if ts.n_samples > 1 else ts
for ts in y_hat
]
return y_hat

0 comments on commit adb66fd

Please sign in to comment.