Skip to content

Commit

Permalink
[MNT] differential testing and minor improvements to `forecasting.base` tests (#6619)
Browse files Browse the repository at this point in the history

This PR adds differential testing and makes minor improvements to
`forecasting.base` tests:

* replaces `ARIMA` with a faster and internal alternative
* refactors fixture generation
  • Loading branch information
fkiraly committed Jun 20, 2024
1 parent 0c2b202 commit 6438c61
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 38 deletions.
99 changes: 69 additions & 30 deletions sktime/forecasting/base/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@

from sktime.datatypes import check_is_mtype, convert
from sktime.datatypes._utilities import get_cutoff, get_window
from sktime.forecasting.arima import ARIMA
from sktime.forecasting.compose import YfromX
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.theta import ThetaForecaster
from sktime.forecasting.var import VAR
from sktime.split import temporal_train_test_split
from sktime.tests.test_switch import run_test_module_changed
from sktime.utils._testing.hierarchical import _make_hierarchical
from sktime.utils._testing.panel import _make_panel
from sktime.utils._testing.series import _make_series
from sktime.utils.dependencies import _check_estimator_deps
from sktime.utils.dependencies import _check_estimator_deps, _check_soft_dependencies
from sktime.utils.parallel import _get_parallel_test_fixtures

PANEL_MTYPES = ["pd-multiindex", "nested_univ", "numpy3D"]
Expand All @@ -34,8 +34,9 @@


@pytest.mark.skipif(
not _check_estimator_deps(ARIMA, severity="none"),
reason="skip test if required soft dependency for ARIMA not available",
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"])
or not _check_soft_dependencies("skpro", severity="none"),
reason="run only if base module has changed or datatypes module has changed",
)
@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("mtype", PANEL_MTYPES)
Expand All @@ -49,7 +50,7 @@ def test_vectorization_series_to_panel(mtype, backend):

y = _make_panel(n_instances=n_instances, random_state=42, return_mtype=mtype)

f = ARIMA()
f = YfromX.create_test_instance()
f.set_config(**backend.copy())
y_pred = f.fit(y).predict([1, 2, 3])
valid, _, metadata = check_is_mtype(
Expand All @@ -58,7 +59,7 @@ def test_vectorization_series_to_panel(mtype, backend):

msg = (
f"vectorization of forecasters does not work for test example "
f"of mtype {mtype}, using the ARIMA forecaster"
f"of mtype {mtype}, using the YfromX forecaster"
)

assert valid, msg
Expand Down Expand Up @@ -87,15 +88,15 @@ def test_vectorization_series_to_panel(mtype, backend):


@pytest.mark.skipif(
not _check_estimator_deps(ARIMA, severity="none"),
reason="skip test if required soft dependency for ARIMA not available",
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"]),
reason="run only if base module has changed or datatypes module has changed",
)
@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("mtype", HIER_MTYPES)
def test_vectorization_series_to_hier(mtype, backend):
"""Test that forecaster vectorization works for Hierarchical data.
This test passes Hierarchical data to the ARIMA forecaster which internally has an
This test passes Hierarchical data to the YfromX forecaster which internally has an
implementation for Series only, so the BaseForecaster has to vectorize.
"""
hierarchy_levels = (2, 4)
Expand All @@ -104,7 +105,9 @@ def test_vectorization_series_to_hier(mtype, backend):
y = _make_hierarchical(hierarchy_levels=hierarchy_levels, random_state=84)
y = convert(y, from_type="pd_multiindex_hier", to_type=mtype)

f = ARIMA()
f = YfromX.create_test_instance()
assert f.get_tags()["scitype:y"] == "univariate" # check the assumption

f.set_config(**backend.copy())
y_pred = f.fit(y).predict([1, 2, 3])
valid, _, metadata = check_is_mtype(
Expand All @@ -113,7 +116,7 @@ def test_vectorization_series_to_hier(mtype, backend):

msg = (
f"vectorization of forecasters does not work for test example "
f"of mtype {mtype}, using the ARIMA forecaster"
f"of mtype {mtype}, using the YfromX forecaster"
)

assert valid, msg
Expand Down Expand Up @@ -144,8 +147,9 @@ def test_vectorization_series_to_hier(mtype, backend):


@pytest.mark.skipif(
not _check_estimator_deps(ARIMA, severity="none"),
reason="skip test if required soft dependency for ARIMA not available",
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"])
or not _check_soft_dependencies("skpro", severity="none"),
reason="run only if base module has changed or datatypes module has changed",
)
@pytest.mark.parametrize("method", PROBA_DF_METHODS)
@pytest.mark.parametrize("mtype", PANEL_MTYPES)
Expand All @@ -159,7 +163,8 @@ def test_vectorization_series_to_panel_proba(method, mtype):

y = _make_panel(n_instances=n_instances, random_state=42, return_mtype=mtype)

est = ARIMA().fit(y)
est = _get_exog_proba_fcst()
est.fit(y)
y_pred = getattr(est, method)([1, 2, 3])

if method in ["predict_interval", "predict_quantiles"]:
Expand All @@ -182,8 +187,9 @@ def test_vectorization_series_to_panel_proba(method, mtype):


@pytest.mark.skipif(
not _check_estimator_deps(ARIMA, severity="none"),
reason="skip test if required soft dependency for ARIMA not available",
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"])
or not _check_soft_dependencies("skpro", severity="none"),
reason="run only if base module has changed or datatypes module has changed",
)
@pytest.mark.parametrize("method", PROBA_DF_METHODS)
@pytest.mark.parametrize("mtype", HIER_MTYPES)
Expand All @@ -197,7 +203,8 @@ def test_vectorization_series_to_hier_proba(method, mtype):
y = _make_hierarchical(hierarchy_levels=hierarchy_levels, random_state=84)
y = convert(y, from_type="pd_multiindex_hier", to_type=mtype)

est = ARIMA().fit(y)
est = _get_exog_proba_fcst()
est.fit(y)
y_pred = getattr(est, method)([1, 2, 3])

if method in ["predict_interval", "predict_quantiles"]:
Expand All @@ -220,16 +227,18 @@ def test_vectorization_series_to_hier_proba(method, mtype):


@pytest.mark.skipif(
not _check_estimator_deps(ARIMA, severity="none"),
reason="skip test if required soft dependency for ARIMA not available",
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"])
or not _check_soft_dependencies("skpro", severity="none"),
reason="run only if base module has changed or datatypes module has changed",
)
@pytest.mark.parametrize("method", PROBA_DF_METHODS)
def test_vectorization_preserves_row_index_names(method):
"""Test that forecaster vectorization preserves row index names in forecast."""
hierarchy_levels = (2, 4)
y = _make_hierarchical(hierarchy_levels=hierarchy_levels, random_state=84)

est = ARIMA().fit(y, fh=[1, 2, 3])
est = _get_exog_proba_fcst()
est.fit(y, fh=[1, 2, 3])
y_pred = getattr(est, method)()

msg = (
Expand All @@ -241,8 +250,8 @@ def test_vectorization_preserves_row_index_names(method):


@pytest.mark.skipif(
not _check_estimator_deps(ARIMA, severity="none"),
reason="skip test if required soft dependency for ARIMA not available",
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"]),
reason="run only if base module has changed or datatypes module has changed",
)
@pytest.mark.parametrize("mtype", HIER_MTYPES)
@pytest.mark.parametrize("exogeneous", [True, False])
Expand All @@ -264,7 +273,8 @@ def test_vectorization_multivariate(mtype, exogeneous):
X_fit = None
X_pred = None

est = ARIMA().fit(y=y_fit, X=X_fit, fh=[1, 2, 3])
est = YfromX.create_test_instance()
est.fit(y=y_fit, X=X_fit, fh=[1, 2, 3])
y_pred = est.predict(X=X_pred)
valid, _, metadata = check_is_mtype(
y_pred, mtype, return_metadata=True, msg_return_dict="list"
Expand Down Expand Up @@ -298,6 +308,10 @@ def test_vectorization_multivariate(mtype, exogeneous):
assert y_pred_equal_length, msg


@pytest.mark.skipif(
not run_test_module_changed("sktime.forecasting.base"),
reason="run only if base module has changed",
)
def test_col_vectorization_correct_col_order():
"""Test that forecaster vectorization preserves column index ordering.
Expand All @@ -321,6 +335,10 @@ def test_col_vectorization_correct_col_order():
assert (y_pred == y.iloc[4]).all().all()


@pytest.mark.skipif(
not run_test_module_changed("sktime.forecasting.base"),
reason="run only if base module has changed",
)
def test_row_vectorization_correct_row_order():
"""Test that forecaster vectorization preserves row index ordering.
Expand Down Expand Up @@ -394,8 +412,8 @@ def test_predict_residuals():


@pytest.mark.skipif(
not _check_estimator_deps(ARIMA, severity="none"),
reason="skip test if required soft dependency not available",
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"]),
reason="run only if base module has changed or datatypes module has changed",
)
@pytest.mark.parametrize("nullable_type", ["Int64", "Float64", "boolean"])
def test_nullable_dtypes(nullable_type):
Expand All @@ -410,7 +428,7 @@ def test_nullable_dtypes(nullable_type):
y = pd.Series([1, 0] * 20, dtype=dtype)
y.index = pd.date_range("1/1/21", periods=40)

f = ARIMA()
f = YfromX.create_test_instance()

fh = list(range(1, len(X_test) + 1))
f.fit(X=X_train, y=y, fh=fh)
Expand All @@ -421,8 +439,9 @@ def test_nullable_dtypes(nullable_type):


@pytest.mark.skipif(
not _check_estimator_deps(VAR, severity="none"),
reason="skip test if required soft dependency not available",
not _check_estimator_deps(VAR, severity="none")
or not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"]),
reason="run only if base module has changed or datatypes module has changed",
)
def test_range_fh_in_fit():
"""Test using ``range`` in ``fit``."""
Expand All @@ -436,8 +455,9 @@ def test_range_fh_in_fit():


@pytest.mark.skipif(
not _check_estimator_deps(VAR, severity="none"),
reason="skip test if required soft dependency not available",
not _check_estimator_deps(VAR, severity="none")
or not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"]),
reason="run only if base module has changed or datatypes module has changed",
)
def test_range_fh_in_predict():
"""Test using ``range`` in ``predict``."""
Expand All @@ -460,6 +480,10 @@ def test_range_fh_in_predict():
assert var_predictions.shape == (10 * 2, 5)


@pytest.mark.skipif(
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"]),
reason="run only if base module has changed or datatypes module has changed",
)
def test_remember_data():
"""Test that the ``remember_data`` flag works as expected."""
from sktime.datasets import load_airline
Expand All @@ -482,6 +506,10 @@ def test_remember_data():
assert f._y is not None


@pytest.mark.skipif(
not run_test_module_changed(["sktime.forecasting.base", "sktime.datatypes"]),
reason="run only if base module has changed or datatypes module has changed",
)
def test_panel_with_inner_freq():
"""Test that panel data with inner frequency set returns the correct predictions."""
from sktime.datasets import load_airline
Expand Down Expand Up @@ -517,3 +545,14 @@ def test_panel_with_inner_freq():

msg = "Panel not returning same predictions as simple case."
assert y_pred.droplevel("hour").sort_index().equals(y_pred_simple), msg


def _get_exog_proba_fcst():
    """Build a fast probabilistic forecaster that can use exogenous data.

    Wraps an skpro ``ResidualDouble`` regressor in ``YfromX``. The same
    ``LinearRegression`` object is passed for both ``ResidualDouble`` arguments.

    Returns
    -------
    YfromX
        unfitted forecaster wrapping the probabilistic regressor
    """
    # local imports: resolved only when this helper is actually called
    from sklearn.linear_model import LinearRegression
    from skpro.regression.residual import ResidualDouble

    base_regressor = LinearRegression()
    return YfromX(ResidualDouble(base_regressor, base_regressor))
10 changes: 8 additions & 2 deletions sktime/forecasting/base/tests/test_base_bugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
from sktime.forecasting.reconcile import ReconcilerForecaster
from sktime.forecasting.trend import PolynomialTrendForecaster
from sktime.split import ExpandingWindowSplitter
from sktime.tests.test_switch import run_test_module_changed
from sktime.transformations.hierarchical.aggregate import Aggregator
from sktime.transformations.series.difference import Differencer
from sktime.utils._testing.hierarchical import _make_hierarchical
from sktime.utils.dependencies import _check_estimator_deps


@pytest.mark.skipif(
not _check_estimator_deps(ExponentialSmoothing, severity="none"),
reason="skip test if required soft dependency not available",
not run_test_module_changed("sktime.forecasting.base")
or not _check_estimator_deps(ExponentialSmoothing, severity="none"),
reason="run only if base module has changed",
)
def test_heterogeneous_get_fitted_params():
"""Regression test for bugfix #4574, related to get_fitted_params."""
Expand Down Expand Up @@ -60,6 +62,10 @@ def test_heterogeneous_get_fitted_params():
reconciler.get_fitted_params() # triggers an error pre-fix


@pytest.mark.skipif(
not run_test_module_changed("sktime.forecasting.base"),
reason="run only if base module has changed",
)
def test_predict_residuals_conversion():
"""Regression test for bugfix #4766, related to predict_residuals internal type."""
from sktime.datasets import load_longley
Expand Down
Loading

0 comments on commit 6438c61

Please sign in to comment.