/
dispatch.py
1124 lines (886 loc) · 42.7 KB
/
dispatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type-based dispatch for TensorFlow's Python APIs.
"Python APIs" refers to Python functions that have been exported with
`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also
referred to as "ops".
There are currently two dispatch systems for TensorFlow:
* The "fallback dispatch" system calls an API's standard implementation first,
and only tries to perform dispatch if that standard implementation raises a
TypeError (or ValueError) exception.
* The "type-based dispatch" system checks the types of the parameters passed
to an API, and performs dispatch if those types match any signatures that
have been registered for dispatch.
The fallback dispatch system was the original dispatch system, but it was
somewhat brittle and had limitations, such as an inability to support dispatch
for some operations (like convert_to_tensor). We plan to remove the fallback
dispatch system in favor of the type-based dispatch system, once all users have
been switched over to use it.
### Fallback Dispatch
The fallback dispatch system is based on "operation dispatchers", which can be
used to override the behavior for TensorFlow ops when they are called with
otherwise unsupported argument types. In particular, when an operation is
called with arguments that would cause it to raise a TypeError, it falls back on
its registered operation dispatchers. If any registered dispatcher can handle
the arguments, then the result of the first such dispatcher is returned.
Otherwise, the original TypeError is raised.
### Type-based Dispatch
The main interface for the type-based dispatch system is the `dispatch_for_api`
decorator, which overrides the default implementation for a TensorFlow API.
The decorated function (known as the "dispatch target") will override the
default implementation for the API when the API is called with parameters that
match a specified type signature.
### Dispatch Support
By default, dispatch support is added to the generated op wrappers for any
visible ops. APIs/ops that are implemented in Python can opt in to
dispatch support using the `add_dispatch_support` decorator.
"""
import collections
import itertools
import typing # pylint: disable=unused-import (used in doctests)
from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher
from tensorflow.python.framework import ops
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export as tf_export_lib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import traceback_utils
from tensorflow.python.util import type_annotations
from tensorflow.python.util.tf_export import tf_export
# Private function attributes used to store dispatchers on TensorFlow APIs.
# `FALLBACK_DISPATCH_ATTR` holds the list that `OpDispatcher.register` appends
# to and `dispatch` iterates; `TYPE_BASED_DISPATCH_ATTR` holds the
# `PythonAPIDispatcher` installed by `add_type_based_api_dispatcher`.
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"

# Dispatchers which should be used for all operations.  Populated by
# `GlobalOpDispatcher.register`; `dispatch` consults these only after every
# op-specific dispatcher has returned `NOT_SUPPORTED`.
_GLOBAL_DISPATCHERS = []
################################################################################
# Fallback Dispatch
################################################################################
@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  Each operation dispatcher acts as an override handler for a single
  TensorFlow operation, and its results are used when the handler indicates
  that it can handle the operation's arguments (by returning any value other
  than `OpDispatcher.NOT_SUPPORTED`).
  """

  # Sentinel value that can be returned to indicate that an operation
  # dispatcher does not support a given set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    If this operation dispatcher can handle the given arguments, then
    return an appropriate value (or raise an appropriate exception).

    Args:
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled. Must
        have a dispatch list (which is added automatically for generated ops,
        and can be added to Python ops using the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` does not have a fallback dispatch list.
    """
    if not hasattr(op, FALLBACK_DISPATCH_ATTR):
      raise AssertionError("Dispatching not enabled for %s" % op)
    getattr(op, FALLBACK_DISPATCH_ATTR).append(self)
@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  Unlike an `OpDispatcher`, which is registered on a single op, a registered
  global dispatcher is consulted by `dispatch` for every op, after that op's
  own dispatchers have all returned `NOT_SUPPORTED`.
  """

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handle the specified operation with the specified arguments.

    Subclasses should override this method.

    Args:
      op: Python function: the operation being dispatched.
      args: The arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `GlobalOpDispatcher.NOT_SUPPORTED` if
      this dispatcher can not handle the given arguments.
    """
    # The base implementation declines explicitly.  Falling off the end would
    # return `None`, which `dispatch` would treat as a successful result.
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)
def dispatch(op, args, kwargs):
  """Runs the fallback dispatchers for `op` and returns the first real result.

  Dispatchers registered on `op` itself (via `OpDispatcher.register`) are tried
  first, in registration order; then every dispatcher in `_GLOBAL_DISPATCHERS`
  is tried.  The first `handle` call that returns something other than
  `OpDispatcher.NOT_SUPPORTED` determines the result.

  Args:
    op: Python function: the operation to dispatch for.  Must carry a fallback
      dispatch list (see `add_fallback_dispatch_list`).
    args: The positional arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The first successful dispatcher's result, or `OpDispatcher.NOT_SUPPORTED`
    if no registered dispatcher can handle the given arguments.
  """
  unsupported = OpDispatcher.NOT_SUPPORTED

  # Op-specific dispatchers get the first chance at the call.
  for op_dispatcher in getattr(op, FALLBACK_DISPATCH_ATTR):
    outcome = op_dispatcher.handle(args, kwargs)
    if outcome is not unsupported:
      return outcome

  # Global dispatchers are also told which op is being called.
  for global_dispatcher in _GLOBAL_DISPATCHERS:
    outcome = global_dispatcher.handle(op, args, kwargs)
    if outcome is not unsupported:
      return outcome

  return unsupported
class _TypeBasedDispatcher(OpDispatcher):
  """Fallback dispatcher that triggers on the types of an op's arguments.

  The override function is called when at least one positional or keyword
  argument is an instance of the configured type(s), or is a `list`/`tuple`
  containing at least one such instance (only one level of nesting is
  inspected).
  """

  def __init__(self, override_func, types):
    self._override_func = override_func
    self._types = types

  def _handles(self, args, kwargs):
    """Returns True if any argument (or direct list/tuple element) matches."""
    wanted = self._types

    def _matches(value):
      if isinstance(value, wanted):
        return True
      if isinstance(value, (list, tuple)):
        return any(isinstance(item, wanted) for item in value)
      return False

    all_values = itertools.chain(args, kwargs.values())
    return any(_matches(value) for value in all_values)

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator that registers a fallback override of `op` for given types.

  The decorated function is used in place of `op` (during fallback dispatch)
  when any positional or keyword argument -- or any element of a list/tuple
  argument -- is an instance of one of `types`.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.

  Raises:
    AssertionError: (raised by the returned decorator) if the decorated
      function's argspec differs from `op`'s.
  """

  def decorator(func):
    same_signature = tf_inspect.getargspec(func) == tf_inspect.getargspec(op)
    if not same_signature:
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    type_dispatcher = _TypeBasedDispatcher(func, types)
    type_dispatcher.register(op)
    return func

  return decorator
