Skip to content

Commit

Permalink
added test for TFTModel categorical static covariate support
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Jul 20, 2022
1 parent ba82e6d commit 81be526
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions darts/tests/models/forecasting/test_TFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch.nn import MSELoss

from darts.models.forecasting.tft_model import TFTModel
from darts.models.forecasting.tft_submodels import get_embedding_size
from darts.utils.likelihood_models import QuantileRegression

TORCH_AVAILABLE = True
Expand Down Expand Up @@ -171,22 +172,49 @@ def test_static_covariates_support(self):
)

target_multi = target_multi.with_static_covariates(
pd.DataFrame([[0.0, 1.0], [2.0, 3.0]], index=["st1", "st2"])
pd.DataFrame(
[[0.0, 1.0, 0, 2], [2.0, 3.0, 1, 3]],
columns=["st1", "st2", "cat1", "cat2"],
)
)

# should work with cyclic encoding for time index
# set categorical embedding sizes once with automatic embedding size with an `int` and once by
# manually setting it with `tuple(int, int)`
model = TFTModel(
input_chunk_length=3,
output_chunk_length=4,
add_encoders={"cyclic": {"future": "hour"}},
categorical_embedding_sizes={"cat1": 2, "cat2": (2, 2)},
pl_trainer_kwargs={"fast_dev_run": True},
)
model.fit(target_multi, verbose=False)

assert len(model.model.static_variables) == len(
target_multi.static_covariates.columns
)

model.predict(n=1, series=target_multi, verbose=False)
# check model embeddings
target_embedding = {
"static_covariate_2": (
2,
get_embedding_size(2),
), # automatic embedding size
"static_covariate_3": (2, 2), # manual embedding size
}
assert model.categorical_embedding_sizes == target_embedding
for cat_var, embedding_dims in target_embedding.items():
assert (
model.model.input_embeddings.embeddings[cat_var].num_embeddings
== embedding_dims[0]
)
assert (
model.model.input_embeddings.embeddings[cat_var].embedding_dim
== embedding_dims[1]
)

preds = model.predict(n=1, series=target_multi, verbose=False)
assert preds.static_covariates.equals(target_multi.static_covariates)

# raise an error when trained with static covariates of wrong dimensionality
target_multi = target_multi.with_static_covariates(
Expand Down

0 comments on commit 81be526

Please sign in to comment.