Skip to content

Commit

Permalink
[MNT] add differential testing to utils module (#6620)
Browse files Browse the repository at this point in the history
This PR adds differntial testing decorators in the `utils` module.
  • Loading branch information
fkiraly committed Jun 20, 2024
1 parent 6fb8161 commit 0c2b202
Show file tree
Hide file tree
Showing 13 changed files with 156 additions and 1 deletion.
29 changes: 29 additions & 0 deletions sktime/utils/estimators/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sktime.datasets import load_airline
from sktime.forecasting.base import BaseForecaster
from sktime.forecasting.naive import NaiveForecaster
from sktime.tests.test_switch import run_test_module_changed
from sktime.transformations.base import BaseTransformer
from sktime.transformations.series.boxcox import BoxCoxTransformer
from sktime.utils.estimators import make_mock_estimator
Expand All @@ -18,6 +19,10 @@
y_series = load_airline().iloc[:-5]


@pytest.mark.skipif(
not run_test_module_changed("sktime.utils.estimators"),
reason="Run test if estimators module has changed.",
)
@pytest.mark.parametrize(
"base", [BaseForecaster, BaseClassifier, BaseClusterer, BaseTransformer]
)
Expand Down Expand Up @@ -46,6 +51,10 @@ def _score(self):
assert hasattr(dummy_instance, "_MockEstimatorMixin__log")


@pytest.mark.skipif(
not run_test_module_changed("sktime.utils.estimators"),
reason="Run test if estimators module has changed.",
)
def test_add_log_item():
"""Test _MockEstimatorMixin.add_log_item behaviour."""
mixin = _MockEstimatorMixin()
Expand All @@ -55,6 +64,10 @@ def test_add_log_item():
assert mixin.log[1] == 2


@pytest.mark.skipif(
not run_test_module_changed("sktime.utils.estimators"),
reason="Run test if estimators module has changed.",
)
def test_log_is_property():
"""Test _MockEstimatorMixin.log can't be overwritten."""
mixin = _MockEstimatorMixin()
Expand All @@ -63,6 +76,10 @@ def test_log_is_property():
assert "can't set attribute" in str(excinfo.value)


@pytest.mark.skipif(
not run_test_module_changed("sktime.utils.estimators"),
reason="Run test if estimators module has changed.",
)
def test_method_logger_exception():
"""Test that _method_logger only works for _MockEstimatorMixin subclasses."""

Expand All @@ -82,6 +99,10 @@ def _method(self):
assert "Estimator is not a Mock Estimator" in str(excinfo.value)


@pytest.mark.skipif(
not run_test_module_changed("sktime.utils.estimators"),
reason="Run test if estimators module has changed.",
)
def test_method_logger():
"""Test that method logger returns the correct output."""

Expand Down Expand Up @@ -143,6 +164,10 @@ def _method3(self):
]


