Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/tft static categorical #1081

Merged
merged 20 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
2dd3e7b
categorical static covariate support for TFTModel
dennisbader Jun 26, 2022
624feee
from_group_dataframe fix
dennisbader Jun 26, 2022
de0963d
Merge branch 'master' into feat/tft_static_categorical
dennisbader Jul 8, 2022
34221b7
added static covariate transformer
dennisbader Jul 8, 2022
6aa1635
OneHotEncoder support for StaticCovariatesTransformer
dennisbader Jul 9, 2022
95b9b84
small fix
dennisbader Jul 9, 2022
c82272f
TFTModel static covariate handling
dennisbader Jul 10, 2022
da1143b
improved transformer with specifying which columns to transform
dennisbader Jul 11, 2022
d7b4f4f
Merge branch 'master' into feat/tft_static_categorical
dennisbader Jul 12, 2022
a274438
docs improvements
dennisbader Jul 12, 2022
aef7e09
docstring improvement for StaticCvoariatesTransformer
dennisbader Jul 17, 2022
91770f0
added static covariates notebook example
dennisbader Jul 17, 2022
e785ecb
Merge branch 'master' into feat/tft_static_categorical
dennisbader Jul 17, 2022
044bd2b
TFTModel docstring update
dennisbader Jul 17, 2022
094d57f
Apply suggestions from code review
dennisbader Jul 18, 2022
8e10b86
applied suggestions from PR review
dennisbader Jul 18, 2022
00021df
Merge branch 'master' into feat/tft_static_categorical
dennisbader Jul 18, 2022
b9f5b34
applied suggestions from PR review part 2
dennisbader Jul 18, 2022
ba82e6d
added automatic embedding size option for TFTModel
dennisbader Jul 19, 2022
81be526
added test for TFTModel categorical static covariate support
dennisbader Jul 20, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/merge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
example-name: [00-quickstart.ipynb, 01-multi-time-series-and-covariates.ipynb, 02-data-processing.ipynb, 03-FFT-examples.ipynb, 04-RNN-examples.ipynb, 05-TCN-examples.ipynb, 06-Transformer-examples.ipynb, 07-NBEATS-examples.ipynb, 08-DeepAR-examples.ipynb, 09-DeepTCN-examples.ipynb, 10-Kalman-filter-examples.ipynb, 11-GP-filter-examples.ipynb, 12-Dynamic-Time-Warping-example.ipynb, 13-TFT-examples.ipynb]
example-name: [00-quickstart.ipynb, 01-multi-time-series-and-covariates.ipynb, 02-data-processing.ipynb, 03-FFT-examples.ipynb, 04-RNN-examples.ipynb, 05-TCN-examples.ipynb, 06-Transformer-examples.ipynb, 07-NBEATS-examples.ipynb, 08-DeepAR-examples.ipynb, 09-DeepTCN-examples.ipynb, 10-Kalman-filter-examples.ipynb, 11-GP-filter-examples.ipynb, 12-Dynamic-Time-Warping-example.ipynb, 13-TFT-examples.ipynb, 15-static-covariates.ipynb]
steps:
- name: "1. Clone repository"
uses: actions/checkout@v2
Expand Down
1 change: 1 addition & 0 deletions darts/dataprocessing/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
TopDownReconciliator,
)
from .scaler import Scaler
from .static_covariates_transformer import StaticCovariatesTransformer
366 changes: 366 additions & 0 deletions darts/dataprocessing/transformers/static_covariates_transformer.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ def __init__(
self.n_rnn_layers = n_rnn_layers
self.dropout = dropout

@staticmethod
def _supports_static_covariates() -> bool:
return False

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
6 changes: 6 additions & 0 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def __init__(self, *args, **kwargs):
# This is only used if the model has been fit on one time series.
self.training_series: Optional[TimeSeries] = None

# Static covariates sample from the (first) target series used for training the model through the `fit()`
# function.
self.static_covariates: Optional[pd.DataFrame] = None

# state; whether the model has been fit (on a single time series)
self._fit_called = False

Expand Down Expand Up @@ -959,11 +963,13 @@ def fit(
if isinstance(series, TimeSeries):
# if only one series is provided, save it for prediction time (including covariates, if available)
self.training_series = series
self.static_covariates = series.static_covariates
if past_covariates is not None:
self.past_covariate_series = past_covariates
if future_covariates is not None:
self.future_covariate_series = future_covariates
else:
self.static_covariates = series[0].static_covariates
if past_covariates is not None:
self._expect_past_covariates = True
if future_covariates is not None:
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,10 @@ def __init__(
if isinstance(layer_widths, int):
self.layer_widths = [layer_widths] * num_stacks

@staticmethod
def _supports_static_covariates() -> bool:
return False

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,10 @@ def _check_sizes(tup, name):

return pooling_kernel_sizes, n_freq_downsample

@staticmethod
def _supports_static_covariates() -> bool:
return False

def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
# samples are made of (past_target, past_covariates, future_target)
input_dim = train_sample[0].shape[1] + (
Expand Down
4 changes: 4 additions & 0 deletions darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,7 @@ def _verify_train_dataset_type(self, train_dataset: TrainingDataset):
train_dataset.ds_past.shift == 1,
"RNNModel requires a shifted training dataset with shift=1.",
)

@staticmethod
def _supports_static_covariates() -> bool:
return False
4 changes: 4 additions & 0 deletions darts/models/forecasting/tcn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,7 @@ def _build_train_dataset(
shift=self.output_chunk_length,
max_samples_per_ts=max_samples_per_ts,
)

@staticmethod
def _supports_static_covariates() -> bool:
return False