/
map_fn.py
656 lines (556 loc) · 27.4 KB
/
map_fn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations."""
import re
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["map_fn"])
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  See also `tf.scan`.

  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
  calls `fn` to transform each element; and then stacks the transformed
  values back together.

  #### Mapping functions with single-Tensor inputs and outputs

  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
  then `map_fn(fn, elems)` is equivalent to
  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`. E.g.:

  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.

  #### Mapping functions with multi-arity inputs and outputs

  `map_fn` also supports functions with multi-arity inputs and outputs:

  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
    must all have the same outer-dimension size (`num_elems`); and `fn` is
    used to transform each tuple (or structure) of corresponding slices from
    `elems`. E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
    transform each tuple of slices `(t1[i], t2[i], t3[i])`
    (where `0 <= i < num_elems`).

  * If `fn` returns a tuple (or nested structure) of tensors, then the
    result is formed by stacking corresponding elements from those structures.

  #### Specifying `fn`'s output signature

  If `fn`'s input and output signatures are different, then the output
  signature must be specified using `fn_output_signature`. (The input and
  output signatures differ if their structures, dtypes, or tensor types do
  not match). E.g.:

  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
  ...           elems=tf.constant(["hello", "moon"]),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
  ...           fn_output_signature=tf.string)
  <tf.Tensor: shape=(2,), dtype=string,
   numpy=array([b'TheDog', b'ACat'], dtype=object)>

  `fn_output_signature` can be specified using any of the following:

  * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
  * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
  * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
  * A (possibly nested) tuple, list, or dict containing the above types.

  #### RaggedTensors

  `map_fn` supports `tf.RaggedTensor` inputs and outputs. In particular:

  * If `elems` is a `RaggedTensor`, then `fn` will be called with each
    row of that ragged tensor.
    * If `elems` has only one ragged dimension, then the values passed to
      `fn` will be `tf.Tensor`s.
    * If `elems` has multiple ragged dimensions, then the values passed to
      `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.

  * If the result of `map_fn` should be a `RaggedTensor`, then use a
    `tf.RaggedTensorSpec` to specify `fn_output_signature`.
    * If `fn` returns `tf.Tensor`s with varying sizes, then use a
      `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
      single ragged tensor (which will have ragged_rank=1).
    * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
      with the same `ragged_rank`.

  >>> # Example: RaggedTensor input
  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>

  >>> # Example: RaggedTensor output
  >>> elems = tf.constant([3, 5, 0, 2])
  >>> tf.map_fn(tf.range, elems,
  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
  ...                                                   dtype=tf.int32))
  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `RaggedTensor`. If you wish to map a function over the
  individual values, then you should use:

  * `tf.ragged.map_flat_values(fn, rt)`
    (if fn is expressible as TensorFlow ops)
  * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
    (otherwise)

  E.g.:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>

  #### SparseTensors

  `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs. In particular:

  * If `elems` is a `SparseTensor`, then `fn` will be called with each row
    of that sparse tensor. In particular, the value passed to `fn` will be a
    `tf.sparse.SparseTensor` with one fewer dimension than `elems`.

  * If the result of `map_fn` should be a `SparseTensor`, then use a
    `tf.SparseTensorSpec` to specify `fn_output_signature`. The individual
    `SparseTensor`s returned by `fn` will be stacked into a single
    `SparseTensor` with one more dimension.

  >>> # Example: SparseTensor input
  >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>

  >>> # Example: SparseTensor output
  >>> tf.sparse.to_dense(
  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
    array([[[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],
           [[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]], dtype=float32)>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `SparseTensor`. If you wish to map a function over the nonzero
  values, then you should use:

  * If the function is expressible as TensorFlow ops, use:
    ```python
    tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
    ```
  * Otherwise, use:
    ```python
    tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values),
                           st.dense_shape)
    ```

  #### `map_fn` vs. vectorized operations

  `map_fn` will apply the operations used by `fn` to each element of `elems`,
  resulting in `O(elems.shape[0])` total operations. This is somewhat
  mitigated by the fact that `map_fn` can process elements in parallel.
  However, a transform expressed using `map_fn` is still typically less
  efficient than an equivalent transform expressed using vectorized operations.

  `map_fn` should typically only be used if one of the following is true:

  * It is difficult or expensive to express the desired transform with
    vectorized operations.
  * `fn` creates large intermediate values, so an equivalent vectorized
    transform would take too much memory.
  * Processing elements in parallel is more efficient than an equivalent
    vectorized transform.
  * Efficiency of the transform is not critical, and using `map_fn` is
    more readable.

  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
  across `elems` could be rewritten more efficiently using vectorized ops:

  >>> elems = tf.constant([3, 5, 2])
  >>> tf.range(3) + tf.expand_dims(elems, 1)
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  In some cases, `tf.vectorized_map` can be used to automatically convert a
  function to a vectorized equivalent.

  #### Eager execution

  When executing eagerly, `map_fn` does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  >>> fn=lambda t: tf.range(t, t + 3)
  >>> @tf.function
  ... def func(elems):
  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
  >>> func(tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  Note: if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  `tf.function` for more details. The recommendation would be to debug without
  `tf.function` but switch to it to get performance benefits of running `map_fn`
  in parallel.

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`. Its output must have the
      same structure as `fn_output_signature` if one is provided; otherwise it
      must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unstacked along their first dimension. `fn` will be applied to the
      nested sequence of the resulting slices. `elems` may include ragged and
      sparse tensors. `elems` must consist of at least one tensor.
    dtype: Deprecated: Equivalent to `fn_output_signature`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) False disables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.
    fn_output_signature: The output signature of `fn`. Must be specified if
      `fn`'s input and output signatures are different (i.e., if their
      structures, dtypes, or tensor types do not match).
      `fn_output_signature` can be specified using any of the following:

      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
      * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
      * A (possibly nested) tuple, list, or dict containing the above types.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor stacks the
    results of applying `fn` to tensors unstacked from `elems` along the first
    dimension, from first to last. The result may include ragged and sparse
    tensors.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `fn_output_signature` do not match.
    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
      do not match, or if the `elems` does not contain any tensor.

  Examples:

    >>> elems = np.array([1, 2, 3, 4, 5, 6])
    >>> tf.map_fn(lambda x: x * x, elems)
    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>

    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>

    >>> elems = np.array([1, 2, 3])
    >>> tf.map_fn(lambda x: (x, -x), elems,
    ...          fn_output_signature=(tf.int64, tf.int64))
    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
  """
  # This function uses a `while_loop` to call `fn` on each value of the input
  # tensor(s) (unstacked on dimension 0). The following sequence of variables
  # are used to transform the input tensor(s) (`elems`) into the output
  # tensor(s) (`result`):
  #
  #   - Preparing and unstacking input values for the while_loop:
  #     - elems: The input tensor(s) to map_fn. May include composite tensors.
  #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
  #                   May include composite tensors.
  #     - elems_batchable: Concatenation of "batchable tensor lists" for each
  #                        tensor in elems_flat. This "boxes" composite tensors
  #                        into sliceable tf.Tensor objects. For more info see:
  #                        TensorSpec._to_batched_tensor_list
  #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
  #                           in elems_batchable into elems_value_batchable.
  #
  #   - Calling `fn` on each unstacked value in the body of the while_loop:
  #     - elems_value_batchable: Single unstacked value from elems_batchable.
  #     - elems_value_flat: Single unstacked value from elems_flat,
  #                         constructed from elems_value_batchable (using
  #                         TensorSpec._from_tensor_list).
  #     - elems_value: Single unstacked value from elems (the input to fn).
  #     - result_value: Result of calling `fn(elems_value)`. May contain
  #                     composite tensors.
  #     - result_value_flat: Flattened list of tensors from result_value.
  #                          May contain composite tensors.
  #     - result_value_batchable: Concatenation of batchable tensor lists for
  #                               each tensor in result_value_flat
  #                               (using TensorSpec._to_tensor_list).
  #
  #   - Collecting and stacking output values from the while_loop:
  #     - result_batchable_ta: List of TensorArrays used to stack each tensor
  #                            in result_value_batchable into result_batchable.
  #     - result_batchable: Stacked tensors from result_batchable_ta.
  #     - result_flat: Flat list of tensors for the result, constructed from
  #                    result_batchable (using TensorSpec._from_tensor_list).
  #     - result: Structured result value packed from result_flat
  #               (using nest.pack_sequence_as).

  # `dtype` is the deprecated spelling of `fn_output_signature`.
  if fn_output_signature is None:
    fn_output_signature = dtype

  if not callable(fn):
    # Typical non-callables (ints, tensors, ...) have no `__name__`; fall back
    # to `repr` so that the documented TypeError is raised instead of an
    # AttributeError from building the message.
    fn_description = getattr(fn, "__name__", repr(fn))
    raise TypeError(f"The provided function {fn_description} is not callable. "
                    "fn must be callable.")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1
  elif not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(
        logging.WARN, "Setting parallel_iterations > 1 has no "
        "effect when executing eagerly. Consider calling map_fn"
        " with tf.function to execute fn in "
        "parallel.", 1)
    parallel_iterations = 1

  # Explicitly read values of ResourceVariables.
  elems = variable_utils.convert_variables_to_tensors(elems)
  # Flatten the input tensors, and get the TypeSpec for each one.
  elems_flat = nest.flatten(elems)

  # Check in case this is an empty list
  if not elems_flat:
    raise ValueError(
        "elems must be a Tensor or (possibly nested) sequence of Tensors. "
        f"Got {elems}, which does not contain any Tensors.")

  elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
  elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)

  # Flatten fn's output signature.
  if fn_output_signature is None:
    # If fn_output_signature was not specified, then assume that it matches the
    # input signature.
    result_flat_signature = [
        _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
        for s in elems_flat_signature
    ]
    result_unflatten = elems_unflatten
  else:
    result_flat_signature = [
        _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
    ]
    result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    varscope_caching_device_was_none = False
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Everything between installing the temporary caching device and removing
    # it runs under try/finally so that an exception (e.g. raised by `fn` or by
    # input validation) cannot leave the variable scope modified.
    try:
      elems_flat = [
          ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
      ]

      # Check that inputs are not scalars.
      first_elem = elems_flat[0]
      if hasattr(first_elem, "shape"):
        elems_static_shape = first_elem.shape
        if (elems_static_shape.ndims is not None and
            elems_static_shape.ndims < 1):
          raise ValueError(
              "Elements in elems must be 1+ dimensional Tensors, not scalars")

      # Box any composite tensors into tensor lists.
      elems_batchable = _elems_flat_to_batchable(elems_flat)

      # Find the number of iterations, n. (may be known statically.)
      n_static = tensor_shape.Dimension(
          tensor_shape.dimension_value(
              elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
      for tensor in elems_batchable[1:]:
        n_static.assert_is_compatible_with(
            tensor_shape.Dimension(
                tensor_shape.dimension_value(
                    tensor.get_shape().with_rank_at_least(1)[0])))
      n = n_static.value or array_ops.shape(elems_batchable[0])[0]

      # Convert elems to tensor array.
      # TODO(edloper): Should we set infer_shape=False for composite tensors?
      elems_batchable_ta = [
          tensor_array_ops.TensorArray(
              dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
          for t in elems_batchable
      ]
      # Unpack elements
      elems_batchable_ta = [
          ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
      ]

      i = constant_op.constant(0)

      # Prepare result tensor array.
      # TODO(edloper): Should we set infer_shape=False for composite tensors?
      result_batchable_tensor_spec = (
          _result_flat_signature_to_batchable_tensor_spec(
              result_flat_signature))
      result_batchable_ta = [
          tensor_array_ops.TensorArray(
              dtype=spec.dtype, size=n, dynamic_size=False,
              infer_shape=infer_shape, element_shape=spec.shape)
          for spec in result_batchable_tensor_spec
      ]

      def compute(i, tas):
        """The loop body of map_fn.

        Args:
          i: the loop counter
          tas: the flat TensorArray accumulator list

        Returns:
          (i + 1, tas): the updated counter + updated TensorArrays

        Raises:
          TypeError: if fn_output_signature and result_value structure don't
            match
          ValueError: if fn_output_signature and result_value lengths don't
            match
        """
        elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
        elems_value_flat = _elems_value_batchable_to_flat(
            elems_value_batchable, elems_flat_signature)
        elems_value = elems_unflatten(elems_value_flat)
        ag_ctx = autograph_ctx.control_status_ctx()
        autographed_fn = autograph.tf_convert(fn, ag_ctx)
        result_value = autographed_fn(elems_value)
        nest.assert_same_structure(fn_output_signature or elems, result_value)
        result_value_flat = nest.flatten(result_value)
        result_value_batchable = _result_value_flat_to_batchable(
            result_value_flat, result_flat_signature)
        tas = [
            ta.write(i, value)
            for (ta, value) in zip(tas, result_value_batchable)
        ]
        return (i + 1, tas)

      _, r_a = control_flow_ops.while_loop(
          lambda i, _: i < n,
          compute, (i, result_batchable_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)
      result_batchable = [r.stack() for r in r_a]

      # Update each output tensor w/ static shape info about the outer
      # dimension.
      for r in result_batchable:
        r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
            r.get_shape()[1:]))
    finally:
      # TODO(akshayka): Remove the in_graph_mode check once caching devices are
      # supported in Eager
      # The flag can only be True in graph mode, where `varscope` is bound.
      if varscope_caching_device_was_none:
        varscope.set_caching_device(None)

    result_flat = _result_batchable_to_flat(result_batchable,
                                            result_flat_signature,
                                            n_static)
    result = result_unflatten(result_flat)
    return result
def _dtype_to_spec(d):
  """Returns `d` as a TypeSpec, wrapping non-TypeSpec values in a TensorSpec.

  Args:
    d: Either a `type_spec.TypeSpec` (returned unchanged) or any other value,
      which is used as the dtype argument of a shape-less `TensorSpec`.

  Returns:
    `d` itself when it is already a TypeSpec, otherwise
    `tensor_spec.TensorSpec(None, d)`.
  """
  if isinstance(d, type_spec.TypeSpec):
    return d
  return tensor_spec.TensorSpec(None, d)
def _most_general_compatible_type(spec):
  """Returns the most general TypeSpec compatible with `spec`.

  For tensor, ragged-tensor and sparse-tensor specs this rebuilds the spec
  with its shape argument replaced by `None`, keeping the remaining
  attributes (dtype, and for ragged specs the ragged rank and row-splits
  dtype). Any other spec is returned unchanged.

  Args:
    spec: The TypeSpec to generalize.

  Returns:
    A TypeSpec of the same kind as `spec` with no shape constraint, or `spec`
    itself when its kind is not one of the three handled here.
  """
  # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
  # pylint: disable=protected-access
  if isinstance(spec, tensor_spec.TensorSpec):
    return tensor_spec.TensorSpec(None, spec.dtype)
  if isinstance(spec, ragged_tensor.RaggedTensorSpec):
    return ragged_tensor.RaggedTensorSpec(
        None, spec._dtype, spec._ragged_rank, spec._row_splits_dtype)
  if isinstance(spec, sparse_tensor.SparseTensorSpec):
    return sparse_tensor.SparseTensorSpec(None, spec.dtype)
  return spec
def _result_flat_signature_to_batchable_tensor_spec(result_flat_signature):
"""Converts result_flat_signature -> result_batchable_tensor_specs."""
tensor_specs = []
for spec in result_flat_signature:
if not isinstance(spec, type_spec.BatchableTypeSpec):
raise TypeError("map_fn can not generate %s outputs" % (spec,))
tensor_specs.extend(spec._flat_tensor_specs) # pylint: disable=protected-access
return tensor_specs
def _elems_flat_to_batchable(elems_flat):
"""Converts elems_flat -> elems_batchable."""
elems_batchable = []
for elems_tensor in elems_flat:
spec = type_spec.type_spec_from_value(elems_tensor)
if not isinstance(spec, type_spec.BatchableTypeSpec):
raise TypeError("map_fn can not consume %s inputs: got %r" %
(spec, elems_tensor))
# pylint: disable=protected-access
elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor))
return elems_batchable
def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature):
"""Converts elems_value_batchable -> elems_value_flat."""
elems_value_flat = []
i = 0
for spec in elems_flat_signature:
# pylint: disable=protected-access
spec = spec._unbatch()
tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)]
elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list))
i += len(tensor_list)
assert i == len(elems_value_batchable)
return elems_value_flat
def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
"""Converts result_value_flat -> result_value_batchable."""
result_value_batchable = []
for (r_value, r_spec) in zip(result_value_flat, result_flat_signature):
if isinstance(r_spec, tensor_spec.TensorSpec):
result_value_batchable.append(r_value)
else:
if not r_spec.is_compatible_with(r_value):
raise ValueError(
"Error in map_fn:\n Expected `fn` to return a:\n %s\n"
" But it returned a:\n %s\n (value=%s)\n"
" To fix, update the `fn_output_signature` (or `dtype`) "
"argument to `map_fn`." %
(r_spec, type_spec.type_spec_from_value(r_value), r_value))
result_value_batchable.extend(r_spec._to_tensor_list(r_value)) # pylint: disable=protected-access
return result_value_batchable
def _result_batchable_to_flat(result_batchable, result_flat_signature,
batch_size):
"""Converts result_batchable -> result_flat."""
result_flat = []
i = 0
for spec in result_flat_signature:
# pylint: disable=protected-access
num_tensors = len(spec._flat_tensor_specs)
result_flat.append(
spec._batch(batch_size)._from_compatible_tensor_list(
result_batchable[i:i + num_tensors]))
i += num_tensors
assert i == len(result_batchable)
return result_flat
@tf_export("map_fn", v1=[])
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.map_fn(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""",
    warn_once=True,
    back_prop=False)
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn_v2(fn,
              elems,
              dtype=None,
              parallel_iterations=None,
              back_prop=True,
              swap_memory=False,
              infer_shape=True,
              name=None,
              fn_output_signature=None):
  """Transform `elems` by applying `fn` to each element unstacked on axis 0."""
  # `dtype` is the deprecated alias; it only applies when the caller did not
  # pass `fn_output_signature`. The resolved signature is forwarded to `map_fn`
  # under the new name, and `dtype` itself is not passed on.
  output_signature = (
      dtype if fn_output_signature is None else fn_output_signature)
  return map_fn(
      fn=fn,
      elems=elems,
      fn_output_signature=output_signature,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      name=name)
# Docstring for v2 is the same as v1, except that back_prop is deprecated.
# The pattern is anchored to the start of the `back_prop` Args line and accepts
# any amount of leading indentation (including none), so the rewrite does not
# silently miss when the v1 docstring's indentation differs from what a
# hard-coded leading space would require.
map_fn_v2.__doc__ = re.sub(
    r"^([ \t]*back_prop: \(optional\) )(.*)",
    r"\1Deprecated: prefer using `tf.stop_gradient` instead. \2",
    map_fn.__doc__,
    flags=re.MULTILINE)
# Import-time sanity check that the substitution above actually matched.
assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__