# pylint: enable=g-doc-return-or-yield
def add_fallback_dispatch_list(target):
  """Decorator that gives `target` an empty fallback dispatch list.

  The list is stored under `FALLBACK_DISPATCH_ATTR`, where
  `OpDispatcher.register` appends to it and `dispatch` reads it.

  Args:
    target: The op (Python function) to decorate.

  Returns:
    `target` itself, with the dispatch-list attribute set.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  if not hasattr(target, FALLBACK_DISPATCH_ATTR):
    setattr(target, FALLBACK_DISPATCH_ATTR, [])
    return target
  raise AssertionError("%s already has a dispatch list" % target)
# Alias for backwards-compatibility: `add_dispatch_list` is the same function
# object as `add_fallback_dispatch_list`.
add_dispatch_list = add_fallback_dispatch_list
################################################################################
# Type-based Dispatch
################################################################################
@tf_export("experimental.dispatch_for_api")
def dispatch_for_api(api, *signatures):
  """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature. Signatures are specified using dictionaries
  that map parameter names to type annotations. E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match. For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values. For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value. This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden. In particular, parameters must have the same names, and
  must occur in the same order. The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called. In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations. If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
  dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
  if dispatcher is None:
    raise ValueError(f"{api} does not support dispatch.")

  # Build the checkers eagerly, so malformed signatures are reported when the
  # decorator is created rather than when it is applied.
  api_signature = tf_inspect.signature(api)
  signature_checkers = [
      _make_signature_checker(api_signature, signature)
      for signature in signatures
  ]

  def decorator(dispatch_target):
    """Decorator that registers the given dispatch target."""
    if not callable(dispatch_target):
      raise TypeError("Expected dispatch_target to be callable; "
                      f"got {dispatch_target!r}")
    dispatch_target = _add_name_scope_wrapper(dispatch_target, api_signature)
    _check_signature(api_signature, dispatch_target)

    for signature_checker in signature_checkers:
      dispatcher.Register(signature_checker, dispatch_target)
    _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(signatures)

    # With no explicit signatures, fall back to the target's own annotations.
    if not signature_checkers:
      signature = _signature_from_annotations(dispatch_target)
      checker = _make_signature_checker(api_signature, signature)
      dispatcher.Register(checker, dispatch_target)
      _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].append(signature)

    return dispatch_target

  return decorator
# Nested dict mapping `api_func` -> `dispatch_target` -> `List[signature]`,
# which can be used for documentation generation and for improved error messages
# when APIs are called with unsupported types.
#
# `add_type_based_api_dispatcher` creates each API's inner
# `collections.defaultdict(list)`; `dispatch_for_api` adds signatures to it, and
# `unregister_dispatch_for` deletes dispatch targets from it.
_TYPE_BASED_DISPATCH_SIGNATURES = {}
def apis_with_type_based_dispatch():
  """Lists the TensorFlow APIs that have a type-based dispatcher.

  Returns:
    A list of API functions, sorted by their fully-qualified
    `module.name` string.
  """

  def qualified_name(api):
    return f"{api.__module__}.{api.__name__}"

  return sorted(_TYPE_BASED_DISPATCH_SIGNATURES, key=qualified_name)
