/
py_utils.py
529 lines (427 loc) · 15.9 KB
/
py_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some python utils function and classes."""
import base64
import contextlib
import functools
import io
import itertools
import logging
import os
import random
import re
import shutil
import string
import sys
import textwrap
import threading
import typing
from typing import Any, Callable, Dict, Iterator, List, NoReturn, Optional, Tuple, Type, TypeVar, Union
import uuid
from absl import logging as absl_logging
from etils import epath
from six.moves import urllib
from tensorflow_datasets.core import constants
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
# Alias for the generic nested-structure type (dicts/lists/tuples of values).
Tree = type_utils.Tree
# NOTE: When used on an instance method, the cache is shared across all
# instances and IS NOT per-instance.
# See
# https://stackoverflow.com/questions/14946264/python-lru-cache-decorator-per-instance
# For @property methods, use @memoized_property below.
memoize = functools.lru_cache
# Generic type variables used by the helpers in this module.
T = TypeVar('T')
U = TypeVar('U')
# Callable TypeVar: decorators annotated with `Fn` preserve the signature.
Fn = TypeVar('Fn', bound=Callable[..., Any])
def is_notebook() -> bool:
  """Returns True if running in a notebook (Colab, Jupyter) environment."""
  # Inspired from the tqdm autonotebook code.
  # Look at `sys.modules` instead of importing, so that the check itself
  # never triggers the (potentially heavy) IPython import.
  ipython_module = sys.modules.get('IPython')
  if ipython_module is None:
    return False
  try:
    shell_config = ipython_module.get_ipython().config
  except:  # pylint: disable=bare-except
    return False
  # A plain IPython terminal has no kernel app registered in its config.
  return 'IPKernelApp' in shell_config
# TODO(tfds): Should likely have a `logging_utils` wrapper around `absl.logging`
# so logging messages are displayed on Colab.
def print_notebook(*args: Any) -> None:
  """Like `print`/`logging.info`. Colab do not print stderr by default."""
  message = ' '.join(str(arg) for arg in args)
  if not is_notebook():
    absl_logging.info(message)
  else:
    # In notebooks, stderr (where absl logs go) is hidden: print instead.
    print(message)
def warning(text: str) -> None:
  """Emits a warning, via `print` in notebooks (stderr is hidden there)."""
  if not is_notebook():
    absl_logging.warning(text)
  else:
    print(text)
@contextlib.contextmanager
def temporary_assignment(obj, attr, value):
  """Temporarily assign obj.attr to value."""
  previous_value = getattr(obj, attr)
  setattr(obj, attr, value)
  try:
    yield
  finally:
    # Restore the original attribute value even if the body raised.
    setattr(obj, attr, previous_value)
def zip_dict(*dicts):
  """Iterate over items of dictionaries grouped by their keys."""
  # Union of every key across all dicts. Raises KeyError when the dicts do
  # not all share the same key set.
  all_keys = set(itertools.chain(*dicts))
  for key in all_keys:
    yield key, tuple(d[key] for d in dicts)
@contextlib.contextmanager
def disable_logging():
  """Temporarily disable the logging."""
  root_logger = logging.getLogger()
  was_disabled = root_logger.disabled
  root_logger.disabled = True
  try:
    yield
  finally:
    # Restore the previous state (the logger may already have been disabled).
    root_logger.disabled = was_disabled
class NonMutableDict(Dict[T, U]):
  """Dict where keys can only be added but not modified.

  Raises an error if a key is overwritten. The error message can be customized
  during construction. It will be formatted using {key} for the overwritten key.
  """

  def __init__(self, *args, **kwargs):
    # Pop the custom message first so it is not forwarded to `dict.__init__`.
    self._error_msg = kwargs.pop(
        'error_msg',
        'Try to overwrite existing key: {key}',
    )
    if kwargs:
      raise ValueError('NonMutableDict cannot be initialized with kwargs.')
    super().__init__(*args, **kwargs)

  def __setitem__(self, key, value):
    if key in self:
      raise ValueError(self._error_msg.format(key=key))
    return super().__setitem__(key, value)

  def update(self, other):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    overwritten_keys = set(self) & set(other)
    if overwritten_keys:
      raise ValueError(self._error_msg.format(key=overwritten_keys))
    return super().update(other)
class classproperty(property):  # pylint: disable=invalid-name
  """Descriptor to be used as decorator for @classmethods."""

  def __get__(self, obj, objtype=None):
    # `self.fget` is the wrapped classmethod; binding it to the owning class
    # via `__get__(None, objtype)` and calling it produces the value, so the
    # attribute reads like a property on the class itself.
    return self.fget.__get__(None, objtype)()  # pytype: disable=attribute-error
if typing.TYPE_CHECKING:
  # TODO(b/171883689): There is likely a better way to annotate descriptors

  # Static-analysis-only redefinitions: make type-checkers see the decorated
  # attribute as a plain `T` value rather than as a descriptor object.
  def classproperty(fn: Callable[[Type[Any]], T]) -> T:  # pylint: disable=function-redefined
    return fn(type(None))

  def memoized_property(fn: Callable[[Any], T]) -> T:  # pylint: disable=function-redefined
    return fn(None)
def map_nested(function, data_struct, dict_only=False, map_tuple=False):
  """Apply a function recursively to each element of a nested data struct."""
  # Could add support for more exotic data_struct, like OrderedDict
  if isinstance(data_struct, dict):
    return {
        key: map_nested(function, value, dict_only, map_tuple)
        for key, value in data_struct.items()
    }
  if not dict_only:
    sequence_types = (list, tuple) if map_tuple else (list,)
    if isinstance(data_struct, sequence_types):
      mapped_values = [
          map_nested(function, value, dict_only, map_tuple)
          for value in data_struct
      ]
      # Preserve the container type (list stays list, tuple stays tuple).
      if isinstance(data_struct, tuple):
        return tuple(mapped_values)
      return mapped_values
  # Leaf value: apply the function directly.
  return function(data_struct)
def zip_nested(arg0, *args, **kwargs):
  """Zip data struct together and return a data struct with the same shape."""
  # Python 2 do not support kwargs only arguments
  dict_only = kwargs.pop('dict_only', False)
  assert not kwargs
  # Could add support for more exotic data_struct, like OrderedDict
  if isinstance(arg0, dict):
    return {
        key: zip_nested(*values, dict_only=dict_only)
        for key, values in zip_dict(arg0, *args)
    }
  if not dict_only and isinstance(arg0, list):
    return [
        zip_nested(*values, dict_only=dict_only)
        for values in zip(arg0, *args)
    ]
  # Leaf: group the parallel values into a tuple.
  return (arg0,) + args
def flatten_nest_dict(d: type_utils.TreeDict[T]) -> Dict[str, T]:
  """Return the dict with all nested keys flattened joined with '/'."""
  # Use NonMutableDict to ensure there is no collision between features keys
  flat_dict = NonMutableDict()
  for key, value in d.items():
    if not isinstance(value, dict):
      flat_dict[key] = value
    else:
      # Recursively flatten the sub-dict, prefixing its keys with `key/`.
      for sub_key, sub_value in flatten_nest_dict(value).items():
        flat_dict[f'{key}/{sub_key}'] = sub_value
  return flat_dict
# Note: Could use `tree.flatten_with_path` instead, but makes it harder for
# users to compile from source.
def flatten_with_path(
    structure: Tree[T],
) -> Iterator[Tuple[Tuple[Union[str, int], ...], T]]:  # pytype: disable=invalid-annotation
  """Convert a TreeDict into a flat list of paths and their values.

  ```py
  flatten_with_path({'a': {'b': v}}) == [(('a', 'b'), v)]
  ```

  Args:
    structure: Nested input structure

  Yields:
    The `(path, value)` tuple. With path being the tuple of `dict` keys and
    `list` indexes
  """
  if isinstance(structure, dict):
    # Sort dict items so traversal order is deterministic.
    children = iter(sorted(structure.items()))
  elif isinstance(structure, (list, tuple)):
    children = enumerate(structure)
  else:
    children = None  # Leaf node: recursion stops here.
  if children is None:
    yield (), structure  # Yield the leaf value with an empty path.
    return
  for key, sub_structure in children:
    # Recurse and prepend our key to every sub-path.
    for sub_path, sub_value in flatten_with_path(sub_structure):
      yield (key,) + sub_path, sub_value
def dedent(text):
  """Wrapper around `textwrap.dedent` which also `strip()` and handle `None`."""
  if not text:
    # Preserve falsy inputs (`None`, `''`) unchanged.
    return text
  return textwrap.dedent(text).strip()
def indent(text: str, indent: str) -> str:  # pylint: disable=redefined-outer-name
  """Dedents/strips `text`, then indents every line but the first."""
  dedented_text = dedent(text)
  return dedented_text.replace('\n', '\n' + indent)
def pack_as_nest_dict(flat_d, nest_d):
  """Pack a 1-lvl dict into a nested dict with same structure as `nest_d`.

  Inverse of `flatten_nest_dict`. Note: `flat_d` is consumed (keys are
  popped) while packing.

  Args:
    flat_d: Flat dict with keys joined by '/' (as produced by
      `flatten_nest_dict`).
    nest_d: Nested dict defining the target structure (leaf values are only
      used to decide where to recurse).

  Returns:
    The values of `flat_d` packed into the structure of `nest_d`.

  Raises:
    ValueError: If `flat_d` contains keys not matching `nest_d`'s structure.
  """
  nest_out_d = {}
  for k, v in nest_d.items():
    if isinstance(v, dict):
      v_flat = flatten_nest_dict(v)
      sub_d = {
          k2: flat_d.pop('{}/{}'.format(k, k2)) for k2, _ in v_flat.items()
      }
      # Recursively pack the dictionary
      nest_out_d[k] = pack_as_nest_dict(sub_d, v)
    else:
      nest_out_d[k] = flat_d.pop(k)
  if flat_d:  # At the end, flat_d should be empty
    # Fixed error-message typo/grammar ('strucure do not' -> 'structure
    # does not').
    raise ValueError(
        'Flat dict structure does not match the nested dict. Extra keys: '
        '{}'.format(list(flat_d.keys()))
    )
  return nest_out_d
def nullcontext(enter_result=None):
  """No-op context manager whose `__enter__` returns `enter_result`.

  Originally a hand-rolled backport of `contextlib.nullcontext` (stdlib since
  Python 3.7); now simply delegates to the standard-library implementation.

  Args:
    enter_result: Value bound by the `with ... as x` clause.

  Returns:
    A context manager yielding `enter_result` and doing nothing on exit.
  """
  return contextlib.nullcontext(enter_result)
def _get_incomplete_path(filename):
  """Returns a temporary filename based on filename."""
  # 6 random alphanumeric chars make concurrent writers very unlikely to
  # collide on the same temporary name.
  alphabet = string.ascii_uppercase + string.digits
  random_suffix = ''.join(random.choices(alphabet, k=6))
  return filename + constants.INCOMPLETE_SUFFIX + random_suffix
@contextlib.contextmanager
def incomplete_dir(dirname: epath.PathLike) -> Iterator[str]:
  """Create temporary dir for dirname and rename on exit."""
  final_name = os.fspath(dirname)
  tmp_name = _get_incomplete_path(final_name)
  tmp_dir = epath.Path(tmp_name)
  tmp_dir.mkdir(parents=True, exist_ok=True)
  try:
    yield tmp_name
    # Success: promote the temporary dir to its final name.
    tmp_dir.rename(final_name)
  finally:
    # On failure the rename never ran, so the temporary dir is removed.
    if tmp_dir.exists():
      tmp_dir.rmtree()
@contextlib.contextmanager
def incomplete_file(
    path: epath.Path,
) -> Iterator[epath.Path]:
  """Writes to path atomically, by writing to temp file and renaming it."""
  # Unique sibling name, e.g. `name.incomplete.<32-hex-uuid>`.
  tmp_name = f'{path.name}{constants.INCOMPLETE_SUFFIX}.{uuid.uuid4().hex}'
  tmp_path = path.parent / tmp_name
  try:
    yield tmp_path
    # Success: atomically move the temp file over the destination.
    tmp_path.replace(path)
  finally:
    # If the body raised, the temp file is still there: remove it.
    tmp_path.unlink(missing_ok=True)
def is_incomplete_file(path: epath.Path) -> bool:
  """Returns whether the given filename suggests that it's incomplete."""
  # Matches `<name><INCOMPLETE_SUFFIX>.<32 hex chars>`, the shape produced
  # by `incomplete_file` above.
  incomplete_pattern = (
      rf'^.+{re.escape(constants.INCOMPLETE_SUFFIX)}\.[0-9a-fA-F]{{32}}$'
  )
  return re.search(incomplete_pattern, path.name) is not None
