/
np_utils.py
689 lines (546 loc) · 21.1 KB
/
np_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for internal use."""
# pylint: disable=g-direct-tensorflow-import
import inspect
import numbers
import os
import re
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.types import core
from tensorflow.python.util import nest
def _canonicalize_axis(axis, rank):
  """Canonicalizes a single (possibly negative) `axis` against `rank`.

  Thin wrapper around `_canonicalize_axes` for the one-axis case.

  Args:
    axis: the axis to canonicalize; may be negative.
    rank: the rank of the array the axis refers to (static or tensor).

  Returns:
    The non-negative equivalent of `axis`.
  """
  (canonical,) = _canonicalize_axes([axis], rank)
  return canonical
def _canonicalize_axes(axes, rank):
  """Maps each axis in `axes` to its non-negative equivalent for `rank`.

  Args:
    axes: an iterable of axes; each may be negative.
    rank: the rank of the array. If a static value can be extracted from it,
      a plain Python conditional is used; otherwise a `cond` is built.

  Returns:
    A list with one canonicalized axis per entry of `axes`.
  """
  rank = _maybe_static(rank)

  if isinstance(rank, core.Tensor):
    # Rank is only known at run time, so choose the branch with `cond`.
    def canonicalize(axis):
      return cond(axis < 0, lambda: axis + rank, lambda: axis)
  else:
    def canonicalize(axis):
      return axis + rank if axis < 0 else axis

  return [canonicalize(axis) for axis in axes]
def _supports_signature():
return hasattr(inspect, 'signature')
def _to_tf_type(dtype):
  """Returns the TensorFlow `DType` corresponding to `dtype`.

  Args:
    dtype: a Python type, a numpy type, or a TF `DType`; anything accepted by
      `dtypes.as_dtype`.

  Returns:
    A TensorFlow `DType`.
  """
  tf_dtype = dtypes.as_dtype(dtype)
  return tf_dtype
def _to_numpy_type(dtype):
  """Returns the numpy type corresponding to `dtype`.

  Args:
    dtype: a Python type, a numpy type, or a TF `DType`.

  Returns:
    For a TF `DType`, its `as_numpy_dtype`; otherwise `np.dtype(dtype)`.
  """
  return (dtype.as_numpy_dtype if isinstance(dtype, dtypes.DType)
          else np.dtype(dtype))
def isscalar(val):
  """Returns whether `val` is a scalar value or a scalar Tensor.

  `np_arrays.ndarray` inputs are unwrapped to their `.data` first.

  Args:
    val: the value to test.

  Returns:
    For a tensor with known rank, a Python bool. For a tensor with unknown
    rank, the result of `math_ops.equal(array_ops.rank(val), 0)`. For
    anything else, `np.isscalar(val)`.
  """
  if isinstance(val, np_arrays.ndarray):
    val = val.data
  if not isinstance(val, core.Tensor):
    return np.isscalar(val)
  ndims = val.shape.ndims
  if ndims is None:
    # Rank is not statically known; compare it dynamically.
    return math_ops.equal(array_ops.rank(val), 0)
  return ndims == 0
def _has_docstring(f):
return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
f.__doc__)
def _add_blank_line(s):
if s.endswith('\n'):
return s + '\n'
else:
return s + '\n\n'
def _np_signature(f):
"""An enhanced inspect.signature that can handle numpy.ufunc."""
# TODO(wangpeng): consider migrating away from inspect.signature.
# inspect.signature is supported in Python 3.3.
if not hasattr(inspect, 'signature'):
return None
if f is None:
return None
if not isinstance(f, np.ufunc):
try:
return inspect.signature(f)
except ValueError:
return None
def names_from_num(prefix, n):
if n <= 0:
return []
elif n == 1:
return [prefix]
else:
return [prefix + str(i + 1) for i in range(n)]
input_names = names_from_num('x', f.nin)
output_names = names_from_num('out', f.nout)
keyword_only_params = [('where', True), ('casting', 'same_kind'),
('order', 'K'), ('dtype', None), ('subok', True),
('signature', None), ('extobj', None)]
params = []
params += [
inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
for name in input_names
]
if f.nout > 1:
params += [
inspect.Parameter(
name, inspect.Parameter.POSITIONAL_ONLY, default=None)
for name in output_names
]
params += [
inspect.Parameter(
'out',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=None if f.nout == 1 else (None,) * f.nout)
]
params += [
inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
for name, default in keyword_only_params
]
return inspect.Signature(params)
# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't
# allow positional-only argument. So we conflate positional-only, keyword-only
# and positional-or-keyword arguments here.
def _is_compatible_param_kind(a, b):
def relax(k):
if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY):
return inspect.Parameter.POSITIONAL_OR_KEYWORD
return k
return relax(a) == relax(b)
def _prepare_np_fun_name_and_fun(np_fun_name, np_fun):
"""Mutually propagates information between `np_fun_name` and `np_fun`.
If one is None and the other is not, we'll try to make the former not None in
a best effort.
Args:
np_fun_name: name for the np_fun symbol. At least one of np_fun or
np_fun_name shoud be set.
np_fun: the numpy function whose docstring will be used.
Returns:
Processed `np_fun_name` and `np_fun`.
"""
if np_fun_name is not None:
assert isinstance(np_fun_name, str)
if np_fun is not None:
assert not isinstance(np_fun, str)
if np_fun is None:
assert np_fun_name is not None
try:
np_fun = getattr(np, str(np_fun_name))
except AttributeError:
np_fun = None
if np_fun_name is None:
assert np_fun is not None
np_fun_name = np_fun.__name__
return np_fun_name, np_fun
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Builds the docstring for a tf-numpy symbol.

  The result contains a header naming the numpy symbol, an optional list of
  unsupported arguments, `f`'s own docstring if it has one, and finally the
  numpy reference produced by `_add_np_doc`.

  Args:
    f: the tf-numpy function being documented.
    np_f: the numpy function (may be None if `np_fun_name` is given).
    np_fun_name: the numpy symbol name; defaults to `np_f.__name__`.
    unsupported_params: optional iterable of unsupported parameter names.
    link: optional link spec; see `np_doc`.

  Returns:
    The assembled docstring.
  """
  assert np_f or np_fun_name
  name = np_fun_name or np_f.__name__
  pieces = ['TensorFlow variant of NumPy\'s `%s`.\n\n' % name]
  if unsupported_params:
    quoted = ', '.join('`' + param + '`' for param in unsupported_params)
    pieces.append('Unsupported arguments: ' + quoted + '.\n\n')
  doc = ''.join(pieces)
  if _has_docstring(f):
    doc = _add_blank_line(doc + f.__doc__)
  # TODO(wangpeng): Re-enable the following and choose inlined vs. link to numpy
  # doc according to some global switch.
  return _add_np_doc(doc, name, np_f, link=link)