def type_based_dispatch_signatures_for(cls):
  """Collects the registered dispatch signatures that mention a given class.

  Intended for documentation generation.

  Args:
    cls: The class to look for.  The search is recursive: a signature matches
      if `cls` appears anywhere in one of its annotations, including nested
      inside `typing.Union[...]` or `typing.List[...]`.

  Returns:
    A `dict` mapping each TensorFlow API function to the list of its
    registered dispatch signatures that mention `cls`.  (Each signature is a
    dict mapping argument names to type annotations; see `dispatch_for_api`.)
    APIs with no matching signatures are omitted.
  """

  def mentions_cls(annotation):
    """Returns True if `annotation` refers to `cls` at any nesting depth."""
    if isinstance(annotation, dict):
      return any(mentions_cls(value) for value in annotation.values())
    if annotation is cls:
      return True
    if not (type_annotations.is_generic_list(annotation) or
            type_annotations.is_generic_union(annotation)):
      return False
    nested = type_annotations.get_generic_type_args(annotation)
    return any(mentions_cls(arg) for arg in nested)

  matches = {}
  for api, signatures_by_target in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    for target_signatures in signatures_by_target.values():
      relevant = [sig for sig in target_signatures if mentions_cls(sig)]
      if relevant:
        matches.setdefault(api, []).extend(relevant)
  return matches
# TODO(edloper): Consider using a mechanism like this to automatically add
# the `name` argument to all TensorFlow APIs that are implemented in Python
# (so each Python function doesn't need to do it manually).
def _add_name_scope_wrapper(func, api_signature):
  """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.

  If `func` already expects a "name" arg (or accepts `**kwargs`), or if
  `api_signature` does not expect a "name" arg, then returns `func` as-is.

  Args:
    func: The function to wrap. Signature must match `api_signature` (except
      the "name" parameter may be missing).
    api_signature: The signature of the original API (used to find the index for
      the "name" parameter).

  Returns:
    The wrapped function (or the original function if no wrapping is needed).
    The wrapper's advertised `__signature__` is `func`'s signature with the
    API's "name" parameter inserted at the same index it has in
    `api_signature`.
  """
  if "name" not in api_signature.parameters:
    return func  # no wrapping needed (API has no name parameter).

  func_signature = tf_inspect.signature(func)
  func_argspec = tf_inspect.getargspec(func)
  if "name" in func_signature.parameters or func_argspec.keywords is not None:
    return func  # No wrapping needed (already has name parameter).

  name_index = list(api_signature.parameters).index("name")

  def wrapped_func(*args, **kwargs):
    # Strip "name" (passed positionally or by keyword) before calling `func`.
    if name_index < len(args):
      name = args[name_index]
      args = args[:name_index] + args[name_index + 1:]
    else:
      name = kwargs.pop("name", None)
    if name is None:
      return func(*args, **kwargs)
    else:
      with ops.name_scope(name):
        return func(*args, **kwargs)

  wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
  # Put "name" back at the position the API uses (not unconditionally at the
  # end); otherwise `_check_signature` rejects APIs whose "name" parameter is
  # followed by other parameters.  When "name" is last, this is an append.
  wrapped_params = list(func_signature.parameters.values())
  wrapped_params.insert(name_index, api_signature.parameters["name"])
  wrapped_func.__signature__ = func_signature.replace(parameters=wrapped_params)
  # Remove the decorator marker so signature inspection uses `__signature__`
  # above instead of unwrapping back to `func`.
  del wrapped_func._tf_decorator
  return wrapped_func
