Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix MixedCovTorchModels multi TS predictions with n<ocl #2374

Merged
merged 2 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Fixed**
- Fixed a bug where `n_steps_between` did not work properly with custom business frequencies. This affected metrics computation. [#2357](https://github.com/unit8co/darts/pull/2357) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug when calling `predict()` with a `MixedCovariatesTorchModel` (e.g. TiDE, N/DLinear, ...) `n<output_chunk_length` and a list of series with length `len(series) < n`, where the predictions did not return the correct number of series. [#2374](https://github.com/unit8co/darts/pull/2374) by [Dennis Bader](https://github.com/dennisbader).

**Dependencies**
- Improvements to linting via updated pre-commit configurations: [#2324](https://github.com/unit8co/darts/pull/2324) by [Jirka Borovec](https://github.com/borda).
Expand Down
10 changes: 6 additions & 4 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,13 +830,15 @@ def _get_batch_prediction(
batch_prediction = [out[:, :roll_size, :]]
prediction_length = roll_size

while prediction_length < n:
# we want the last prediction to end exactly at `n` into the future.
# predict at least `output_chunk_length` points, so that we use the most recent target values
min_n = n if n >= self.output_chunk_length else self.output_chunk_length
while prediction_length < min_n:
# we want the last prediction to end exactly at `min_n` into the future.
# this means we may have to truncate the previous prediction and step
# back the roll size for the last chunk
if prediction_length + self.output_chunk_length > n:
if prediction_length + self.output_chunk_length > min_n:
spillover_prediction_length = (
prediction_length + self.output_chunk_length - n
prediction_length + self.output_chunk_length - min_n
)
roll_size -= spillover_prediction_length
prediction_length -= spillover_prediction_length
Expand Down
56 changes: 0 additions & 56 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2779,62 +2779,6 @@ def extreme_lags(
None,
)

def predict(
self,
n: int,
series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
trainer: Optional[pl.Trainer] = None,
batch_size: Optional[int] = None,
verbose: Optional[bool] = None,
n_jobs: int = 1,
roll_size: Optional[int] = None,
num_samples: int = 1,
num_loader_workers: int = 0,
mc_dropout: bool = False,
predict_likelihood_parameters: bool = False,
show_warnings: bool = True,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
# since we have future covariates, the inference dataset for future input must be at least of length
# `output_chunk_length`. If not, we would have to step back which causes past input to be shorter than
# `input_chunk_length`.

if n >= self.output_chunk_length:
return super().predict(
n=n,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
trainer=trainer,
batch_size=batch_size,
verbose=verbose,
n_jobs=n_jobs,
roll_size=roll_size,
num_samples=num_samples,
num_loader_workers=num_loader_workers,
mc_dropout=mc_dropout,
predict_likelihood_parameters=predict_likelihood_parameters,
show_warnings=show_warnings,
)
else:
return super().predict(
n=self.output_chunk_length,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
trainer=trainer,
batch_size=batch_size,
verbose=verbose,
n_jobs=n_jobs,
roll_size=roll_size,
num_samples=num_samples,
num_loader_workers=num_loader_workers,
mc_dropout=mc_dropout,
predict_likelihood_parameters=predict_likelihood_parameters,
show_warnings=show_warnings,
)[:n]


class SplitCovariatesTorchModel(TorchForecastingModel, ABC):
def _build_train_dataset(
Expand Down
23 changes: 23 additions & 0 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,29 @@ def test_output_shift(self, config):
_ = model_fc_shift.predict(n=ocl, **add_covs)
assert f"provided {cov_name} covariates at dataset index" in str(err.value)

@pytest.mark.parametrize("config", itertools.product(models, [2, 3, 4]))
def test_multi_ts_prediction(self, config):
(model_cls, model_kwargs), n = config
model_kwargs = copy.deepcopy(model_kwargs)
model_kwargs["output_chunk_length"] = 3
series = tg.linear_timeseries(
length=model_kwargs["input_chunk_length"]
+ model_kwargs["output_chunk_length"]
)
model = model_cls(**model_kwargs)
model.fit(series)
# test with more series that `n`
n_series_more = 5
pred = model.predict(n=n, series=[series] * n_series_more)
assert len(pred) == n_series_more
assert all(len(p) == n for p in pred)

# test with less series that `n`
n_series_less = 1
pred = model.predict(n=n, series=[series] * n_series_less)
assert len(pred) == n_series_less
assert all(len(p) == n for p in pred)

def helper_equality_encoders(
self, first_encoders: Dict[str, Any], second_encoders: Dict[str, Any]
):
Expand Down
Loading