# Global setting read by `get_np_doc_form`. It is initialized from the
# TF_NP_DOC_FORM environment variable and defaults to numpy version '1.16'.
_np_doc_form = os.environ.get('TF_NP_DOC_FORM', '1.16')
def get_np_doc_form():
  """Returns the current numpy-docstring form setting.

  Returns:
    The value last passed to `set_np_doc_form` (or the TF_NP_DOC_FORM
    default). See `set_np_doc_form` for the valid values.
  """
  current_form = _np_doc_form
  return current_form
def set_np_doc_form(value):
  r"""Sets how tf-numpy docstrings refer to the original numpy docstrings.

  The setting is stored in a module-level global that is read whenever a
  tf-numpy docstring is generated, so only docstrings generated afterwards
  are affected. Recognized values:

  * `'inlined'`: the numpy docstring is copied verbatim into the tf-numpy
    docstring.
  * `'stable'`: a link to the current stable numpy documentation is added.
  * `'dev'`: a link to the current development numpy documentation is added.
  * a version matching `\d+(\.\d+(\.\d+)?)?`, e.g. '1.16': a link to that
    numpy version's documentation is added.

  Any other string results in no numpy reference being added.

  Args:
    value: the new setting.
  """
  global _np_doc_form
  _np_doc_form = value
class Link:
  """A `link` spec for `np_doc`: `value` is used verbatim as the whole URL."""

  def __init__(self, v):
    # The complete URL to embed in the generated docstring.
    self.value = v
class AliasOf:
  """A `link` spec for `np_doc`: generate the link for another numpy name."""

  def __init__(self, v):
    # The numpy function name used in place of `np_fun_name` for the link.
    self.value = v
class NoLink:
  """A `link` spec for `np_doc`: add no numpy documentation link."""
