-
Notifications
You must be signed in to change notification settings - Fork 5
/
helper_functions.py
293 lines (241 loc) · 6.43 KB
/
helper_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
# helper_functions.py
from __future__ import annotations
from operator import itemgetter
from tqdm import tqdm
import numpy as np
from typing import Callable, Any, Sequence, Iterable, Sized
from time import time
from functools import wraps
from hashlib import sha512
import warnings
# Optional imports
try:
import pandas as pd
except ImportError:
pd = None
def next_power_of_2(x : int) -> int:
    """
    Returns the next power of two greater than or equal to the input.

    The result is computed with exact integer arithmetic, so it remains
    correct for inputs that a 64-bit float cannot represent exactly (values
    above 2**53) and for integers wider than 64 bits.

    Parameters
    ----------
    x : int
        The input number. A non-integer value is rounded up to the next
        integer before the power of two is computed.

    Returns
    -------
    nextpow2 : int
        The next power of two, nextpow2 >= x. Returns 0 when x <= 0.
    """
    if x <= 0:
        return 0
    # Round up to an integer without going through floating-point math.
    n = int(x)
    if n < x:
        n += 1
    # (n - 1).bit_length() is the exponent of the smallest power of two >= n.
    return 1 << (n - 1).bit_length()
def hash_str_repeatable(s : str) -> int:
    """
    Compute a hash of a string that is identical across interpreter runs.

    Python's built-in string hashing is randomized per process by default,
    so this function derives the hash from the SHA-512 digest of the UTF-8
    encoded string instead.
    See: https://docs.python.org/3/using/cmdline.html#cmdoption-R

    Parameters
    ----------
    s : str
        The string to hash.

    Returns
    -------
    s_hash : int
        The SHA-512 digest of `s`, interpreted as an unsigned big-endian
        integer.
    """
    digest = sha512(s.encode('utf-8')).digest()
    # Reading the raw digest big-endian equals parsing its hex form in base 16.
    return int.from_bytes(digest, 'big')
def hashable_val(val : Any) -> Any:
    """
    Return something usable as a dictionary key in place of `val`.

    `nummap` and `valmap` use values as dictionary keys. Hashable values are
    passed through unchanged; unhashable ones are replaced by their string
    representation.

    Parameters
    ----------
    val : Any
        The value to make hashable.

    Returns
    -------
    hashable_val : Any
        `val` itself if `hash(val)` succeeds, otherwise `str(val)`.
    """
    try:
        hash(val)
    except TypeError:
        # Unhashable (e.g. a list or dict): fall back to its string form.
        return str(val)
    else:
        return val
def is_num(val : Any) -> bool:
    """
    Type checking function to see if the input is a number.

    A value counts as a number when `float(val)` accepts it, with two
    exceptions: booleans, and text-like values (`str`, `bytes`,
    `bytearray`), which `float()` would otherwise parse.

    Parameters
    ----------
    val : Any
        The value to check.

    Returns
    -------
    isnum : bool
        Returns True if the input is a number, False otherwise.
    """
    # bool is an int subclass and float() parses str/bytes/bytearray, so
    # these must be rejected before attempting the conversion.
    if isinstance(val, (bool, str, bytes, bytearray)):
        return False
    try:
        float(val)
    except OverflowError:
        # An integer too large for a float (e.g. 10**400) is still a number.
        return True
    except (ValueError, TypeError):
        return False
    return True
def length(x : Any) -> int | None:
    """
    Genericized length function that works on scalars (which have length 1).

    Parameters
    ----------
    x : Any
        The value to check.

    Returns
    -------
    x_len : int
        The length of the input. Sized objects report `len(x)`; Python and
        NumPy integer/float/bool scalars and zero-dimensional NumPy arrays
        report 1. Anything else returns None.
    """
    # A 0-d array defines __len__ (so it looks Sized), but len() raises
    # TypeError on it; treat it as a single scalar value instead.
    if isinstance(x, np.ndarray) and x.ndim == 0:
        return 1
    if isinstance(x, Sized):
        return len(x)
    # bool is covered by int, and np.float64 by both float and np.floating.
    if isinstance(x, (int, float, np.integer, np.floating, np.bool_)):
        return 1
    return None
def get_list(x : Any) -> list[Any]:
    """
    Wrap or convert the input so that the result is always a list.

    Parameters
    ----------
    x : Any
        The object to convert.

    Returns
    -------
    x_list : list
        - None gives an empty list.
        - A string or a pandas DataFrame (when pandas is installed) is
          wrapped whole in a one-element list.
        - A zero-dimensional NumPy array gives a one-element list holding
          its scalar value.
        - Any other iterable is expanded with `list(x)`.
        - Everything else is wrapped in a one-element list.
    """
    if x is None:
        return []
    # Strings and DataFrames are iterable but must be kept as single items.
    if isinstance(x, str) or (pd is not None and isinstance(x, pd.DataFrame)):
        return [x]
    if not isinstance(x, Iterable):
        return [x]
    # 0-d arrays advertise __iter__ but cannot be iterated; unwrap the scalar.
    if isinstance(x, np.ndarray) and x.ndim == 0:
        return [x[()]]
    return list(x)
