Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/hist fc predict kwargs #2103

Merged
merged 6 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**

**Fixed**
- Fixed a bug in historical forecasts, where some `fit/predict_kwargs` were not passed to the underlying model's fit/predict methods. [#2103](https://github.com/unit8co/darts/pull/2103) by [Dennis Bader](https://github.com/dennisbader).
- Fixed an import error when trying to create a `TorchForecastingModel` with PyTorch Lightning v<2.0.0. [#2087](https://github.com/unit8co/darts/pull/2087) by [Eschibli](https://github.com/eschibli).

### For developers of the library:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
each parallel job. If `parallel_params` is a `Sequence` of fixed attribute names, only those
attribute names specified will take on different values between different parallel jobs.
mask_components
Optionally, whether or not to automatically apply any provided `component_mask`s to the
Optionally, whether to automatically apply any provided `component_mask`s to the
`TimeSeries` inputs passed to `transform`, `fit`, `inverse_transform`, or `fit_transform`.
If `True`, any specified `component_mask` will be applied to each input timeseries
before passing them to the called method; the masked components will also be automatically
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
each parallel job. If `parallel_params` is a `Sequence` of fixed attribute names, only those
attribute names specified will take on different values between different parallel jobs.
mask_components
Optionally, whether or not to automatically apply any provided `component_mask`s to the
Optionally, whether to automatically apply any provided `component_mask`s to the
`TimeSeries` inputs passed to `transform`, `fit`, `inverse_transform`, or `fit_transform`.
If `True`, any specified `component_mask` will be applied to each input timeseries
before passing them to the called method; the masked components will also be automatically
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
each parallel job. If `parallel_params` is a `Sequence` of fixed attribute names, only those
attribute names specified will take on different values between different parallel jobs.
mask_components
Optionally, whether or not to automatically apply any provided `component_mask`s to the
Optionally, whether to automatically apply any provided `component_mask`s to the
`TimeSeries` inputs passed to `transform`, `fit`, `inverse_transform`, or `fit_transform`.
If `True`, any specified `component_mask` will be applied to each input timeseries
before passing them to the called method; the masked components will also be automatically
Expand Down
2 changes: 1 addition & 1 deletion darts/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _get_values_or_raise(
series_b
A univariate (deterministic or stochastic) ``TimeSeries`` instance (the predicted series).
intersect
A boolean for whether or not to only consider the time intersection between `series_a` and `series_b`
A boolean for whether to only consider the time intersection between `series_a` and `series_b`
stochastic_quantile
Optionally, for stochastic predicted series, return either all sample values with (`stochastic_quantile=None`)
or any deterministic quantile sample values by setting `stochastic_quantile=quantile` {>=0,<=1}.
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def predict(
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
forecast = np.tile(self.mean_val, (n, 1))
Expand Down Expand Up @@ -125,6 +126,7 @@ def predict(
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
forecast = np.array([self.last_k_vals[i % self.K, :] for i in range(n)])
Expand Down Expand Up @@ -174,6 +176,7 @@ def predict(
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
first, last = (
Expand Down Expand Up @@ -244,6 +247,7 @@ def predict(
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)

Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __init__(
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
be discarded). Default: ``False``.
save_checkpoints
Whether or not to automatically save the untrained model and checkpoints from training.
Whether to automatically save the untrained model and checkpoints from training.
To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
:class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
:class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/dlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def __init__(
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
be discarded). Default: ``False``.
save_checkpoints
Whether or not to automatically save the untrained model and checkpoints from training.
Whether to automatically save the untrained model and checkpoints from training.
To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
:class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
:class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
Expand Down
8 changes: 7 additions & 1 deletion darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ def fit(self, series: TimeSeries):

return self

def predict(self, n, num_samples=1, verbose: bool = False):
def predict(
self,
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)

if num_samples == 1:
Expand Down
8 changes: 7 additions & 1 deletion darts/models/forecasting/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,13 @@ def fit(self, series: TimeSeries):

return self

def predict(self, n: int, num_samples: int = 1, verbose: bool = False):
def predict(
self,
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
trend_forecast = np.array(
[self.trend_function(i + len(self.training_series)) for i in range(n)]
Expand Down
117 changes: 89 additions & 28 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def supports_likelihood_parameter_prediction(self) -> bool:
"""
return getattr(self, "likelihood", None) is not None

@property
@abstractmethod
def supports_transferrable_series_prediction(self) -> bool:
"""
Whether the model supports prediction for any input `series`.
"""
pass

@property
def uses_past_covariates(self) -> bool:
"""
Expand Down Expand Up @@ -278,7 +286,13 @@ def output_chunk_length(self) -> Optional[int]:
return None

@abstractmethod
def predict(self, n: int, num_samples: int = 1) -> TimeSeries:
def predict(
self,
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
) -> TimeSeries:
"""Forecasts values for `n` time steps after the end of the training series.

Parameters
Expand All @@ -288,6 +302,10 @@ def predict(self, n: int, num_samples: int = 1) -> TimeSeries:
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
for deterministic models.
verbose
Optionally, set the prediction verbosity. Not effective for all models.
show_warnings
Optionally, control whether warnings are shown. Not effective for all models.

Returns
-------
Expand Down Expand Up @@ -315,49 +333,59 @@ def predict(self, n: int, num_samples: int = 1) -> TimeSeries:
def _fit_wrapper(
self,
series: TimeSeries,
past_covariates: Optional[TimeSeries],
future_covariates: Optional[TimeSeries],
past_covariates: Optional[TimeSeries] = None,
future_covariates: Optional[TimeSeries] = None,
**kwargs,
):
supported_params = inspect.signature(self.fit).parameters
kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params}

add_kwargs = {}
# handle past and future covariates based on model support
for covs, name in zip([past_covariates, future_covariates], ["past", "future"]):
covs_name = f"{name}_covariates"
for covs, covs_name in zip(
[past_covariates, future_covariates],
["past_covariates", "future_covariates"],
):
if getattr(self, f"supports_{covs_name}"):
kwargs_[covs_name] = covs
add_kwargs[covs_name] = covs
elif covs is not None:
raise_log(
ValueError(f"Model cannot be fit/trained with `{covs_name}`."),
logger,
)
self.fit(series, **kwargs_)
self.fit(series=series, **add_kwargs, **kwargs)

def _predict_wrapper(
self,
n: int,
series: Optional[TimeSeries] = None,
past_covariates: Optional[TimeSeries] = None,
future_covariates: Optional[TimeSeries] = None,
predict_likelihood_parameters: bool = False,
**kwargs,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
supported_params = set(inspect.signature(self.predict).parameters)

# if predict() accepts covariates, the model might not support them at inference
for covs_name in ["past_covariates", "future_covariates"]:
if covs_name in kwargs and not getattr(self, f"supports_{covs_name}"):
if kwargs[covs_name] is None:
supported_params = supported_params - {covs_name}
else:
raise_log(
ValueError(
f"Model prediction does not support `{covs_name}`, either because it "
f"does not support `{covs_name}` in general, or because it was fit/trained "
f"without using `{covs_name}`."
),
logger,
)
add_kwargs = {}
# not all models supports input `series` at inference
if self.supports_transferrable_series_prediction:
add_kwargs["series"] = series

# even if predict() accepts covariates, the model might not support them at inference
for covs, name in zip(
[past_covariates, future_covariates],
["past_covariates", "future_covariates"],
):
if getattr(self, f"supports_{name}"):
add_kwargs[name] = covs
elif covs is not None:
raise_log(
ValueError(
f"Model prediction does not support `{name}`, either because it "
f"does not support `{name}` in general, or because it was fit/trained "
f"without using `{name}`."
),
logger,
)

kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params}
return self.predict(n, **kwargs_)
if self.supports_likelihood_parameter_prediction:
add_kwargs["predict_likelihood_parameters"] = predict_likelihood_parameters
return self.predict(n=n, **add_kwargs, **kwargs)

@property
def min_train_series_length(self) -> int:
Expand Down Expand Up @@ -2099,6 +2127,13 @@ def extreme_lags(
# that use an input to predict an output.
return -self.min_train_series_length, -1, None, None, None, None

@property
def supports_transferrable_series_prediction(self) -> bool:
"""
Whether the model supports prediction for any input `series`.
"""
return False


class GlobalForecastingModel(ForecastingModel, ABC):
"""The base class for "global" forecasting models, handling several time series and optional covariates.
Expand Down Expand Up @@ -2314,6 +2349,13 @@ def supports_optimized_historical_forecasts(self) -> bool:
"""
return True

@property
def supports_transferrable_series_prediction(self) -> bool:
"""
Whether the model supports prediction for any input `series`.
"""
return True

def _sanity_check_predict_likelihood_parameters(
self, n: int, output_chunk_length: Union[int, None], num_samples: int
):
Expand Down Expand Up @@ -2420,6 +2462,8 @@ def predict(
n: int,
future_covariates: Optional[TimeSeries] = None,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
**kwargs,
) -> TimeSeries:
"""Forecasts values for `n` time steps after the end of the training series.
Expand All @@ -2437,6 +2481,10 @@ def predict(
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
for deterministic models.
verbose
Optionally, set the prediction verbosity. Not effective for all models.
show_warnings
Optionally, control whether warnings are shown. Not effective for all models.

Returns
-------
Expand Down Expand Up @@ -2580,6 +2628,8 @@ def predict(
series: Optional[TimeSeries] = None,
future_covariates: Optional[TimeSeries] = None,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
**kwargs,
) -> TimeSeries:
"""If the `series` parameter is not set, forecasts values for `n` time steps after the end of the training
Expand All @@ -2605,6 +2655,10 @@ def predict(
num_samples
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
for deterministic models.
verbose
Optionally, set the prediction verbosity. Not effective for all models.
show_warnings
Optionally, control whether warnings are shown. Not effective for all models.

Returns
-------
Expand Down Expand Up @@ -2699,6 +2753,13 @@ def _predict(
"""
pass

@property
def supports_transferrable_series_prediction(self) -> bool:
"""
Whether the model supports prediction for any input `series`.
"""
return True

@property
def _supports_non_retrainable_historical_forecasts(self) -> bool:
return True
Expand Down
2 changes: 2 additions & 0 deletions darts/models/forecasting/kalman_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def predict(
series: Optional[TimeSeries] = None,
future_covariates: Optional[TimeSeries] = None,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
**kwargs,
) -> TimeSeries:
# we override `predict()` to pass a non-None `series`, so that historic_future_covariates
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def __init__(
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
be discarded). Default: ``False``.
save_checkpoints
Whether or not to automatically save the untrained model and checkpoints from training.
Whether to automatically save the untrained model and checkpoints from training.
To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
:class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
:class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def __init__(
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
be discarded). Default: ``False``.
save_checkpoints
Whether or not to automatically save the untrained model and checkpoints from training.
Whether to automatically save the untrained model and checkpoints from training.
To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
:class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
:class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/nlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __init__(
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
be discarded). Default: ``False``.
save_checkpoints
Whether or not to automatically save the untrained model and checkpoints from training.
Whether to automatically save the untrained model and checkpoints from training.
To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
:class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
:class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
Expand Down
2 changes: 2 additions & 0 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,8 @@ def predict(
**kwargs : dict, optional
Additional keyword arguments passed to the `predict` method of the model. Only works with
univariate target series.
show_warnings
Optionally, control whether warnings are shown. Not effective for all models.
"""
if series is None:
# then there must be a single TS, and that was saved in super().fit as self.training_series
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def __init__(
If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will
be discarded). Default: ``False``.
save_checkpoints
Whether or not to automatically save the untrained model and checkpoints from training.
Whether to automatically save the untrained model and checkpoints from training.
To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where
:class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
:class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
Expand Down
1 change: 1 addition & 0 deletions darts/models/forecasting/sf_auto_ces.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def predict(
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
forecast_dict = self.model.predict(
Expand Down