@tf_export("experimental.unregister_dispatch_for")
def unregister_dispatch_for(dispatch_target):
  """Unregisters a function that was registered with `@dispatch_for_*`.

  This is primarily intended for testing purposes.

  Example:

  >>> # Define a type and register a dispatcher to override `tf.abs`:
  >>> class MyTensor(tf.experimental.ExtensionType):
  ...   value: tf.Tensor
  >>> @dispatch_for_api(tf.abs)
  ... def my_abs(x: MyTensor):
  ...   return MyTensor(tf.abs(x.value))
  >>> tf.abs(MyTensor(5))
  MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>)

  >>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`.
  >>> unregister_dispatch_for(my_abs)
  >>> tf.abs(MyTensor(5))
  Traceback (most recent call last):
  ...
  ValueError: Attempt to convert a value ... to a Tensor.

  Args:
    dispatch_target: The function to unregister.

  Raises:
    ValueError: If `dispatch_target` was not registered using
      `@dispatch_for_api`, `@dispatch_for_unary_elementwise_apis`, or
      `@dispatch_for_binary_elementwise_apis`.
  """
  found = False

  # Check if dispatch_target registered by `@dispatch_for_api`
  for api, signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    if dispatch_target in signatures:
      dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR)
      dispatcher.Unregister(dispatch_target)
      del signatures[dispatch_target]
      found = True

  # Check if dispatch_target registered by `@dispatch_for_*_elementwise_apis`.
  # Such handlers are not registered directly; each per-API target that was
  # generated for the handler is unregistered recursively instead.
  elementwise_keys_to_delete = [
      key for (key, handler) in _ELEMENTWISE_API_HANDLERS.items()
      if handler is dispatch_target
  ]
  for key in set(elementwise_keys_to_delete):
    for _, target in _ELEMENTWISE_API_TARGETS[key]:
      unregister_dispatch_for(target)
    del _ELEMENTWISE_API_HANDLERS[key]
    del _ELEMENTWISE_API_TARGETS[key]
    found = True

  if not found:
    raise ValueError(f"Function {dispatch_target} was not registered using "
                     "a `@dispatch_for_*` decorator.")
def register_dispatchable_type(cls):
  """Class decorator marking `cls` as a type recognized by type-based dispatch.

  Do *not* apply this to subclasses of `CompositeTensor` or `ExtensionType`;
  those are registered automatically.

  Note: this exists to support internal legacy use cases (such as
  RaggedTensorValue), and will probably not be exposed as a public API.

  Args:
    cls: The class to register.

  Returns:
    `cls`, unchanged.
  """
  register = _api_dispatcher.register_dispatchable_type
  register(cls)
  return cls
def add_type_based_api_dispatcher(target):
  """Attaches a `PythonAPIDispatcher` to the TensorFlow API function `target`.

  APIs whose undecorated signature takes `*args` or `**kwargs` are currently
  returned unchanged, without a dispatcher.

  Args:
    target: The TensorFlow API function.

  Returns:
    `target` itself.

  Raises:
    ValueError: If `target` already has a type-based API dispatcher.
  """
  if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
    raise ValueError(f"{target} already has a type-based API dispatcher.")

  _, undecorated = tf_decorator.unwrap(target)
  argspec = tf_inspect.getargspec(undecorated)
  if argspec.varargs or argspec.keywords:
    # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
    # and keywords.  Examples of APIs that take varargs and kwargs: meshgrid,
    # einsum, map_values, map_flat_values.
    return target

  api_dispatcher = _api_dispatcher.PythonAPIDispatcher(
      undecorated.__name__, argspec.args, argspec.defaults)
  setattr(target, TYPE_BASED_DISPATCH_ATTR, api_dispatcher)
  _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
  return target
def _check_signature(api_signature, func):
  """Verifies that dispatch target `func` is signature-compatible with an API.

  Args:
    api_signature: The signature of the TensorFlow API.
    func: The dispatch target.

  Raises:
    ValueError: if the signatures are incompatible.  They are compatible when
      both have the same number of parameters and each pair of corresponding
      parameters agrees on `name` and `kind` (default values and annotations
      are not compared).  A `func` that takes `*args` and `**kwargs` but has
      no named positional parameters is always accepted.
  """
  argspec = tf_inspect.getargspec(func)
  is_passthrough = (argspec.varargs is not None and
                    argspec.keywords is not None and
                    not argspec.args)
  if is_passthrough:
    return

  func_signature = tf_inspect.signature(func)
  api_params = list(api_signature.parameters.values())
  func_params = list(func_signature.parameters.values())
  compatible = len(api_params) == len(func_params) and all(
      api_param.name == func_param.name and api_param.kind == func_param.kind
      for api_param, func_param in zip(api_params, func_params))
  if not compatible:
    raise ValueError(f"Dispatch function's signature {func_signature} does "
                     f"not match API's signature {api_signature}.")
