Skip to content

Commit

Permalink
Merge branch 'master' into trainer_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
martins0n authored Feb 17, 2022
2 parents d1fd2cd + a2393b4 commit 9f86088
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- [BUG] nn models make forecast without inverse_transform ([#541](https://github.com/tinkoff-ai/etna/pull/541))

## [1.6.3] - 2022-02-14

Expand Down
1 change: 1 addition & 0 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,5 @@ def forecast(self, ts: TSDataset) -> TSDataset:
# shape (segments, encoder_length)

ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :]
ts.inverse_transform()
return ts
1 change: 1 addition & 0 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,5 @@ def forecast(self, ts: TSDataset) -> TSDataset:
# shape (segments, encoder_length)

ts.loc[:, pd.IndexSlice[:, "target"]] = predicts.T[-len(ts.df) :]
ts.inverse_transform()
return ts
41 changes: 41 additions & 0 deletions tests/test_models/nn/test_deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import PytorchForecastingTransform
from etna.transforms import StandardScalerTransform


def test_fit_wrong_order_transform(weekly_period_df):
Expand Down Expand Up @@ -68,6 +69,46 @@ def test_deepar_model_run_weekly_overfit(weekly_period_df, horizon):
assert mae(ts_test, ts_pred) < 0.2207


@pytest.mark.long
@pytest.mark.parametrize("horizon", [8])
def test_deepar_model_run_weekly_overfit_with_scaler(weekly_period_df, horizon):
"""
Given: I have dataframe with 2 segments with weekly seasonality with known future
When: I use scale transformations
Then: I get {horizon} periods per dataset as a forecast and they "the same" as past
"""

ts_start = sorted(set(weekly_period_df.timestamp))[-horizon]
train, test = (
weekly_period_df[lambda x: x.timestamp < ts_start],
weekly_period_df[lambda x: x.timestamp >= ts_start],
)

ts_train = TSDataset(TSDataset.to_dataset(train), "D")
ts_test = TSDataset(TSDataset.to_dataset(test), "D")
std = StandardScalerTransform(in_column="target")
dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, out_column="regressor_dateflags")
pft = PytorchForecastingTransform(
max_encoder_length=21,
max_prediction_length=horizon,
time_varying_known_reals=["time_idx"],
time_varying_known_categoricals=["regressor_dateflags_day_number_in_week"],
time_varying_unknown_reals=["target"],
target_normalizer=GroupNormalizer(groups=["segment"]),
)

ts_train.fit_transform([std, dft, pft])

model = DeepARModel(max_epochs=300, learning_rate=[0.1])
ts_pred = ts_train.make_future(horizon)
model.fit(ts_train)
ts_pred = model.forecast(ts_pred)

mae = MAE("macro")

assert mae(ts_test, ts_pred) < 0.2207


def test_forecast_without_make_future(weekly_period_df):
ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D")
pft = PytorchForecastingTransform(
Expand Down
42 changes: 42 additions & 0 deletions tests/test_models/nn/test_tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from etna.transforms import AddConstTransform
from etna.transforms import DateFlagsTransform
from etna.transforms import PytorchForecastingTransform
from etna.transforms import StandardScalerTransform


def test_fit_wrong_order_transform(weekly_period_df):
Expand Down Expand Up @@ -68,6 +69,47 @@ def test_tft_model_run_weekly_overfit(weekly_period_df, horizon):
assert mae(ts_test, ts_pred) < 0.24


@pytest.mark.long
@pytest.mark.parametrize("horizon", [8])
def test_tft_model_run_weekly_overfit_with_scaler(weekly_period_df, horizon):
"""
Given: I have dataframe with 2 segments with weekly seasonality with known future
When: I use scale transformations
Then: I get {horizon} periods per dataset as a forecast and they "the same" as past
"""

ts_start = sorted(set(weekly_period_df.timestamp))[-horizon]
train, test = (
weekly_period_df[lambda x: x.timestamp < ts_start],
weekly_period_df[lambda x: x.timestamp >= ts_start],
)

ts_train = TSDataset(TSDataset.to_dataset(train), "D")
ts_test = TSDataset(TSDataset.to_dataset(test), "D")
std = StandardScalerTransform(in_column="target")
dft = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, out_column="regressor_dateflag")
pft = PytorchForecastingTransform(
max_encoder_length=21,
min_encoder_length=21,
max_prediction_length=horizon,
time_varying_known_reals=["time_idx"],
time_varying_known_categoricals=["regressor_dateflag_day_number_in_week"],
time_varying_unknown_reals=["target"],
static_categoricals=["segment"],
target_normalizer=None,
)

ts_train.fit_transform([std, dft, pft])

model = TFTModel(max_epochs=300, learning_rate=[0.1])
ts_pred = ts_train.make_future(horizon)
model.fit(ts_train)
ts_pred = model.forecast(ts_pred)

mae = MAE("macro")
assert mae(ts_test, ts_pred) < 0.24


def test_forecast_without_make_future(weekly_period_df):
ts = TSDataset(TSDataset.to_dataset(weekly_period_df), "D")
pft = PytorchForecastingTransform(
Expand Down

0 comments on commit 9f86088

Please sign in to comment.