def slice_by_index(sequence : Sequence[Any],
                   indices : int | Iterable[int],
                   ) -> list:
    """
    Pick the items of a sequence found at the given indices.

    Parameters
    ----------
    sequence : Sequence
        The sequence to read from.
    indices : int | Iterable
        A single index or an iterable of indices; normalized to a list with
        `get_list`.

    Returns
    -------
    slice : list
        The values of `sequence` at the requested indices, in the order the
        indices were given. Empty if `sequence` is None or no indices were
        supplied.
    """
    wanted = get_list(indices)
    if sequence is None or not wanted:
        return []
    return [sequence[idx] for idx in wanted]
def vprint(verbose : bool, *args, **kwargs) -> None:
    """
    Forward the arguments to `print()` when `verbose` is truthy.

    Parameters
    ----------
    verbose : bool
        Flag to determine whether to print something.
    *args, **kwargs
        Passed unchanged to `print()`; should include something to print.
    """
    if not verbose:
        return
    print(*args, **kwargs)
def warn_short_format(message, category, filename, lineno, file=None, line=None) -> str:
    """
    Custom warning format for use in vwarn().

    Takes the same arguments as `warnings.formatwarning` but builds the
    result from the category name and the message only; `filename`,
    `lineno`, `file` and `line` are ignored.

    Returns
    -------
    formatted : str
        A single line of the form "<CategoryName>: <message>" ending with a
        newline.
    """
    category_name = category.__name__
    return f'{category_name}: {message}\n'
def vwarn(verbose : bool, *args, **kwargs) -> None:
    """
    Warn only if verbose is True.

    While the warning is issued, `warnings.formatwarning` is temporarily
    replaced by `warn_short_format`. The previous formatter is always
    restored, even if `warnings.warn` raises (for example when a warnings
    filter turns the warning into an exception).

    Parameters
    ----------
    verbose : bool
        Flag to determine whether to issue the warning.
    *args, **kwargs
        Passed unchanged to `warnings.warn()`; must include a warning
        message.
    """
    if not verbose:
        return
    warn_default_format = warnings.formatwarning
    warnings.formatwarning = warn_short_format  # type: ignore
    try:
        warnings.warn(*args, **kwargs)
    finally:
        # Restore the formatter no matter how warn() exits, so a raised
        # warning cannot leave the short format installed globally.
        warnings.formatwarning = warn_default_format
def vwrite(verbose : bool, *args, **kwargs) -> None:
    """
    Forward the arguments to `tqdm.write()` when `verbose` is truthy.

    Parameters
    ----------
    verbose : bool
        Flag to determine whether to write something.
    *args, **kwargs
        Passed unchanged to `tqdm.write()`; should include something to
        write.
    """
    if not verbose:
        return
    tqdm.write(*args, **kwargs)
def timeit(fcn : Callable):
    """
    Function decorator to print out the function runtime in milliseconds.

    Each call to the decorated function is timed, and a message with the
    function name and the elapsed time (three decimal places) is printed
    once the call returns. No message is printed if the call raises.

    Parameters
    ----------
    fcn : Callable
        Function to time.

    Returns
    -------
    timed : Callable
        A wrapper with the same call signature as `fcn` that returns
        whatever `fcn` returns.
    """
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring intervals; time() can jump when the system clock changes.
    from time import perf_counter

    # Not every callable carries __name__ (e.g. functools.partial objects).
    fcn_name = getattr(fcn, '__name__', repr(fcn))

    @wraps(fcn)
    def timed(*args, **kw):
        start = perf_counter()
        output = fcn(*args, **kw)
        elapsed_ms = (perf_counter() - start) * 1000
        print(f'"{fcn_name}" took {elapsed_ms:.3f} ms to execute.\n')
        return output
    return timed
def empty_list() -> list:
    """
    Sentinel for default arguments being an empty list.

    Returns
    -------
    empty_list : list
        A new empty list; every call returns a distinct list object.
    """
    return list()
def flatten(nested_x : Iterable[Any]) -> list[Any]:
    """
    Flattens a nested iterable into a list with all nested items.

    Nested iterables are expanded depth-first, in order. Strings and bytes
    are treated as single items and are not split into characters when they
    appear as elements. The traversal uses an explicit stack rather than
    recursion, so arbitrarily deep nesting does not hit the interpreter's
    recursion limit.

    Parameters
    ----------
    nested_x : Iterable
        Nested iterable.

    Returns
    -------
    flattened_x : list
        The nested iterable flattened into a list.
    """
    flattened_x = []
    # Stack of partially consumed iterators; the top one is the innermost
    # level currently being walked.
    pending = [iter(nested_x)]
    while pending:
        for element in pending[-1]:
            if isinstance(element, Iterable) and not isinstance(element, (str, bytes)):
                # Descend into the nested iterable; the parent iterator
                # keeps its position and is resumed once this one is done.
                pending.append(iter(element))
                break
            flattened_x.append(element)
        else:
            # Innermost iterator exhausted: go back up one level.
            pending.pop()
    return flattened_x