@contextlib.contextmanager
def atomic_write(path, mode):
  """Writes to path atomically, by writing to temp file and renaming it."""
  tmp_path = f'{path}{constants.INCOMPLETE_SUFFIX}_{uuid.uuid4().hex}'
  with tf.io.gfile.GFile(tmp_path, mode) as file_:
    yield file_
  # Only reached when the body succeeded; note that on error the temporary
  # file is left behind (no cleanup), matching historical behavior.
  tf.io.gfile.rename(tmp_path, path, overwrite=True)
def reraise(
    e: Exception,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
) -> NoReturn:
  """Reraise an exception with an additional message.

  Prefers mutating the exception in-place (rather than wrapping it), so
  callers' `except <OriginalType>` clauses still match and no nested
  `During handling of the above exception...` traceback is produced.

  Args:
    e: The exception being handled. NOTE(review): must be called from inside
      an `except` block — the bare `raise` below relies on an active
      exception.
    prefix: Text prepended to the exception message.
    suffix: Text appended to the exception message (on a new line).
  """
  prefix = prefix or ''
  suffix = '\n' + suffix if suffix else ''
  # If unsure about modifying the function inplace, create a new exception
  # and stack it in the chain.
  if (
      # Exceptions with custom error message
      type(e).__str__ is not BaseException.__str__
      # This should never happens unless the user plays with Exception
      # internals
      or not hasattr(e, 'args')
      or not isinstance(e.args, tuple)
  ):
    msg = f'{prefix}{e}{suffix}'
    # Could try to dynamically create a
    # `type(type(e).__name__, (ReraisedError, type(e)), {})`, but should be
    # carefull when nesting `reraise` as well as compatibility with external
    # code.
    # Some base exception class (ImportError, OSError) and subclasses (
    # ModuleNotFoundError, FileNotFoundError) have custom `__str__` error
    # message. We re-raise those with same type to allow except in caller code.
    if isinstance(e, (ImportError, OSError)):
      exception = type(e)(msg)
    else:
      exception = RuntimeError(f'{type(e).__name__}: {msg}')
    raise exception from e
  # Otherwise, modify the exception in-place
  elif len(e.args) <= 1:
    exception_msg = e.args[0] if e.args else ''
    e.args = (f'{prefix}{exception_msg}{suffix}',)
    raise  # pylint: disable=misplaced-bare-raise
  # If there is more than 1 args, concatenate the message with other args
  else:
    # Keep non-string args unchanged; the filter drops empty strings (e.g.
    # an empty prefix/suffix) from the args tuple.
    e.args = tuple(
        p for p in (prefix,) + e.args + (suffix,) if not isinstance(p, str) or p
    )
    raise  # pylint: disable=misplaced-bare-raise
