
Commit

Feat/fit predict encodings (#1925)
* added encode_train_inference to encoders

* added generate_fit_predict_encodings to ForecastingModel

* simplify TransferrableFut..Model.generate_predict_encodings

* update changelog

* Apply suggestions from code review

Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>

* apply suggestions from PR review part 2

---------

Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
dennisbader and madtoinou committed Jul 31, 2023
1 parent d30f163 commit 3376f27
Showing 6 changed files with 553 additions and 172 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added model property `ForecastingModel.supports_multivariate` to indicate whether the model supports multivariate forecasting. [#1848](https://github.com/unit8co/darts/pull/1848) by [Felix Divo](https://github.com/felixdivo).
- `Prophet` now supports conditional seasonalities, and properly handles all parameters passed to `Prophet.add_seasonality()` and model creation parameter `add_seasonalities` [#1829](https://github.com/unit8co/darts/pull/1829) by [Idan Shilon](https://github.com/id5h).
- Added support for direct prediction of the likelihood parameters to probabilistic models using a likelihood (regression and torch models). Set `predict_likelihood_parameters=True` when calling `predict()`. [#1811](https://github.com/unit8co/darts/pull/1811) by [Antoine Madrona](https://github.com/madtoinou).
- Added method `generate_fit_predict_encodings()` to generate the encodings (from `add_encoders` at model creation) required for training and prediction. [#1925](https://github.com/unit8co/darts/pull/1925) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to `EnsembleModel`:
- Model creation parameter `forecasting_models` now supports a mix of `LocalForecastingModel` and `GlobalForecastingModel` (single `TimeSeries` training/inference only, due to the local models). [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
- Future and past covariates can now be used even if `forecasting_models` have different covariates support. The covariates passed to `fit()`/`predict()` are used only by models that support it. [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
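For context on the new method announced in the changelog entry above (not part of the diff itself): a minimal usage sketch of `generate_fit_predict_encodings()`, assuming its signature follows the existing `generate_fit_encodings()`/`generate_predict_encodings()` pattern (`n`, `series`, optional covariates) and returns a `(past_covariates, future_covariates)` tuple.

```python
from darts.models import LinearRegressionModel
from darts.utils.timeseries_generation import linear_timeseries

series = linear_timeseries(length=48)  # dummy daily target series

# encoders are configured at model creation via `add_encoders`
model = LinearRegressionModel(
    lags=12,
    lags_future_covariates=[0],
    add_encoders={"datetime_attribute": {"future": ["month"]}},
)

# one call generates encodings covering both the training span and the
# forecast horizon; `past_cov` is expected to be None here since only
# future encoders are configured (assumed behavior)
past_cov, future_cov = model.generate_fit_predict_encodings(n=6, series=series)

model.fit(series, future_covariates=future_cov)
forecast = model.predict(n=6, future_covariates=future_cov)
```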
135 changes: 130 additions & 5 deletions darts/dataprocessing/encoders/encoder_base.py
@@ -14,11 +14,36 @@
from darts.logging import get_logger, raise_if, raise_log
from darts.utils.timeseries_generation import generate_index

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

SupportedIndex = Union[pd.DatetimeIndex, pd.RangeIndex]
EncoderOutputType = Optional[Union[Sequence[TimeSeries], List[TimeSeries]]]
logger = get_logger(__name__)


class _EncoderMethod:
"""Connects the encoder stage to the corresponding methods"""

def __init__(self, stage: Literal["train", "inference", "train_inference"]):
self.method = None
if stage == "train":
self.method = "encode_train"
elif stage == "inference":
self.method = "encode_inference"
elif stage == "train_inference":
self.method = "encode_train_inference"
else:
raise_log(
ValueError(
f"Unknown encoder `stage={stage}`. Must be on of `('train', 'inference', 'train_inference')`"
),
logger,
)
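
An illustrative note on how this helper is intended to be used (a hypothetical caller sketch, not code from this commit): the `method` string is resolved dynamically on an encoder instance, so one code path can serve the train, inference, and train_inference stages.

```python
# `_EncoderMethod` is a private helper; importing it directly is for illustration only
from darts.dataprocessing.encoders.encoder_base import _EncoderMethod

method = _EncoderMethod("train_inference").method  # -> "encode_train_inference"
# hypothetical dispatch on a single encoder instance:
# encoded = getattr(single_encoder, method)(n=n, target=series, covariates=covariates)
```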


class CovariatesIndexGenerator(ABC):
def __init__(
self,
Expand Down Expand Up @@ -128,6 +153,33 @@ def generate_inference_idx(
"""
pass

def generate_train_inference_idx(
self, n: int, target: TimeSeries, covariates: Optional[TimeSeries] = None
) -> Tuple[SupportedIndex, pd.Timestamp]:
"""
Generates/extracts time index (or integer index) for covariates for training and inference / prediction.
Parameters
----------
n
The forecasting horizon.
target
The target TimeSeries used for training and inference / prediction as `series`.
covariates
Optionally, the covariates used for training and inference / prediction.
If given, the returned time index is equal to the `covariates` time index. Else, the returned time index
covers the minimum required covariate time spans for performing training and inference / prediction with a
specific forecasting model. These requirements are derived from parameters set at
:class:`CovariatesIndexGenerator` creation.
"""
train_idx, target_end = self.generate_train_idx(
target=target, covariates=covariates
)
inference_idx, _ = self.generate_inference_idx(
n=n, target=target, covariates=covariates
)
return train_idx.__class__.union(train_idx, inference_idx), target_end
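
To illustrate what the index union amounts to (a standalone sketch with made-up dates, not darts code): the training index and the inference index typically overlap, and their union spans everything needed for both `fit()` and `predict()`.

```python
import pandas as pd

# hypothetical outputs of generate_train_idx() and generate_inference_idx()
train_idx = pd.date_range("2020-01-01", periods=24, freq="D")
inference_idx = pd.date_range("2020-01-19", periods=12, freq="D")

# same operation as `train_idx.__class__.union(train_idx, inference_idx)` above
combined = train_idx.union(inference_idx)
print(combined.min(), combined.max(), len(combined))  # 2020-01-01, 2020-01-30, 30 steps
```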

@property
@abstractmethod
def base_component_name(self) -> str:
@@ -444,7 +496,7 @@ def encode_train(
covariates
Optionally, the past or future covariates used for training.
merge_covariates
-Whether or not to merge the encoded TimeSeries with `covariates`.
+Whether to merge the encoded TimeSeries with `covariates`.
"""
pass

@@ -468,8 +520,32 @@ def encode_inference(
covariates
Optionally, the past or future covariates used for prediction.
merge_covariates
-Whether or not to merge the encoded TimeSeries with `covariates`.
+Whether to merge the encoded TimeSeries with `covariates`.
"""
pass

@abstractmethod
def encode_train_inference(
self,
n: int,
target: TimeSeries,
covariates: Optional[TimeSeries] = None,
merge_covariates: bool = True,
**kwargs,
) -> TimeSeries:
"""Each subclass must implement a method to encode the covariates index for training and prediction.
Parameters
----------
n
The forecast horizon
target
The target TimeSeries used during training and prediction.
covariates
Optionally, the past or future covariates used for training and prediction.
merge_covariates
Whether to merge the encoded TimeSeries with `covariates`.
"""
pass

@@ -589,7 +665,7 @@ def encode_train(
`PastCovariatesIndexGenerator`, future covariates if `self.index_generator` is a
`FutureCovariatesIndexGenerator`
merge_covariates
-Whether or not to merge the encoded TimeSeries with `covariates`.
+Whether to merge the encoded TimeSeries with `covariates`.
"""
# exclude encoded components from covariates to add the newly encoded components later
covariates = self._drop_encoded_components(covariates, self.components)
@@ -636,7 +712,7 @@ def encode_inference(
`PastCovariatesIndexGenerator`, future covariates if `self.index_generator` is a
`FutureCovariatesIndexGenerator`
merge_covariates
-Whether or not to merge the encoded TimeSeries with `covariates`.
+Whether to merge the encoded TimeSeries with `covariates`.
"""
# some encoders must be fit before `encode_inference()`
raise_if(
@@ -671,6 +747,55 @@ def encode_inference(

return encoded

def encode_train_inference(
self,
n: int,
target: TimeSeries,
covariates: Optional[TimeSeries] = None,
merge_covariates: bool = True,
**kwargs,
) -> TimeSeries:
"""Returns encoded index for inference/prediction.
Parameters
----------
n
The forecast horizon
target
The target TimeSeries used during training and prediction.
covariates
Optionally, the covariates used for training and prediction: past covariates if `self.index_generator` is a
`PastCovariatesIndexGenerator`, future covariates if `self.index_generator` is a
`FutureCovariatesIndexGenerator`
merge_covariates
Whether to merge the encoded TimeSeries with `covariates`.
"""
# exclude encoded components from covariates to add the newly encoded components later
covariates = self._drop_encoded_components(covariates, self.components)

# generate index and encodings
index, target_end = self.index_generator.generate_train_inference_idx(
n, target, covariates
)
encoded = self._encode(index, target_end, target.dtype)

# optionally, merge encodings with original `covariates` series
encoded = (
self._merge_covariates(encoded, covariates=covariates)
if merge_covariates
else encoded
)

# save encoded component names
if self.components.empty:
components = encoded.components
if covariates is not None:
components = components[~components.isin(covariates.components)]
self._components = components

self._fit_called = True
return encoded
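
A rough, darts-level sketch of the effect of this method for a datetime-attribute encoder (assumed behavior, using public utilities rather than the internal `_encode()`/`_merge_covariates()` calls): encodings are built over the combined train + inference index and, with `merge_covariates=True`, stacked onto the user-provided covariates.

```python
import numpy as np
import pandas as pd
from darts import TimeSeries
from darts.utils.timeseries_generation import datetime_attribute_timeseries

# combined train + inference index (see the index-union sketch above)
idx = pd.date_range("2020-01-01", periods=30, freq="D")

# roughly what a "month" datetime-attribute encoder would produce over that index
encoded = datetime_attribute_timeseries(idx, attribute="month", dtype=np.float64)

# user-provided covariates over the same span; merging stacks the new component
user_covariates = TimeSeries.from_times_and_values(idx, np.arange(30, dtype=np.float64))
merged = user_covariates.stack(encoded)
print(merged.components.tolist())
```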

@property
@abstractmethod
def accept_transformer(self) -> List[bool]:
@@ -782,7 +907,7 @@ def _update_mask(self, covariates: List[TimeSeries]) -> None:

@property
def fit_called(self) -> bool:
"""Return whether or not the transformer has been fitted."""
"""Return whether the transformer has been fitted."""
return self._fit_called

