Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions skbase/validate/_named_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def is_sequence_named_objects(
seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]],
allow_dict: bool = True,
require_unique_names=False,
object_type: Optional[Union[type, Tuple[type]]] = None,
) -> bool:
"""Indicate if input is a sequence of named BaseObject instances.

Expand Down Expand Up @@ -68,6 +69,10 @@ def is_sequence_named_objects(
depends on whether `seq_to_check` follows sequence of named
BaseObject format.

object_type : class or tuple[class], default=None
The class type(s) that is used to ensure that all elements of named objects
match the expected type.

Returns
-------
bool
Expand All @@ -82,7 +87,7 @@ def is_sequence_named_objects(

Examples
--------
>>> from skbase.base import BaseObject
>>> from skbase.base import BaseObject, BaseEstimator
>>> from skbase.validate import is_sequence_named_objects
>>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
>>> is_sequence_named_objects(named_objects)
Expand All @@ -107,9 +112,21 @@ def is_sequence_named_objects(
>>> named_items = [("1", 7), ("2", 42)]
>>> is_sequence_named_objects(named_items)
False

The validation can require the object elements to be a certain class type

>>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
>>> is_sequence_named_objects(named_objects, object_type=BaseEstimator)
False
>>> named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())]
>>> is_sequence_named_objects(named_objects, object_type=BaseEstimator)
True
"""
# Want to end quickly if the input isn't sequence or is a dict and we
# aren't allowing dicts
if object_type is None:
object_type = BaseObject

is_dict = isinstance(seq_to_check, dict)
if (not is_dict and not isinstance(seq_to_check, collections.abc.Sequence)) or (
not allow_dict and is_dict
Expand All @@ -123,7 +140,7 @@ def is_sequence_named_objects(
if TYPE_CHECKING: # pragma: no cover
assert isinstance(seq_to_check, dict) # nosec B101
elements_expected_format = [
isinstance(name, str) and isinstance(obj, BaseObject)
isinstance(name, str) and isinstance(obj, object_type)
for name, obj in seq_to_check.items()
]
all_unique_names = True
Expand All @@ -134,7 +151,7 @@ def is_sequence_named_objects(
if (
isinstance(it, tuple)
and len(it) == 2
and (isinstance(it[0], str) and isinstance(it[1], BaseObject))
and (isinstance(it[0], str) and isinstance(it[1], object_type))
):
elements_expected_format.append(True)
names.append(it[0])
Expand All @@ -157,6 +174,7 @@ def check_sequence_named_objects(
seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really don't like this additional amount of non-functional code generated from overload, but it's a matter of taste and not a blocker for me...

allow_dict: Literal[True] = True,
require_unique_names=False,
object_type: Optional[Union[type, Tuple[type]]] = None,
sequence_name: Optional[str] = None,
) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]:
... # pragma: no cover
Expand All @@ -167,6 +185,7 @@ def check_sequence_named_objects(
seq_to_check: Sequence[Tuple[str, BaseObject]],
allow_dict: Literal[False],
require_unique_names=False,
object_type: Optional[Union[type, Tuple[type]]] = None,
sequence_name: Optional[str] = None,
) -> Sequence[Tuple[str, BaseObject]]:
... # pragma: no cover
Expand All @@ -177,6 +196,7 @@ def check_sequence_named_objects(
seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]],
allow_dict: bool = True,
require_unique_names=False,
object_type: Optional[Union[type, Tuple[type]]] = None,
sequence_name: Optional[str] = None,
) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]:
... # pragma: no cover
Expand All @@ -186,6 +206,7 @@ def check_sequence_named_objects(
seq_to_check: Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]],
allow_dict: bool = True,
require_unique_names=False,
object_type: Optional[Union[type, Tuple[type]]] = None,
sequence_name: Optional[str] = None,
) -> Union[Sequence[Tuple[str, BaseObject]], Dict[str, BaseObject]]:
"""Check if input is a sequence of named BaseObject instances.
Expand Down Expand Up @@ -222,6 +243,9 @@ def check_sequence_named_objects(
- If False, then whether or not the function returns True or False
depends on whether `seq_to_check` follows sequence of named BaseObject format.

object_type : class or tuple[class], default=None
The class type(s) that is used to ensure that all elements of named objects
match the expected type.
sequence_name : str, default=None
Optional name used to refer to the input `seq_to_check` when
raising any errors. Ignored ``raise_error=False``.
Expand All @@ -242,7 +266,7 @@ def check_sequence_named_objects(

Examples
--------
>>> from skbase.base import BaseObject
>>> from skbase.base import BaseObject, BaseEstimator
>>> from skbase.validate import check_sequence_named_objects
>>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
>>> check_sequence_named_objects(named_objects)
Expand All @@ -267,11 +291,21 @@ def check_sequence_named_objects(

>>> named_items = [("1", 7), ("2", 42)]
>>> check_sequence_named_objects(named_items) # doctest: +SKIP

The validation can require the object elements to be a certain class type

>>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
>>> check_sequence_named_objects( \
named_objects, object_type=BaseEstimator) # doctest: +SKIP
>>> named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())]
>>> check_sequence_named_objects(named_objects, object_type=BaseEstimator)
[('Step 1', BaseEstimator()), ('Step 2', BaseEstimator())]
"""
is_expected_format = is_sequence_named_objects(
seq_to_check,
allow_dict=allow_dict,
require_unique_names=require_unique_names,
object_type=object_type,
)
# Raise error is format is not expected.
if not is_expected_format:
Expand Down
32 changes: 32 additions & 0 deletions skbase/validate/tests/test_iterable_named_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ def test_is_sequence_named_objects_output(
]
assert is_sequence_named_objects(c for c in named_objects) is False

# Validate use of object_type parameter
# Won't work because one named object is a BaseObject but not a BaseEstimator
assert is_sequence_named_objects(named_objects, object_type=BaseEstimator) is False

# Should work because we allow BaseObject or BaseEstimator types
named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())]
assert (
is_sequence_named_objects(
named_objects, object_type=(BaseObject, BaseEstimator)
)
is True
)
assert is_sequence_named_objects(named_objects, object_type=BaseEstimator) is True


def test_check_sequence_named_objects_output(
fixture_estimator_instance, fixture_object_instance
Expand Down Expand Up @@ -131,3 +145,21 @@ def test_check_sequence_named_objects_output(
]
with pytest.raises(ValueError):
check_sequence_named_objects(c for c in named_objects)

# Validate use of object_type parameter
# Won't work because one named object is a BaseObject but not a BaseEstimator
with pytest.raises(ValueError):
check_sequence_named_objects(named_objects, object_type=BaseEstimator)

# Should work because we allow BaseObject or BaseEstimator types
named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())]
assert (
check_sequence_named_objects(
named_objects, object_type=(BaseObject, BaseEstimator)
)
== named_objects
)
assert (
check_sequence_named_objects(named_objects, object_type=BaseEstimator)
== named_objects
)