Skip to content

Commit

Permalink
add feature projection for past covariates to TiDEModel (#1993)
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Sep 15, 2023
1 parent a9b6fbc commit fca3993
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added short examples in the docstring of all the models, including covariates usage and some model-specific parameters. [#1956](https://github.com/unit8co/darts/pull/1956) by [Antoine Madrona](https://github.com/madtoinou).
- All `RegressionModel`s now support component/column-specific lags for target, past, and future covariates series. [#1962](https://github.com/unit8co/darts/pull/1962) by [Antoine Madrona](https://github.com/madtoinou).
- Added method `TimeSeries.cumsum()` to get the cumulative sum of the time series along the time axis. [#1988](https://github.com/unit8co/darts/pull/1988) by [Eliot Zubkoff](https://github.com/Eliotdoesprogramming).
- 🔴 Added past covariates feature projection to `TiDEModel` with parameter `temporal_width_past` following the advice of the model architect. Parameter `temporal_width` was renamed to `temporal_width_future`. Additionally, added the option to bypass the feature projection with `temporal_width_past/future=0`. [#1993](https://github.com/unit8co/darts/pull/1993) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**
- Fixed a bug in `TimeSeries.from_dataframe()` when using a pandas.DataFrame with `df.columns.name != None`. [#1938](https://github.com/unit8co/darts/pull/1938) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
150 changes: 107 additions & 43 deletions darts/models/forecasting/tide_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch.nn as nn

from darts.logging import get_logger
from darts.logging import get_logger, raise_log
from darts.models.forecasting.pl_forecasting_module import (
PLMixedCovariatesModule,
io_processor,
Expand Down Expand Up @@ -77,7 +77,8 @@ def __init__(
decoder_output_dim: int,
hidden_size: int,
temporal_decoder_hidden: int,
temporal_width: int,
temporal_width_past: int,
temporal_width_future: int,
use_layer_norm: bool,
dropout: float,
**kwargs,
Expand Down Expand Up @@ -106,7 +107,9 @@ def __init__(
The width of the hidden layers in the encoder/decoder Residual Blocks.
temporal_decoder_hidden
The width of the hidden layers in the temporal decoder.
temporal_width
temporal_width_past
The width of the past covariate embedding space.
temporal_width_future
The width of the future covariate embedding space.
use_layer_norm
Whether to use layer normalization in the Residual Blocks.
Expand All @@ -131,6 +134,7 @@ def __init__(

self.input_dim = input_dim
self.output_dim = output_dim
self.past_cov_dim = input_dim - output_dim - future_cov_dim
self.future_cov_dim = future_cov_dim
self.static_cov_dim = static_cov_dim
self.nr_params = nr_params
Expand All @@ -141,28 +145,52 @@ def __init__(
self.temporal_decoder_hidden = temporal_decoder_hidden
self.use_layer_norm = use_layer_norm
self.dropout = dropout
self.temporal_width = temporal_width
self.temporal_width_past = temporal_width_past
self.temporal_width_future = temporal_width_future

# past covariates handling: either feature projection, raw features, or no features
self.past_cov_projection = None
if self.past_cov_dim and temporal_width_past:
# residual block for past covariates feature projection
self.past_cov_projection = _ResidualBlock(
input_dim=self.past_cov_dim,
output_dim=temporal_width_past,
hidden_size=hidden_size,
use_layer_norm=use_layer_norm,
dropout=dropout,
)
past_covariates_flat_dim = self.input_chunk_length * temporal_width_past
elif self.past_cov_dim:
# skip projection and use raw features
past_covariates_flat_dim = self.input_chunk_length * self.past_cov_dim
else:
past_covariates_flat_dim = 0

# residual block for input feature projection
# this is only needed when covariates are used
if future_cov_dim:
self.feature_projection = _ResidualBlock(
# future covariates handling: either feature projection, raw features, or no features
self.future_cov_projection = None
if future_cov_dim and self.temporal_width_future:
# residual block for future covariates feature projection
self.future_cov_projection = _ResidualBlock(
input_dim=future_cov_dim,
output_dim=temporal_width,
output_dim=temporal_width_future,
hidden_size=hidden_size,
use_layer_norm=use_layer_norm,
dropout=dropout,
)
historical_future_covariates_flat_dim = (
self.input_chunk_length + self.output_chunk_length
) * temporal_width_future
elif future_cov_dim:
# skip projection and use raw features
historical_future_covariates_flat_dim = (
self.input_chunk_length + self.output_chunk_length
) * future_cov_dim
else:
self.feature_projection = None
historical_future_covariates_flat_dim = 0

# original paper doesn't specify how to use past covariates
# we assume that they pass them raw to the encoder
historical_future_covariates_flat_dim = (
self.input_chunk_length + self.output_chunk_length
) * (self.temporal_width if future_cov_dim > 0 else 0)
encoder_dim = (
self.input_chunk_length * (input_dim - future_cov_dim)
self.input_chunk_length * output_dim
+ past_covariates_flat_dim
+ historical_future_covariates_flat_dim
+ static_cov_dim
)
Expand Down Expand Up @@ -210,9 +238,14 @@ def __init__(
),
)

decoder_input_dim = decoder_output_dim * self.nr_params
if temporal_width_future and future_cov_dim:
decoder_input_dim += temporal_width_future
elif future_cov_dim:
decoder_input_dim += future_cov_dim

self.temporal_decoder = _ResidualBlock(
input_dim=decoder_output_dim * self.nr_params
+ (temporal_width if future_cov_dim > 0 else 0),
input_dim=decoder_input_dim,
output_dim=output_dim * self.nr_params,
hidden_size=temporal_decoder_hidden,
use_layer_norm=use_layer_norm,
Expand Down Expand Up @@ -246,44 +279,49 @@ def forward(

x_lookback = x[:, :, : self.output_dim]

# future covariates need to be extracted from x and stacked with historical future covariates
if self.future_cov_dim > 0:
x_dynamic_covariates = torch.cat(
# future covariates: feature projection or raw features
# historical future covariates need to be extracted from x and stacked with part of future covariates
if self.future_cov_dim:
x_dynamic_future_covariates = torch.cat(
[
x_future_covariates,
x[
:,
:,
None if self.future_cov_dim == 0 else -self.future_cov_dim :,
],
x_future_covariates,
],
dim=1,
)

# project input features across all input time steps
x_dynamic_covariates_proj = self.feature_projection(x_dynamic_covariates)

if self.temporal_width_future:
# project input features across all input and output time steps
x_dynamic_future_covariates = self.future_cov_projection(
x_dynamic_future_covariates
)
else:
x_dynamic_covariates = None
x_dynamic_covariates_proj = None
x_dynamic_future_covariates = None

# extract past covariates, if they exist
if self.input_dim - self.output_dim - self.future_cov_dim > 0:
x_past_covariates = x[
# past covariates: feature projection or raw features
# the past covariates are embedded in `x`
if self.past_cov_dim:
x_dynamic_past_covariates = x[
:,
:,
self.output_dim : None
if self.future_cov_dim == 0
else -self.future_cov_dim :,
self.output_dim : self.output_dim + self.past_cov_dim,
]
if self.temporal_width_past:
# project input features across all input time steps
x_dynamic_past_covariates = self.past_cov_projection(
x_dynamic_past_covariates
)
else:
x_past_covariates = None
x_dynamic_past_covariates = None

# setup input to encoder
encoded = [
x_lookback,
x_past_covariates,
x_dynamic_covariates_proj,
x_dynamic_past_covariates,
x_dynamic_future_covariates,
x_static_covariates,
]
encoded = [t.flatten(start_dim=1) for t in encoded if t is not None]
Expand All @@ -299,7 +337,7 @@ def forward(
# stack and temporally decode with future covariate last output steps
temporal_decoder_input = [
decoded,
x_dynamic_covariates_proj[:, -self.output_chunk_length :, :]
x_dynamic_future_covariates[:, -self.output_chunk_length :, :]
if self.future_cov_dim > 0
else None,
]
Expand Down Expand Up @@ -331,7 +369,8 @@ def __init__(
num_decoder_layers: int = 1,
decoder_output_dim: int = 16,
hidden_size: int = 128,
temporal_width: int = 4,
temporal_width_past: int = 4,
temporal_width_future: int = 4,
temporal_decoder_hidden: int = 32,
use_layer_norm: bool = False,
dropout: float = 0.1,
Expand Down Expand Up @@ -369,8 +408,12 @@ def __init__(
The dimensionality of the output of the decoder.
hidden_size
The width of the layers in the residual blocks of the encoder and decoder.
temporal_width
The width of the layers in the future covariate projection residual block.
temporal_width_past
The width of the layers in the past covariate projection residual block. If `0`,
will bypass feature projection and use the raw feature data.
temporal_width_future
The width of the layers in the future covariate projection residual block. If `0`,
will bypass feature projection and use the raw feature data.
temporal_decoder_hidden
The width of the layers in the temporal decoder.
use_layer_norm
Expand Down Expand Up @@ -550,6 +593,13 @@ def encode_year(idx):
`TiDE example notebook <https://unit8co.github.io/darts/examples/18-TiDE-examples.html>`_ presents
techniques that can be used to improve the forecasts quality compared to this simple usage example.
"""
if temporal_width_past < 0 or temporal_width_future < 0:
raise_log(
ValueError(
"`temporal_width_past` and `temporal_width_future` must be >= 0."
),
logger=logger,
)
super().__init__(**self._extract_torch_model_params(**self.model_params))

# extract pytorch lightning module kwargs
Expand All @@ -559,7 +609,8 @@ def encode_year(idx):
self.num_decoder_layers = num_decoder_layers
self.decoder_output_dim = decoder_output_dim
self.hidden_size = hidden_size
self.temporal_width = temporal_width
self.temporal_width_past = temporal_width_past
self.temporal_width_future = temporal_width_future
self.temporal_decoder_hidden = temporal_decoder_hidden

self._considers_static_covariates = use_static_covariates
Expand Down Expand Up @@ -603,6 +654,18 @@ def _create_model(

nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

past_cov_dim = input_dim - output_dim - future_cov_dim
if past_cov_dim and self.temporal_width_past >= past_cov_dim:
logger.warning(
f"number of `past_covariates` features is <= `temporal_width_past`, leading to feature expansion."
f"number of covariates: {past_cov_dim}, `temporal_width_past={self.temporal_width_past}`."
)
if future_cov_dim and self.temporal_width_future >= future_cov_dim:
logger.warning(
f"number of `future_covariates` features is <= `temporal_width_future`, leading to feature expansion."
f"number of covariates: {future_cov_dim}, `temporal_width_future={self.temporal_width_future}`."
)

return _TideModule(
input_dim=input_dim,
output_dim=output_dim,
Expand All @@ -613,7 +676,8 @@ def _create_model(
num_decoder_layers=self.num_decoder_layers,
decoder_output_dim=self.decoder_output_dim,
hidden_size=self.hidden_size,
temporal_width=self.temporal_width,
temporal_width_past=self.temporal_width_past,
temporal_width_future=self.temporal_width_future,
temporal_decoder_hidden=self.temporal_decoder_hidden,
use_layer_norm=self.use_layer_norm,
dropout=self.dropout,
Expand Down
61 changes: 59 additions & 2 deletions darts/tests/models/forecasting/test_tide_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,54 @@ def test_future_and_past_covariate_handling(self):
)
model.fit(ts_time_index, verbose=False, epochs=1)

model = TiDEModel(
input_chunk_length=1,
output_chunk_length=1,
add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
**tfm_kwargs
)
model.fit(ts_time_index, verbose=False, epochs=1)

@pytest.mark.parametrize("temporal_widths", [(-1, 1), (1, -1)])
def test_failing_future_and_past_temporal_widths(self, temporal_widths):
# invalid temporal widths
with pytest.raises(ValueError):
TiDEModel(
input_chunk_length=1,
output_chunk_length=1,
temporal_width_past=temporal_widths[0],
temporal_width_future=temporal_widths[1],
**tfm_kwargs
)

@pytest.mark.parametrize(
"temporal_widths",
[
(2, 2), # feature projection to same amount of features
(1, 2), # past: feature reduction, future: same amount of features
(2, 1), # past: same amount of features, future: feature reduction
(3, 3), # feature expansion
(0, 2), # bypass past feature projection
(2, 0), # bypass future feature projection
(0, 0), # bypass all feature projection
],
)
def test_future_and_past_temporal_widths(self, temporal_widths):
ts_time_index = tg.sine_timeseries(length=2, freq="h")

# feature projection to 2 features (same amount as input features)
model = TiDEModel(
input_chunk_length=1,
output_chunk_length=1,
temporal_width_past=temporal_widths[0],
temporal_width_future=temporal_widths[1],
add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
**tfm_kwargs
)
model.fit(ts_time_index, verbose=False, epochs=1)
assert model.model.temporal_width_past == temporal_widths[0]
assert model.model.temporal_width_future == temporal_widths[1]

def test_past_covariate_handling(self):
ts_time_index = tg.sine_timeseries(length=2, freq="h")

Expand All @@ -142,7 +190,12 @@ def test_future_and_past_covariate_as_timeseries_handling(self):
use_reversible_instance_norm=enable_rin,
**tfm_kwargs
)
model.fit(ts_time_index, ts_time_index, verbose=False, epochs=1)
model.fit(
ts_time_index,
past_covariates=ts_time_index,
verbose=False,
epochs=1,
)

# test with past_covariates and future_covariates timeseries
model = TiDEModel(
Expand All @@ -153,7 +206,11 @@ def test_future_and_past_covariate_as_timeseries_handling(self):
**tfm_kwargs
)
model.fit(
ts_time_index, ts_time_index, ts_time_index, verbose=False, epochs=1
ts_time_index,
past_covariates=ts_time_index,
future_covariates=ts_time_index,
verbose=False,
epochs=1,
)

def test_static_covariates_support(self):
Expand Down

0 comments on commit fca3993

Please sign in to comment.