Skip to content

Commit

Permalink
[ENH] test for more than one parameter sets per estimator (#2862)
Browse files Browse the repository at this point in the history
In order to test parameter settings for estimators properly, each
important parameter should have been set to a non-default.
This PR introduces a test which checks that there are at least two test
parameter sets (if the estimator has at least one parameter), and
correctness of these sets.

Requires:
* #4279 for the general test for
`get_test_params` (no specific number assumed)
* #2835 as it uses
the parameter name interface
* #3428 to pass the
`no-softdeps` step.

Current estimators are not tested due to differential testing.

If diff testing is turned off, helps tracking down the estimators that
do not yet have two parameter sets, see #3429.
  • Loading branch information
fkiraly committed Aug 12, 2023
1 parent 4815c01 commit 9a384d5
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions sktime/tests/test_all_estimators.py
Expand Up @@ -768,13 +768,14 @@ def _coerce_to_list_of_str(obj):
else:
return []

# reserved_param_names = estimator_class.get_class_tag(
# "reserved_params", tag_value_default=None
# )
# reserved_param_names = _coerce_to_list_of_str(reserved_param_names)
# reserved_set = set(reserved_param_names)
reserved_param_names = estimator_class.get_class_tag(
"reserved_params", tag_value_default=None
)
reserved_param_names = _coerce_to_list_of_str(reserved_param_names)
reserved_set = set(reserved_param_names)

param_names = estimator_class.get_param_names()
unreserved_param_names = set(param_names).difference(reserved_set)

key_list = [x.keys() for x in param_list]

Expand All @@ -799,6 +800,22 @@ def _coerce_to_list_of_str(obj):
f"but found some parameters that are not __init__ args: {notfound_errs}"
)

if len(unreserved_param_names) > 0:
assert (
len(param_list) > 1
), "get_test_params should return at least two test parameter sets"
params_tested = set()
for params in param_list:
params_tested = params_tested.union(params.keys())

# this test is too harsh for the current estimator base
# params_not_tested = set(unreserved_param_names).difference(params_tested)
# assert len(params_not_tested) == 0, (
# f"get_test_params shoud set each parameter of {estimator_class} "
# f"to a non-default value at least once, but the following "
# f"parameters are not tested: {params_not_tested}"
# )

def test_create_test_instances_and_names(self, estimator_class):
"""Check that create_test_instances_and_names works.
Expand Down

0 comments on commit 9a384d5

Please sign in to comment.