Skip to content

Fix bug in new NNs with forecasting interval #1108

Merged
merged 9 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
### Fixed
- Fix inference tests on new segments for `DeepARModel` and `TFTModel` ([#1109](https://github.com/tinkoff-ai/etna/pull/1109))
- Fix alignment during forecasting in new NNs, add validation of context size during forecasting in new NNs, add validation of batch in `MLPNet` ([#1108](https://github.com/tinkoff-ai/etna/pull/1108))
- Fix `MeanSegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1104](https://github.com/tinkoff-ai/etna/pull/1104))
-
- Fix `SegmentEncoderTransform` to work with subset of segments and raise error on new segments ([#1103](https://github.com/tinkoff-ai/etna/pull/1103))
Expand Down
9 changes: 8 additions & 1 deletion etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,14 +622,21 @@ def forecast(self, ts: "TSDataset", prediction_size: int) -> "TSDataset":
:
Dataset with predictions
"""
expected_length = prediction_size + self.encoder_length
if len(ts.index) < expected_length:
raise ValueError(
"Given context isn't big enough, try to decrease context_size, prediction_size or increase length of given dataset!"
)

test_dataset = ts.to_torch_dataset(
make_samples=functools.partial(
self.net.make_samples, encoder_length=self.encoder_length, decoder_length=prediction_size
),
dropna=False,
)
predictions = self.raw_predict(test_dataset)
future_ts = ts.tsdataset_idx_slice(start_idx=self.encoder_length, end_idx=self.encoder_length + prediction_size)
end_idx = len(ts.index)
future_ts = ts.tsdataset_idx_slice(start_idx=end_idx - prediction_size, end_idx=end_idx)
for (segment, feature_nm), value in predictions.items():
# we don't want to change dtype after assignment, but there can happen cast to float32
future_ts.df.loc[:, pd.IndexSlice[segment, feature_nm]] = value[:prediction_size, :].astype(np.float64)
Expand Down
4 changes: 2 additions & 2 deletions etna/models/deadline_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _get_context_beginning(
# if we have len(history_timestamps) == 0, then len(df) <= prediction_size
if len(history_timestamps) == 0:
raise ValueError(
"Given context isn't big enough, try to decrease context_size, prediction_size of increase length of given dataframe!"
"Given context isn't big enough, try to decrease context_size, prediction_size or increase length of given dataframe!"
)

if seasonality is SeasonalityMode.month:
Expand All @@ -130,7 +130,7 @@ def _get_context_beginning(

if first_index < history_timestamps.iloc[0]:
raise ValueError(
"Given context isn't big enough, try to decrease context_size, prediction_size of increase length of given dataframe!"
"Given context isn't big enough, try to decrease context_size, prediction_size or increase length of given dataframe!"
)

return first_index
Expand Down
7 changes: 7 additions & 0 deletions etna/models/nn/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def __init__(
layers.append(nn.Linear(in_features=hidden_size[-1], out_features=1))
self.mlp = nn.Sequential(*layers)

@staticmethod
def _validate_batch(batch: MLPBatch):
    """Check that the features of the batch contain no NaN values.

    Parameters
    ----------
    batch:
        batch of data; only the ``"decoder_real"`` entry is inspected

    Raises
    ------
    ValueError:
        if ``batch["decoder_real"]`` contains at least one NaN
    """
    features = batch["decoder_real"]
    has_nans = bool(features.isnan().any())
    if has_nans:
        raise ValueError("There are NaNs in features, this model can't work with them!")

def forward(self, batch: MLPBatch): # type: ignore
"""Forward pass.

Expand All @@ -78,6 +83,7 @@ def forward(self, batch: MLPBatch): # type: ignore
:
forecast
"""
self._validate_batch(batch)
decoder_real = batch["decoder_real"].float()
return self.mlp(decoder_real)

Expand All @@ -93,6 +99,7 @@ def step(self, batch: MLPBatch, *args, **kwargs): # type: ignore
:
loss, true_target, prediction_target
"""
self._validate_batch(batch)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we validate twice.

We could call the forward pass from `step` instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand it, `step` isn't called during forecasting; it is only called during training.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use `step` in `training_step` and `validation_step`; it is only for training.

decoder_real = batch["decoder_real"].float()
decoder_target = batch["decoder_target"].float()

Expand Down
4 changes: 2 additions & 2 deletions etna/models/seasonal_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def forecast(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
expected_length = prediction_size + self.shift
if len(df) < expected_length:
raise ValueError(
"Given context isn't big enough, try to decrease context_size, prediction_size of increase length of given dataframe!"
"Given context isn't big enough, try to decrease context_size, prediction_size or increase length of given dataframe!"
)

history = df["target"][-expected_length:-prediction_size]
Expand Down Expand Up @@ -127,7 +127,7 @@ def predict(self, df: pd.DataFrame, prediction_size: int) -> np.ndarray:
expected_length = prediction_size + self.shift
if len(df) < expected_length:
raise ValueError(
"Given context isn't big enough, try to decrease context_size, prediction_size of increase length of given dataframe!"
"Given context isn't big enough, try to decrease context_size, prediction_size or increase length of given dataframe!"
)

context = df["target"][-expected_length:].values
Expand Down
23 changes: 22 additions & 1 deletion tests/test_models/nn/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
],
)
def test_mlp_model_run_weekly_overfit_with_scaler(ts_dataset_weekly_function_with_horizon, horizon):

ts_train, ts_test = ts_dataset_weekly_function_with_horizon(horizon)
lag = LagTransform(in_column="target", lags=list(range(horizon, horizon + 4)))
fourier = FourierTransform(period=7, order=3)
Expand Down Expand Up @@ -80,6 +79,17 @@ def test_mlp_make_samples(simple_df_relevance):
np.testing.assert_equal(df[["target"]].iloc[decoder_length : 2 * decoder_length], second_sample["decoder_target"])


def test_mlp_forward_fail_nans():
    """Check that ``MLPNet.forward`` raises ``ValueError`` when the features contain a NaN."""
    features_with_nan = torch.Tensor([[torch.nan, 2, 3], [1, 2, 3], [1, 2, 3]])
    target = torch.Tensor([[1], [2], [3]])
    batch = {"decoder_real": features_with_nan, "decoder_target": target, "segment": "A"}
    net = MLPNet(input_size=3, hidden_size=[1], lr=1e-2, loss=nn.MSELoss(), optimizer_params=None)
    with pytest.raises(ValueError, match="There are NaNs in features"):
        net.forward(batch)


def test_mlp_step():

batch = {
Expand All @@ -96,6 +106,17 @@ def test_mlp_step():
assert output.shape == torch.Size([3, 1])


def test_mlp_step_fail_nans():
    """Check that ``MLPNet.step`` raises ``ValueError`` when the features contain a NaN."""
    features_with_nan = torch.Tensor([[torch.nan, 2, 3], [1, 2, 3], [1, 2, 3]])
    target = torch.Tensor([[1], [2], [3]])
    batch = {"decoder_real": features_with_nan, "decoder_target": target, "segment": "A"}
    net = MLPNet(input_size=3, hidden_size=[1], lr=1e-2, loss=nn.MSELoss(), optimizer_params=None)
    with pytest.raises(ValueError, match="There are NaNs in features"):
        net.step(batch)


def test_mlp_layers():
model = MLPNet(input_size=3, hidden_size=[10], lr=1e-2, loss=None, optimizer_params=None)
model_ = nn.Sequential(
Expand Down
30 changes: 21 additions & 9 deletions tests/test_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@pytest.fixture()
def deep_base_model_mock():
model = MagicMock()
model.encoder_length = 10
model.train_batch_size = 32
model.train_dataloader_params = {}
model.val_dataloader_params = {}
Expand All @@ -23,6 +24,13 @@ def deep_base_model_mock():
return model


@pytest.fixture()
def ts_mock():
    """Return a mocked dataset whose ``index`` reports a length of 100."""
    ts = MagicMock()
    # MagicMock supports configuring magic methods, so ``len(ts.index)`` yields 100.
    ts.index.__len__.return_value = 100
    return ts


@pytest.fixture()
def sized_torch_dataset_mock():
torch_dataset = MagicMock()
Expand Down Expand Up @@ -138,29 +146,33 @@ def test_deep_base_model_raw_predict_call(dataloader, deep_base_model_mock):
np.testing.assert_allclose(predictions_dict[("segment2", "target")], batch["target"][1].numpy())


def test_deep_base_model_forecast_inverse_transform_call_check(deep_base_model_mock):
ts = MagicMock()
def test_deep_base_model_forecast_inverse_transform_call_check(deep_base_model_mock, ts_mock):
horizon = 7
DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon)
ts.tsdataset_idx_slice.return_value.inverse_transform.assert_called_once()
DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts_mock, prediction_size=horizon)
ts_mock.tsdataset_idx_slice.return_value.inverse_transform.assert_called_once()


def test_deep_base_model_forecast_fail_not_enough_context(deep_base_model_mock, ts_mock):
    """Check that ``forecast`` raises when the horizon equals the whole length of the dataset index."""
    # The requested horizon alone already spans the full index.
    too_long_horizon = len(ts_mock.index)
    with pytest.raises(ValueError, match="Given context isn't big enough"):
        DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts_mock, prediction_size=too_long_horizon)


def test_deep_base_model_forecast_loop(simple_df, deep_base_model_mock):
ts = MagicMock()
def test_deep_base_model_forecast_loop(simple_df, deep_base_model_mock, ts_mock):
ts_after_tsdataset_idx_slice = MagicMock()
horizon = 7

raw_predict = {("A", "target"): np.arange(10).reshape(-1, 1), ("B", "target"): -np.arange(10).reshape(-1, 1)}
deep_base_model_mock.raw_predict.return_value = raw_predict

ts_after_tsdataset_idx_slice.df = simple_df.df.iloc[-horizon:]
ts.tsdataset_idx_slice.return_value = ts_after_tsdataset_idx_slice
ts_mock.tsdataset_idx_slice.return_value = ts_after_tsdataset_idx_slice

future = DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts, prediction_size=horizon)
future = DeepBaseModel.forecast(self=deep_base_model_mock, ts=ts_mock, prediction_size=horizon)
np.testing.assert_allclose(
future.df.loc[:, pd.IndexSlice["A", "target"]], raw_predict[("A", "target")][:horizon, 0]
)
np.testing.assert_allclose(
future.df.loc[:, pd.IndexSlice["B", "target"]], raw_predict[("B", "target")][:horizon, 0]
)
ts.tsdataset_idx_slice.return_value.inverse_transform.assert_called_once()
ts_mock.tsdataset_idx_slice.return_value.inverse_transform.assert_called_once()
67 changes: 39 additions & 28 deletions tests/test_models/test_inference/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,32 +86,30 @@ def _test_forecast_in_sample_full_no_target(ts, model, transforms):
def test_forecast_in_sample_full_no_target(self, model, transforms, example_tsds):
self._test_forecast_in_sample_full_no_target(example_tsds, model, transforms)

@to_be_fixed(raises=AssertionError)
# Looks like a problem of current implementation of NNs
@pytest.mark.parametrize(
"model, transforms",
[
(RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
),
(LinearPerSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
],
)
def test_forecast_in_sample_full_no_target_failed_assertion_error(self, model, transforms, example_tsds):
self._test_forecast_in_sample_full_no_target(example_tsds, model, transforms)
def test_forecast_in_sample_full_no_target_failed_nans_sklearn(self, model, transforms, example_tsds):
with pytest.raises(ValueError, match="Input contains NaN, infinity or a value too large"):
self._test_forecast_in_sample_full_no_target(example_tsds, model, transforms)

@pytest.mark.parametrize(
"model, transforms",
[
(LinearPerSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(LinearMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
),
],
)
def test_forecast_in_sample_full_no_target_failed_nans_lags(self, model, transforms, example_tsds):
with pytest.raises(ValueError, match="Input contains NaN, infinity or a value too large"):
def test_forecast_in_sample_full_no_target_failed_nans_nn(self, model, transforms, example_tsds):
with pytest.raises(ValueError, match="There are NaNs in features"):
self._test_forecast_in_sample_full_no_target(example_tsds, model, transforms)

@pytest.mark.parametrize(
Expand All @@ -121,6 +119,7 @@ def test_forecast_in_sample_full_no_target_failed_nans_lags(self, model, transfo
(NaiveModel(lag=3), []),
(SeasonalMovingAverageModel(), []),
(DeadlineMovingAverageModel(window=1), []),
(RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []),
],
)
def test_forecast_in_sample_full_no_target_failed_not_enough_context(self, model, transforms, example_tsds):
Expand Down Expand Up @@ -182,7 +181,6 @@ class TestForecastInSampleFull:
(HoltModel(), []),
(HoltWintersModel(), []),
(SimpleExpSmoothingModel(), []),
(RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []),
],
)
def test_forecast_in_sample_full(self, model, transforms, example_tsds):
Expand All @@ -197,7 +195,7 @@ def test_forecast_in_sample_full(self, model, transforms, example_tsds):
(ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
],
)
def test_forecast_in_sample_full_failed_nans_lags(self, model, transforms, example_tsds):
def test_forecast_in_sample_full_failed_nans_sklearn(self, model, transforms, example_tsds):
with pytest.raises(ValueError, match="Input contains NaN, infinity or a value too large"):
_test_prediction_in_sample_full(example_tsds, model, transforms, method_name="forecast")

Expand All @@ -210,8 +208,8 @@ def test_forecast_in_sample_full_failed_nans_lags(self, model, transforms, examp
),
],
)
def test_forecast_in_sample_full_failed_nans_lags_nns(self, model, transforms, example_tsds):
with pytest.raises(AssertionError):
def test_forecast_in_sample_full_failed_nans_nn(self, model, transforms, example_tsds):
with pytest.raises(ValueError, match="There are NaNs in features"):
_test_prediction_in_sample_full(example_tsds, model, transforms, method_name="forecast")

@pytest.mark.parametrize(
Expand All @@ -221,12 +219,22 @@ def test_forecast_in_sample_full_failed_nans_lags_nns(self, model, transforms, e
(NaiveModel(lag=3), []),
(SeasonalMovingAverageModel(), []),
(DeadlineMovingAverageModel(window=1), []),
(RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []),
],
)
def test_forecast_in_sample_full_failed_not_enough_context(self, model, transforms, example_tsds):
with pytest.raises(ValueError, match="Given context isn't big enough"):
_test_prediction_in_sample_full(example_tsds, model, transforms, method_name="forecast")

@to_be_fixed(raises=AssertionError)
# Looks like a problem of the current implementation of NNs.
# NOTE(review): the parameter list is empty, so no cases are generated right now;
# presumably kept as a placeholder for future NN cases -- confirm.
@pytest.mark.parametrize("model, transforms", [])
def test_forecast_in_sample_full_failed_nans_lags_nns(self, model, transforms, example_tsds):
    """Run the full in-sample ``forecast`` check for cases marked as to-be-fixed."""
    _test_prediction_in_sample_full(example_tsds, model, transforms, method_name="forecast")

@to_be_fixed(raises=NotImplementedError, match="It is not possible to make in-sample predictions")
@pytest.mark.parametrize(
"model, transforms",
Expand Down Expand Up @@ -549,7 +557,7 @@ def _test_forecast_out_sample_suffix(ts, model, transforms, full_prediction_size
# firstly we should forecast prefix to use it as a context
forecast_prefix_ts = deepcopy(forecast_gap_ts)
forecast_prefix_ts.df = forecast_prefix_ts.df.iloc[:-suffix_prediction_size]
model.forecast(forecast_prefix_ts, prediction_size=prediction_size_diff)
forecast_prefix_ts = model.forecast(forecast_prefix_ts, prediction_size=prediction_size_diff)
forecast_gap_ts.df = forecast_gap_ts.df.combine_first(forecast_prefix_ts.df)

# forecast suffix with known context for it
Expand Down Expand Up @@ -584,25 +592,28 @@ def _test_forecast_out_sample_suffix(ts, model, transforms, full_prediction_size
(SeasonalMovingAverageModel(), []),
(NaiveModel(lag=3), []),
(DeadlineMovingAverageModel(window=1), []),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
),
],
)
def test_forecast_out_sample_suffix(self, model, transforms, example_tsds):
self._test_forecast_out_sample_suffix(example_tsds, model, transforms)

@to_be_fixed(raises=AssertionError)
# Looks like a problem of current implementation of NNs
@pytest.mark.parametrize(
"model, transforms",
[
(RNNModel(input_size=1, encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), []),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
),
],
)
def test_forecast_out_sample_suffix_failed_assertion_error(self, model, transforms, example_tsds):
self._test_forecast_out_sample_suffix(example_tsds, model, transforms)
def test_forecast_out_sample_suffix_failed_rnn(self, model, transforms, example_tsds):
"""This test is expected to fail due to autoregression in RNN.

More about it in issue: https://github.com/tinkoff-ai/etna/issues/1087
"""
with pytest.raises(AssertionError):
self._test_forecast_out_sample_suffix(example_tsds, model, transforms)

@to_be_fixed(raises=NotImplementedError, match="You can only forecast from the next point after the last one")
@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_inference/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_predict_in_sample_full(self, model, transforms, example_tsds):
(ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[2, 3])]),
],
)
def test_predict_in_sample_full_failed_not_enough_context(self, model, transforms, example_tsds):
def test_predict_in_sample_full_failed_nans_sklearn(self, model, transforms, example_tsds):
with pytest.raises(ValueError, match="Input contains NaN, infinity or a value too large"):
_test_prediction_in_sample_full(example_tsds, model, transforms, method_name="predict")

Expand Down