@contextlib.contextmanager
def try_reraise(*args, **kwargs):
  """Context manager which reraise exceptions with an additional message.

  Contrary to `raise ... from ...` and `raise Exception().with_traceback(tb)`,
  this function tries to modify the original exception, to avoid nested
  `During handling of the above exception, another exception occurred:`
  stacktraces.

  Args:
    *args: Prefix to add to the exception message
    **kwargs: Prefix to add to the exception message

  Yields:
    None
  """
  try:
    yield
  except Exception as e:  # pylint: disable=broad-except
    # Delegates message formatting and re-raising to `reraise`; args/kwargs
    # map onto its `prefix`/`suffix` parameters.
    reraise(e, *args, **kwargs)
def rgetattr(obj, attr, *args):
  """Get attr that handles dots in attr name."""
  # Walk the attribute chain one segment at a time, forwarding the optional
  # default (`*args`) to every `getattr` step.
  result = obj
  for name in attr.split('.'):
    result = getattr(result, name, *args)
  return result
def has_sufficient_disk_space(needed_bytes, directory='.'):
  """Returns True if `directory`'s filesystem has over `needed_bytes` free.

  Optimistically returns True when the free space cannot be determined
  (`OSError` from `shutil.disk_usage`).
  """
  try:
    usage = shutil.disk_usage(os.path.abspath(directory))
  except OSError:
    return True
  return needed_bytes < usage.free
