-
-
Notifications
You must be signed in to change notification settings - Fork 609
/
checks.py
502 lines (397 loc) · 17.7 KB
/
checks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
# Copyright (c) 2021 Oleg Polakow. All rights reserved.
# This code is licensed under Apache 2.0 with Commons Clause license (see LICENSE.md for details)
"""Utilities for validation during runtime."""
import os
from collections.abc import Hashable, Mapping
from inspect import signature, getmro
from keyword import iskeyword
import dill
import numpy as np
import pandas as pd
from numba.core.registry import CPUDispatcher
from vectorbt import _typing as tp
# ############# Checks ############# #
def is_np_array(arg: tp.Any) -> bool:
    """Return True if `arg` is an instance of `np.ndarray`, False otherwise."""
    ndarray_type = np.ndarray
    return isinstance(arg, ndarray_type)
def is_series(arg: tp.Any) -> bool:
    """Return True if `arg` is an instance of `pd.Series`, False otherwise."""
    series_type = pd.Series
    return isinstance(arg, series_type)
def is_index(arg: tp.Any) -> bool:
    """Return True if `arg` is an instance of `pd.Index`, False otherwise."""
    index_type = pd.Index
    return isinstance(arg, index_type)
def is_frame(arg: tp.Any) -> bool:
    """Return True if `arg` is an instance of `pd.DataFrame`, False otherwise."""
    frame_type = pd.DataFrame
    return isinstance(arg, frame_type)
def is_pandas(arg: tp.Any) -> bool:
    """Return True if `arg` is either a `pd.Series` or a `pd.DataFrame`."""
    if is_series(arg):
        return True
    return is_frame(arg)
def is_any_array(arg: tp.Any) -> bool:
    """Return True if `arg` is a `np.ndarray`, `pd.Series` or `pd.DataFrame`."""
    if is_pandas(arg):
        return True
    return isinstance(arg, np.ndarray)
def _to_any_array(arg: tp.ArrayLike) -> tp.AnyArray:
    """Turn an array-like object into an array.

    NumPy arrays and pandas Series/DataFrames are returned unchanged;
    anything else goes through `np.asarray`."""
    return arg if is_any_array(arg) else np.asarray(arg)
def is_sequence(arg: tp.Any) -> bool:
    """Check whether the argument is a sequence.

    An object is treated as a sequence when both `len(arg)` and the empty
    slice `arg[0:0]` succeed. Returns False (never raises) for objects that
    lack a length, are not subscriptable, or are mappings."""
    try:
        len(arg)
        arg[0:0]
    except (TypeError, KeyError):
        # Mappings reject a slice key with TypeError on Python < 3.12, but with
        # KeyError from 3.12 on, where slice objects became hashable.
        return False
    return True
def is_iterable(arg: tp.Any) -> bool:
    """Return True if `iter(arg)` succeeds, False if it raises `TypeError`."""
    try:
        iter(arg)
    except TypeError:
        return False
    else:
        return True
def is_numba_func(arg: tp.Any) -> bool:
    """Check whether the argument is a Numba-compiled function.

    Behavior is driven by the `numba` section of the vectorbt settings:

    * If `check_func_type` is disabled, every argument is accepted.
    * If the environment variable `NUMBA_DISABLE_JIT` equals `'1'`, the argument
      is accepted when `check_func_suffix` is disabled or when its `__name__`
      ends with `_nb`.
    * Otherwise the argument must be an instance of `CPUDispatcher`."""
    from vectorbt._settings import settings

    cfg = settings['numba']
    if not cfg['check_func_type']:
        return True

    jit_disabled = os.environ.get('NUMBA_DISABLE_JIT') == '1'
    if jit_disabled:
        if not cfg['check_func_suffix']:
            return True
        if arg.__name__.endswith('_nb'):
            return True
    return isinstance(arg, CPUDispatcher)
def is_hashable(arg: tp.Any) -> bool:
    """Return True if `arg` can actually be hashed.

    The argument must be an instance of `Hashable` and `hash(arg)` must
    not raise `TypeError`."""
    if isinstance(arg, Hashable):
        # Exposing __hash__() is not enough (e.g. a tuple holding a list),
        # so the hash is computed for real.
        try:
            hash(arg)
        except TypeError:
            return False
        return True
    return False
