/
tensor_conversion_registry.py
142 lines (117 loc) · 5.24 KB
/
tensor_conversion_registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for tensor conversion functions."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numbers
import threading

import numpy as np
import six

from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export
# `constant_op` is loaded lazily: importing it eagerly would close the cycle
# ops -> tensor_conversion_registry -> constant_op -> ops.
constant_op = lazy_loader.LazyLoader(
    "constant_op", globals(), "tensorflow.python.framework.constant_op")

# priority -> list of (base_type, conversion_func) pairs, kept in the order
# they were registered.
_tensor_conversion_func_registry = collections.defaultdict(list)

# queried type -> resolved list of (base_type, conversion_func) pairs.
# Emptied every time a new conversion function is registered.
_tensor_conversion_func_cache = {}

# Held while registering a function and while populating the cache.
_tensor_conversion_func_lock = threading.Lock()

# Python numeric types and NumPy scalars/arrays: instances of these are always
# converted with `_default_conversion_function`, and registering a custom
# conversion for any of them is refused.
_UNCONVERTIBLE_TYPES = tuple(six.integer_types) + (
    float, np.generic, np.ndarray)
def _default_conversion_function(value, dtype, name, as_ref):
  """Fallback converter that delegates to `constant_op.constant`.

  Args:
    value: The object to convert.
    dtype: Requested dtype, passed through unchanged.
    name: Name for the resulting tensor, passed through unchanged.
    as_ref: Present only to match the conversion-function signature; it is
      ignored.

  Returns:
    The result of `constant_op.constant(value, dtype, name=name)`.
  """
  _ = as_ref  # Intentionally ignored.
  result = constant_op.constant(value, dtype, name=name)
  return result
# TODO(josh11b): Add ctx argument to conversion_func() signature.
@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function for converting objects of `base_type` to `Tensor`.

  The conversion function must have the following signature:

  ```python
      def conversion_func(value, dtype=None, name=None, as_ref=False):
        # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion function
  `F` runs before another conversion function `G`, ensure that `F` is
  registered with a smaller priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values run
      earlier than conversion functions with larger priority values. Defaults to
      100.

  Raises:
    TypeError: If `base_type` is not a type or a tuple of types, if any of
      those types is a Python numeric type or a NumPy scalar/array type, if
      `conversion_func` is not callable, or if `priority` is not a real number.
  """
  base_types = base_type if isinstance(base_type, tuple) else (base_type,)
  if any(not isinstance(x, type) for x in base_types):
    raise TypeError("Argument `base_type` must be a type or a tuple of types. "
                    f"Obtained: {base_type}")
  if any(issubclass(x, _UNCONVERTIBLE_TYPES) for x in base_types):
    raise TypeError("Cannot register conversions for Python numeric types and "
                    "NumPy scalars and arrays.")
  del base_types  # Only needed for validation.
  if not callable(conversion_func):
    raise TypeError("Argument `conversion_func` must be callable. Received "
                    f"{conversion_func}.")
  # `get` orders the registry with `sorted()` over its priority keys. A
  # priority that cannot be ordered against the others (e.g. None or a string)
  # would be stored silently here, and every later uncached lookup would then
  # raise TypeError. Reject it at the call site instead.
  if not isinstance(priority, numbers.Real):
    raise TypeError("Argument `priority` must be a number. Received "
                    f"{priority}.")

  with _tensor_conversion_func_lock:
    _tensor_conversion_func_registry[priority].append(
        (base_type, conversion_func))
    # Resolved lookups may now be stale; force them to be recomputed.
    _tensor_conversion_func_cache.clear()
def get(query):
  """Returns the conversion functions that apply to instances of `query`.

  Args:
    query: The type to look up.

  Returns:
    A list of `(base_type, conversion_func)` pairs whose `base_type` is a
    superclass of `query`, ordered by ascending priority and, within one
    priority, by registration order. If `query` is a subclass of
    `_UNCONVERTIBLE_TYPES`, the list is the single pair
    `(query, _default_conversion_function)`. For other types the returned list
    is the cached object itself, so callers should not mutate it.
  """
  if issubclass(query, _UNCONVERTIBLE_TYPES):
    return [(query, _default_conversion_function)]

  resolved = _tensor_conversion_func_cache.get(query)
  if resolved is not None:
    return resolved

  with _tensor_conversion_func_lock:
    # Re-check: another thread may have filled the entry while we waited.
    resolved = _tensor_conversion_func_cache.get(query)
    if resolved is None:
      resolved = [
          (base_type, conversion_func)
          for _, funcs_at_priority in sorted(
              _tensor_conversion_func_registry.items())
          for base_type, conversion_func in funcs_at_priority
          if issubclass(query, base_type)
      ]
      _tensor_conversion_func_cache[query] = resolved
  return resolved