def _make_signature_checker(api_signature, signature):
  """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is
      being checked.
    signature: Dictionary mapping parameter names (or positional indices) to
      type annotations.

  Returns:
    A `PySignatureChecker`.

  Raises:
    TypeError: If `signature` is not a dict keyed by `str`/`int`.
    ValueError: If `signature` names an unknown parameter, annotates a
      non-positional parameter, or uses an unsupported annotation.
  """
  if not isinstance(signature, dict) or not all(
      isinstance(key, (str, int)) for key in signature):
    raise TypeError("signatures must be dictionaries mapping parameter names "
                    "to type annotations.")

  param_names = list(api_signature.parameters)
  position_of = {name: i for i, name in enumerate(param_names)}
  positional_kinds = (tf_inspect.Parameter.POSITIONAL_ONLY,
                      tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)

  index_checker_pairs = []
  for key, annotation in signature.items():
    # Integer keys below the parameter count are positional indices; translate
    # them to the corresponding parameter name.
    param_name = key
    if isinstance(param_name, int) and param_name < len(param_names):
      param_name = param_names[param_name]

    param = api_signature.parameters.get(param_name, None)
    if param is None:
      raise ValueError("signature includes annotation for unknown "
                       f"parameter {param_name!r}.")
    if param.kind not in positional_kinds:
      raise ValueError("Dispatch currently only supports type annotations "
                       "for positional parameters; can't handle annotation "
                       f"for {param.kind!r} parameter {param_name}.")

    type_checker = make_type_checker(annotation)
    index_checker_pairs.append((position_of[param_name], type_checker))

  return _api_dispatcher.PySignatureChecker(index_checker_pairs)
# Cache for InstanceTypeChecker objects (we only want to create one
# InstanceTypeChecker for each type, since each one uses an internal cache
# to avoid repeated calls back into Python's isinstance).
# Keys are either a single type object or a tuple of type objects (sorted by
# `id`) collected from a `Union[...]`; see `make_type_checker`.
_is_instance_checker_cache = {}
def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Supported annotations are type objects, `None` (treated as `type(None)`),
  `typing.List[...]`, and `typing.Union[...]`, nested arbitrarily.

  Args:
    annotation: The type annotation to build a checker for.

  Returns:
    A type checker built by `_api_dispatcher` for `annotation`.  Instance
    checkers for type objects (and for groups of type objects inside a
    `Union`) are cached in `_is_instance_checker_cache` and reused.

  Raises:
    ValueError: If `annotation` is not a supported kind of annotation.
    AssertionError: If a `List[...]` annotation does not have exactly one type
      parameter.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)

    # If the union contains two or more simple types, then use a single
    # InstanceChecker to check them.  Sorting by `id` makes the cache key
    # independent of the order in which the types are listed.
    simple_types = tuple(
        sorted((t for t in type_args if isinstance(t, type)), key=id))
    if len(simple_types) > 1:
      if simple_types not in _is_instance_checker_cache:
        _is_instance_checker_cache[simple_types] = (
            _api_dispatcher.MakeInstanceChecker(*simple_types))
      options = ([_is_instance_checker_cache[simple_types]] +
                 [make_type_checker(t) for t in type_args
                  if not isinstance(t, type)])
    else:
      options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)

  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError("Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)

  elif isinstance(annotation, type):
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]

  elif annotation is None:
    return make_type_checker(type(None))

  else:
    raise ValueError(f"Type annotation {annotation} is not currently supported"
                     " by dispatch. Supported annotations: type objects, "
                     "List[...], and Union[...]")
def _signature_from_annotations(func):
  """Builds a dict mapping from parameter names to type annotations.

  Used by `dispatch_for_api` when it is called without explicit signatures.

  Args:
    func: The dispatch target whose parameter annotations should be read.

  Returns:
    A dict mapping each annotated parameter's name to its annotation;
    unannotated parameters are omitted.

  Raises:
    ValueError: If none of `func`'s parameters are annotated.
  """
  func_signature = tf_inspect.signature(func)
  # `Parameter.empty` is a sentinel class, so compare by identity rather than
  # relying on the annotation object's `__ne__`.
  signature = {
      name: param.annotation
      for name, param in func_signature.parameters.items()
      if param.annotation is not tf_inspect.Parameter.empty
  }
  if not signature:
    raise ValueError("The dispatch_for_api decorator must be called with at "
                     "least one signature, or applied to a function that "
                     "has type annotations on its parameters.")
  return signature