def is_index_equal(arg1: tp.Any, arg2: tp.Any, strict: bool = True) -> bool:
    """Check whether two indexes are equal.

    With `strict=False` this is just `pd.Index.equals`. With `strict=True` the
    names are compared first (`names` for two multi-indexes, `name` for two
    regular indexes), and a multi-index never equals a regular index.
    Types are still not compared."""
    if strict:
        multi1 = isinstance(arg1, pd.MultiIndex)
        multi2 = isinstance(arg2, pd.MultiIndex)
        if multi1 and multi2:
            names_differ = arg1.names != arg2.names
        elif multi1 or multi2:
            # Exactly one side is a multi-index
            return False
        else:
            names_differ = arg1.name != arg2.name
        if names_differ:
            return False
    return pd.Index.equals(arg1, arg2)
def is_default_index(arg: tp.Any) -> bool:
    """Return True if the index equals an unnamed `pd.RangeIndex` running from 0 to `len(arg)` with step 1."""
    basic_range = pd.RangeIndex(start=0, stop=len(arg), step=1)
    return is_index_equal(arg, basic_range)
def is_namedtuple(x: tp.Any) -> bool:
    """Return True if `x` looks like a namedtuple instance.

    The type of `x` must derive directly (and only) from `tuple` and expose
    a `_fields` tuple whose entries are all exactly of type `str`."""
    cls = type(x)
    bases = cls.__bases__
    if len(bases) != 1 or bases[0] != tuple:
        return False
    fields = getattr(cls, '_fields', None)
    if not isinstance(fields, tuple):
        return False
    for field_name in fields:
        if type(field_name) != str:
            return False
    return True
def func_accepts_arg(func: tp.Callable, arg_name: str, arg_kind: tp.Optional[tp.MaybeTuple[int]] = None) -> bool:
    """Check whether the signature of `func` has a parameter called `arg_name`.

    Without `arg_kind`, a name prefixed with `**` is looked up among
    variable-keyword parameters, a name prefixed with `*` among
    variable-positional parameters, and any other name among the remaining
    parameters. With `arg_kind` (one kind or a tuple of kinds), only
    parameters of those kinds are considered and the name is matched as-is."""
    params = signature(func).parameters.values()

    if arg_kind is None:
        if arg_name.startswith('**'):
            return any(p.name == arg_name[2:] for p in params if p.kind == p.VAR_KEYWORD)
        if arg_name.startswith('*'):
            return any(p.name == arg_name[1:] for p in params if p.kind == p.VAR_POSITIONAL)
        return any(
            p.name == arg_name
            for p in params
            if p.kind != p.VAR_POSITIONAL and p.kind != p.VAR_KEYWORD
        )

    # A single kind is wrapped so that the membership test below always works
    kinds = (arg_kind,) if isinstance(arg_kind, int) else arg_kind
    return any(p.name == arg_name for p in params if p.kind in kinds)
def is_equal(arg1: tp.Any, arg2: tp.Any,
             equality_func: tp.Callable[[tp.Any, tp.Any], bool] = lambda x, y: x == y) -> bool:
    """Check whether two objects are equal.

    Returns whatever `equality_func(arg1, arg2)` returns (by default the
    result of `arg1 == arg2`). If the comparison raises an ordinary
    exception, the objects are considered not equal and False is returned."""
    try:
        return equality_func(arg1, arg2)
    except Exception:
        # Best-effort comparison: a failing comparison means "not equal".
        # KeyboardInterrupt/SystemExit are not Exception subclasses and
        # therefore still propagate to the caller.
        return False
