/
array_ops_stack.py
214 lines (169 loc) · 6.99 KB
/
array_ops_stack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Tests for this file live in python/kernel_tests/array_ops_test.py
"""Operations to stack and unstack tensors."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export("stack")
@dispatch.add_dispatch_support
def stack(values, axis=0, name="stack"):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

  See also `tf.concat`, `tf.tile`, `tf.repeat`.

  The tensors in `values` are packed along a new `axis` dimension, so the
  result has rank one higher than each tensor in `values`.

  Given a list of `N` tensors, each of shape `(A, B, C)`:

  * if `axis == 0` the `output` tensor has shape `(N, A, B, C)`;
  * if `axis == 1` the `output` tensor has shape `(A, N, B, C)`;
  * and so on.

  For example:

  >>> x = tf.constant([1, 4])
  >>> y = tf.constant([2, 5])
  >>> z = tf.constant([3, 6])
  >>> tf.stack([x, y, z])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)>
  >>> tf.stack([x, y, z], axis=1)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[1, 2, 3],
         [4, 5, 6]], dtype=int32)>

  This is the opposite of unstack. The numpy equivalent is `np.stack`:

  >>> np.array_equal(np.stack([x, y, z]), tf.stack([x, y, z]))
  True

  Args:
    values: A list of `Tensor` objects with the same shape and type.
    axis: An `int`. The axis to stack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
    name: A name for this operation (optional).

  Returns:
    output: A stacked `Tensor` with the same type as `values`.

  Raises:
    ValueError: If `axis` is out of the range [-(R+1), R+1).
  """
  if axis == 0:
    try:
      # Fast path: a list made only of constants can be turned into a single
      # constant op without going through the pack op.
      return ops.convert_to_tensor(values, name=name)
    except (TypeError, ValueError, NotImplementedError):
      # The list holds non-constant tensors; fall through to the pack op.
      pass

  # Validate `axis` against the first element whenever a shape tuple is
  # available for it; otherwise the range check is skipped here.
  first_value = ops.convert_to_tensor(values[0], name=name)
  element_shape = first_value._shape_tuple()  # pylint: disable=protected-access
  if element_shape is not None:
    # Stacking adds one dimension to the rank of each element.
    stacked_rank = len(element_shape) + 1
    if axis < -stacked_rank or axis >= stacked_rank:
      lower_bound = -stacked_rank
      raise ValueError(f"Argument `axis` = {axis} not in range "
                       f"[{lower_bound}, {stacked_rank})")

  return gen_array_ops.pack(values, axis=axis, name=name)
@tf_export("unstack")
@dispatch.add_dispatch_support
def unstack(value, num=None, axis=0, name="unstack"):
  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.

  Unpacks tensors from `value` by chipping it along the `axis` dimension.

  >>> x = tf.reshape(tf.range(12), (3,4))
  >>>
  >>> p, q, r = tf.unstack(x)
  >>> p.shape.as_list()
  [4]

  >>> i, j, k, l = tf.unstack(x, axis=1)
  >>> i.shape.as_list()
  [3]

  This is the opposite of stack.

  >>> x = tf.stack([i, j, k, l], axis=1)

  More generally if you have a tensor of shape `(A, B, C, D)`:

  >>> A, B, C, D = [2, 3, 4, 5]
  >>> t = tf.random.normal(shape=[A, B, C, D])

  The number of tensors returned is equal to the length of the target `axis`:

  >>> axis = 2
  >>> items = tf.unstack(t, axis=axis)
  >>> len(items) == t.shape[axis]
  True

  The shape of each result tensor is equal to the shape of the input tensor,
  with the target `axis` removed.

  >>> items[0].shape.as_list()  # [A, B, D]
  [2, 3, 5]

  The value of each tensor `items[i]` is equal to the slice of the input tensor
  across `axis` at index `i`:

  >>> for i in range(len(items)):
  ...   slice = t[:,:,i,:]
  ...   assert tf.reduce_all(slice == items[i])

  #### Python iterable unpacking

  With eager execution you _can_ unstack the 0th axis of a tensor using python's
  iterable unpacking:

  >>> t = tf.constant([1,2,3])
  >>> a,b,c = t

  `unstack` is still necessary because iterable unpacking doesn't work in
  a `@tf.function`: symbolic tensors are not iterable.

  You need to use `tf.unstack` here:

  >>> @tf.function
  ... def bad(t):
  ...   a,b,c = t
  ...   return a
  >>>
  >>> bad(t)
  Traceback (most recent call last):
  ...
  OperatorNotAllowedInGraphError: ...

  >>> @tf.function
  ... def good(t):
  ...   a,b,c = tf.unstack(t)
  ...   return a
  >>>
  >>> good(t).numpy()
  1

  #### Unknown shapes

  Eager tensors have concrete values, so their shape is always known.
  Inside a `tf.function` the symbolic tensors may have unknown shapes.
  If the length of `axis` is unknown `tf.unstack` will fail because it cannot
  handle an unknown number of tensors:

  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  ... def bad(t):
  ...   tensors = tf.unstack(t)
  ...   return tensors[0]
  >>>
  >>> bad(tf.constant([1.0, 2.0, 3.0]))
  Traceback (most recent call last):
  ...
  ValueError: Cannot infer argument `num` from shape (None,)

  If you know the `axis` length you can pass it as the `num` argument. But this
  must be a constant value.

  If you actually need a variable number of tensors in a single `tf.function`
  trace, you will need to use explicit loops and a `tf.TensorArray` instead.

  Args:
    value: A rank `R > 0` `Tensor` to be unstacked.
    num: An `int`. The length of the dimension `axis`. Automatically inferred if
      `None` (the default).
    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-R, R)`.
    name: A name for the operation (optional).

  Returns:
    The list of `Tensor` objects unstacked from `value`.

  Raises:
    ValueError: If `axis` is out of the range `[-R, R)`.
    ValueError: If `num` is unspecified and cannot be inferred.
    InvalidArgumentError: If `num` does not match the shape of `value`.
  """
  if num is None:
    # No explicit count was given, so try to read it from the static shape.
    value = ops.convert_to_tensor(value)
    static_shape = value.get_shape()
    rank = static_shape.ndims
    if rank is not None:
      # The rank is known: range-check `axis` before indexing the dims.
      if axis < -rank or axis >= rank:
        raise ValueError(f"Argument `axis` = {axis} not in range "
                         f"[{-rank}, {rank})")
      num = static_shape.dims[axis].value
    # Either the rank or the length of the chosen dimension is unknown.
    if num is None:
      raise ValueError(f"Cannot infer argument `num` from shape {static_shape}")

  return gen_array_ops.unpack(value, num=num, axis=axis, name=name)