From 61e19d9c17f0a4dd15de71d5ec67130616d03301 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Wed, 22 Feb 2023 20:31:58 -0500 Subject: [PATCH 1/2] Add object_type param to named object check --- skbase/validate/_named_objects.py | 42 +++++++++++++++++-- .../tests/test_iterable_named_objects.py | 10 +++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/skbase/validate/_named_objects.py b/skbase/validate/_named_objects.py index 90b55ebf..d91a4d92 100644 --- a/skbase/validate/_named_objects.py +++ b/skbase/validate/_named_objects.py @@ -33,6 +33,7 @@ def is_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: bool = True, require_unique_names=False, + object_type: Optional[type] = None, ) -> bool: """Indicate if input is a sequence of named BaseObject instances. @@ -68,6 +69,10 @@ def is_sequence_named_objects( depends on whether `seq_to_check` follows sequence of named BaseObject format. + object_type : class, default=None + The class type that is used to ensure that all elements of named objects + match the expected type. + Returns ------- bool @@ -82,7 +87,7 @@ def is_sequence_named_objects( Examples -------- - >>> from skbase.base import BaseObject + >>> from skbase.base import BaseObject, BaseEstimator >>> from skbase.validate import is_sequence_named_objects >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())] >>> is_sequence_named_objects(named_objects) @@ -107,9 +112,21 @@ def is_sequence_named_objects( >>> named_items = [("1", 7), ("2", 42)] >>> is_sequence_named_objects(named_items) False + + The validation can require the object elements to be a certain class type + + >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())] + >>> is_sequence_named_objects(named_objects, object_type=BaseEstimator) + False + >>> named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())] + >>> is_sequence_named_objects(named_objects, object_type=BaseEstimator) + True """ # Want to end quickly if the input isn't sequence or is a dict and we # aren't allowing dicts + if object_type is None: + object_type = BaseObject + is_dict = isinstance(seq_to_check, dict) if (not is_dict and not isinstance(seq_to_check, collections.abc.Sequence)) or ( not allow_dict and is_dict @@ -123,7 +140,7 @@ def is_sequence_named_objects( if TYPE_CHECKING: # pragma: no cover assert isinstance(seq_to_check, dict) # nosec B101 elements_expected_format = [ - isinstance(name, str) and isinstance(obj, BaseObject) + isinstance(name, str) and isinstance(obj, object_type) for name, obj in seq_to_check.items() ] all_unique_names = True @@ -134,7 +151,7 @@ def is_sequence_named_objects( if ( isinstance(it, tuple) and len(it) == 2 - and (isinstance(it[0], str) and isinstance(it[1], BaseObject)) + and (isinstance(it[0], str) and isinstance(it[1], object_type)) ): elements_expected_format.append(True) names.append(it[0]) @@ -157,6 +174,7 @@ def check_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: Literal[True] = True, require_unique_names=False, + object_type: Optional[type] = None, sequence_name: Optional[str] = None, ) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]: ... # pragma: no cover @@ -167,6 +185,7 @@ def check_sequence_named_objects( seq_to_check: Sequence[Tuple[str, BaseObject]], allow_dict: Literal[False], require_unique_names=False, + object_type: Optional[type] = None, sequence_name: Optional[str] = None, ) -> Sequence[Tuple[str, BaseObject]]: ... # pragma: no cover @@ -177,6 +196,7 @@ def check_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: bool = True, require_unique_names=False, + object_type: Optional[type] = None, sequence_name: Optional[str] = None, ) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]: ... # pragma: no cover @@ -186,6 +206,7 @@ def check_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: bool = True, require_unique_names=False, + object_type: Optional[type] = None, sequence_name: Optional[str] = None, ) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]: """Check if input is a sequence of named BaseObject instances. @@ -222,6 +243,9 @@ def check_sequence_named_objects( - If False, then whether or not the function returns True or False depends on whether `seq_to_check` follows sequence of named BaseObject format. + object_type : class, default=None + The class type that is used to ensure that all elements of named objects + match the expected type. sequence_name : str, default=None Optional name used to refer to the input `seq_to_check` when raising any errors. Ignored ``raise_error=False``. @@ -242,7 +266,7 @@ def check_sequence_named_objects( Examples -------- - >>> from skbase.base import BaseObject + >>> from skbase.base import BaseObject, BaseEstimator >>> from skbase.validate import check_sequence_named_objects >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())] >>> check_sequence_named_objects(named_objects) @@ -267,11 +291,21 @@ def check_sequence_named_objects( >>> named_items = [("1", 7), ("2", 42)] >>> check_sequence_named_objects(named_items) # doctest: +SKIP + + The validation can require the object elements to be a certain class type + + >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())] + >>> check_sequence_named_objects( \ + named_objects, object_type=BaseEstimator) # doctest: +SKIP + >>> named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())] + >>> check_sequence_named_objects(named_objects, object_type=BaseEstimator) + [('Step 1', BaseEstimator()), ('Step 2', BaseEstimator())] """ is_expected_format = is_sequence_named_objects( seq_to_check, allow_dict=allow_dict, require_unique_names=require_unique_names, + object_type=object_type, ) # Raise error is format is not expected. if not is_expected_format: diff --git a/skbase/validate/tests/test_iterable_named_objects.py b/skbase/validate/tests/test_iterable_named_objects.py index 60b320d1..1a9335f2 100644 --- a/skbase/validate/tests/test_iterable_named_objects.py +++ b/skbase/validate/tests/test_iterable_named_objects.py @@ -131,3 +131,13 @@ def test_check_sequence_named_objects_output( ] with pytest.raises(ValueError): check_sequence_named_objects(c for c in named_objects) + + # Validate use of object_type parameter + with pytest.raises(ValueError): + check_sequence_named_objects(named_objects, object_type=BaseEstimator) + + named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())] + assert ( + check_sequence_named_objects(named_objects, object_type=BaseEstimator) + == named_objects + ) From 9ab5d1a6bfdfea8a5d5ca3739eb4cd9ee144bb81 Mon Sep 17 00:00:00 2001 From: rnkuhns Date: Thu, 23 Feb 2023 17:54:08 -0500 Subject: [PATCH 2/2] Document that tuples are allowed for object_type param --- skbase/validate/_named_objects.py | 18 +++++++-------- .../tests/test_iterable_named_objects.py | 22 +++++++++++++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/skbase/validate/_named_objects.py b/skbase/validate/_named_objects.py index d91a4d92..edd1035e 100644 --- a/skbase/validate/_named_objects.py +++ b/skbase/validate/_named_objects.py @@ -33,7 +33,7 @@ def is_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: bool = True, require_unique_names=False, - object_type: Optional[type] = None, + object_type: Optional[Union[type, Tuple[type]]] = None, ) -> bool: """Indicate if input is a sequence of named BaseObject instances. @@ -69,8 +69,8 @@ def is_sequence_named_objects( depends on whether `seq_to_check` follows sequence of named BaseObject format. - object_type : class, default=None - The class type that is used to ensure that all elements of named objects + object_type : class or tuple[class], default=None + The class type(s) that is used to ensure that all elements of named objects match the expected type. Returns @@ -174,7 +174,7 @@ def check_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: Literal[True] = True, require_unique_names=False, - object_type: Optional[type] = None, + object_type: Optional[Union[type, Tuple[type]]] = None, sequence_name: Optional[str] = None, ) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]: ... # pragma: no cover @@ -185,7 +185,7 @@ def check_sequence_named_objects( seq_to_check: Sequence[Tuple[str, BaseObject]], allow_dict: Literal[False], require_unique_names=False, - object_type: Optional[type] = None, + object_type: Optional[Union[type, Tuple[type]]] = None, sequence_name: Optional[str] = None, ) -> Sequence[Tuple[str, BaseObject]]: ... # pragma: no cover @@ -196,7 +196,7 @@ def check_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: bool = True, require_unique_names=False, - object_type: Optional[type] = None, + object_type: Optional[Union[type, Tuple[type]]] = None, sequence_name: Optional[str] = None, ) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]: ... # pragma: no cover @@ -206,7 +206,7 @@ def check_sequence_named_objects( seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]], allow_dict: bool = True, require_unique_names=False, - object_type: Optional[type] = None, + object_type: Optional[Union[type, Tuple[type]]] = None, sequence_name: Optional[str] = None, ) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]: """Check if input is a sequence of named BaseObject instances. @@ -243,8 +243,8 @@ def check_sequence_named_objects( - If False, then whether or not the function returns True or False depends on whether `seq_to_check` follows sequence of named BaseObject format. - object_type : class, default=None - The class type that is used to ensure that all elements of named objects + object_type : class or tuple[class], default=None + The class type(s) that is used to ensure that all elements of named objects match the expected type. sequence_name : str, default=None Optional name used to refer to the input `seq_to_check` when diff --git a/skbase/validate/tests/test_iterable_named_objects.py b/skbase/validate/tests/test_iterable_named_objects.py index 1a9335f2..ba3fc973 100644 --- a/skbase/validate/tests/test_iterable_named_objects.py +++ b/skbase/validate/tests/test_iterable_named_objects.py @@ -79,6 +79,20 @@ def test_is_sequence_named_objects_output( ] assert is_sequence_named_objects(c for c in named_objects) is False + # Validate use of object_type parameter + # Won't work because one named object is a BaseObject but not a BaseEstimator + assert is_sequence_named_objects(named_objects, object_type=BaseEstimator) is False + + # Should work because we allow BaseObject or BaseEstimator types + named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())] + assert ( + is_sequence_named_objects( + named_objects, object_type=(BaseObject, BaseEstimator) + ) + is True + ) + assert is_sequence_named_objects(named_objects, object_type=BaseEstimator) is True + def test_check_sequence_named_objects_output( fixture_estimator_instance, fixture_object_instance @@ -133,10 +147,18 @@ def test_check_sequence_named_objects_output( check_sequence_named_objects(c for c in named_objects) # Validate use of object_type parameter + # Won't work because one named object is a BaseObject but not a BaseEstimator with pytest.raises(ValueError): check_sequence_named_objects(named_objects, object_type=BaseEstimator) + # Should work because we allow BaseObject or BaseEstimator types named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())] + assert ( + check_sequence_named_objects( + named_objects, object_type=(BaseObject, BaseEstimator) + ) + == named_objects + ) assert ( check_sequence_named_objects(named_objects, object_type=BaseEstimator) == named_objects