def is_deep_equal(arg1: tp.Any, arg2: tp.Any, check_exact: bool = False, **kwargs) -> bool:
    """Check whether two objects are equal (deep check).

    Both objects must be of exactly the same type. Depending on that type:

    * `pd.Series`, `pd.DataFrame`, `pd.Index`: compared with the matching
      `pd.testing.assert_*_equal` function, passing `check_exact`.
    * `np.ndarray`: data types must match; values are compared with
      `np.testing.assert_array_equal` and, if that fails and `check_exact`
      is False, with `np.testing.assert_allclose`. Structured arrays are
      compared field by field.
    * `tuple`, `list`: lengths must match and elements are compared recursively.
    * `dict`: key sets must match and values are compared recursively.
    * Anything else: compared with `==`, then by comparing `dill.dumps` output.

    `check_exact` is propagated to recursive calls. Each entry of `**kwargs`
    is forwarded only to those comparison functions that accept it.
    Any ordinary exception raised during comparison results in False."""

    def _select_kwargs(_method):
        # Keep only the keyword arguments that `_method` can take
        return {k: v for k, v in kwargs.items() if func_accepts_arg(_method, k)}

    def _arrays_match(assert_method):
        # True if `assert_method` passes on the arrays (on each field for structured arrays)
        try:
            _kwargs = _select_kwargs(assert_method)
            if arg1.dtype.fields is not None:
                for field in arg1.dtype.names:
                    assert_method(arg1[field], arg2[field], **_kwargs)
            else:
                assert_method(arg1, arg2, **_kwargs)
        except Exception:
            return False
        return True

    try:
        if type(arg1) is not type(arg2):
            return False

        if isinstance(arg1, pd.Series):
            _kwargs = _select_kwargs(pd.testing.assert_series_equal)
            pd.testing.assert_series_equal(arg1, arg2, check_exact=check_exact, **_kwargs)
            return True
        if isinstance(arg1, pd.DataFrame):
            _kwargs = _select_kwargs(pd.testing.assert_frame_equal)
            pd.testing.assert_frame_equal(arg1, arg2, check_exact=check_exact, **_kwargs)
            return True
        if isinstance(arg1, pd.Index):
            _kwargs = _select_kwargs(pd.testing.assert_index_equal)
            pd.testing.assert_index_equal(arg1, arg2, check_exact=check_exact, **_kwargs)
            return True

        if isinstance(arg1, np.ndarray):
            if not arg1.dtype == arg2.dtype:
                return False
            if _arrays_match(np.testing.assert_array_equal):
                return True
            if check_exact:
                return False
            return _arrays_match(np.testing.assert_allclose)

        if isinstance(arg1, (tuple, list)):
            # Without the length check, extra trailing elements in arg2 would go unnoticed
            if len(arg1) != len(arg2):
                return False
            return all(
                is_deep_equal(v1, v2, check_exact=check_exact, **kwargs)
                for v1, v2 in zip(arg1, arg2)
            )

        if isinstance(arg1, dict):
            # Without the key check, extra keys in arg2 would go unnoticed
            if arg1.keys() != arg2.keys():
                return False
            return all(
                is_deep_equal(arg1[k], arg2[k], check_exact=check_exact, **kwargs)
                for k in arg1
            )

        try:
            if arg1 == arg2:
                return True
        except Exception:
            # `==` may be unsupported or have no single truth value;
            # fall through to the serialization-based comparison.
            pass
        try:
            _kwargs = _select_kwargs(dill.dumps)
            if dill.dumps(arg1, **_kwargs) == dill.dumps(arg2, **_kwargs):
                return True
        except Exception:
            # Objects that cannot be serialized are treated as not equal
            pass
        return False
    except Exception:
        # Failed pandas assertions and any other comparison error mean "not equal"
        return False
def is_subclass_of(arg: tp.Any, types: tp.MaybeTuple[tp.Union[tp.Type, str]]) -> bool:
    """Check whether the class `arg` is a subclass of `types`.

    `types` may be a type (checked with `issubclass`), a string (matched
    against `str(base)` and `base.__name__` of every class in the MRO of `arg`),
    or a tuple of these, in which case a single match is enough.
    Any other kind of `types` yields False."""
    if isinstance(types, type):
        return issubclass(arg, types)
    if isinstance(types, str):
        return any(
            str(base) == types or base.__name__ == types
            for base in getmro(arg)
        )
    if isinstance(types, tuple):
        return any(is_subclass_of(arg, candidate) for candidate in types)
    return False
def is_instance_of(arg: tp.Any, types: tp.MaybeTuple[tp.Union[tp.Type, str]]) -> bool:
    """Check whether the type of `arg` is a subclass of `types`.

    `types` can be a single type or string, or a tuple of them;
    the check itself is delegated to `is_subclass_of`."""
    arg_type = type(arg)
    return is_subclass_of(arg_type, types)
def is_mapping(arg: tp.Any) -> bool:
    """Return True if `arg` is an instance of `collections.abc.Mapping`."""
    mapping_type = Mapping
    return isinstance(arg, mapping_type)
