From f937dd9d957b6269fe309aec796b1990ec04c93c Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 28 Sep 2021 14:05:31 +0300 Subject: [PATCH 1/6] Fix bug with hardcoded frequency in PytorchForecastingTransformer, add notes about position sensitivity of transform, write tests --- etna/transforms/pytorch_forecasting.py | 27 ++++++++++++++----- tests/test_models/nn/test_deepar.py | 4 +-- tests/test_models/nn/test_tft.py | 6 ++--- .../test_pytorch_forecasting_transform.py | 25 +++++++++++++++++ 4 files changed, 50 insertions(+), 12 deletions(-) create mode 100644 tests/test_transforms/test_pytorch_forecasting_transform.py diff --git a/etna/transforms/pytorch_forecasting.py b/etna/transforms/pytorch_forecasting.py index 5cab6c31d..b8b3dffa9 100644 --- a/etna/transforms/pytorch_forecasting.py +++ b/etna/transforms/pytorch_forecasting.py @@ -48,6 +48,10 @@ def __init__( ): """Parameters for TimeSeriesDataSet object. + Notes + ----- + This transform should be added at the very end of `transforms` parameter. + Reference --------- https://github.com/jdb78/pytorch-forecasting/blob/v0.8.5/pytorch_forecasting/data/timeseries.py#L117 @@ -76,6 +80,14 @@ def __init__( self.lags = lags self.scalers = scalers + @staticmethod + def _calculate_freq_unit(freq: str) -> pd.Timedelta: + """Calculate frequency unit by its string representation.""" + if freq[0].isdigit(): + return pd.Timedelta(freq) + else: + return pd.Timedelta(1, unit=freq) + def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform": """ Fit TimeSeriesDataSet. @@ -89,8 +101,8 @@ def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform": ------- PytorchForecastingTransform """ - ts = TSDataset(df, "1d") - self.freq = ts.freq + self.freq = pd.infer_freq(df.index) + ts = TSDataset(df, self.freq) ts = ts.to_pandas(flatten=True) ts = ts.dropna() self.min_timestamp = ts.timestamp.min() @@ -99,8 +111,8 @@ def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform": for feature_name in self.time_varying_known_categoricals: ts[feature_name] = ts[feature_name].astype(str) - ts["time_idx"] = ts["timestamp"] - self.min_timestamp - ts["time_idx"] = ts["time_idx"].apply(lambda x: x / self.freq) + freq_unit = self._calculate_freq_unit(self.freq) + ts["time_idx"] = (ts["timestamp"] - self.min_timestamp) / freq_unit ts["time_idx"] = ts["time_idx"].astype(int) pf_dataset = TimeSeriesDataSet( @@ -151,14 +163,15 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: We save TimeSeriesDataSet in instance to use it in the model. It`s not right pattern of using Transforms and TSDataset. """ - ts = TSDataset(df, "1d") + ts = TSDataset(df, self.freq) ts = ts.to_pandas(flatten=True) ts = ts[ts.timestamp >= self.min_timestamp] ts = ts.fillna(0) - ts["time_idx"] = ts["timestamp"] - self.min_timestamp - ts["time_idx"] = ts["time_idx"].apply(lambda x: x / self.freq) + freq_unit = self._calculate_freq_unit(self.freq) + ts["time_idx"] = (ts["timestamp"] - self.min_timestamp) / freq_unit ts["time_idx"] = ts["time_idx"].astype(int) + if self.time_varying_known_categoricals: for feature_name in self.time_varying_known_categoricals: ts[feature_name] = ts[feature_name].astype(str) diff --git a/tests/test_models/nn/test_deepar.py b/tests/test_models/nn/test_deepar.py index 14c6b23d8..9e8aecb1c 100644 --- a/tests/test_models/nn/test_deepar.py +++ b/tests/test_models/nn/test_deepar.py @@ -31,8 +31,8 @@ def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon): weekly_period_df[lambda x: x.timestamp >= ts_start], ) - ts_train = TSDataset(TSDataset.to_dataset(train), "1d") - ts_test = TSDataset(TSDataset.to_dataset(test), "1d") + ts_train = TSDataset(TSDataset.to_dataset(train), "D") + ts_test = TSDataset(TSDataset.to_dataset(test), "D") dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False) pft = PytorchForecastingTransform( max_encoder_length=21, diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/nn/test_tft.py index 91186e5f7..cbc998810 100644 --- a/tests/test_models/nn/test_tft.py +++ b/tests/test_models/nn/test_tft.py @@ -31,8 +31,8 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon): weekly_period_df[lambda x: x.timestamp >= ts_start], ) - ts_train = TSDataset(TSDataset.to_dataset(train), "1d") - ts_test = TSDataset(TSDataset.to_dataset(test), "1d") + ts_train = TSDataset(TSDataset.to_dataset(train), "D") + ts_test = TSDataset(TSDataset.to_dataset(test), "D") dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False) pft = PytorchForecastingTransform( max_encoder_length=21, @@ -53,4 +53,4 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon): ts_pred = tftmodel.forecast(ts_pred) mae = MAE("macro") - assert mae(ts_test, ts_pred) < 0.23 + assert mae(ts_test, ts_pred) < 0.24 diff --git a/tests/test_transforms/test_pytorch_forecasting_transform.py b/tests/test_transforms/test_pytorch_forecasting_transform.py new file mode 100644 index 000000000..1b9b1cad8 --- /dev/null +++ b/tests/test_transforms/test_pytorch_forecasting_transform.py @@ -0,0 +1,25 @@ +import pytest + +from etna.transforms import PytorchForecastingTransform + + +@pytest.mark.parametrize("days_offset", [1, 2, 5, 10]) +def test_time_idx(days_offset, example_tsds): + """Check that PytorchForecastingTransform works with different frequencies correctly.""" + df = example_tsds.to_pandas() + new_df = df.loc[df.index[::days_offset]] + + transform = PytorchForecastingTransform( + max_encoder_length=3, + min_encoder_length=3, + max_prediction_length=3, + time_varying_known_reals=["time_idx"], + time_varying_unknown_reals=["target"], + static_categoricals=["segment"], + ) + transform.fit_transform(new_df) + + time_idx = transform.pf_dataset_train.data["time"].tolist() + expected_len = new_df.shape[0] + expected_list = list(range(expected_len)) * len(example_tsds.segments) + assert time_idx == expected_list From 2c42aea468433d1d1b4c3084fa86743ef1d8b46c Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 28 Sep 2021 14:09:45 +0300 Subject: [PATCH 2/6] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62cd27080..6eb89248c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix saving init params for SARIMAXModel ([#81](https://github.com/tinkoff-ai/etna-ts/pull/81)) - Imports of nn models, PytorchForecastingTransform and Transform ([#80](https://github.com/tinkoff-ai/etna-ts/pull/80))) - Add more obvious Exception Error for forecasting with unfitted model ([#102](https://github.com/tinkoff-ai/etna-ts/pull/102)) +- Fix bug with hardcoded frequency in PytorchForecastingTransform ([#107](https://github.com/tinkoff-ai/etna-ts/pull/107)) ## [1.0.0] - 2021-09-05 ### Added From 3b7ecb5e64d39bcdf04d666ae3b378a7d10ea181 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 28 Sep 2021 14:22:53 +0300 Subject: [PATCH 3/6] Refactor names --- etna/transforms/pytorch_forecasting.py | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/etna/transforms/pytorch_forecasting.py b/etna/transforms/pytorch_forecasting.py index b8b3dffa9..56d619c18 100644 --- a/etna/transforms/pytorch_forecasting.py +++ b/etna/transforms/pytorch_forecasting.py @@ -103,20 +103,20 @@ def fit(self, df: pd.DataFrame) -> "PytorchForecastingTransform": """ self.freq = pd.infer_freq(df.index) ts = TSDataset(df, self.freq) - ts = ts.to_pandas(flatten=True) - ts = ts.dropna() - self.min_timestamp = ts.timestamp.min() + df_flat = ts.to_pandas(flatten=True) + df_flat = df_flat.dropna() + self.min_timestamp = df_flat.timestamp.min() if self.time_varying_known_categoricals: for feature_name in self.time_varying_known_categoricals: - ts[feature_name] = ts[feature_name].astype(str) + df_flat[feature_name] = df_flat[feature_name].astype(str) freq_unit = self._calculate_freq_unit(self.freq) - ts["time_idx"] = (ts["timestamp"] - self.min_timestamp) / freq_unit - ts["time_idx"] = ts["time_idx"].astype(int) + df_flat["time_idx"] = (df_flat["timestamp"] - self.min_timestamp) / freq_unit + df_flat["time_idx"] = df_flat["time_idx"].astype(int) pf_dataset = TimeSeriesDataSet( - ts, + df_flat, time_idx="time_idx", target="target", group_ids=["segment"], @@ -164,24 +164,24 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: It`s not right pattern of using Transforms and TSDataset. """ ts = TSDataset(df, self.freq) - ts = ts.to_pandas(flatten=True) - ts = ts[ts.timestamp >= self.min_timestamp] - ts = ts.fillna(0) + df_flat = ts.to_pandas(flatten=True) + df_flat = df_flat[df_flat.timestamp >= self.min_timestamp] + df_flat = df_flat.fillna(0) freq_unit = self._calculate_freq_unit(self.freq) - ts["time_idx"] = (ts["timestamp"] - self.min_timestamp) / freq_unit - ts["time_idx"] = ts["time_idx"].astype(int) + df_flat["time_idx"] = (df_flat["timestamp"] - self.min_timestamp) / freq_unit + df_flat["time_idx"] = df_flat["time_idx"].astype(int) if self.time_varying_known_categoricals: for feature_name in self.time_varying_known_categoricals: - ts[feature_name] = ts[feature_name].astype(str) + df_flat[feature_name] = df_flat[feature_name].astype(str) if inspect.stack()[1].function == "make_future": pf_dataset_predict = TimeSeriesDataSet.from_parameters( - self.pf_dataset_params, ts, predict=True, stop_randomization=True + self.pf_dataset_params, df_flat, predict=True, stop_randomization=True ) self.pf_dataset_predict = pf_dataset_predict else: - pf_dataset_train = TimeSeriesDataSet.from_parameters(self.pf_dataset_params, ts) + pf_dataset_train = TimeSeriesDataSet.from_parameters(self.pf_dataset_params, df_flat) self.pf_dataset_train = pf_dataset_train return df From 817dbf5c3f5b2c51e2bd93a00316d17ebd111fbf Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 28 Sep 2021 14:53:23 +0300 Subject: [PATCH 4/6] Update changelog --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19e8e32d8..2c8e1db8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Add more obvious Exception Error for forecasting with unfitted model ([#102](https://github.com/tinkoff-ai/etna-ts/pull/102)) +- Fix bug with hardcoded frequency in PytorchForecastingTransform ([#107](https://github.com/tinkoff-ai/etna-ts/pull/107)) ## [1.1.1] - 2021-09-23 ### Fixed @@ -46,8 +47,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix working with 'cap' and 'floor' features in Prophet model ([#62](https://github.com/tinkoff-ai/etna-ts/pull/62))) - Fix saving init params for SARIMAXModel ([#81](https://github.com/tinkoff-ai/etna-ts/pull/81)) - Imports of nn models, PytorchForecastingTransform and Transform ([#80](https://github.com/tinkoff-ai/etna-ts/pull/80))) -- Add more obvious Exception Error for forecasting with unfitted model ([#102](https://github.com/tinkoff-ai/etna-ts/pull/102)) -- Fix bug with hardcoded frequency in PytorchForecastingTransform ([#107](https://github.com/tinkoff-ai/etna-ts/pull/107)) ## [1.0.0] - 2021-09-05 ### Added From e8418b7c3c1128f35bfadf312bd6e13c8e54b26f Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 28 Sep 2021 17:13:58 +0300 Subject: [PATCH 5/6] Add exception in case of wrong position of PytorchForecastingTransform, add tests for it --- etna/models/nn/deepar.py | 20 +++++++++++++++++--- etna/models/nn/tft.py | 21 +++++++++++++++++---- tests/test_models/nn/test_deepar.py | 25 ++++++++++++++++++++++--- tests/test_models/nn/test_tft.py | 27 ++++++++++++++++++++++++--- 4 files changed, 80 insertions(+), 13 deletions(-) diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 3b5c03f43..0651d93dc 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -7,6 +7,7 @@ from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.models import DeepAR +from etna.transforms import PytorchForecastingTransform from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger from etna.models.base import Model @@ -89,6 +90,17 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> DeepAR: dropout=self.dropout, ) + @staticmethod + def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform: + """Get PytorchForecastingTransform from ts.transforms or raise exception if not found.""" + if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform): + return ts.transforms[-1] + else: + raise ValueError( + "Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms" + ) + + @log_decorator def fit(self, ts: TSDataset) -> "DeepARModel": """ @@ -103,7 +115,8 @@ def fit(self, ts: TSDataset) -> "DeepARModel": ------- DeepARModel """ - self.model = self._from_dataset(ts.transforms[-1].pf_dataset_train) + pf_transform = self._get_pf_transform(ts) + self.model = self._from_dataset(pf_transform.pf_dataset_train) self.trainer = pl.Trainer( logger=tslogger.pl_loggers, @@ -113,7 +126,7 @@ def fit(self, ts: TSDataset) -> "DeepARModel": gradient_clip_val=self.gradient_clip_val, ) - train_dataloader = ts.transforms[-1].pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size) + train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size) self.trainer.fit(self.model, train_dataloader) @@ -134,7 +147,8 @@ def forecast(self, ts: TSDataset) -> TSDataset: TSDataset TSDataset with predictions. """ - prediction_dataloader = ts.transforms[-1].pf_dataset_predict.to_dataloader( + pf_transform = self._get_pf_transform(ts) + prediction_dataloader = pf_transform.pf_dataset_predict.to_dataloader( train=False, batch_size=self.batch_size * 2 ) diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index b3744a515..b31cc71f3 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -8,6 +8,7 @@ from pytorch_forecasting.models import TemporalFusionTransformer from etna.datasets.tsdataset import TSDataset +from etna.transforms import PytorchForecastingTransform from etna.loggers import tslogger from etna.models.base import Model from etna.models.base import log_decorator @@ -97,6 +98,16 @@ def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> TemporalFusionTransfor hidden_continuous_size=self.hidden_continuous_size, ) + @staticmethod + def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform: + """Get PytorchForecastingTransform from ts.transforms or raise exception if not found.""" + if ts.transforms is not None and isinstance(ts.transforms[-1], PytorchForecastingTransform): + return ts.transforms[-1] + else: + raise ValueError( + "Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms" + ) + @log_decorator def fit(self, ts: TSDataset) -> "TFTModel": """ @@ -111,7 +122,8 @@ def fit(self, ts: TSDataset) -> "TFTModel": ------- TFTModel """ - self.model = self._from_dataset(ts.transforms[-1].pf_dataset_train) + pf_transform = self._get_pf_transform(ts) + self.model = self._from_dataset(pf_transform.pf_dataset_train) self.trainer = pl.Trainer( logger=tslogger.pl_loggers, @@ -121,14 +133,14 @@ def fit(self, ts: TSDataset) -> "TFTModel": gradient_clip_val=self.gradient_clip_val, ) - train_dataloader = ts.transforms[-1].pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size) + train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size) self.trainer.fit(self.model, train_dataloader) return self @log_decorator - def forecast(self, ts: TSDataset) -> pd.DataFrame: + def forecast(self, ts: TSDataset) -> TSDataset: """ Predict future. @@ -142,7 +154,8 @@ def forecast(self, ts: TSDataset) -> pd.DataFrame: TSDataset TSDataset with predictions. """ - prediction_dataloader = ts.transforms[-1].pf_dataset_predict.to_dataloader( + pf_transform = self._get_pf_transform(ts) + prediction_dataloader = pf_transform.pf_dataset_predict.to_dataloader( train=False, batch_size=self.batch_size * 2 ) diff --git a/tests/test_models/nn/test_deepar.py b/tests/test_models/nn/test_deepar.py index e1ca8cc80..c19445d5b 100644 --- a/tests/test_models/nn/test_deepar.py +++ b/tests/test_models/nn/test_deepar.py @@ -9,9 +9,28 @@ from etna.metrics import MAE from etna.models.nn import DeepARModel from etna.transforms import DateFlagsTransform +from etna.transforms import AddConstTransform from etna.transforms import PytorchForecastingTransform +def test_fit_wrong_order_transform(weekly_period_df): + ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D") + add_const = AddConstTransform(in_column="target", value=1.0) + pft = PytorchForecastingTransform( + max_encoder_length=21, + max_prediction_length=8, + time_varying_known_reals=["time_idx"], + time_varying_unknown_reals=["target"], + target_normalizer=GroupNormalizer(groups=["segment"]), + ) + + ts.fit_transform([pft, add_const]) + + model = DeepARModel(max_epochs=300, learning_rate=[0.1]) + with pytest.raises(ValueError, match="add PytorchForecastingTransform"): + model.fit(ts) + + @pytest.mark.long @pytest.mark.parametrize("horizon", [8, 21]) def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon): @@ -45,10 +64,10 @@ def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon): ts_train.fit_transform([dft, pft]) - tftmodel = DeepARModel(max_epochs=300, learning_rate=[0.1]) + model = DeepARModel(max_epochs=300, learning_rate=[0.1]) ts_pred = ts_train.make_future(horizon) - tftmodel.fit(ts_train) - ts_pred = tftmodel.forecast(ts_pred) + model.fit(ts_train) + ts_pred = model.forecast(ts_pred) mae = MAE("macro") diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/nn/test_tft.py index ab5d19bb4..2059d1dd1 100644 --- a/tests/test_models/nn/test_tft.py +++ b/tests/test_models/nn/test_tft.py @@ -8,9 +8,30 @@ from etna.metrics import MAE from etna.models.nn import TFTModel from etna.transforms import DateFlagsTransform +from etna.transforms import AddConstTransform from etna.transforms import PytorchForecastingTransform +def test_fit_wrong_order_transform(weekly_period_df): + ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D") + add_const = AddConstTransform(in_column="target", value=1.0) + pft = PytorchForecastingTransform( + max_encoder_length=21, + min_encoder_length=21, + max_prediction_length=8, + time_varying_known_reals=["time_idx"], + time_varying_unknown_reals=["target"], + static_categoricals=["segment"], + target_normalizer=None, + ) + + ts.fit_transform([pft, add_const]) + + model = TFTModel(max_epochs=300, learning_rate=[0.1]) + with pytest.raises(ValueError, match="add PytorchForecastingTransform"): + model.fit(ts) + + @pytest.mark.long @pytest.mark.parametrize("horizon", [8, 21]) def test_tft_model_run_weekly_overfit(weekly_period_df, horizon): @@ -47,10 +68,10 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon): ts_train.fit_transform([dft, pft]) - tftmodel = TFTModel(max_epochs=300, learning_rate=[0.1]) + model = TFTModel(max_epochs=300, learning_rate=[0.1]) ts_pred = ts_train.make_future(horizon) - tftmodel.fit(ts_train) - ts_pred = tftmodel.forecast(ts_pred) + model.fit(ts_train) + ts_pred = model.forecast(ts_pred) mae = MAE("macro") assert mae(ts_test, ts_pred) < 0.24 From 3290dcdccf88fb6faeba7b766ea0f6c17b1b1dad Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 28 Sep 2021 17:18:58 +0300 Subject: [PATCH 6/6] Reformat code --- etna/models/nn/deepar.py | 3 +-- etna/models/nn/tft.py | 2 +- tests/test_models/nn/test_deepar.py | 2 +- tests/test_models/nn/test_tft.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py index 0651d93dc..d4d53f20f 100644 --- a/etna/models/nn/deepar.py +++ b/etna/models/nn/deepar.py @@ -7,11 +7,11 @@ from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.models import DeepAR -from etna.transforms import PytorchForecastingTransform from etna.datasets.tsdataset import TSDataset from etna.loggers import tslogger from etna.models.base import Model from etna.models.base import log_decorator +from etna.transforms import PytorchForecastingTransform class DeepARModel(Model): @@ -100,7 +100,6 @@ def _get_pf_transform(ts: TSDataset) -> PytorchForecastingTransform: "Not valid usage of transforms, please add PytorchForecastingTransform at the end of transforms" ) - @log_decorator def fit(self, ts: TSDataset) -> "DeepARModel": """ diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py index b31cc71f3..3de36a4ed 100644 --- a/etna/models/nn/tft.py +++ b/etna/models/nn/tft.py @@ -8,10 +8,10 @@ from pytorch_forecasting.models import TemporalFusionTransformer from etna.datasets.tsdataset import TSDataset -from etna.transforms import PytorchForecastingTransform from etna.loggers import tslogger from etna.models.base import Model from etna.models.base import log_decorator +from etna.transforms import PytorchForecastingTransform class TFTModel(Model): diff --git a/tests/test_models/nn/test_deepar.py b/tests/test_models/nn/test_deepar.py index c19445d5b..1811791d6 100644 --- a/tests/test_models/nn/test_deepar.py +++ b/tests/test_models/nn/test_deepar.py @@ -8,8 +8,8 @@ from etna.datasets.tsdataset import TSDataset from etna.metrics import MAE from etna.models.nn import DeepARModel -from etna.transforms import DateFlagsTransform from etna.transforms import AddConstTransform +from etna.transforms import DateFlagsTransform from etna.transforms import PytorchForecastingTransform diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/nn/test_tft.py index 2059d1dd1..a5b3f91b6 100644 --- a/tests/test_models/nn/test_tft.py +++ b/tests/test_models/nn/test_tft.py @@ -7,8 +7,8 @@ from etna.datasets.tsdataset import TSDataset from etna.metrics import MAE from etna.models.nn import TFTModel -from etna.transforms import DateFlagsTransform from etna.transforms import AddConstTransform +from etna.transforms import DateFlagsTransform from etna.transforms import PytorchForecastingTransform