# Registries for elementwise APIs and API handlers.
#
# _*_ELEMENTWISE_APIS: A list of TensorFlow APIs that have been registered
#   as elementwise operations using the `register_*_elementwise_api`
#   decorators.
#
# _ELEMENTWISE_API_HANDLERS: Dict mapping from argument type(s) to API
#   handlers that have been registered with the `dispatch_for_*_elementwise_apis`
#   decorators.  Keys are tuples of type annotations (e.g. `(x_type,)` for a
#   unary handler).
#
# _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of
#   `(api, dispatch_target)` pairs.  Used to implement
#   `unregister_dispatch_for`.
_UNARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_APIS = []
_ELEMENTWISE_API_HANDLERS = {}
_ELEMENTWISE_API_TARGETS = {}
@tf_export("experimental.dispatch_for_unary_elementwise_apis")
def dispatch_for_unary_elementwise_apis(x_type):
  """Decorator to override default implementation for unary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any unary elementwise API whenever the value
  for the first argument (typically named `x`) matches the type annotation
  `x_type`. The elementwise api handler is called with two arguments:

    `elementwise_api_handler(api_func, x)`

  Where `api_func` is a function that takes a single parameter and performs the
  elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
  elementwise api.

  The following example shows how this decorator can be used to update all
  unary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_unary_elementwise_apis(MaskedTensor)
  ... def unary_elementwise_api_handler(api_func, x):
  ...   return MaskedTensor(api_func(x.values), x.mask)
  >>> mt = MaskedTensor([1, -2, -3], [True, False, True])
  >>> abs_mt = tf.abs(mt)
  >>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
  values=[1 2 3], mask=[ True False True]

  For unary elementwise operations that take extra arguments beyond `x`, those
  arguments are *not* passed to the elementwise api handler, but are
  automatically added when `api_func` is called. E.g., in the following
  example, the `dtype` parameter is not passed to
  `unary_elementwise_api_handler`, but is added by `api_func`.

  >>> ones_mt = tf.ones_like(mt, dtype=tf.float32)
  >>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
  values=[1.0 1.0 1.0], mask=[ True False True]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
      See `dispatch_for_api` for a list of supported annotation types.

  Returns:
    A decorator.

  #### Registered APIs

  The unary elementwise APIs are:

  <<API_LIST>>
  """
  # Unary handlers are keyed by a 1-tuple of their type annotation.
  handler_key = (x_type,)

  def decorator(handler):
    if handler_key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A unary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[handler_key]}) "
                       f"has already been registered for {x_type}.")
    _ELEMENTWISE_API_HANDLERS[handler_key] = handler
    # Install the handler on every API already known to be unary elementwise.
    for unary_api in _UNARY_ELEMENTWISE_APIS:
      _add_dispatch_for_unary_elementwise_api(unary_api, x_type, handler)
    return handler

  return decorator
@tf_export("experimental.dispatch_for_binary_elementwise_apis")
def dispatch_for_binary_elementwise_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any binary elementwise API whenever the value
  for the first two arguments (typically named `x` and `y`) match the specified
  type annotations.  The elementwise api handler is called with three
  arguments:

    `elementwise_api_handler(api_func, x, y)`

  Where `x` and `y` are the first two arguments to the elementwise api, and
  `api_func` is a TensorFlow function that takes two parameters and performs the
  elementwise operation (e.g., `tf.add`).

  The following example shows how this decorator can be used to update all
  binary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_api_handler(api_func, x, y):
  ...   return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
  >>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
  >>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
  >>> c = tf.add(a, b)
  >>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
  values=[ 3  6  9 12  5], mask=[ True  True  True False False]

  The handler is attached to every binary elementwise API registered so far,
  and to any API registered later with `register_binary_elementwise_api`.
  The returned decorator raises `ValueError` if a binary elementwise handler
  has already been registered for `(x_type, y_type)`.

  Args:
    x_type: A type annotation for the first argument (`x`), indicating when the
      api handler should be called.
    y_type: A type annotation for the second argument (`y`), indicating when
      the api handler should be called.

  Returns:
    A decorator.

  #### Registered APIs

  The binary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # Only one elementwise handler may exist per type signature.
    if (x_type, y_type) in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[x_type, y_type]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    # Recording the handler lets `register_binary_elementwise_api` attach it to
    # APIs that are registered after this point.
    _ELEMENTWISE_API_HANDLERS[x_type, y_type] = handler
    for api in _BINARY_ELEMENTWISE_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)
    return handler

  return decorator
def register_unary_elementwise_api(func):
  """Marks `func` as a unary elementwise API; usable as a decorator.

  Every unary handler already registered through
  `dispatch_for_unary_elementwise_apis` is hooked up to `func` right away.

  Args:
    func: The TensorFlow op to register.

  Returns:
    `func` itself, unchanged.
  """
  _UNARY_ELEMENTWISE_APIS.append(func)
  for signature_types, handler in _ELEMENTWISE_API_HANDLERS.items():
    if len(signature_types) != 1:
      continue  # Binary handler; not applicable to a unary API.
    (x_type,) = signature_types
    _add_dispatch_for_unary_elementwise_api(func, x_type, handler)
  return func
def register_binary_elementwise_api(func):
  """Marks `func` as a binary elementwise API; usable as a decorator.

  Every binary handler already registered through
  `dispatch_for_binary_elementwise_apis` is hooked up to `func` right away.

  Args:
    func: The TensorFlow op to register.

  Returns:
    `func` itself, unchanged.
  """
  _BINARY_ELEMENTWISE_APIS.append(func)
  for signature_types, handler in _ELEMENTWISE_API_HANDLERS.items():
    if len(signature_types) != 2:
      continue  # Unary handler; not applicable to a binary API.
    x_type, y_type = signature_types
    _add_dispatch_for_binary_elementwise_api(func, x_type, y_type, handler)
  return func