def is_mapping_like(arg: tp.Any) -> bool:
    """Return True if `arg` is a mapping, a `pd.Series`, a `pd.Index` or a namedtuple."""
    checks = (is_mapping, is_series, is_index, is_namedtuple)
    return any(check(arg) for check in checks)
def is_valid_variable_name(arg: str) -> bool:
    """Return True if `arg` is an identifier that is not a reserved keyword."""
    if not arg.isidentifier():
        return False
    return not iskeyword(arg)
# ############# Asserts ############# #
def safe_assert(arg: tp.Any, msg: tp.Optional[str] = None) -> None:
    """Raise `AssertionError` carrying `msg` if `arg` is falsy; do nothing otherwise."""
    if arg:
        return
    raise AssertionError(msg)
def assert_in(arg1: tp.Any, arg2: tp.Sequence) -> None:
    """Raise `AssertionError` unless `arg1` is contained in `arg2`."""
    if arg1 in arg2:
        return
    raise AssertionError(f"{arg1} not found in {arg2}")
def assert_numba_func(func: tp.Callable) -> None:
    """Raise `AssertionError` unless `is_numba_func` accepts `func`."""
    if is_numba_func(func):
        return
    raise AssertionError(f"Function {func} must be Numba compiled")
def assert_not_none(arg: tp.Any) -> None:
    """Raise `AssertionError` if `arg` is None; any other value passes."""
    if arg is not None:
        return
    raise AssertionError("Argument cannot be None")
def assert_instance_of(arg: tp.Any, types: tp.MaybeTuple[tp.Type]) -> None:
    """Raise `AssertionError` unless `is_instance_of(arg, types)` holds.

    The error message names the expected type(s) and the actual type of `arg`."""
    if is_instance_of(arg, types):
        return
    if isinstance(types, tuple):
        message = f"Type must be one of {types}, not {type(arg)}"
    else:
        message = f"Type must be {types}, not {type(arg)}"
    raise AssertionError(message)
def assert_subclass_of(arg: tp.Type, classes: tp.MaybeTuple[tp.Type]) -> None:
    """Raise `AssertionError` unless `is_subclass_of(arg, classes)` holds.

    The error message names the expected class(es) and the class that was passed."""
    if is_subclass_of(arg, classes):
        return
    if isinstance(classes, tuple):
        message = f"Class must be a subclass of one of {classes}, not {arg}"
    else:
        message = f"Class must be a subclass of {classes}, not {arg}"
    raise AssertionError(message)
def assert_type_equal(arg1: tp.Any, arg2: tp.Any) -> None:
    """Raise `AssertionError` if the types of the two arguments differ."""
    type1, type2 = type(arg1), type(arg2)
    if type1 != type2:
        raise AssertionError(f"Types {type1} and {type2} do not match")
def assert_dtype(arg: tp.ArrayLike, dtype: tp.DTypeLike) -> None:
    """Raise `AssertionError` if the data type of the argument is not `dtype`.

    The argument is first converted with `_to_any_array`. For a DataFrame,
    every column is checked and the first offending column is reported."""
    arg = _to_any_array(arg)
    if not isinstance(arg, pd.DataFrame):
        if arg.dtype != dtype:
            raise AssertionError(f"Data type must be {dtype}, not {arg.dtype}")
        return
    for i, col_dtype in enumerate(arg.dtypes):
        if col_dtype != dtype:
            raise AssertionError(f"Data type of column {i} must be {dtype}, not {col_dtype}")
def assert_subdtype(arg: tp.ArrayLike, dtype: tp.DTypeLike) -> None:
    """Raise `AssertionError` if the data type of the argument is not a sub data type of `dtype`.

    The argument is first converted with `_to_any_array` and the check uses
    `np.issubdtype`. For a DataFrame, every column is checked and the first
    offending column is reported."""
    arg = _to_any_array(arg)
    if not isinstance(arg, pd.DataFrame):
        if not np.issubdtype(arg.dtype, dtype):
            raise AssertionError(f"Data type must be {dtype}, not {arg.dtype}")
        return
    for i, col_dtype in enumerate(arg.dtypes):
        if not np.issubdtype(col_dtype, dtype):
            raise AssertionError(f"Data type of column {i} must be {dtype}, not {col_dtype}")
