-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathindexed_slices.py
455 lines (372 loc) · 16.3 KB
/
indexed_slices.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Indexed slices."""
# pylint: disable=g-bad-name
import collections
import warnings
import numpy as np
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import composite_tensor_gradient
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.types import internal
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export
class IndexedSlicesCompositeTensorGradient(
    composite_tensor_gradient.CompositeTensorGradient):
  """CompositeTensorGradient for IndexedSlices.

  The `IndexedSlices` value is used directly as its own gradient component,
  and the component gradient is returned as-is when reassembling.
  """

  def get_gradient_components(self, value):
    """Returns `value` itself as the gradient component."""
    return value

  def replace_gradient_components(self, value, component_grads):
    """Returns `component_grads` unchanged; `value` is not consulted."""
    del value  # Unused: the component gradient is already the result.
    return component_grads
# TODO(mdan): Should IndexedSlices be a "tensor"?
@tf_export("IndexedSlices")
class IndexedSlices(
    internal.IndexedSlices,
    internal.NativeObject,
    composite_tensor.CompositeTensor):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An optional third `Tensor`, `dense_shape`, records the shape of the dense
  tensor that the slices were taken from.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, ..., Dn]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. `tf.gather`).

  >>> v = tf.Variable([[0.,1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8]])
  >>> with tf.GradientTape() as tape:
  ...   r = tf.gather(v, [1,3])
  >>> index_slices = tape.gradient(r,v)
  >>> index_slices
  <...IndexedSlices object ...>
  >>> index_slices.indices.numpy()
  array([1, 3], dtype=int32)
  >>> index_slices.values.numpy()
  array([[1., 1., 1.],
         [1., 1., 1.]], dtype=float32)

  Contrast this representation with
  `tf.sparse.SparseTensor`,
  which uses multi-dimensional indices and scalar values.
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`."""
    self._values = values
    self._indices = indices
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Gets the `tf.TensorShape` representing the shape of the dense tensor.

    Without a `dense_shape` tensor the result is a shape of unknown rank;
    otherwise it is derived from `dense_shape` via
    `tensor_util.constant_value_as_shape`.

    Returns:
      A `tf.TensorShape` object.
    """
    if self._dense_shape is not None:
      return tensor_util.constant_value_as_shape(self._dense_shape)
    return tensor_shape.TensorShape(None)

  @property
  def name(self):
    """The name of this `IndexedSlices`."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self.values.dtype

  @property
  def graph(self):
    """The `Graph` that contains the values, indices, and shape tensors."""
    return self._values.graph

  def __str__(self):
    # The dense shape is only mentioned when one was supplied.
    if self._dense_shape is None:
      dense_shape_part = ""
    else:
      dense_shape_part = ", dense_shape=%s" % (self._dense_shape,)
    return "IndexedSlices(indices=%s, values=%s%s)" % (
        self._indices, self._values, dense_shape_part)

  def __neg__(self):
    # Negation only touches the values; indices and dense shape are shared.
    return IndexedSlices(-self.values, self.indices, self.dense_shape)

  __composite_gradient__ = IndexedSlicesCompositeTensorGradient()

  @property
  def _type_spec(self):
    # The number of slices is constrained by both `indices` and the leading
    # dimension of `values`; the dense shape's trailing dimensions come from
    # `values`, refined by `dense_shape` when it is available.
    values_shape = self._values.shape
    indices_shape = self._indices.shape.merge_with(values_shape[:1])
    dense_shape = tensor_shape.TensorShape([None]).concatenate(
        values_shape[1:])
    dense_shape_dtype = None
    if self._dense_shape is not None:
      dense_shape_dtype = self._dense_shape.dtype
      static_dense_shape = tensor_util.constant_value_as_shape(
          self._dense_shape)
      dense_shape = dense_shape.merge_with(static_dense_shape)
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def _shape_invariant_to_type_spec(self, shape):
    # From tf.while_loop docs: "If a loop variable is an IndexedSlices, the
    # shape invariant must be a shape invariant of the values tensor of the
    # IndexedSlices. It means the shapes of the three tensors of the
    # IndexedSlices are (shape, [shape[0]], [shape.ndims])."
    indices_shape = shape[:1]
    dense_shape = tensor_shape.TensorShape([None]).concatenate(shape[1:])
    dense_shape_dtype = (
        None if self._dense_shape is None else self._dense_shape.dtype)
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def consumers(self):
    """Returns the result of `self._consumers()`.

    NOTE(review): `_consumers` is not defined in this class; presumably it is
    inherited from `composite_tensor.CompositeTensor` — confirm.
    """
    return self._consumers()
