Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] basic "test all estimators" suite (#89)
Adds a `TestAllEstimators` which checks generic fit/predict capabilities of estimators. This is partly migrated from `sktime`, and uses the `scikit-base` test framework classes (also originally from `sktime`). This can also be used as a joint refactor point with `sktime` and `skbase`. Additions and changes: * `TestAllEstimators` class * this requires a `BaseFixtureGenerator` which generates `scenarios`. For this, the `skpro` fixture generator now inherits from `skbase`. * Also, scenarios are required. This is not available in `skbase` yet, so the `sktime` scenario framework is copied to `tests.scenarios`. * In incremental testing, an override is introduced so all classes are tested if the central test framework module changes.
- Loading branch information
Showing
6 changed files
with
708 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Test scenarios for estimators.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
"""Testing utility to play back usage scenarios for estimators. | ||
Contains TestScenario class which applies method/args subsequently | ||
""" | ||
# copied from sktime. Should be jointly refactored to scikit-base. | ||
|
||
__author__ = ["fkiraly"] | ||
|
||
__all__ = ["TestScenario"] | ||
|
||
|
||
from copy import deepcopy | ||
from inspect import isclass | ||
|
||
|
||
class TestScenario: | ||
"""Class to run pre-defined method execution scenarios for objects. | ||
Parameters | ||
---------- | ||
args : dict of dict, default = None | ||
dict of argument dicts to be used in methods | ||
names for keys need not equal names of methods these are used in | ||
but scripted method will look at key with same name as default | ||
must be passed to constructor, set in a child class | ||
or dynamically created in get_args | ||
default_method_sequence : list of str, default = None | ||
default sequence for methods to be called | ||
optional, if given, default method sequence to use in `run` | ||
if not provided, at least one of the sequence arguments must be passed in `run` | ||
or default_arg_sequence must be provided | ||
default_arg_sequence : list of str, default = None | ||
default sequence of keys for keyword argument dicts to be used | ||
names for keys need not equal names of methods | ||
if not provided, at least one of the sequence arguments must be passed in `run` | ||
or default_method_sequence must be provided | ||
Methods | ||
------- | ||
run(obj, args=None, default_method_sequence=None) | ||
Run a call(args) scenario on obj, and retrieve method outputs. | ||
is_applicable(obj) | ||
Check whether scenario is applicable to obj. | ||
get_args(key, obj) | ||
Dynamically create args for call defined by key and obj. | ||
Defaults to self.args[key] if not overridden. | ||
""" | ||
|
||
def __init__( | ||
self, args=None, default_method_sequence=None, default_arg_sequence=None | ||
): | ||
if default_method_sequence is not None: | ||
self.default_method_sequence = _check_list_of_str(default_method_sequence) | ||
elif not hasattr(self, "default_method_sequence"): | ||
self.default_method_sequence = None | ||
if default_arg_sequence is not None: | ||
self.default_arg_sequence = _check_list_of_str(default_arg_sequence) | ||
elif not hasattr(self, "default_arg_sequence"): | ||
self.default_arg_sequence = None | ||
if args is not None: | ||
self.args = _check_dict_of_dict(args) | ||
else: | ||
if not hasattr(self, "args"): | ||
raise RuntimeError( | ||
f"{self.__class__.__name__} (scenario class) failed to construct, " | ||
"args must either be given to __init__ or set as an attribute" | ||
) | ||
_check_dict_of_dict(self.args) | ||
|
||
def get_args(self, key, obj=None, deepcopy_args=True): | ||
"""Return args for key. Can be overridden for dynamic arg generation. | ||
If overridden, must not have any side effects on self.args | ||
e.g., avoid assignments args[key] = x without deepcopying self.args first | ||
Parameters | ||
---------- | ||
key : str, argument key to construct/retrieve args for | ||
obj : obj, optional, default=None. Object to construct args for. | ||
deepcopy_args : bool, optional, default=True. Whether to deepcopy return. | ||
Returns | ||
------- | ||
args : argument dict to be used for a method, keyed by `key` | ||
names for keys need not equal names of methods these are used in | ||
but scripted method will look at key with same name as default | ||
""" | ||
args = self.args.get(key, {}) | ||
if deepcopy_args: | ||
args = deepcopy(args) | ||
return args | ||
|
||
def run( | ||
self, | ||
obj, | ||
method_sequence=None, | ||
arg_sequence=None, | ||
return_all=False, | ||
return_args=False, | ||
deepcopy_return=False, | ||
): | ||
"""Run a call(args) scenario on obj, and retrieve method outputs. | ||
Runs a sequence of commands | ||
res_1 = obj.method_1(**args_1) | ||
res_2 = obj.method_2(**args_2) | ||
etc, where method_i is method_sequence[i], | ||
and args_i is self.args[arg_sequence[i]] | ||
and returns results. Args are passed as deepcopy to avoid side effects. | ||
if method_i is __init__ (a constructor), | ||
obj is changed to obj.__init__(**args_i) from the next line on | ||
Parameters | ||
---------- | ||
obj : class or object with methods in method_sequence | ||
method_sequence : list of str, default = arg_sequence if passed | ||
if arg_sequence is also None, then default = self.default_method_sequence | ||
sequence of method names to be run | ||
arg_sequence : list of str, default = method_sequence if passed | ||
if method_sequence is also None, then default = self.default_arg_sequence | ||
sequence of keys for keyword argument dicts to be used | ||
names for keys need not equal names of methods | ||
return_all : bool, default = False | ||
whether all or only the last result should be returned | ||
if False, only the last result is returned | ||
if True, list of deepcopies of intermediate results is returned | ||
return_args : bool, default = False | ||
whether arguments should also be returned | ||
if False, there is no second return argument | ||
if True, "args_after_call" return argument is returned | ||
deepcopy_return : bool, default = False | ||
whether returns are deepcopied before returned | ||
if True, returns are deepcopies of return | ||
if False, returns are references/assignments, not deepcopies | ||
NOTE: if self is returned (e.g., in fit), and deepcopy_return=False | ||
method calls may continue to have side effects on that return | ||
Returns | ||
------- | ||
results : output of the last method call, if return_all = False | ||
list of deepcopies of all outputs, if return_all = True | ||
args_after_call : list of args after method call, only if return_args = True | ||
i-th element is deepcopy of args of i-th method call, after method call | ||
this is possibly subject to side effects by the method | ||
""" | ||
# if both None, fill with defaults if exist | ||
if method_sequence is None and arg_sequence is None: | ||
method_sequence = getattr(self, "default_method_sequence", None) | ||
arg_sequence = getattr(self, "default_arg_sequence", None) | ||
|
||
# if both are still None, raise an error | ||
if method_sequence is None and arg_sequence is None: | ||
raise ValueError( | ||
"at least one of method_sequence, arg_sequence must be not None " | ||
"if no defaults are set in the class" | ||
) | ||
|
||
# if only one is None, fill one with the other | ||
if method_sequence is None: | ||
method_sequence = _check_list_of_str(arg_sequence) | ||
else: | ||
method_sequence = _check_list_of_str(method_sequence) | ||
if arg_sequence is None: | ||
arg_sequence = _check_list_of_str(method_sequence) | ||
else: | ||
arg_sequence = _check_list_of_str(arg_sequence) | ||
|
||
# check that length of sequences is the same | ||
num_calls = len(arg_sequence) | ||
if not num_calls == len(method_sequence): | ||
raise ValueError("arg_sequence and method_sequence must have same length") | ||
|
||
# execute the commands in sequence, report result(s) | ||
results = [] | ||
args_after_call = [] | ||
for i in range(num_calls): | ||
methodname = method_sequence[i] | ||
args = deepcopy(self.get_args(key=arg_sequence[i], obj=obj)) | ||
|
||
if methodname != "__init__": | ||
res = getattr(obj, methodname)(**args) | ||
# if constructor is called, run directly and replace obj | ||
else: | ||
if isclass(obj): | ||
res = obj(**args) | ||
else: | ||
res = type(obj)(**args) | ||
obj = res | ||
|
||
args_after_call += [args] | ||
|
||
if deepcopy_return: | ||
res = deepcopy(res) | ||
|
||
if return_all: | ||
results += [res] | ||
else: | ||
results = res | ||
|
||
if return_args: | ||
return results, args_after_call | ||
else: | ||
return results | ||
|
||
def is_applicable(self, obj): | ||
"""Check whether scenario is applicable to obj. | ||
Abstract method, children should implement. This just returns "true". | ||
Example for child class: scenario is univariate time series forecasting. | ||
Then, this returns False on multivariate, True on univariate forecasters. | ||
Parameters | ||
---------- | ||
obj : class or object to check against scenario | ||
Returns | ||
------- | ||
applicable: bool | ||
True if self is applicable to obj, False if not | ||
"applicable" is defined as the implementer chooses, as output of this method | ||
False is typically used as a "skip" flag in unit or integration testing | ||
""" | ||
return True | ||
|
||
|
||
def _check_list_of_str(obj, name="obj"): | ||
"""Check whether obj is a list of str. | ||
Parameters | ||
---------- | ||
obj : any object, check whether is list of str | ||
name : str, default="obj", name of obj to display in error message | ||
Returns | ||
------- | ||
obj, unaltered | ||
Raises | ||
------ | ||
TypeError if obj is not list of str | ||
""" | ||
if not isinstance(obj, list) or not all(isinstance(x, str) for x in obj): | ||
raise TypeError(f"{name} must be a list of str") | ||
return obj | ||
|
||
|
||
def _check_dict_of_dict(obj, name="obj"): | ||
"""Check whether obj is a dict of dict, with str keys. | ||
Parameters | ||
---------- | ||
obj : any object, check whether is dict of dict, with str keys | ||
name : str, default="obj", name of obj to display in error message | ||
Returns | ||
------- | ||
obj, unaltered | ||
Raises | ||
------ | ||
TypeError if obj is not dict of dict, with str keys | ||
""" | ||
if not ( | ||
isinstance(obj, dict) | ||
and all(isinstance(x, dict) for x in obj.values()) | ||
and all(isinstance(x, str) for x in obj.keys()) | ||
): | ||
raise TypeError(f"{name} must be a dict of dict, with str keys") | ||
return obj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""Retrieval utility for test scenarios.""" | ||
# copied from sktime. Should be jointly refactored to scikit-base. | ||
|
||
__author__ = ["fkiraly"] | ||
|
||
__all__ = ["retrieve_scenarios"] | ||
|
||
|
||
from inspect import isclass | ||
|
||
from skpro.tests.scenarios.scenarios_regressor_proba import scenarios_regressor_proba | ||
|
||
scenarios = dict() | ||
scenarios["regressor_proba"] = scenarios_regressor_proba | ||
|
||
|
||
def retrieve_scenarios(obj, filter_tags=None): | ||
"""Retrieve test scenarios for obj, or by estimator scitype string. | ||
Exactly one of the arguments obj, estimator_type must be provided. | ||
Parameters | ||
---------- | ||
obj : class or object, or string, or list of str. | ||
Which kind of estimator/object to retrieve scenarios for. | ||
If object, must be a class or object inheriting from BaseObject. | ||
If string(s), must be in registry.BASE_CLASS_REGISTER (first col) | ||
for instance 'classifier', 'regressor', 'transformer', 'forecaster' | ||
filter_tags: dict of (str or list of str), default=None | ||
subsets the returned objectss as follows: | ||
each key/value pair is statement in "and"/conjunction | ||
key is tag name to sub-set on | ||
value str or list of string are tag values | ||
condition is "key must be equal to value, or in set(value)" | ||
Returns | ||
------- | ||
scenarios : list of objects, instances of BaseScenario | ||
""" | ||
# if class, get scitypes from inference; otherwise, str or list of str | ||
if not isinstance(obj, str): | ||
if isclass(obj): | ||
if hasattr(obj, "get_class_tag"): | ||
estimator_type = obj.get_class_tag("object_type", "object") | ||
else: | ||
estimator_type = "object" | ||
else: | ||
if hasattr(obj, "get_tag"): | ||
estimator_type = obj.get_tag("object_type", "object", False) | ||
else: | ||
estimator_type = "object" | ||
else: | ||
estimator_type = obj | ||
|
||
# coerce to list, ensure estimator_type is list of str | ||
if not isinstance(estimator_type, list): | ||
estimator_type = [estimator_type] | ||
|
||
# now loop through types and retrieve scenarios | ||
scenarios_for_type = [] | ||
for est_type in estimator_type: | ||
scens = scenarios.get(est_type) | ||
if scens is not None: | ||
scenarios_for_type += scenarios.get(est_type) | ||
|
||
# instantiate all scenarios by calling constructor | ||
scenarios_for_type = [x() for x in scenarios_for_type] | ||
|
||
# if obj was an object, filter to applicable scenarios | ||
if not isinstance(obj, str) and not isinstance(obj, list): | ||
scenarios_for_type = [x for x in scenarios_for_type if x.is_applicable(obj)] | ||
|
||
if filter_tags is not None: | ||
scenarios_for_type = [ | ||
scen for scen in scenarios_for_type if _check_tag_cond(scen, filter_tags) | ||
] | ||
|
||
return scenarios_for_type | ||
|
||
|
||
def _check_tag_cond(obj, filter_tags=None): | ||
"""Check whether object satisfies filter_tags condition. | ||
Parameters | ||
---------- | ||
obj: object inheriting from sktime BaseObject | ||
filter_tags: dict of (str or list of str), default=None | ||
subsets the returned objectss as follows: | ||
each key/value pair is statement in "and"/conjunction | ||
key is tag name to sub-set on | ||
value str or list of string are tag values | ||
condition is "key must be equal to value, or in set(value)" | ||
Returns | ||
------- | ||
cond_sat: bool, whether estimator satisfies condition in filter_tags | ||
""" | ||
if not isinstance(filter_tags, dict): | ||
raise TypeError("filter_tags must be a dict") | ||
|
||
cond_sat = True | ||
|
||
for key, value in filter_tags.items(): | ||
if not isinstance(value, list): | ||
value = [value] | ||
cond_sat = cond_sat and obj.get_class_tag(key) in set(value) | ||
|
||
return cond_sat |
Oops, something went wrong.