Skip to content

Commit

Permalink
[BUG] fix pmdarima interfaces breaking for X containing more indice…
Browse files Browse the repository at this point in the history
…s than forecasting horizon (#3667)

Fixes #3657.

The bug was caused by `pmdarima` models breaking when the `X` passed was
strictly larger than the indices in the forecasting horizon.

The example code in #3657 has been added as a test (with minor
generalization to cover more estimators).

In the future, we should probably also add test scenarios where `X` is
strictly larger than the forecasting horizon.

Depends on:
* #4474, which fixes a bug that was
masked by #3657
* #4483 for slicing `X`
  • Loading branch information
fkiraly committed Apr 21, 2023
1 parent 9bae78d commit 41c1baa
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
11 changes: 11 additions & 0 deletions sktime/forecasting/base/adapters/_pmdarima.py
Expand Up @@ -8,6 +8,7 @@

import pandas as pd

from sktime.datatypes._utilities import get_slice
from sktime.forecasting.base import BaseForecaster
from sktime.forecasting.base._base import DEFAULT_ALPHA

Expand Down Expand Up @@ -87,6 +88,15 @@ def _predict(self, fh, X=None):
y_pred : pandas.Series
Returns series of predicted values.
"""
fh_abs = fh.to_absolute(self.cutoff).to_pandas()
fh_abs_int = fh.to_absolute_int(fh_abs[0], self.cutoff).to_pandas()
end_int = fh_abs_int[-1] + 2
# +2 becuase + 1 for "end" (python index), +1 for starting to count at 1 in fh

if X is not None:
X = get_slice(X, start=self.cutoff[0], start_inclusive=False)
X = X.iloc[:end_int]

# distinguish between in-sample and out-of-sample prediction
fh_oos = fh.to_out_of_sample(self.cutoff)
fh_ins = fh.to_in_sample(self.cutoff)
Expand All @@ -108,6 +118,7 @@ def _predict(self, fh, X=None):
# ensure that name is not added nor removed
# otherwise this may upset conversion to pd.DataFrame
y_pred.name = self._y.name
y_pred.index = fh_abs
return y_pred

def _predict_in_sample(
Expand Down
33 changes: 32 additions & 1 deletion sktime/forecasting/model_evaluation/tests/test_evaluate.py
Expand Up @@ -21,6 +21,7 @@

from sktime.datasets import load_airline, load_longley
from sktime.exceptions import FitFailedWarning
from sktime.forecasting.arima import ARIMA, AutoARIMA
from sktime.forecasting.compose._reduce import DirectReductionForecaster
from sktime.forecasting.ets import AutoETS
from sktime.forecasting.exp_smoothing import ExponentialSmoothing
Expand All @@ -32,6 +33,7 @@
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.tests._config import TEST_FHS, TEST_STEP_LENGTHS_INT
from sktime.performance_metrics.forecasting import (
MeanAbsoluteError,
MeanAbsolutePercentageError,
MeanAbsoluteScaledError,
)
Expand All @@ -44,7 +46,10 @@
from sktime.utils._testing.forecasting import make_forecasting_problem
from sktime.utils._testing.hierarchical import _make_hierarchical
from sktime.utils._testing.series import _make_series
from sktime.utils.validation._dependencies import _check_soft_dependencies
from sktime.utils.validation._dependencies import (
_check_estimator_deps,
_check_soft_dependencies,
)


def _check_evaluate_output(out, cv, y, scoring):
Expand Down Expand Up @@ -282,6 +287,32 @@ def test_evaluate_hierarchical(backend):
assert np.all(out_exog[scoring_name] != out_no_exog[scoring_name])


# ARIMA models from statsmodels, pmdarima
ARIMA_MODELS = [ARIMA, AutoARIMA]

# breaks for SARIMAX, see issue #3670, this should be fixed
# ARIMA_MODELS = [ARIMA, AutoARIMA, SARIMAX]


@pytest.mark.parametrize("cls", ARIMA_MODELS)
def test_evaluate_bigger_X(cls):
"""Check that evaluating ARIMA models with exogeneous X works.
Example adapted from bug report #3657.
"""
if not _check_estimator_deps(cls, severity="none"):
return None

y, X = load_longley()

f = cls.create_test_instance()
cv = ExpandingWindowSplitter(initial_window=3, step_length=1, fh=np.arange(1, 4))
loss = MeanAbsoluteError()

# check that this does not break
evaluate(forecaster=f, y=y, X=X, cv=cv, error_score="raise", scoring=loss)


PROBA_METRICS = [CRPS, EmpiricalCoverage, LogLoss, PinballLoss]


Expand Down

0 comments on commit 41c1baa

Please sign in to comment.