@pytest.mark.skipif(
not run_test_module_changed("sktime.utils.estimators"),
reason="Run test if estimators module has changed.",
)
@pytest.mark.parametrize(
"estimator_class, method_regex, logged_methods",
[
Expand All @@ -161,6 +186,10 @@ def test_make_mock_estimator(estimator_class, method_regex, logged_methods):
assert set(methods_called) >= set(logged_methods)


@pytest.mark.skipif(
not run_test_module_changed("sktime.utils.estimators"),
reason="Run test if estimators module has changed.",
)
@pytest.mark.parametrize(
"estimator_class, estimator_kwargs",
[
Expand Down
4 changes: 3 additions & 1 deletion sktime/utils/numba/tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import pytest
from numpy.testing import assert_array_equal

from sktime.tests.test_switch import run_test_module_changed
from sktime.utils.dependencies import _check_soft_dependencies

DATATYPES = ["int32", "int64", "float32", "float64"]


@pytest.mark.skipif(
not _check_soft_dependencies("numba", severity="none"),
not run_test_module_changed(["sktime.utils.numba"])
or not _check_soft_dependencies("numba", severity="none"),
reason="skip test if required soft dependency not available",
)
@pytest.mark.parametrize("type", DATATYPES)
Expand Down
9 changes: 9 additions & 0 deletions sktime/utils/sklearn/tests/test_sklearn_df_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
import pandas as pd
import pytest

from sktime.tests.test_switch import run_test_module_changed
from sktime.utils.sklearn._adapt_df import prep_skl_df


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils.sklearn"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("copy_df", [True, False])
def test_prep_skl_df_coercion(copy_df):
"""Test that prep_skl_df behaves correctly on the coercion case."""
Expand All @@ -22,6 +27,10 @@ def test_prep_skl_df_coercion(copy_df):
assert res is mixed_example


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils.sklearn"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("copy_df", [True, False])
def test_prep_skl_df_non_coercion(copy_df):
"""Test that prep_skl_df behaves correctly on the non-coercion case."""
Expand Down
13 changes: 13 additions & 0 deletions sktime/utils/sklearn/tests/test_sklearn_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sktime.classification.feature_based import SummaryClassifier
from sktime.forecasting.naive import NaiveForecaster
from sktime.tests.test_switch import run_test_for_class
from sktime.utils.sklearn import is_sklearn_estimator, sklearn_scitype

CORRECT_SCITYPES = {
Expand All @@ -23,6 +24,10 @@
sktime_estimators = [SummaryClassifier, NaiveForecaster]


@pytest.mark.skipif(
not run_test_for_class(is_sklearn_estimator),
reason="Run if utilities have changed.",
)
@pytest.mark.parametrize("estimator", sklearn_estimators)
def test_is_sklearn_estimator_positive(estimator):
"""Test that is_sklearn_estimator recognizes positive examples correctly."""
Expand All @@ -33,6 +38,10 @@ def test_is_sklearn_estimator_positive(estimator):
assert is_sklearn_estimator(estimator), msg


@pytest.mark.skipif(
not run_test_for_class(is_sklearn_estimator),
reason="Run if utilities have changed.",
)
@pytest.mark.parametrize("estimator", sktime_estimators)
def test_is_sklearn_estimator_negative(estimator):
"""Test that is_sklearn_estimator recognizes negative examples correctly."""
Expand All @@ -43,6 +52,10 @@ def test_is_sklearn_estimator_negative(estimator):
assert not is_sklearn_estimator(estimator), msg


@pytest.mark.skipif(
not run_test_for_class(sklearn_scitype),
reason="Run if utilities have changed.",
)
@pytest.mark.parametrize("estimator", sklearn_estimators)
def test_sklearn_scitype(estimator):
"""Test that sklearn_scitype returns the correct scitype string."""
Expand Down
18 changes: 18 additions & 0 deletions sktime/utils/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import numpy as np
import pandas as pd
import pytest

from sktime.datasets import load_airline
from sktime.datatypes import VectorizedDF
from sktime.datatypes._utilities import get_time_index
from sktime.tests.test_switch import run_test_module_changed
from sktime.utils._testing.hierarchical import _bottom_hier_datagen
from sktime.utils.datetime import (
_coerce_duration_to_int,
Expand All @@ -20,6 +22,10 @@
from sktime.utils.dependencies import _check_soft_dependencies


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils", "sktime.datatypes"]),
reason="Run if utils or datatypes module has changed.",
)
def test_get_freq():
"""Test whether get_freq runs without error."""
x = pd.Series(
Expand All @@ -45,6 +51,10 @@ def test_get_freq():
assert _get_freq(x4) is None


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils", "sktime.datatypes"]),
reason="Run if utils or datatypes module has changed.",
)
def test_coerce_duration_to_int() -> None:
"""Test _coerce_duration_to_int."""
assert _coerce_duration_to_int(duration=0) == 0
Expand All @@ -69,6 +79,10 @@ def test_coerce_duration_to_int() -> None:
)


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils", "sktime.datatypes"]),
reason="Run if utils or datatypes module has changed.",
)
def test_infer_freq() -> None:
"""Test frequency inference."""
assert infer_freq(None) is None
Expand Down Expand Up @@ -97,6 +111,10 @@ def test_infer_freq() -> None:
assert infer_freq(y) in ["M", "ME"]


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils", "sktime.datatypes"]),
reason="Run if utils or datatypes module has changed.",
)
def test_set_freq_hier():
"""Test that setting frequency on a DatetimeIndex MultiIndex works."""
# from pandas 2.1.0 on, freq is preserved correctly,
Expand Down
9 changes: 9 additions & 0 deletions sktime/utils/tests/test_multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import pandas as pd
import pytest

from sktime.tests.test_switch import run_test_module_changed
from sktime.utils.multiindex import flatten_multiindex, rename_multiindex


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils.multiindex"]),
reason="Run if multiindex module has changed.",
)
def test_flatten_multiindex():
"""Test flatten_multiindex contract."""
mi = pd.MultiIndex.from_product([["a", "b"], [0, 42]])
Expand All @@ -18,6 +23,10 @@ def test_flatten_multiindex():
assert (expected == flat).all()


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils.multiindex"]),
reason="Run if multiindex module has changed.",
)
def test_rename_multiindex():
"""Test rename_multiindex contract."""
mi = pd.MultiIndex.from_tuples([("a", 1), ("a", 42), ("b", 1), ("c", 0)])
Expand Down
13 changes: 13 additions & 0 deletions sktime/utils/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@

