Feat/static covs #966

Merged: 29 commits from feat/static_covs into master (Jun 5, 2022)
Changes shown are from 11 of the 29 commits.

Commits:
125c078
added methods `from_longitudinal_dataframe` and `add_static_covariates`
dennisbader Apr 13, 2022
fde974e
dataset adaption for static covs
dennisbader Apr 23, 2022
d6d4885
extended datasets for static covariates support and unified variable …
dennisbader May 19, 2022
4717ff1
adapted PLXCovariatesModules with static covariates
dennisbader May 19, 2022
5b6b781
adapted TFTModel for static covariate support
dennisbader May 20, 2022
55eaf3b
added temporary fix for static covariates with scalers
dennisbader May 20, 2022
3ced3da
Merge branch 'master' into feat/static_covs
dennisbader May 20, 2022
29924f4
unittests for from_longitudinal_dataframe() and set_static_covariates
dennisbader May 24, 2022
079d969
updated dataset tests
dennisbader May 24, 2022
3511b81
fixed all downstream issues from new static covariates in datasets
dennisbader May 27, 2022
eacaf3b
added check for equal static covariates between fit and predict
dennisbader May 28, 2022
55c5090
added tests for passing static covariates in TimeSeries methods
dennisbader May 28, 2022
cc07f5f
added static covariate support for stacking TimeSeries
dennisbader May 28, 2022
0aacd5a
transpose static covariates
dennisbader May 29, 2022
2845f86
added method `static_covariates_values()`
dennisbader May 29, 2022
2ac58e4
updated docs
dennisbader May 29, 2022
a6fa4fb
static covariate support for concatenation
dennisbader May 30, 2022
a4ba617
static covariate support for concatenation
dennisbader May 30, 2022
0586b7d
static covariates are now passed to the torch models
dennisbader May 30, 2022
c18e806
non-numerical dtype support for static covariates
dennisbader May 31, 2022
a048ecc
added slicing support for static covariates
dennisbader May 31, 2022
3661385
multicomponent static covariate support for TFT
dennisbader May 31, 2022
5b9258b
Merge branch 'master' into feat/static_covs
dennisbader May 31, 2022
3a9ad83
added arithmetic static covariate support
dennisbader May 31, 2022
d00c08d
Merge branch 'master' into feat/static_covs
dennisbader Jun 3, 2022
f5fa989
updated all timeseries methods/operations with static cov transfer
dennisbader Jun 4, 2022
41adf3f
applied suggestion from PR review part 1
dennisbader Jun 4, 2022
6dc7ff8
apply suggestions from code review part 2
dennisbader Jun 4, 2022
d001e17
fix black issue from PR suggestion
dennisbader Jun 4, 2022
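The commit messages above name the new user-facing entry points. As a hedged illustration only: the sketch below exercises `TimeSeries.from_longitudinal_dataframe` and `static_covariates_values()`, both taken from the commit messages; the argument names and DataFrame layout are assumptions not confirmed by this diff.

```python
# Hedged sketch of the API surface named in the commits above.
# `from_longitudinal_dataframe` and `static_covariates_values` come from the
# commit messages; argument names (group_cols, static_cols, ...) are assumed.
import pandas as pd
from darts import TimeSeries

df = pd.DataFrame({
    "time": list(pd.date_range("2022-01-01", periods=6)) * 2,
    "value": range(12),
    "store": ["A"] * 6 + ["B"] * 6,   # one sub-series per store
    "sqft": [100] * 6 + [250] * 6,    # static per store
})

# build one TimeSeries per group, each carrying its static covariates
series_list = TimeSeries.from_longitudinal_dataframe(
    df, group_cols="store", static_cols="sqft",
    time_col="time", value_cols="value",
)
print(series_list[0].static_covariates_values())
```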
4 changes: 2 additions & 2 deletions darts/dataprocessing/transformers/boxcox.py
@@ -167,7 +167,7 @@ def ts_transform(
)
return series.with_values(
BoxCox._reshape_out(series, transformed_vals, component_mask=component_mask)
)
).set_static_covariates(series.static_covariates)

@staticmethod
def ts_inverse_transform(
@@ -185,7 +185,7 @@ def ts_inverse_transform(
BoxCox._reshape_out(
series, inv_transformed_vals, component_mask=component_mask
)
)
).set_static_covariates(series.static_covariates)

def fit(
self, series: Union[TimeSeries, Sequence[TimeSeries]], **kwargs
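Both hunks above apply the same transfer pattern: a transformer that rebuilds a `TimeSeries` via `with_values` must explicitly re-attach the input's static covariates. A minimal sketch of that pattern, with a hypothetical `transform_fn` standing in for the Box-Cox math:

```python
# Minimal sketch of the re-attachment pattern, not the actual BoxCox code.
# `transform_fn` is hypothetical; `with_values`, `set_static_covariates` and
# the `static_covariates` property are the names used in this diff.
def ts_transform_sketch(series, transform_fn):
    transformed_vals = transform_fn(series.all_values())
    # `with_values` returns a fresh TimeSeries, so static covariates
    # would otherwise be lost here
    return series.with_values(transformed_vals).set_static_covariates(
        series.static_covariates
    )
```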
4 changes: 2 additions & 2 deletions darts/dataprocessing/transformers/scaler.py
@@ -106,7 +106,7 @@ def ts_transform(series: TimeSeries, transformer, **kwargs) -> TimeSeries:
values=transformed_vals,
fill_missing_dates=False,
columns=series.columns,
)
).set_static_covariates(series.static_covariates)

@staticmethod
def ts_inverse_transform(
@@ -126,7 +126,7 @@ def ts_inverse_transform(
values=inv_transformed_vals,
fill_missing_dates=False,
columns=series.columns,
)
).set_static_covariates(series.static_covariates)

@staticmethod
def ts_fit(series: TimeSeries, transformer, *args, **kwargs) -> Any:
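From the user's side, this change means scaling no longer drops static covariates. A hedged end-to-end check, assuming static covariates are attached as a small pandas DataFrame (the exact layout is not shown in this diff):

```python
# Hedged usage check; `set_static_covariates` and the `static_covariates`
# property are per this PR, the one-row DataFrame layout is an assumption.
import numpy as np
import pandas as pd
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler

series = TimeSeries.from_values(np.random.rand(100, 1)).set_static_covariates(
    pd.DataFrame({"store_id": [3], "region": ["west"]})
)
scaled = Scaler().fit_transform(series)
assert scaled.static_covariates.equals(series.static_covariates)
```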
62 changes: 33 additions & 29 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -342,7 +342,16 @@ def epochs_trained(self):

class PLPastCovariatesModule(PLForecastingModule, ABC):
def _produce_train_output(self, input_batch: Tuple):
past_target, past_covariate = input_batch
"""
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
training.

Parameters
----------
input_batch
``(past_target, past_covariates, static_covariates)``
"""
past_target, past_covariate, _ = input_batch
# Currently all our PastCovariates models require past target and covariates concatenated
inpt = (
torch.cat([past_target, past_covariate], dim=2)
@@ -363,13 +372,13 @@ def _get_batch_prediction(
n
prediction length
input_batch
(past_target, past_covariates, future_past_covariates)
``(past_target, past_covariates, future_past_covariates, static_covariates)``
roll_size
roll input arrays after every sequence by ``roll_size``. Initially, ``roll_size`` is equivalent to
``self.output_chunk_length``
"""
dim_component = 2
past_target, past_covariates, future_past_covariates = input_batch
past_target, past_covariates, future_past_covariates, _ = input_batch

n_targets = past_target.shape[dim_component]
n_past_covs = (
@@ -462,63 +471,56 @@ class PLMixedCovariatesModule(PLForecastingModule, ABC):
def _produce_train_output(
self, input_batch: Tuple
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Feeds MixedCovariatesTorchModel with input and output chunks of a MixedCovariatesSequentialDataset for
training.

Parameters
----------
input_batch
``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``.
"""
return self(self._process_input_batch(input_batch))

def _process_input_batch(
self, input_batch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Converts output of MixedCovariatesDataset (training dataset) into an input/past- and
output/future chunk.

Parameters
----------
input_batch
``(past_target, past_covariates, historic_future_covariates, future_covariates)``.
``(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates)``.

Returns
-------
tuple
``(x_past, x_future)`` the input/past and output/future chunks.
``(x_past, x_future, x_static)`` the input/past and output/future chunks.
"""

(
past_target,
past_covariates,
historic_future_covariates,
future_covariates,
static_covariates,
) = input_batch
dim_variable = 2

# TODO: implement static covariates
static_covariates = None

x_past = torch.cat(
[
tensor
for tensor in [
past_target,
past_covariates,
historic_future_covariates,
static_covariates,
]
if tensor is not None
],
dim=dim_variable,
)

x_future = None
if future_covariates is not None or static_covariates is not None:
x_future = torch.cat(
[
tensor
for tensor in [future_covariates, static_covariates]
if tensor is not None
],
dim=dim_variable,
)

return x_past, x_future
return x_past, future_covariates, static_covariates

def _get_batch_prediction(
self, n: int, input_batch: Tuple, roll_size: int
@@ -545,6 +547,7 @@ def _get_batch_prediction(
historic_future_covariates,
future_covariates,
future_past_covariates,
static_covariates,
) = input_batch

n_targets = past_target.shape[dim_component]
@@ -557,18 +560,19 @@
else 0
)

input_past, input_future = self._process_input_batch(
input_past, input_future, input_static = self._process_input_batch(
(
past_target,
past_covariates,
historic_future_covariates,
future_covariates[:, :roll_size, :]
if future_covariates is not None
else None,
static_covariates,
)
)

out = self._produce_predict_output(x=(input_past, input_future))[
out = self._produce_predict_output(x=(input_past, input_future, input_static))[
:, self.first_prediction_index :, :
]

@@ -636,9 +640,9 @@ def _get_batch_prediction(
input_future = future_covariates[:, left_future:right_future, :]

# take only last part of the output sequence where needed
out = self._produce_predict_output(x=(input_past, input_future))[
:, self.first_prediction_index :, :
]
out = self._produce_predict_output(
x=(input_past, input_future, input_static)
)[:, self.first_prediction_index :, :]

batch_prediction.append(out)
prediction_length += self.output_chunk_length
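The net effect of the changes in this file: past inputs are still concatenated along the variable dimension, but future and static covariates are now passed through as separate tensors instead of being folded into `x_past`/`x_future`. A standalone sketch with assumed shapes:

```python
# Standalone sketch of the new batch processing; shapes are assumptions
# consistent with the docstrings above: (n_samples, n_time_steps, n_variables).
import torch

def process_input_batch(past_target, past_covs, historic_future_covs,
                        future_covs, static_covs):
    dim_variable = 2
    x_past = torch.cat(
        [t for t in (past_target, past_covs, historic_future_covs)
         if t is not None],
        dim=dim_variable,
    )
    # future and static covariates are passed through untouched
    return x_past, future_covs, static_covs

x_past, x_future, x_static = process_input_batch(
    torch.rand(32, 24, 1),   # past_target
    torch.rand(32, 24, 2),   # past_covariates
    torch.rand(32, 24, 1),   # historic_future_covariates
    torch.rand(32, 12, 1),   # future_covariates
    torch.rand(32, 1, 3),    # static_covariates, constant over time
)
print(x_past.shape)  # torch.Size([32, 24, 4])
```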
4 changes: 2 additions & 2 deletions darts/models/forecasting/rnn_model.py
@@ -103,7 +103,7 @@ def forward(self, x, h=None):
return predictions, last_hidden_state

def _produce_train_output(self, input_batch: Tuple):
past_target, historic_future_covariates, future_covariates = input_batch
past_target, historic_future_covariates, future_covariates, _ = input_batch
# For the RNN we concatenate the past_target with the future_covariates
# (they have the same length because we enforce a Shift dataset for RNNs)
model_input = (
@@ -127,7 +127,7 @@ def _get_batch_prediction(
"""
This model is recurrent, so we have to write a specific way to obtain the time series forecasts of length n.
"""
past_target, historic_future_covariates, future_covariates = input_batch
past_target, historic_future_covariates, future_covariates, _ = input_batch

if historic_future_covariates is not None:
# RNNs need as inputs (target[t] and covariates[t+1]) so here we shift the covariates
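For reference, the alignment the RNN comment describes: the network consumes `target[t]` together with `covariates[t+1]`, which is why the shifted dataset makes both tensors the same length. A toy illustration with assumed shapes:

```python
# Toy illustration of the RNN input alignment; shapes are assumptions.
import torch

past_target = torch.rand(32, 24, 1)        # target values at time t
future_covariates = torch.rand(32, 24, 2)  # covariates already shifted to t+1
model_input = torch.cat([past_target, future_covariates], dim=2)
print(model_input.shape)  # torch.Size([32, 24, 3])
```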
41 changes: 17 additions & 24 deletions darts/models/forecasting/tft_model.py
@@ -31,7 +31,7 @@
logger = get_logger(__name__)

MixedCovariatesTrainTensorType = Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]


@@ -331,26 +331,25 @@ def get_attention_mask_future(
)
return mask

def forward(self, x: Tuple[torch.Tensor, Optional[torch.Tensor]]) -> torch.Tensor:
def forward(
self, x: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
) -> torch.Tensor:
"""TFT model forward pass.

Parameters
----------
x
comes as tuple `(x_past, x_future)` where `x_past` is the input/past chunk and `x_future`
comes as tuple `(x_past, x_future, x_static)` where `x_past` is the input/past chunk and `x_future`
is the output/future chunk. Input dimensions are `(n_samples, n_time_steps, n_variables)`

Returns
-------
torch.Tensor
the output tensor
"""
x_cont_past, x_cont_future = x
x_cont_past, x_cont_future, x_static = x
dim_samples, dim_time, dim_variable = 0, 1, 2

# TODO: implement static covariates
static_covariates = None

batch_size = x_cont_past.shape[dim_samples]
encoder_length = self.input_chunk_length
decoder_length = self.output_chunk_length
@@ -411,27 +410,21 @@ def forward(self, x: Tuple[torch.Tensor, Optional[torch.Tensor]]) -> torch.Tensor:
}

# Embedding and variable selection
if static_covariates is not None:
# TODO: implement static covariates
# # static embeddings will be constant over entire batch
# static_embedding = {name: input_vectors[name][:, 0] for name in self.static_variables}
# static_embedding, static_covariate_var = self.static_covariates_vsn(static_embedding)
raise NotImplementedError("Static covariates have yet to be defined")
if self.static_variables:
static_embedding = {
name: x_static[:, 0, i].unsqueeze(-1)
for i, name in enumerate(self.static_variables)
}
static_embedding, static_covariate_var = self.static_covariates_vsn(
static_embedding
)
else:
static_embedding = torch.zeros(
(x_cont_past.shape[0], self.hidden_size),
dtype=x_cont_past.dtype,
device=self.device,
)

# # TODO: implement below when static covariates are supported
# # this is only to interpret the output
# static_covariate_var = torch.zeros(
# (x_cont_past.shape[0], 0),
# dtype=x_cont_past.dtype,
# device=x_cont_past.device,
# )

static_context_expanded = self.expand_static_context(
context=self.static_context_grn(static_embedding), time_steps=time_steps
)
@@ -751,7 +744,8 @@ def __init__(
def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
"""
`train_sample` contains the following tensors:
(past_target, past_covariates, historic_future_covariates, future_covariates, future_target)
(past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates,
future_target)

each tensor has shape (n_timesteps, n_variables)
- past/historic tensors have shape (input_chunk_length, n_variables)
@@ -771,6 +765,7 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
past_covariate,
historic_future_covariate,
future_covariate,
static_covariates,
future_target,
) = train_sample

@@ -797,8 +792,6 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
axis=1,
)

static_covariates = None # placeholder for future

self.output_dim = (
(future_target.shape[1], 1)
if self.likelihood is None
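The key addition in the TFT forward pass is the static-embedding step: static covariates are constant over time, so only time index 0 is read, and each variable becomes a `(batch, 1)` tensor keyed by name for the variable selection network. A standalone sketch (the VSN itself is omitted):

```python
# Sketch of the static-embedding dict construction from the hunk above;
# variable names and shapes are illustrative assumptions.
import torch

x_static = torch.rand(32, 12, 3)  # (batch, time, n_static_variables)
static_variables = ["store_id", "region", "sku"]

static_embedding = {
    name: x_static[:, 0, i].unsqueeze(-1)  # (batch, 1); static, so t=0 suffices
    for i, name in enumerate(static_variables)
}
print({k: v.shape for k, v in static_embedding.items()})
```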
16 changes: 12 additions & 4 deletions darts/models/forecasting/tft_submodels.py
@@ -20,7 +20,7 @@
'
"""

from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -378,17 +378,25 @@ def __init__(
self,
input_sizes: Dict[str, int],
hidden_size: int,
input_embedding_flags: Dict[str, bool] = {},
input_embedding_flags: Optional[Dict[str, bool]] = None,
dropout: float = 0.1,
context_size: int = None,
single_variable_grns: Dict[str, _GatedResidualNetwork] = {},
prescalers: Dict[str, nn.Linear] = {},
single_variable_grns: Optional[Dict[str, _GatedResidualNetwork]] = None,
prescalers: Optional[Dict[str, nn.Linear]] = None,
):
"""
Calculate weights for ``num_inputs`` variables which are each of size ``input_size``
"""
super().__init__()

input_embedding_flags = (
input_embedding_flags if input_embedding_flags is not None else {}
)
single_variable_grns = (
single_variable_grns if single_variable_grns is not None else {}
)
prescalers = prescalers if prescalers is not None else {}

self.hidden_size = hidden_size
self.input_sizes = input_sizes
self.input_embedding_flags = input_embedding_flags
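The signature change above replaces mutable default arguments (`{}`) with `None` plus an in-body fallback. That matters because a mutable default is created once at definition time and shared across all calls; a toy demonstration unrelated to darts itself:

```python
# Why the Optional[...] = None idiom is used above: toy demonstration of the
# mutable-default-argument pitfall.
def bad(flags={}):
    flags["calls"] = flags.get("calls", 0) + 1
    return flags

def good(flags=None):
    flags = flags if flags is not None else {}
    flags["calls"] = flags.get("calls", 0) + 1
    return flags

print(bad())   # {'calls': 1}
print(bad())   # {'calls': 2}  <- state leaked from the previous call
print(good())  # {'calls': 1}
print(good())  # {'calls': 1}  <- fresh dict every call
```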