Skip to content

Commit

Permalink
[BUG] fix BATS and TBATS _predict_interval interface, part 2 (#…
Browse files Browse the repository at this point in the history
…4505)

With more test coverage on `predict_interval` from
#4470, `BATS` and `TBATS` fail some
in-sample cases.

These were not addressed in #4492
  • Loading branch information
fkiraly committed Apr 23, 2023
1 parent 2928e0b commit 06dfabb
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions sktime/forecasting/base/adapters/_tbats.py
Expand Up @@ -194,9 +194,8 @@ def _tbats_forecast(self, fh):

if not fh.is_all_in_sample(cutoff=self.cutoff):
fh_out = fh.to_out_of_sample(cutoff=self.cutoff)
steps = fh_out.to_pandas().max()
steps = fh_out.to_pandas()[-1]
y_out = self._forecaster.forecast(steps=steps, confidence_level=None)

else:
y_out = nans(len(fh))

Expand Down Expand Up @@ -227,17 +226,22 @@ def _tbats_forecast_interval(self, fh, conf_lev):
"""
fh = fh.to_relative(cutoff=self.cutoff)
fh_out = fh.to_out_of_sample(cutoff=self.cutoff)
steps = fh_out.to_pandas().max()

_, tbats_ci = self._forecaster.forecast(steps=steps, confidence_level=conf_lev)
out = pd.DataFrame(tbats_ci)
if not fh.is_all_in_sample(cutoff=self.cutoff):
steps = fh_out.to_pandas()[-1]
_, tbats_ci = self._forecaster.forecast(
steps=steps, confidence_level=conf_lev
)
out = pd.DataFrame(tbats_ci)
# pred_int
lower = pd.Series(out["lower_bound"])
upper = pd.Series(out["upper_bound"])
pred_int_oos = pd.DataFrame({"lower": lower, "upper": upper})
pred_int_oos = pred_int_oos.iloc[fh_out.to_indexer()]
pred_int_oos.index = fh_out.to_absolute_index(self.cutoff)
else:
pred_int_oos = pd.DataFrame(columns=["lower", "upper"])

# pred_int
lower = pd.Series(out["lower_bound"])
upper = pd.Series(out["upper_bound"])
pred_int_oos = pd.DataFrame({"lower": lower, "upper": upper})
pred_int_oos = pred_int_oos.iloc[fh_out.to_indexer()]
pred_int_oos.index = fh_out.to_absolute_index(self.cutoff)
full_ix = fh.to_absolute_index(self.cutoff)
pred_int = pred_int_oos.reindex(full_ix)

Expand Down

0 comments on commit 06dfabb

Please sign in to comment.