def unary_elementwise_apis():
  """Returns a tuple of the APIs registered as unary elementwise.

  Returns:
    A tuple snapshot of the APIs registered so far (in registration order).
    Registrations made after the call do not alter the returned tuple.
  """
  return tuple(_UNARY_ELEMENTWISE_APIS)
def binary_elementwise_apis():
  """Returns a tuple of the APIs registered as binary elementwise.

  Returns:
    A tuple snapshot of the APIs registered so far (in registration order).
    Registrations made after the call do not alter the returned tuple.
  """
  return tuple(_BINARY_ELEMENTWISE_APIS)
def _add_dispatch_for_unary_elementwise_api(api, x_type,
                                            elementwise_api_handler):
  """Registers a unary elementwise handler as a dispatcher for a given API.

  Defines a `dispatch_target` wrapper and registers it with
  `dispatch_for_api` so it runs when `api`'s first parameter matches
  `x_type`.  The wrapper:

  * removes the `name` argument (if `api` has one),
  * splits `x` off from the remaining arguments,
  * calls `elementwise_api_handler(tensor_api, x)`, where `tensor_api` is
    `api` with the remaining arguments bound.

  If a `name` was supplied, the handler call runs inside
  `ops.name_scope(name, None, [x])`.

  Args:
    api: The API to add a dispatcher for.  Its first parameter is treated as
      `x`.
    x_type: Type annotation for `x` that selects this dispatcher.
    elementwise_api_handler: Callable invoked as
      `elementwise_api_handler(tensor_api, x)`.

  The `(api, dispatch_target)` pair is appended to
  `_ELEMENTWISE_API_TARGETS[(x_type,)]`.
  """
  api_signature = tf_inspect.signature(api)
  x_name = list(api_signature.parameters)[0]
  name_index = _find_name_index(api_signature)

  # Extra arguments only need binding into `tensor_api` when the signature has
  # parameters beyond `x` and `name`.  With at most two parameters, one of
  # which is `name`, `api` can be handed to the handler unchanged.  `name`
  # is handled by the name scope below instead.
  need_to_bind_api_args = (
      len(api_signature.parameters) > 2 or
      "name" not in api_signature.parameters)

  @dispatch_for_api(api, {x_name: x_type})
  def dispatch_target(*args, **kwargs):
    # Strip `name` out of the positional or keyword arguments.
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    # `x` may arrive positionally or by keyword.  Whatever remains in
    # `args`/`kwargs` is forwarded to `api` through `tensor_api`.
    # NOTE(review): a missing keyword `x` raises KeyError here, whereas the
    # binary variant defaults missing arguments to None.
    if args:
      x, args = args[0], args[1:]
    else:
      x = kwargs.pop(x_name)

    if need_to_bind_api_args:
      # The lambda closes over the already-trimmed `args` and `kwargs`.
      tensor_api = lambda v: api(v, *args, **kwargs)
    else:
      tensor_api = api

    if name is None:
      return elementwise_api_handler(tensor_api, x)
    else:
      with ops.name_scope(name, None, [x]):
        return elementwise_api_handler(tensor_api, x)

  # Give the wrapper a descriptive name (presumably for debugging/tracebacks).
  # NOTE(review): this runs after `dispatch_for_api` has already registered
  # the function; confirm the dispatcher doesn't capture `__name__` earlier.
  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type,), [])
  target_list.append((api, dispatch_target))