# Plain (values, indices, dense_shape) record; `IndexedSlicesSpec` returns
# one of these from `_from_components` when every component is a NumPy array
# and TF2 behavior is disabled.
IndexedSlicesValue = collections.namedtuple(
    "IndexedSlicesValue", "values indices dense_shape")
@tf_export("IndexedSlicesSpec")
class IndexedSlicesSpec(type_spec.TypeSpec):
  """Type specification for a `tf.IndexedSlices`.

  The spec decomposes into two or three component tensors: `values`,
  `indices`, and (only when `dense_shape_dtype` is set) `dense_shape`.
  """

  __slots__ = ["_shape", "_values_dtype", "_indices_dtype",
               "_dense_shape_dtype", "_indices_shape"]

  value_type = property(lambda self: IndexedSlices)

  def __init__(self, shape=None, dtype=dtypes.float32,
               indices_dtype=dtypes.int64, dense_shape_dtype=None,
               indices_shape=None):
    """Constructs a type specification for a `tf.IndexedSlices`.

    Args:
      shape: The dense shape of the `IndexedSlices`, or `None` to allow any
        dense shape.
      dtype: `tf.DType` of values in the `IndexedSlices`.
      indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`.  One
        of `tf.int32` or `tf.int64`.
      dense_shape_dtype: `tf.DType` of the `dense_shape` in the `IndexedSlices`.
        One of `tf.int32`, `tf.int64`, or `None` (if the `IndexedSlices` has
        no `dense_shape` tensor).
      indices_shape: The shape of the `indices` component, which indicates
        how many slices are in the `IndexedSlices`. It is forced to rank 1;
        `None` leaves the number of slices unknown.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._values_dtype = dtypes.as_dtype(dtype)
    self._indices_dtype = dtypes.as_dtype(indices_dtype)
    self._dense_shape_dtype = (
        None if dense_shape_dtype is None
        else dtypes.as_dtype(dense_shape_dtype))
    self._indices_shape = tensor_shape.as_shape(indices_shape).with_rank(1)

  def _serialize(self):
    return (self._shape, self._values_dtype, self._indices_dtype,
            self._dense_shape_dtype, self._indices_shape)

  @property
  def _component_specs(self):
    # `values` has one row per index, followed by the dense shape's trailing
    # dimensions; `dense_shape` (if present) is a vector of length ndims.
    values_spec = tensor_spec.TensorSpec(
        self._indices_shape.concatenate(self._shape[1:]), self._values_dtype)
    indices_spec = tensor_spec.TensorSpec(
        self._indices_shape, self._indices_dtype)
    if self._dense_shape_dtype is None:
      return (values_spec, indices_spec)
    dense_shape_spec = tensor_spec.TensorSpec(
        [self._shape.ndims], self._dense_shape_dtype)
    return (values_spec, indices_spec, dense_shape_spec)

  def _to_components(self, value):
    if value.dense_shape is not None:
      return (value.values, value.indices, value.dense_shape)
    return (value.values, value.indices)

  def _from_components(self, tensor_list):
    # Outside TF2, all-NumPy components are wrapped in the lightweight
    # `IndexedSlicesValue` record instead of a real `IndexedSlices`.
    if not all(isinstance(t, np.ndarray) for t in tensor_list) or tf2.enabled():
      return IndexedSlices(*tensor_list)
    if len(tensor_list) == 2:
      return IndexedSlicesValue(tensor_list[0], tensor_list[1], None)
    return IndexedSlicesValue(*tensor_list)
# Register `IndexedSlicesSpec` with `nested_structure_coder` under the
# `TypeSpecProto.INDEXED_SLICES_SPEC` enum value.
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        IndexedSlicesSpec, struct_pb2.TypeSpecProto.INDEXED_SLICES_SPEC))
@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  # The public entry point never asks for a reference tensor.
  return internal_convert_to_tensor_or_indexed_slices(
      value, dtype=dtype, name=name, as_ref=False)
def internal_convert_to_tensor_or_indexed_slices(value,
                                                 dtype=None,
                                                 name=None,
                                                 as_ref=False):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  # An EagerTensor seen while not executing eagerly always goes through
  # `ops.convert_to_tensor`. This is checked before the NativeObject
  # pass-through; presumably EagerTensor would otherwise match it — confirm.
  eager_tensor_outside_eager = (isinstance(value, ops.EagerTensor) and
                                not context.executing_eagerly())
  # TODO(mdan): Name says tensor_or_indexed_slices. So do explicitly just that?
  if not eager_tensor_outside_eager and isinstance(value,
                                                   internal.NativeObject):
    if dtype:
      requested_dtype = dtypes.as_dtype(dtype)
      if not requested_dtype.is_compatible_with(value.dtype):
        raise ValueError(
            "Incompatible tensor conversion requested to `dtype` "
            f"{requested_dtype.name} for `value` ({value}) with dtype"
            f" {value.dtype.name}.")
    return value
  return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
def internal_convert_n_to_tensor_or_indexed_slices(values,
                                                   dtype=None,
                                                   name=None,
                                                   as_ref=False):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified, and `None` entries are passed through as `None`.

  Args:
    values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects
      that can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If `values` is not iterable, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Iterable):
    raise TypeError("Argument `values` must be iterable.")

  def _convert_one(position, item):
    # Converts a single element; `None` is kept as a placeholder.
    if item is None:
      return None
    item_name = None if name is None else "%s_%d" % (name, position)
    return internal_convert_to_tensor_or_indexed_slices(
        item, dtype=dtype, name=item_name, as_ref=as_ref)

  return [_convert_one(position, item) for position, item in enumerate(values)]
def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified, and `None` entries are passed through as `None`.

  Args:
    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  # Non-ref variant of the internal helper.
  return internal_convert_n_to_tensor_or_indexed_slices(
      values, dtype=dtype, name=name, as_ref=False)
# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000
def _indexed_slices_to_tensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a Tensor.

  NOTE(mrry): This function is potentially expensive.

  The dense result is produced with `unsorted_segment_sum` keyed by
  `value.indices`, with `value.dense_shape[0]` segments. When not executing
  eagerly and `dense_shape` has a constant value, a warning is emitted if the
  dense result would have at least `_LARGE_SPARSE_NUM_ELEMENTS` elements.

  Args:
    value: An `IndexedSlices` object.
    dtype: The dtype of the Tensor to be returned, or `None`.
    name: Optional name to use for the returned Tensor.
    as_ref: True if a ref is requested. Ignored by this conversion.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If `dtype` is incompatible with the dtype of `value`, or if
      `value` has no `dense_shape`.
  """
  _ = as_ref  # Accepted for the conversion-function signature; unused.
  if dtype and not dtype.is_compatible_with(value.dtype):
    raise ValueError(
        f"Incompatible tensor conversion requested to `dtype` {dtype.name} for "
        f"IndexedSlices ({value}) with dtype {value.dtype.name}")
  if value.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices for argument `value` "
        f"without dense_shape: {value!s}")
  # TODO(mrry): Consider adding static shape information to
  # IndexedSlices, to avoid using numpy here.
  if not context.executing_eagerly():
    dense_shape_value = tensor_util.constant_value(value.dense_shape)
    if dense_shape_value is not None:
      # Accumulate in int64 explicitly. By default np.prod accumulates in the
      # input dtype or the platform default integer, which is 32-bit on some
      # platforms (e.g. Windows with NumPy < 2.0). There, an int32 dense
      # shape with more than 2**31 elements would wrap around and silently
      # suppress the warning below.
      num_elements = np.prod(dense_shape_value, dtype=np.int64)
      if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
        warnings.warn(
            "Converting sparse IndexedSlices to a dense Tensor with %d "
            "elements. This may consume a large amount of memory." %
            num_elements)
  return gen_math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)
# Register `_indexed_slices_to_tensor` as the tensor conversion function for
# `IndexedSlices` in the tensor conversion registry.
tensor_conversion_registry.register_tensor_conversion_function(
    IndexedSlices,
    _indexed_slices_to_tensor)