def get_class_path(cls, use_tfds_prefix=True):
  """Returns path of given class or object. Eg: `tfds.image.cifar.Cifar10`."""
  # Accept either a class or an instance of one.
  klass = cls if isinstance(cls, type) else type(cls)
  module_path = klass.__module__
  tfds_root = 'tensorflow_datasets'
  if use_tfds_prefix and module_path.startswith(tfds_root):
    # Abbreviate `tensorflow_datasets.foo` -> `tfds.foo`.
    module_path = 'tfds' + module_path[len(tfds_root):]
  return f'{module_path}.{klass.__name__}'
def get_class_url(cls):
  """Returns URL of given class or object."""
  # Drop the class name, keep only the module path, and turn it into a
  # repository-relative file path.
  cls_path = get_class_path(cls, use_tfds_prefix=False)
  module_path = cls_path.rsplit('.', 1)[0].replace('.', '/')
  return f'{constants.SRC_BASE_URL}{module_path}.py'
def build_synchronize_decorator() -> Callable[[Fn], Fn]:
  """Returns a decorator which prevents concurrent calls to functions.

  Usage:
  synchronized = build_synchronize_decorator()

  @synchronized
  def read_value():
    ...

  @synchronized
  def write_value(x):
    ...

  Returns:
    make_threadsafe (fct): The decorator which lock all functions to which it
      is applied under a same lock
  """
  # Every function decorated by the returned decorator shares this one lock.
  shared_lock = threading.Lock()

  def make_threadsafe(fn: Fn) -> Fn:
    @functools.wraps(fn)
    def synchronized_fn(*args, **kwargs):
      with shared_lock:
        return fn(*args, **kwargs)

    return synchronized_fn

  return make_threadsafe