def generate_link(flag, np_fun_name):
  """Generates a numpy documentation URL for `np_fun_name`.

  Args:
    flag: the flag that controls the link form; see `set_np_doc_form`. Only
      'dev', 'stable', or a version string such as '1.16' produce a link.
    np_fun_name: the numpy function name.

  Returns:
    The URL string, or None if `flag` does not select a link form.
  """
  if flag == 'dev':
    template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
  elif flag == 'stable':
    template = (
        'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
  # `fullmatch` instead of `match(...$)`: `$` also matches just before a
  # trailing newline, which would embed '\n' in the URL (e.g. for a flag read
  # from an environment variable).
  elif re.fullmatch(r'\d+(\.\d+(\.\d+)?)?', flag):
    # `flag` is the version number
    template = ('https://numpy.org/doc/' + flag +
                '/reference/generated/numpy.%s.html')
  else:
    return None
  return template % np_fun_name
# Whether `_add_np_doc` should verify generated links with an HTTP request.
# Initialized from TF_NP_CHECK_LINK; only 'True', 'true' or '1' enable it.
_is_check_link = os.environ.get(
    'TF_NP_CHECK_LINK', 'False') in ('True', 'true', '1')
def is_check_link():
  """Returns whether generated numpy doc links are checked over the network."""
  enabled = _is_check_link
  return enabled
def set_check_link(value):
  """Enables or disables network checking of generated numpy doc links.

  Args:
    value: the new setting, returned as-is by `is_check_link`.
  """
  global _is_check_link
  _is_check_link = value
def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring to `doc`, according to `set_np_doc_form`.

  See `set_np_doc_form` for how it controls the form of the numpy docstring.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with the numpy docstring (or a link to it) appended.

  Raises:
    ValueError: if link checking is enabled (see `set_check_link`) and an HTTP
      HEAD request to the link does not return status code 200.
  """
  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the np_f
      # comment.
      doc += np_f.__doc__.replace('>>>', '>')
  elif isinstance(flag, str):
    if link is None:
      url = generate_link(flag, np_fun_name)
    elif isinstance(link, AliasOf):
      url = generate_link(flag, link.value)
    elif isinstance(link, Link):
      url = link.value
    else:
      # E.g. `NoLink`: no link is added.
      url = None
    if url is not None:
      if is_check_link():
        # Imports locally because some builds may not have `requests`
        import requests  # pylint: disable=g-import-not-at-top
        # Bound the request so an unresponsive server cannot block docstring
        # generation indefinitely (requests has no timeout by default).
        r = requests.head(url, timeout=30)
        if r.status_code != 200:
          raise ValueError(
              f'Check link failed at [{url}] with status code {r.status_code}. '
              f'Argument `np_fun_name` is {np_fun_name}.')
      doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
          np_fun_name, url)
  return doc
# Whether `np_doc` raises TypeError on signature mismatches with numpy.
# Initialized from TF_NP_SIG_MISMATCH_IS_ERROR; only 'True', 'true' or '1'
# enable it.
_is_sig_mismatch_an_error = os.environ.get(
    'TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1')
def is_sig_mismatch_an_error():
  """Returns whether `np_doc` treats signature mismatches as errors."""
  strict = _is_sig_mismatch_an_error
  return strict
def set_is_sig_mismatch_an_error(value):
  """Sets whether `np_doc` raises on signature mismatches with numpy.

  Args:
    value: the new setting, returned as-is by `is_sig_mismatch_an_error`.
  """
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value
def np_doc(np_fun_name, np_fun=None, export=True, unsupported_params=None,
           link=None):
  """Attaches numpy docstring to a function.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    unsupported_params: (optional) the list of parameters not supported
      by tf.numpy. It is copied and never modified. Parameters of the numpy
      signature that the decorated function lacks are added to (a copy of)
      it automatically.
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.

  Raises:
    TypeError: (raised by the decorator) if `is_sig_mismatch_an_error()` is
      true and the decorated function's signature disagrees with numpy's.
  """
  np_fun_name_orig, np_fun_orig = np_fun_name, np_fun
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)
  # Private copy: the decorator below must not mutate the caller's list, nor
  # accumulate entries across multiple applications of the same decorator.
  base_unsupported = ([] if unsupported_params is None
                      else list(unsupported_params))

  def decorator(f):
    """The decorator."""
    unsupported = list(base_unsupported)
    if hasattr(inspect, 'signature') and np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if is_sig_mismatch_an_error():
              raise TypeError(
                  f'Cannot find parameter {name} in the numpy function\'s '
                  f'signature (which has these parameters: '
                  f'{list(np_sig.parameters.keys())}). Argument `np_fun_name` '
                  f'is {np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
            continue
          if (is_sig_mismatch_an_error() and
              not _is_compatible_param_kind(param.kind, np_param.kind)):
            raise TypeError(
                f'Parameter {name} is of kind {param.kind} while in numpy it '
                f'is of kind {np_param.kind}. Argument `np_fun_name` is '
                f'{np_fun_name_orig}. Argument `np_fun` is {np_fun_orig}.')
          # Identity comparison: `!=` against an array-valued default would be
          # elementwise, making the truth value ambiguous.
          has_default = param.default is not inspect.Parameter.empty
          np_has_default = np_param.default is not inspect.Parameter.empty
          if is_sig_mismatch_an_error() and has_default != np_has_default:
            raise TypeError(
                'Parameter {} should{} have a default value. Argument '
                '`np_fun_name` is {}. Argument `np_fun` is {}.'.format(
                    name, '' if np_has_default else ' not', np_fun_name_orig,
                    np_fun_orig))
        # Numpy parameters the TF function lacks are reported as unsupported
        # (without duplicating names the caller already listed).
        for name in np_sig.parameters:
          if name not in sig.parameters and name not in unsupported:
            unsupported.append(name)
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    return f

  return decorator
def np_doc_only(np_fun_name, np_fun=None, export=True):
  """Attaches numpy docstring to a function, without signature checking.

  Unlike `np_doc`, the decorated function's signature is not compared with
  the numpy function's, and no unsupported parameters are inferred.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must
      be a function directly under the `numpy` module, not under any submodule
      of `numpy` (e.g. `numpy.random`).

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)

  def decorator(f):
    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name)
    return np_export.np_export(np_fun_name)(f) if export else f

  return decorator
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  # Normalize TF/Python/numpy dtypes to numpy before delegating.
  numpy_dtype = _to_numpy_type(dtype)
  return np.finfo(numpy_dtype)
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
def _maybe_get_dtype(x):
"""Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
# Don't put np.ndarray in this list, because np.result_type looks at the
# value (not just dtype) of np.ndarray to decide the result type.
if isinstance(x, numbers.Real):
return x
if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):
return _to_numpy_type(x.dtype)
if isinstance(x, dtypes.DType):
return x.as_numpy_dtype
if isinstance(x, (list, tuple)):
raise ValueError(
f'Cannot find dtype for type inference from argument `x` of a sequence '
f'type {type(x)}. For sequences, please call this function on each '
f'element individually.')
return x
# Can't use np_doc because np.result_type is a builtin function.
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  # Flatten nested structures and reduce tensor-like inputs to numpy types.
  candidates = [_maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)]
  if not candidates:
    # With no inputs, let numpy decide what the dtype is.
    candidates = [np.asarray([])]
  return np_dtypes._result_type(*candidates)  # pylint: disable=protected-access
def result_type_unary(a, dtype):  # pylint: disable=missing-function-docstring
  """Find the result type from a single input and a dtype.

  Args:
    a: the input value.
    dtype: an explicit dtype; if truthy it alone decides the result.

  Returns:
    `result_type(dtype)` if `dtype` is truthy; `np.str_` for a `str` input;
    `np.bytes_` for a `bytes` input; otherwise `result_type(a)`.
  """
  if dtype:
    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
    return result_type(dtype)

  # np_utils.result_type treats string inputs as dtype strings, not as strings,
  # but for unary ops we want to treat them as string inputs.
  if isinstance(a, str):
    # `np.str_` rather than `np.unicode_`: the latter was only an alias and
    # was removed in NumPy 2.0.
    return np.str_
  elif isinstance(a, bytes):
    return np.bytes_

  # TF and numpy have different interpretations of Python types such as
  # `float`, so we let `np_utils.result_type` decide.
  return result_type(a)
def _result_type_binary(t1, t2):  # pylint: disable=missing-function-docstring
  """A specialization of result_type for 2 arguments for performance reasons."""
  try:
    # Fast path: skip `nest.flatten` and the list building in `result_type`.
    dtype1 = _maybe_get_dtype(t1)
    dtype2 = _maybe_get_dtype(t2)
    return np_dtypes._result_type(dtype1, dtype2)  # pylint: disable=protected-access
  except ValueError:
    # E.g. `_maybe_get_dtype` rejects lists/tuples; `result_type` flattens
    # such inputs first.
    return result_type(t1, t2)
@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  # Normalize both inputs to numpy types, promote them with numpy, then
  # canonicalize the result through np_dtypes.
  promoted = np.promote_types(_to_numpy_type(type1), _to_numpy_type(type2))
  return np_dtypes.canonicalize_dtype(promoted)
def tf_broadcast(*args):
  """Broadcasts tensors against each other.

  Args:
    *args: tensors whose shapes are broadcastable against each other.

  Returns:
    With zero or one argument, `args` itself (a tuple). Otherwise a list of the
    inputs, each broadcast to the common dynamic shape.
  """
  if len(args) <= 1:
    return args
  common_shape = array_ops.shape(args[0])
  for other in args[1:]:
    common_shape = array_ops.broadcast_dynamic_shape(
        common_shape, array_ops.shape(other))
  return [array_ops.broadcast_to(t, common_shape) for t in args]
# TODO(wangpeng): Move the following functions to a separate file and check for
# float dtypes in each of them.
def get_static_value(x):
  """A version of tf.get_static_value that returns None on float dtypes.

  Floating-point and complex tensors are never constant-folded here, in order
  to avoid breaking gradients.

  Args:
    x: a tensor or other value.

  Returns:
    None if `x` is a floating or complex `core.Tensor`; otherwise
    `tensor_util.constant_value(x)`.
  """
  inexact_tensor = isinstance(x, core.Tensor) and (
      x.dtype.is_floating or x.dtype.is_complex)
  return None if inexact_tensor else tensor_util.constant_value(x)
def _maybe_static(x):
  """Returns the static value of `x` if `get_static_value` finds one, else `x`."""
  static_value = get_static_value(x)
  return x if static_value is None else static_value
# All the following functions exist because get_static_value can't handle
# their TF counterparts.
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition statically.

  If `pred` has a static value, only the selected branch function is called.
  Otherwise `control_flow_ops.cond` is used.
  """
  static_pred = get_static_value(pred)
  if static_pred is None:
    return control_flow_ops.cond(pred, true_fn, false_fn)
  chosen_fn = true_fn if static_pred else false_fn
  return chosen_fn()
