/
_named_objects.py
351 lines (277 loc) · 12.7 KB
/
_named_objects.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Validate if an input is one of the allowed named object formats."""
import collections.abc
from skbase.base import BaseObject
# Public names of this module; this list controls what a wildcard import
# (``from <module> import *``) exposes.
__all__ = [
"check_sequence_named_objects",
"is_named_object_tuple",
"is_sequence_named_objects",
]
# Module author metadata.
__author__ = ["RNKuhns"]
def _named_baseobject_error_msg(sequence_name=None, allow_dict=True):
    """Build the message used when input violates the named BaseObject API.

    Parameters
    ----------
    sequence_name : str, default=None
        Label used to refer to the offending input inside the message.
        When None, the generic label "Input" is used.
    allow_dict : bool, default=True
        Whether the message should list a dictionary as an accepted format
        in addition to a sequence of tuples.

    Returns
    -------
    str
        The formatted error message.
    """
    # ``format(x)`` is what an f-string placeholder without a spec evaluates to.
    label = "Input" if sequence_name is None else format(sequence_name)
    dict_clause = " or dict[str, BaseObject instance]" if allow_dict else ""
    expected = "a sequence of (string name, BaseObject instance) tuples" + dict_clause
    return f"Invalid {label!r}, {label!r} should be {expected}."
def is_named_object_tuple(obj, object_type=None):
    """Indicate if input is a tuple of format (str, `object_type`).

    Used to validate that input follows the named object tuple API format:
    a tuple of exactly two elements whose first element is a string name and
    whose second element is an instance of `object_type`.

    Parameters
    ----------
    obj : Any
        The object to be checked to see if it is a (str, `object_type`) tuple.
    object_type : class or tuple of class, default=BaseObject
        Class(es) that the second tuple element is checked to be an instance
        of. If None, then :class:`skbase.base.BaseObject` is used as default.

    Returns
    -------
    bool
        True if `obj` is a (str, `object_type`) tuple, otherwise False.

    See Also
    --------
    is_sequence_named_objects :
        Indicate (True/False) if an input sequence follows the named object API.
    check_sequence_named_objects :
        Validate input to see if it follows sequence of named objects API. An
        error is raised for input that does not conform to the API format.

    Examples
    --------
    >>> from skbase.base import BaseObject, BaseEstimator
    >>> from skbase.validate import is_named_object_tuple

    Default checks for object to be an instance of BaseObject

    >>> is_named_object_tuple(("Step 1", BaseObject()))
    True

    >>> is_named_object_tuple(("Step 2", BaseEstimator()))
    True

    If a different `object_type` is provided then it is used in the
    isinstance check

    >>> is_named_object_tuple(("Step 1", BaseObject()), object_type=BaseEstimator)
    False

    >>> is_named_object_tuple(("Step 1", BaseEstimator()), object_type=BaseEstimator)
    True

    If the input does not follow named object tuple format then False is
    returned

    >>> is_named_object_tuple({"Step 1": BaseEstimator()})
    False

    >>> is_named_object_tuple((1, BaseObject()))
    False
    """
    expected_type = BaseObject if object_type is None else object_type

    # Anything that is not a 2-tuple cannot be a (name, object) pair.
    if not (isinstance(obj, tuple) and len(obj) == 2):
        return False

    return isinstance(obj[0], str) and isinstance(obj[1], expected_type)
def is_sequence_named_objects(
    seq_to_check,
    allow_dict=True,
    require_unique_names=False,
    object_type=None,
):
    """Indicate if input is a sequence of named BaseObject instances.

    This can be a sequence of (str, BaseObject instance) tuples or
    a dictionary with string names as keys and BaseObject instances as values
    (if ``allow_dict=True``).

    Parameters
    ----------
    seq_to_check : Sequence((str, BaseObject)) or Dict[str, BaseObject]
        The input to check for conformance with the named object interface.
        Conforming input are:

        - Sequence that contains (str, BaseObject instance) tuples
        - Dictionary with string names as keys and BaseObject instances as
          values if ``allow_dict=True``

    allow_dict : bool, default=True
        Whether a dictionary of named objects is allowed as conforming named
        object type.

        - If True, then a dictionary with string keys and BaseObject instances
          is allowed format for providing a sequence of named objects.
        - If False, then only sequences that contain (str, BaseObject instance)
          tuples are considered conforming with the named object parameter API.

    require_unique_names : bool, default=False
        Whether names used in the sequence of named BaseObject instances
        must be unique.

        - If True and the names are not unique, then False is always returned.
        - If False, then whether or not the function returns True or False
          depends on whether `seq_to_check` follows sequence of named
          BaseObject format.

    object_type : class or tuple[class], default=None
        The class type(s) that is used to ensure that all elements of named
        objects match the expected type. If None, then
        :class:`skbase.base.BaseObject` is used.

    Returns
    -------
    bool
        Whether the input `seq_to_check` is a sequence that follows the API for
        named base object instances. Input of an unsupported container type
        (not a sequence, or a dictionary when ``allow_dict=False``) yields
        False; no exception is raised for non-conforming input.

    See Also
    --------
    is_named_object_tuple :
        Indicate (True/False) if input follows the named object API format for
        a single named object (e.g., tuple[str, expected class type]).
    check_sequence_named_objects :
        Validate input to see if it follows sequence of named objects API. An
        error is raised for input that does not conform to the API format.

    Notes
    -----
    ``str``, ``bytes`` and ``bytearray`` are sequences of characters/bytes and
    can never contain (name, object) tuples, so they are always reported as
    non-conforming, including when they are empty.

    Examples
    --------
    >>> from skbase.base import BaseObject, BaseEstimator
    >>> from skbase.validate import is_sequence_named_objects
    >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
    >>> is_sequence_named_objects(named_objects)
    True

    Dictionaries are optionally allowed as sequences of named BaseObjects

    >>> dict_named_objects = {"Step 1": BaseObject(), "Step 2": BaseObject()}
    >>> is_sequence_named_objects(dict_named_objects)
    True
    >>> is_sequence_named_objects(dict_named_objects, allow_dict=False)
    False

    Invalid format due to object names not being strings

    >>> incorrectly_named_objects = [(1, BaseObject()), (2, BaseObject())]
    >>> is_sequence_named_objects(incorrectly_named_objects)
    False

    Invalid format due to named items not being BaseObject instances

    >>> named_items = [("1", 7), ("2", 42)]
    >>> is_sequence_named_objects(named_items)
    False

    The validation can require the object elements to be a certain class type

    >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
    >>> is_sequence_named_objects(named_objects, object_type=BaseEstimator)
    False
    >>> named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())]
    >>> is_sequence_named_objects(named_objects, object_type=BaseEstimator)
    True
    """
    if object_type is None:
        object_type = BaseObject

    if isinstance(seq_to_check, dict):
        if not allow_dict:
            return False
        # Dictionary keys are unique by construction, so only the key and
        # value types need to be verified.
        return all(
            isinstance(name, str) and isinstance(obj, object_type)
            for name, obj in seq_to_check.items()
        )

    # Text and byte strings satisfy the Sequence ABC but are never containers
    # of (name, object) tuples; without this guard an empty string would be
    # vacuously accepted as a valid sequence of named objects.
    if isinstance(seq_to_check, (str, bytes, bytearray)):
        return False
    if not isinstance(seq_to_check, collections.abc.Sequence):
        return False

    names = []
    for item in seq_to_check:
        # Stop at the first malformed element; the overall answer is
        # already known to be False.
        if not is_named_object_tuple(item, object_type=object_type):
            return False
        names.append(item[0])

    if require_unique_names and len(set(names)) != len(names):
        return False
    return True
def check_sequence_named_objects(
    seq_to_check,
    allow_dict=True,
    require_unique_names=False,
    object_type=None,
    sequence_name=None,
):
    """Validate that input is a sequence of named BaseObject instances.

    `seq_to_check` is handed back unchanged when it follows the allowed named
    BaseObject convention; otherwise a ValueError is raised. The allowed
    format is a sequence of (str, BaseObject instance) tuples. A dictionary
    with string names as keys and BaseObject instances as values is also
    allowed if ``allow_dict is True``.

    Parameters
    ----------
    seq_to_check : Sequence((str, BaseObject)) or Dict[str, BaseObject]
        The input to check for conformance with the named object interface.
        Conforming input are:

        - Sequence that contains (str, BaseObject instance) tuples
        - Dictionary with string names as keys and BaseObject instances as
          values if ``allow_dict=True``

    allow_dict : bool, default=True
        Whether a dictionary of named objects is allowed as conforming named
        object type.

        - If True, then a dictionary with string keys and BaseObject instances
          is allowed format for providing a sequence of named objects.
        - If False, then only sequences that contain (str, BaseObject instance)
          tuples are considered conforming with the named object parameter API.

    require_unique_names : bool, default=False
        Whether names used in the sequence of named BaseObject instances
        must be unique.

        - If True and the names are not unique, the input is treated as
          non-conforming and an error is raised.
        - If False, conformance depends only on whether `seq_to_check` follows
          the sequence of named BaseObject format.

    object_type : class or tuple[class], default=None
        The class type(s) that is used to ensure that all elements of named
        objects match the expected type.
    sequence_name : str, default=None
        Optional name used to refer to the input `seq_to_check` in the
        message of the raised error.

    Returns
    -------
    Sequence((str, BaseObject)) or Dict[str, BaseObject]
        The `seq_to_check` is returned if it is a conforming named object type.

        - If ``allow_dict=True`` then return type is Sequence((str, BaseObject))
          or Dict[str, BaseObject]
        - If ``allow_dict=False`` then return type is Sequence((str, BaseObject))

    Raises
    ------
    ValueError
        If `seq_to_check` does not conform to the named BaseObject API.

    See Also
    --------
    is_named_object_tuple :
        Indicate (True/False) if input follows the named object API format for
        a single named object (e.g., tuple[str, expected class type]).
    is_sequence_named_objects :
        Indicate (True/False) if an input sequence follows the named object API.

    Examples
    --------
    >>> from skbase.base import BaseObject, BaseEstimator
    >>> from skbase.validate import check_sequence_named_objects
    >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
    >>> check_sequence_named_objects(named_objects)
    [('Step 1', BaseObject()), ('Step 2', BaseObject())]

    Dictionaries are optionally allowed as sequences of named BaseObjects

    >>> named_objects = {"Step 1": BaseObject(), "Step 2": BaseObject()}
    >>> check_sequence_named_objects(named_objects)
    {'Step 1': BaseObject(), 'Step 2': BaseObject()}

    Raises error since dictionaries are not allowed when allow_dict is False

    >>> check_sequence_named_objects(named_objects, allow_dict=False) # doctest: +SKIP

    Raises error due to invalid format due to object names not being strings

    >>> incorrectly_named_objects = [(1, BaseObject()), (2, BaseObject())]
    >>> check_sequence_named_objects(incorrectly_named_objects) # doctest: +SKIP

    Raises error due to invalid format since named items are not BaseObject
    instances

    >>> named_items = [("1", 7), ("2", 42)]
    >>> check_sequence_named_objects(named_items) # doctest: +SKIP

    The validation can require the object elements to be a certain class type

    >>> named_objects = [("Step 1", BaseObject()), ("Step 2", BaseObject())]
    >>> check_sequence_named_objects(
    ...     named_objects, object_type=BaseEstimator
    ... )  # doctest: +SKIP
    >>> named_objects = [("Step 1", BaseEstimator()), ("Step 2", BaseEstimator())]
    >>> check_sequence_named_objects(named_objects, object_type=BaseEstimator)
    [('Step 1', BaseEstimator()), ('Step 2', BaseEstimator())]
    """
    conforms = is_sequence_named_objects(
        seq_to_check,
        allow_dict=allow_dict,
        require_unique_names=require_unique_names,
        object_type=object_type,
    )
    # Conforming input passes straight through to the caller.
    if conforms:
        return seq_to_check

    raise ValueError(
        _named_baseobject_error_msg(sequence_name=sequence_name, allow_dict=allow_dict)
    )