def basename_from_url(url: str) -> str:
  """Returns file name of file at given url."""
  url_path = urllib.parse.urlparse(url).path
  basename = os.path.basename(url_path)
  # Replace `%2F` (html code for `/`) by `_`.
  # This is consistent with how Chrome rename downloaded files.
  basename = basename.replace('%2F', '_')
  return basename if basename else 'unknown_name'
def list_info_files(dir_path: epath.PathLike) -> List[str]:
  """Returns name of info files within dir_path."""
  # Local import to avoid a circular dependency at module load time.
  from tensorflow_datasets.core import file_adapters  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

  path = os.fspath(dir_path)
  info_files = []
  for fname in tf.io.gfile.listdir(path):
    if tf.io.gfile.isdir(os.path.join(path, fname)):
      continue  # Skip sub-directories.
    if file_adapters.is_example_file(fname):
      continue  # Skip data (example) files.
    info_files.append(fname)
  return info_files
def get_base64(
    write_fn: Union[bytes, Callable[[io.BytesIO], None]],
) -> str:
  """Extracts the base64 string of an object by writing into a tmp buffer."""
  if isinstance(write_fn, bytes):
    # Value is already raw bytes: encode it directly.
    raw_bytes = write_fn
  else:
    # Let the callable serialize itself into an in-memory buffer.
    with io.BytesIO() as buffer:
      write_fn(buffer)
      raw_bytes = buffer.getvalue()
  return base64.b64encode(raw_bytes).decode('ascii')  # pytype: disable=bad-return-type
@contextlib.contextmanager
def add_sys_path(path: epath.PathLike) -> Iterator[None]:
  """Temporary add given path to `sys.path`."""
  str_path = os.fspath(path)
  try:
    # Prepend so the added path wins over existing entries during lookup.
    sys.path.insert(0, str_path)
    yield
  finally:
    # Removes the first (i.e. our) occurrence of the path.
    sys.path.remove(str_path)