Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] conditional execution of test_distance and test_distance_params #5099

Merged
merged 10 commits into from Aug 19, 2023
4 changes: 4 additions & 0 deletions sktime/distances/tests/test_numba_distance_parameters.py
Expand Up @@ -10,6 +10,7 @@
from sktime.distances.base import MetricInfo
from sktime.distances.tests._expected_results import _expected_distance_results_params
from sktime.distances.tests._utils import create_test_distance_numpy
from sktime.tests.test_switch import run_test_for_class
from sktime.utils.numba.njit import njit
from sktime.utils.validation._dependencies import _check_soft_dependencies

Expand Down Expand Up @@ -85,6 +86,9 @@ def _test_derivative(q: np.ndarray):
@pytest.mark.parametrize("dist", _METRIC_INFOS)
def test_distance_params(dist: MetricInfo):
"""Test parametisation of distance callables."""
if not run_test_for_class(dist.dist_func):
return None

if dist.canonical_name in DIST_PARAMS:
_test_distance_params(
DIST_PARAMS[dist.canonical_name],
Expand Down
4 changes: 4 additions & 0 deletions sktime/distances/tests/test_numba_distances.py
Expand Up @@ -16,6 +16,7 @@
_test_metric_parameters,
)
from sktime.distances.tests._utils import create_test_distance_numpy
from sktime.tests.test_switch import run_test_for_class
from sktime.utils.validation._dependencies import _check_soft_dependencies

_ran_once = False
Expand Down Expand Up @@ -165,6 +166,9 @@ def test_distance(dist: MetricInfo) -> None:
distance_function = dist.dist_func
distance_factory = distance_numba_class.distance_factory

if not run_test_for_class(distance_function):
return None

_validate_distance_result(
x=np.array([10.0]),
y=np.array([15.0]),
Expand Down
19 changes: 13 additions & 6 deletions sktime/tests/test_switch.py
Expand Up @@ -5,26 +5,27 @@


def run_test_for_class(cls):
"""Check if test should run for a class.
"""Check if test should run for a class or function.

This checks the following conditions:

1. whether all required soft dependencies are not present.
1. whether all required soft dependencies are present.
If not, does not run the test.
2. If yes:
* if ONLY_CHANGED_MODULES setting is on, runs the test if and only
if the module containing the class has changed according to is_class_changed
if the module containing the class/func has changed according to is_class_changed
    * if ONLY_CHANGED_MODULES is off, always runs the test if all soft dependencies
are present.

cls can also be a list, in this case the test is run if and only if:
cls can also be a list of classes or functions,
in this case the test is run if and only if:

* all required soft dependencies are present
* if yes, if any of the estimators in the list should be tested by criterion 2 above

Parameters
----------
cls : class or list of class
cls : class, function or list of classes/functions
class for which to determine whether it should be tested

Returns
Expand All @@ -39,8 +40,14 @@ class for which to determine whether it should be tested
from sktime.utils.git_diff import is_class_changed
from sktime.utils.validation._dependencies import _check_estimator_deps

def _required_deps_present(obj):
if hasattr(obj, "get_class_tag"):
return _check_estimator_deps(obj, severity="none")
else:
return True

# if any of the required soft dependencies are not present, do not run the test
if not all(_check_estimator_deps(x, severity="none") for x in cls):
if not all(_required_deps_present(x) for x in cls):
return False

# if ONLY_CHANGED_MODULES is on, run the test if and only if
Expand Down