/
core.py
284 lines (220 loc) · 9.18 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core TensorFlow types."""
import sys
import textwrap
from typing import Union
import numpy as np
from tensorflow.python.types import doc_typealias
from tensorflow.python.util.tf_export import tf_export
# pylint:disable=g-import-not-at-top
# `Protocol` and `runtime_checkable` are available from the standard `typing`
# module starting with Python 3.8; older interpreters fall back to the
# `typing_extensions` backport, which provides the same names.
if sys.version_info >= (3, 8):
  from typing import Protocol
  from typing import runtime_checkable
else:
  from typing_extensions import Protocol
  from typing_extensions import runtime_checkable
# pylint:enable=g-import-not-at-top
# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced.
# TODO(mdan): Add type annotations.
# TODO(b/178822082): Revisit this API when tf.types gets more resource.
@tf_export("__internal__.types.Tensor", v1=[])
class Tensor(object):
  """The base class of all dense Tensor objects.

  A dense tensor has a static data type (dtype), and may have a static rank and
  shape. Tensor objects are immutable. Mutable objects may be backed by a Tensor
  which holds the unique handle that identifies the mutable object.
  """

  @property
  def dtype(self):
    """The static data type of this tensor.

    This base class carries no data, so it always reports `None`; concrete
    tensor types are expected to override this property.
    """
    return None

  @property
  def shape(self):
    """The static shape of this tensor.

    This base class carries no data, so it always reports `None`; concrete
    tensor types are expected to override this property.
    """
    return None
class Symbol(Tensor):
  """Symbolic "graph" Tensor.

  These objects represent the output of an op definition and do not carry a
  value. The class adds nothing beyond `Tensor`; it exists as a marker type
  that distinguishes graph tensors from eager ones.
  """
class Value(Tensor):
  """Tensor that can be associated with a value (aka "eager tensor").

  These objects represent the (usually future) output of executing an op
  immediately.
  """

  def numpy(self):
    """Returns the value associated with this tensor.

    The base implementation has no value to expose and returns `None`;
    subclasses are expected to override it.
    """
    return None
@tf_export("types.experimental.Callable", v1=[])
class Callable:
  """Base class for TF callables like those created by tf.function.

  Note: Callables are conceptually very similar to `tf.Operation`: a
  `tf.Operation` is a kind of callable.
  """

  def __call__(self, *args, **kwargs):
    """Executes this callable.

    This behaves like a regular op: in eager mode it starts execution
    immediately and returns results; in graph mode it creates ops which
    return symbolic TensorFlow values (like `tf.Tensor`, `tf.data.Dataset`,
    etc.). For example, `tf.function` callables typically generate a
    `tf.raw_ops.PartitionedCall` op, but not always - the exact operations
    being generated are an internal implementation detail.

    This base class only declares the interface; its own `__call__` performs
    no work and returns `None`.

    Args:
      *args: positional arguments for this call.
      **kwargs: keyword arguments for this call.

    Returns:
      The execution results.
    """
@tf_export("types.experimental.ConcreteFunction", v1=[])
class ConcreteFunction(Callable):
  """Base class for graph functions.

  A `ConcreteFunction` encapsulates a single graph function definition and
  is differentiable under `tf.GradientTape` contexts. It is the kind of object
  returned by `GenericFunction.get_concrete_function`.
  """
# TODO(mdan): Name just `types.Function`, for historic continuity?
@tf_export("types.experimental.GenericFunction", v1=[])
class GenericFunction(Callable):
  """Base class for polymorphic graph functions.

  Graph functions are Python callable objects that dispatch calls to a
  TensorFlow graph. Polymorphic graph functions can be backed by multiple TF
  graphs, and automatically select the appropriate specialization based on the
  type of input they were called with. They may also create specializations on
  the fly if necessary, for example by tracing.

  Also see `tf.function`.
  """

  def get_concrete_function(self, *args, **kwargs) -> ConcreteFunction:
    """Returns a `ConcreteFunction` specialized to input types.

    The arguments specified by `args` and `kwargs` follow normal function call
    rules. The returned `ConcreteFunction` has the same set of positional and
    keyword arguments as `self`, but their types are compatible with the types
    specified by `args` and `kwargs` (though not necessarily equal).

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.constant(1.0))
    >>> f_concrete = f.get_concrete_function(x=tf.constant(1.0))

    Unlike normal calls, `get_concrete_function` allows type specifiers instead
    of TensorFlow objects, so for example `tf.Tensor`s may be replaced with
    `tf.TensorSpec`s.

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))

    If the function definition allows only one specialization, `args` and
    `kwargs` may be omitted altogether.

    >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function()

    The returned `ConcreteFunction` can be called normally:

    >>> f_concrete(tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
    >>> f_concrete(x=tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A `ConcreteFunction`.
    """
    # Interface only: this base implementation returns None; presumably
    # overridden by concrete subclasses such as the `tf.function` wrapper.
    pass

  def experimental_get_compiler_ir(self, *args, **kwargs):
    """Returns compiler IR for the compiled function.

    This API is intended *only* for debugging as there are no guarantees on
    backwards compatibility of returned IR or the allowed values of `stage`.

    Args:
      *args: Arguments used for compilation; same arguments as used for calling
        the function. Need to be eager tensors.
      **kwargs: Keyword arguments used for compilation.

    Returns:
      Function callable with the following kwargs:
        - `stage` at which the compiler IR should be serialized. Allowed values
          are:
           - `hlo`: HLO output after conversion from TF
            (https://www.tensorflow.org/xla/operation_semantics).
           - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized
             HLO module proto (a bytes object).
           - `optimized_hlo`: HLO after compiler optimizations.
           - `optimized_hlo_serialized`: Like stage=`optimized_hlo`, but the
             output is a serialized HLO module proto (a bytes object).
           - `optimized_hlo_dot`: optimized HLO in DOT format suitable for
             Graphviz.
        - `device_name` can be either None, in which case the preferred device
          is used for compilation, or a device name. It can be a full device
          name, or a partial one, e.g., `/device:CPU:0`.

      For example, for

      ```python
      @tf.function(jit_compile=True)
      def f(x):
        return x + 1

      f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
      ```

      the output is:

      ```
      HloModule a_inference_f_13__.9

      ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
        %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
        %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
        %constant.3 = f32[] constant(1)
        %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
        %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                                     f32[10,10]{1,0} %broadcast.4)
        %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
        %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
        ROOT %get-tuple-element.8 = f32[10,10]{1,0}
          get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
      }
      ```

    Raises:
      ValueError: If an invalid `stage` is selected or if applied to a function
        which is not compiled (`jit_compile=True` is not set).
      TypeError: When called with input in graph mode.
    """
    # Interface only: this base implementation returns None; presumably
    # overridden by concrete subclasses such as the `tf.function` wrapper.
    pass
@runtime_checkable
class TensorProtocol(Protocol):
  """Protocol type for objects that can be converted to Tensor.

  Matching is structural: any object that defines a `__tf_tensor__` method
  satisfies the protocol, with no inheritance required. Because the class is
  decorated with `@runtime_checkable`, it can also be used with `isinstance`,
  which checks only that the method is present, not its signature.
  """

  def __tf_tensor__(self, dtype=None, name=None):
    """Converts this object to a Tensor.

    Args:
      dtype: data type for the returned Tensor.
      name: a name for the operations which create the Tensor.

    Returns:
      A Tensor.
    """
# TODO(rahulkamat): Add missing types that are convertible to Tensor.
# Every type accepted by `tf.convert_to_tensor`, for use in annotations.
TensorLike = Union[
    Tensor,
    TensorProtocol,
    int,
    float,
    bool,
    str,
    bytes,
    complex,
    tuple,
    list,
    np.ndarray,
    np.generic,
]
# Attach user-facing documentation to the `TensorLike` alias. The text is
# written indented and normalized with `textwrap.dedent`. Blank lines separate
# paragraphs, and the example code is indented so that it is valid Python.
doc_typealias.document(
    obj=TensorLike,
    doc=textwrap.dedent("""\
      Union of all types that can be converted to a `tf.Tensor` by `tf.convert_to_tensor`.

      This definition may be used in user code. Additional types may be added
      in the future as more input types are supported.

      Example:

      ```
      def foo(x: TensorLike):
        pass
      ```

      This definition passes static type verification for:

      ```
      foo(tf.constant([1, 2, 3]))
      foo([1, 2, 3])
      foo(np.array([1, 2, 3]))
      ```
      """),
)
# Register the module-level `TensorLike` constant for export under the API
# name "types.experimental.TensorLike".
tf_export("types.experimental.TensorLike").export_constant(
    __name__,
    "TensorLike",
)