def assert_dtype_equal(arg1: tp.ArrayLike, arg2: tp.ArrayLike) -> None:
    """Raise `AssertionError` if the two arguments have different data types.

    Both arguments are converted with `_to_any_array`. A DataFrame contributes
    one data type per column, anything else a single data type. If both sides
    have the same number of data types, they are compared element-wise;
    otherwise each side must hold a single unique data type and those must match."""

    def _collect_dtypes(obj):
        # One dtype per DataFrame column, or a one-element array otherwise
        if isinstance(obj, pd.DataFrame):
            return obj.dtypes.to_numpy()
        return np.asarray([obj.dtype])

    arg1 = _to_any_array(arg1)
    arg2 = _to_any_array(arg2)
    dtypes1 = _collect_dtypes(arg1)
    dtypes2 = _collect_dtypes(arg2)

    if len(dtypes1) == len(dtypes2):
        if (dtypes1 == dtypes2).all():
            return
    else:
        unique1 = np.unique(dtypes1)
        if len(unique1) == 1:
            unique2 = np.unique(dtypes2)
            if len(unique2) == 1 and np.all(unique1 == unique2):
                return
    raise AssertionError(f"Data types {dtypes1} and {dtypes2} do not match")
def assert_ndim(arg: tp.ArrayLike, ndims: tp.MaybeTuple[int]) -> None:
    """Raise `AssertionError` if the number of dimensions of the argument is not `ndims`.

    The argument is first converted with `_to_any_array`. When `ndims` is
    a tuple, any of the listed numbers of dimensions is accepted."""
    arg = _to_any_array(arg)
    if isinstance(ndims, tuple):
        if arg.ndim in ndims:
            return
        raise AssertionError(f"Number of dimensions must be one of {ndims}, not {arg.ndim}")
    if arg.ndim != ndims:
        raise AssertionError(f"Number of dimensions must be {ndims}, not {arg.ndim}")
def assert_len_equal(arg1: tp.Sized, arg2: tp.Sized) -> None:
    """Raise `AssertionError` if `len(arg1)` and `len(arg2)` differ.

    The arguments are used as-is, without conversion to NumPy arrays."""
    len1, len2 = len(arg1), len(arg2)
    if len1 != len2:
        raise AssertionError(f"Lengths of {arg1} and {arg2} do not match")
def assert_shape_equal(arg1: tp.ArrayLike, arg2: tp.ArrayLike,
                       axis: tp.Optional[tp.Union[int, tp.Tuple[int, int]]] = None) -> None:
    """Raise `AssertionError` if the two arguments have different shapes along `axis`.

    Both arguments are converted with `_to_any_array`. With `axis=None` the
    full shapes are compared. An integer compares the same axis of both
    arguments, while a tuple compares axis `axis[0]` of the first argument
    with axis `axis[1]` of the second one."""
    arg1 = _to_any_array(arg1)
    arg2 = _to_any_array(arg2)
    shape1, shape2 = arg1.shape, arg2.shape

    if axis is None:
        if shape1 != shape2:
            raise AssertionError(f"Shapes {arg1.shape} and {arg2.shape} do not match")
        return
    if isinstance(axis, tuple):
        if shape1[axis[0]] != shape2[axis[1]]:
            raise AssertionError(
                f"Axis {axis[0]} of {arg1.shape} and axis {axis[1]} of {arg2.shape} do not match")
        return
    if shape1[axis] != shape2[axis]:
        raise AssertionError(f"Axis {axis} of {arg1.shape} and {arg2.shape} do not match")
def assert_index_equal(arg1: pd.Index, arg2: pd.Index, **kwargs) -> None:
    """Raise `AssertionError` if `is_index_equal` reports the two indexes as different.

    `**kwargs` are passed on to `is_index_equal`."""
    if is_index_equal(arg1, arg2, **kwargs):
        return
    raise AssertionError(f"Indexes {arg1} and {arg2} do not match")