from sktime.datasets import load_unit_test
from sktime.datatypes import check_is_scitype
from sktime.tests.test_switch import run_test_module_changed
from sktime.utils.sampling import random_partition, stratified_resample

NK_FIXTURES = [(10, 3), (15, 5), (19, 6), (3, 1), (1, 2)]
SEED_FIXTURES = [42, 0, 100, -5]


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils.sampling", "sktime.datatypes"]),
reason="Run if multiindex module has changed.",
)
@pytest.mark.parametrize("n, k", NK_FIXTURES)
def test_partition(n, k):
"""Test that random_partition returns a disjoint partition."""
Expand All @@ -35,6 +40,10 @@ def test_partition(n, k):
assert len(set(x).intersection(y)) == 0


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils.sampling", "sktime.datatypes"]),
reason="Run if multiindex module has changed.",
)
@pytest.mark.parametrize("seed", SEED_FIXTURES)
@pytest.mark.parametrize("n, k", NK_FIXTURES)
def test_seed(n, k, seed):
Expand All @@ -47,6 +56,10 @@ def test_seed(n, k, seed):
assert deep_equals(part, part2)


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils.sampling", "sktime.datatypes"]),
reason="Run if multiindex module has changed.",
)
def test_stratified_resample():
"""Test resampling returns valid data structure and maintains class distribution."""
trainX, trainy = load_unit_test(split="TRAIN")
Expand Down
13 changes: 13 additions & 0 deletions sktime/utils/tests/test_seasonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
import pandas as pd
import pytest

from sktime.tests.test_switch import run_test_module_changed
from sktime.utils._testing.series import _make_series
from sktime.utils.seasonality import _pivot_sp, _unpivot_sp


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("n_timepoints", [49, 1])
@pytest.mark.parametrize("index_type", ["period", "datetime", "range", "int"])
@pytest.mark.parametrize("sp", [2, 10])
Expand Down Expand Up @@ -53,6 +58,10 @@ def test_pivot_sp(sp, index_type, n_timepoints, anchor_side):
assert not np.isnan(df_pivot.iloc[-1, -1])


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("n_timepoints", [49, 1])
@pytest.mark.parametrize("index_type", ["period", "datetime", "range", "int"])
@pytest.mark.parametrize("sp", [2, 10])
Expand Down Expand Up @@ -81,6 +90,10 @@ def test_unpivot_sp(sp, index_type, n_timepoints, anchor_side):
assert np.all(df_unpivot == df)


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("n_timepoints", [50, 2])
@pytest.mark.parametrize("index_type", ["period", "datetime", "range", "int"])
@pytest.mark.parametrize("sp", [3, 10])
Expand Down
13 changes: 13 additions & 0 deletions sktime/utils/tests/test_utils_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import pytest
from scipy.stats import linregress

from sktime.tests.test_switch import run_test_module_changed
from sktime.utils._testing.forecasting import _generate_polynomial_series
from sktime.utils.slope_and_trend import _fit_trend, _slope


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("trend_order", [0, 3])
def test_time_series_slope_against_scipy_linregress(trend_order):
"""Test time series slope against scipy lingress."""
Expand All @@ -21,6 +26,10 @@ def test_time_series_slope_against_scipy_linregress(trend_order):


# Check linear and constant cases
@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("slope", [-1, 0, 1])
def test_time_series_slope_against_simple_cases(slope):
"""Test time series slope against simple cases."""
Expand All @@ -29,6 +38,10 @@ def test_time_series_slope_against_simple_cases(slope):
np.testing.assert_almost_equal(_slope(y), slope, decimal=10)


@pytest.mark.skipif(
not run_test_module_changed(["sktime.utils"]),
reason="Run if utils module has changed.",
)
@pytest.mark.parametrize("order", [0, 1, 2]) # polynomial order
@pytest.mark.parametrize("n_timepoints", [1, 10]) # number of time series observations
@pytest.mark.parametrize("n_instances", [1, 10]) # number of samples
Expand Down
9 changes: 9 additions & 0 deletions sktime/utils/validation/tests/test_check_window_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
"""Tests for window length."""
import pytest

from sktime.tests.test_switch import run_test_for_class
from sktime.utils.validation import check_window_length


@pytest.mark.skipif(
not run_test_for_class(check_window_length),
reason="Run if tested function has changed.",
)
@pytest.mark.parametrize(
"window_length, n_timepoints, expected",
[
Expand All @@ -22,6 +27,10 @@ def test_check_window_length(window_length, n_timepoints, expected):
assert check_window_length(window_length, n_timepoints) == expected


@pytest.mark.skipif(
not run_test_for_class(check_window_length),
reason="Run if tested function has changed.",
)
@pytest.mark.parametrize(
"window_length, n_timepoints",
[
Expand Down
Loading

0 comments on commit 0c2b202

Please sign in to comment.