def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type,
                                             elementwise_api_handler):
  """Registers a binary elementwise handler as a dispatcher for a given API.

  Defines a `dispatch_target` wrapper and registers it with
  `dispatch_for_api` so it runs when `api`'s first two parameters match
  `x_type` and `y_type`.  The wrapper:

  * removes the `name` argument (if `api` has one),
  * splits `x` and `y` off from the remaining arguments,
  * calls `elementwise_api_handler(tensor_api, x, y)`, where `tensor_api` is
    `api` with the remaining arguments bound.

  If a `name` was supplied, the handler call runs inside
  `ops.name_scope(name, None, [x, y])`.

  Args:
    api: The API to add a dispatcher for.  Its first two parameters are
      treated as `x` and `y`.
    x_type: Type annotation for `x` that selects this dispatcher.
    y_type: Type annotation for `y` that selects this dispatcher.
    elementwise_api_handler: Callable invoked as
      `elementwise_api_handler(tensor_api, x, y)`.

  The `(api, dispatch_target)` pair is appended to
  `_ELEMENTWISE_API_TARGETS[(x_type, y_type)]`.
  """
  api_signature = tf_inspect.signature(api)
  x_name, y_name = list(api_signature.parameters)[:2]
  name_index = _find_name_index(api_signature)

  # Extra arguments only need binding into `tensor_api` when the signature has
  # parameters beyond `x`, `y` and `name`.  Otherwise `api` is passed through
  # unchanged; `name` is handled by the name scope below.
  need_to_bind_api_args = (len(api_signature.parameters) > 3 or
                           "name" not in api_signature.parameters)

  @dispatch_for_api(api, {x_name: x_type, y_name: y_type})
  def dispatch_target(*args, **kwargs):
    # Strip `name` out of the positional or keyword arguments.
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    # Each of `x` and `y` may arrive positionally or by keyword.  Whatever
    # remains in `args`/`kwargs` is forwarded to `api` through `tensor_api`.
    if len(args) > 1:
      x, y, args = args[0], args[1], args[2:]
    elif args:
      x, args = args[0], args[1:]
      y = kwargs.pop(y_name, None)
    else:
      # NOTE(review): missing `x`/`y` silently become None here, unlike the
      # unary variant, which raises KeyError for a missing `x`.
      x = kwargs.pop(x_name, None)
      y = kwargs.pop(y_name, None)

    if need_to_bind_api_args:
      # The lambda closes over the already-trimmed `args` and `kwargs`.
      tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs)
    else:
      tensor_api = api

    if name is None:
      return elementwise_api_handler(tensor_api, x, y)
    else:
      with ops.name_scope(name, None, [x, y]):
        return elementwise_api_handler(tensor_api, x, y)

  # Give the wrapper a descriptive name (presumably for debugging/tracebacks).
  # NOTE(review): this runs after `dispatch_for_api` has already registered
  # the function; confirm the dispatcher doesn't capture `__name__` earlier.
  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), [])
  target_list.append((api, dispatch_target))
def _find_name_index(signature):
  """Locates the parameter called `name` within `signature`.

  Args:
    signature: A signature object whose `parameters` mapping iterates over
      parameter names in positional order.

  Returns:
    The zero-based position of the `name` parameter, or -1 if the signature
    has no such parameter.
  """
  for position, param_name in enumerate(signature.parameters):
    if param_name == "name":
      return position
  return -1
def _extract_name_arg(args, kwargs, name_index):
  """Pulls the `name` argument out of a call's arguments.

  Args:
    args: The positional arguments of the call.
    kwargs: The keyword arguments of the call.  When `name` is taken from
      here, it is popped, so this dict is modified in place.
    name_index: Position of `name` in the API signature, or a negative value
      if the API has no `name` parameter.

  Returns:
    A tuple `(args, kwargs, name_value)`:
    * `args` has the `name` entry removed when `name` was passed positionally.
    * `name_value` is None when the API has no `name` parameter or none was
      supplied.
  """
  if name_index >= 0:
    if name_index >= len(args):
      # Not enough positional arguments, so `name` (if any) is a keyword.
      return args, kwargs, kwargs.pop("name", None)
    name_value = args[name_index]
    remaining = args[:name_index] + args[name_index + 1:]
    return remaining, kwargs, name_value
  return args, kwargs, None
def update_docstrings_with_api_lists():
  """Fills in the `<<API_LIST>>` placeholder of each dispatch decorator.

  The docstrings of `dispatch_for_unary_elementwise_apis`,
  `dispatch_for_binary_elementwise_apis` and `dispatch_for_api` (in that
  order) have the placeholder '<<API_LIST>>' replaced by a list of the APIs
  registered for that decorator.
  """
  decorators_and_apis = (
      (dispatch_for_unary_elementwise_apis, _UNARY_ELEMENTWISE_APIS),
      (dispatch_for_binary_elementwise_apis, _BINARY_ELEMENTWISE_APIS),
      (dispatch_for_api, _TYPE_BASED_DISPATCH_SIGNATURES),
  )
  for decorator_fn, registered_apis in decorators_and_apis:
    _update_docstring_with_api_list(decorator_fn, registered_apis)
def _update_docstring_with_api_list(target, api_list):
  """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs.

  Each API in `api_list` that has a canonical exported name is rendered as a
  markdown bullet showing `tf.<name>` and its parameter names.  The bullets
  are sorted and joined with newlines.  The placeholder is matched together
  with the single space that precedes it.

  Args:
    target: The object whose `__doc__` is rewritten in place.
    api_list: An iterable of API functions.  APIs for which
      `get_canonical_name_for_symbol` returns None are skipped.

  If `target.__doc__` is None (e.g. under `python -OO`, which strips
  docstrings), there is nothing to update.  The function then returns
  without changes instead of failing with AttributeError.
  """
  if target.__doc__ is None:
    return

  lines = []
  for func in api_list:
    name = tf_export_lib.get_canonical_name_for_symbol(
        func, add_prefix_to_v1_names=True)
    if name is not None:
      params = tf_inspect.signature(func).parameters.keys()
      lines.append(f" * `tf.{name}({', '.join(params)})`")
  lines.sort()
  target.__doc__ = target.__doc__.replace(" <<API_LIST>>", "\n".join(lines))