def add(a, b):
  """A version of tf.add that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs + rhs
def subtract(a, b):
  """A version of tf.subtract that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs - rhs
def greater(a, b):
  """A version of tf.greater that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs > rhs
def greater_equal(a, b):
  """A version of tf.greater_equal that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs >= rhs
def less_equal(a, b):
  """A version of tf.less_equal that eagerly evaluates if possible."""
  lhs, rhs = _maybe_static(a), _maybe_static(b)
  return lhs <= rhs
def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible.

  When `a` has a static scalar value, the operation short-circuits: a falsy
  `a` is returned as-is and a truthy `a` yields (the static value of) `b`.
  Otherwise the result is computed with `&`.
  """
  a_static = get_static_value(a)
  if a_static is None:
    return a & _maybe_static(b)
  if not np.isscalar(a_static):
    return a_static & _maybe_static(b)
  return _maybe_static(b) if a_static else a_static
def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible.

  When `a` has a static scalar value, the operation short-circuits: a truthy
  `a` is returned as-is and a falsy `a` yields (the static value of) `b`.
  Otherwise the result is computed with `|`.
  """
  a_static = get_static_value(a)
  if a_static is None:
    return a | _maybe_static(b)
  if not np.isscalar(a_static):
    return a_static | _maybe_static(b)
  return a_static if a_static else _maybe_static(b)