def assert_meta_equal(arg1: tp.ArrayLike, arg2: tp.ArrayLike) -> None:
    """Raise `AssertionError` if the two arguments have different metadata.

    After conversion with `_to_any_array`, the type and the shape are compared.
    For two pandas objects the index is compared as well, and for two
    DataFrames also the columns."""
    arg1 = _to_any_array(arg1)
    arg2 = _to_any_array(arg2)
    assert_type_equal(arg1, arg2)
    assert_shape_equal(arg1, arg2)
    if not (is_pandas(arg1) and is_pandas(arg2)):
        return
    assert_index_equal(arg1.index, arg2.index)
    if is_frame(arg1) and is_frame(arg2):
        assert_index_equal(arg1.columns, arg2.columns)
def assert_array_equal(arg1: tp.ArrayLike, arg2: tp.ArrayLike) -> None:
    """Raise `AssertionError` if the two arguments have different metadata or values.

    Metadata is checked with `assert_meta_equal`. Values of two pandas objects
    are compared with their `equals` method, values of two non-pandas objects
    with `np.array_equal`."""
    arg1 = _to_any_array(arg1)
    arg2 = _to_any_array(arg2)
    assert_meta_equal(arg1, arg2)
    if is_pandas(arg1) and is_pandas(arg2):
        values_match = arg1.equals(arg2)
    elif not is_pandas(arg1) and not is_pandas(arg2):
        values_match = np.array_equal(arg1, arg2)
    else:
        # A pandas object paired with a non-pandas one never matches
        values_match = False
    if values_match:
        return
    raise AssertionError(f"Arrays {arg1} and {arg2} do not match")
def assert_level_not_exists(arg: pd.Index, level_name: str) -> None:
    """Raise `AssertionError` if the index already has a level called `level_name`.

    For a multi-index all level names are considered, otherwise just the index name."""
    names = arg.names if isinstance(arg, pd.MultiIndex) else [arg.name]
    if level_name in names:
        raise AssertionError(f"Level {level_name} already exists in {names}")
def assert_equal(arg1: tp.Any, arg2: tp.Any, deep: bool = False) -> None:
    """Raise `AssertionError` if the two arguments are different.

    Uses `is_deep_equal` when `deep` is True and `is_equal` otherwise."""
    if deep:
        if is_deep_equal(arg1, arg2):
            return
        raise AssertionError(f"{arg1} and {arg2} do not match (deep check)")
    if is_equal(arg1, arg2):
        return
    raise AssertionError(f"{arg1} and {arg2} do not match")
def assert_dict_valid(arg: tp.DictLike, lvl_keys: tp.Sequence[tp.MaybeSequence[str]]) -> None:
    """Raise `AssertionError` if the dict has keys that are not listed in `lvl_keys`.

    `lvl_keys` should be a list of lists, one per nesting level of the dict.
    A flat list of strings is treated as a single level. None is treated as
    an empty dict, and an empty `lvl_keys` disables the check.
    Values that are dicts are validated against the remaining levels."""
    if arg is None:
        arg = {}
    if len(lvl_keys) == 0:
        return
    if isinstance(lvl_keys[0], str):
        # Flat list of keys: wrap it into a single level
        lvl_keys = [lvl_keys]

    given = set(arg.keys())
    allowed = set(lvl_keys[0])
    unknown = given - allowed
    if unknown:
        raise AssertionError(f"Keys {unknown} are not recognized. Possible keys are {allowed}.")

    deeper_levels = lvl_keys[1:]
    for _, value in arg.items():
        if isinstance(value, dict):
            assert_dict_valid(value, deeper_levels)
def assert_dict_sequence_valid(arg: tp.DictLikeSequence, lvl_keys: tp.Sequence[tp.MaybeSequence[str]]) -> None:
    """Raise `AssertionError` if a dict, or any dict in a sequence of dicts, has keys not listed in `lvl_keys`.

    None is treated as an empty dict. Each dict is checked with `assert_dict_valid`."""
    if arg is None:
        arg = {}
    dicts = [arg] if isinstance(arg, dict) else arg
    for item in dicts:
        assert_dict_valid(item, lvl_keys)
def assert_sequence(arg: tp.Any) -> None:
    """Raise `ValueError` unless `is_sequence` accepts the argument."""
    if is_sequence(arg):
        return
    raise ValueError(f"{arg} must be a sequence")
def assert_iterable(arg: tp.Any) -> None:
    """Raise `ValueError` unless `is_iterable` accepts the argument."""
    if is_iterable(arg):
        return
    raise ValueError(f"{arg} must be an iterable")