def getitem(a, slice_spec):
  """A version of __getitem__ that eagerly evaluates if possible."""
  container = _maybe_static(a)
  return container[slice_spec]
def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: the axis or axes to reduce over; None reduces over all.
    keepdims: whether to keep reduced dimensions with length 1.

  Returns:
    If `input_tensor` has a static value, the result of `np.all` on it.
    Otherwise, the result of `math_ops.reduce_all`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
  # `np.all` rather than `v.all`: the static value may be a plain Python bool
  # or list (e.g. a non-tensor input passed through unchanged), which has no
  # `.all` method. For numpy arrays and scalars the two are equivalent.
  return np.all(v, axis=axis, keepdims=keepdims)
def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: the axis or axes to reduce over; None reduces over all.
    keepdims: whether to keep reduced dimensions with length 1.

  Returns:
    If `input_tensor` has a static value, the result of `np.any` on it.
    Otherwise, the result of `math_ops.reduce_any`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
  # `np.any` rather than `v.any`: the static value may be a plain Python bool
  # or list (e.g. a non-tensor input passed through unchanged), which has no
  # `.any` method. For numpy arrays and scalars the two are equivalent.
  return np.any(v, axis=axis, keepdims=keepdims)
def tf_rank(t):
  """Returns the rank of `t`.

  This is the static `t.shape.rank` when it is known. Otherwise it is the
  dynamic rank from `array_ops.rank(t)`.
  """
  static_rank = t.shape.rank
  return array_ops.rank(t) if static_rank is None else static_rank