-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsharded_variable.py
910 lines (767 loc) · 35.2 KB
/
sharded_variable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
import copy
import math
from typing import Sequence
import weakref
import numpy as np
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Examples of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions) # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.

    Raises:
      NotImplementedError: Always; subclasses must override this method.
    """
    raise NotImplementedError
@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.

    Raises:
      ValueError: If `num_shards` is less than 1.
    """
    # Validate like the sibling partitioners do; a non-positive shard count
    # would otherwise produce a nonsensical partition list at call time.
    if num_shards < 1:
      raise ValueError('Argument `num_shards` must be positive. '
                       f'Received: {num_shards}')
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    """Returns `num_shards` partitions along `axis` and 1 along other axes.

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: Unused; the shard count does not depend on the element type.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list with one entry per axis of `shape`.
    """
    del dtype
    result = [1] * len(shape)
    # Never ask for more shards than there are elements along `axis`.
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result
@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  Every shard produced by this partitioner holds at least `min_shard_bytes`,
  and the partitioner creates as many shards as that constraint allows (i.e.
  shards are kept as small as possible). `max_shards` caps the number of
  shards.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If any of the arguments is less than 1 (checked in the
        order `min_shard_bytes`, `max_shards`, `bytes_per_string`).
    """
    positive_args = (('min_shard_bytes', min_shard_bytes),
                     ('max_shards', max_shards),
                     ('bytes_per_string', bytes_per_string))
    for arg_name, arg_value in positive_args:
      if arg_value < 1:
        raise ValueError(f'Argument `{arg_name}` must be positive. '
                         f'Received: {arg_value}')
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    """Delegates to `partitioned_variables.min_max_variable_partitioner`."""
    partition_fn = partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)
    return partition_fn(shape, dtype)
@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in `int` created taking
        precedence over `max_shard_bytes`. `None` means no limit.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If `max_shard_bytes` or `bytes_per_string` is less than 1,
        or if `max_shards` is not `None` and is less than 1.
    """
    if max_shard_bytes < 1:
      raise ValueError('Argument `max_shard_bytes` must be positive. '
                       f'Received {max_shard_bytes}')
    # Compare against None explicitly: a truthiness test would let
    # `max_shards=0` pass validation even though 0 is not positive.
    if max_shards is not None and max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')
    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    """Delegates to `partitioned_variables.variable_axis_size_partitioner`."""
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)
class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`.

  Holds one component spec per shard, in the order the shards were given.
  """

  __slots__ = ['_variable_specs']

  @property
  def value_type(self):
    return ShardedVariable

  def __init__(self, *variable_specs):
    # Keep the per-shard specs as an immutable tuple.
    self._variable_specs = tuple(variable_specs)

  def _to_components(self, value):
    # A ShardedVariable decomposes into its list of shard variables.
    return value.variables

  def _from_components(self, variables):
    # Rebuild the container from its shard variables.
    return ShardedVariable(variables)

  @property
  def _component_specs(self):
    return self._variable_specs

  def _serialize(self):
    # The per-shard specs fully describe this spec.
    return self._variable_specs
class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUEmbeddingVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".

    Raises:
      TypeError: If `variables` is not a non-empty sequence of
        `variables.Variable`s.
      ValueError: If the shards have different dtypes, differ in any dimension
        other than the first, or already have a `SaveSliceInfo` set.
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if not isinstance(variables, Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}')
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])

    # Each shard keeps a weak back-reference to its container.
    for v in variables:
      v._sharded_container = weakref.ref(self)  # pylint: disable=protected-access

    # `_var_offsets[i]` is the global coordinate of shard i's first element.
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for all elements in argument '
          '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according '
          'to the order of the elements `variables`. '
          f'Received save slice info {save_slice_info}')

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape,
        dtype=self._dtype,
        name=self._name,
        trainable=self._variables[0].trainable,
        synchronization=variables_lib.VariableSynchronization.NONE,
        aggregation=variables_lib.VariableAggregation.NONE)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if isinstance(slice_spec, list):
      # Normalize to a tuple: the per-shard specs below are built by
      # tuple concatenation, which a list would break.
      slice_spec = tuple(slice_spec)
    elif not isinstance(slice_spec, tuple):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      # A negative step walks the shards back to front.
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.')
      for i in range(len(self._variables)):
        # Shard i owns the half-open range [offset_i, offset_{i+1}); the
        # lower bound must be inclusive so a shard's first row maps to it.
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]

    Raises:
      TypeError: If any field of `slice_spec` is a Tensor.
      ValueError: If the slice step is zero.
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, end and stop.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    # Clamp an out-of-range start the way Python slicing does, so the stride
    # is anchored at the first (or last) real row instead of a phantom index.
    if slice_step > 0:
      slice_start = max(slice_start, 0)
    else:
      slice_start = min(slice_start, self._shape[0] - 1)
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      # Results were collected last shard first; restore shard order.
      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns the matching slice of `value` to each shard; returns `self`.

    `use_locking`, `name` and `read_value` are accepted for API compatibility
    and are not used.
    """
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds the matching slice of `delta` to each shard; returns `self`.

    `use_locking`, `name` and `read_value` are accepted for API compatibility
    and are not used.
    """
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts the matching slice of `delta` from each shard; returns `self`.

    `use_locking`, `name` and `read_value` are accepted for API compatibility
    and are not used.
    """
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def _decompose_indices(self, indices):
    """Decompose a global 1D indices into a list of per-variable indices.

    Args:
      indices: A rank-1 Tensor of global row indices.

    Returns:
      A tuple `(per_var_indices, partition_assignments)`: the shard-local
      indices grouped per shard, and the int32 shard id of every input index.

    Raises:
      ValueError: If `indices` is not rank 1.
      NotImplementedError: If the shards do not follow "div" sharding.
    """
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}')

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that sharding conforms to "div" sharding
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scater_xxx ops are not supported in ShardedVariale that does not '
          'conform to "div" sharding')

    # For index that falls into the partition that has extra 1, assignment is
    # `index // (base + 1)` (no less than `(indices - extra) // base`)
    # For index that falls into the partition that doesn't has extra 1,
    # assignment is `(indices - extra) // base` (no less than
    # `indices // (base + 1)`)
    #
    # Example:
    #   base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    #   index = 10 -> partition_assigment = 0
    #   index = 22 -> partition_assiment = 2
    partition_assignments = math_ops.maximum(indices // (base + 1),
                                             (indices - extra) // base)
    local_indices = array_ops.where(partition_assignments < extra,
                                    indices % (base + 1),
                                    (indices - extra) % base)
    # For whatever reason `dynamic_partition` only supports int32
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(local_indices,
                                                      partition_assignments,
                                                      len(self._variables))

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decompose a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices)
    per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values,
                                                     partition_assignments,
                                                     len(self._variables))

    return [
        indexed_slices_lib.IndexedSlices(
            values=per_var_values[i], indices=per_var_indices[i])
        for i in range(len(self._variables))
    ]

  def _apply_scatter_op_per_shard(self, op_name, sparse_delta, name):
    """Splits `sparse_delta` by shard and calls `op_name` on every shard.

    Args:
      op_name: Name of the `tf.Variable` scatter method to call on each shard.
      sparse_delta: A global `IndexedSlices` to apply.
      name: Optional op name; shard i receives `'{name}/part_{i}'`.

    Returns:
      `self`.
    """
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      getattr(v, op_name)(per_var_sparse_delta[i], name=new_name)
    return self

  # ==================== scatter ops implementations ======================== #
  # `use_locking` is accepted for API compatibility and is not used.

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    return self._apply_scatter_op_per_shard('scatter_add', sparse_delta, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    return self._apply_scatter_op_per_shard('scatter_div', sparse_delta, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    return self._apply_scatter_op_per_shard('scatter_max', sparse_delta, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    return self._apply_scatter_op_per_shard('scatter_min', sparse_delta, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    return self._apply_scatter_op_per_shard('scatter_mul', sparse_delta, name)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    return self._apply_scatter_op_per_shard('scatter_sub', sparse_delta, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    return self._apply_scatter_op_per_shard('scatter_update', sparse_delta,
                                            name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    return self._apply_scatter_op_per_shard('batch_scatter_update',
                                            sparse_delta, name)

  # ================== scatter ops implementations END ====================== #

  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read.

    Rows are gathered shard by shard and then stitched back into the order of
    `indices`, so the result matches a gather on the unsharded variable.
    """
    per_var_indices, partition_assignments = self._decompose_indices(indices)
    # `dynamic_partition` groups indices by shard; remember each index's
    # original position so the per-shard rows can be put back in order.
    original_positions = data_flow_ops.dynamic_partition(
        math_ops.range(array_ops.size(indices)), partition_assignments,
        len(self._variables))
    result = []
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      result.append(v.sparse_read(per_var_indices[i], name=new_name))
    return data_flow_ops.dynamic_stitch(original_positions, result)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    resource_list = []
    # `variables` may be any Sequence (e.g. a tuple), so copy it into a list
    # before appending the saving variable.
    for v in list(self._variables) + [self._saving_variable]:
      resource_list.extend(v._export_to_saved_model_graph(  # pylint:disable=protected-access
          object_map, tensor_map, options, **kwargs))
    object_map[self] = ShardedVariable([object_map[self._saving_variable]],
                                       name=self.name)
    return resource_list

  @property
  def _unique_id(self):
    # String-replace to ensure uniqueness for checkpoint tracking
    return self.variables[0]._unique_id.replace('part_0', 'sharded')  # pylint: disable=protected-access

  @property
  def _distribute_strategy(self):
    return self.variables[0]._distribute_strategy  # pylint: disable=protected-access

  @property
  def _shared_name(self):
    return self._name

  @property
  def is_sharded_variable(self):
    return True

  def numpy(self):
    """Copies the values in this ShardedVariable to a NumPy array.

    First converts to a single Tensor using the registered conversion function,
    which concatenates the shards, then uses Tensor.numpy() to convert to
    a NumPy array.

    Returns:
      A NumPy array of the same shape and dtype.
    """
    return _var_to_tensor(self).numpy()
@tf_export('__internal__.distribute.ShardedVariable', v1=[])
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since `ShardedVariable` can be saved and then restored to a different number
  of shards depending on the restore environment (for example, TF serving APIs
  would restore to one shard for serving efficiency), code using a
  `ShardedVariable` in a tf.function should generally not assume it has the
  same number of shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    shard_specs = [
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables
    ]
    return ShardedVariableSpec(*shard_specs)

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators except `__getitem__`."""
    # `__getitem__` is implemented natively by the mixin, so skip it.
    delegated = [
        op for op in ops.Tensor.OVERLOADABLE_OPERATORS if op != '__getitem__'
    ]
    for operator in delegated:
      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_method = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      # Concatenate the shards into one tensor, then apply the tensor op.
      return tensor_method(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)

  def __tf_experimental_restore_capture__(self, concrete_function,
                                          internal_capture):
    # Avoid restoring captures for functions that use ShardedVariable - the
    # layer will be recreated during Keras model loading
    # TODO(jmullenbach): support loading models with ShardedVariables using
    # tf.saved_model.load
    del concrete_function, internal_capture  # Unused.
    return None

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    return True

  def _write_object_proto(self, proto, options):
    # Serialize as the single full-shape saving variable.
    resource_variable_ops.write_object_proto_for_resource_variable(
        self._saving_variable, proto, options, enforce_naming=False)
def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
"""Converts a `ShardedVariable` to a `Tensor`."""
del name
if dtype is not None and not dtype.is_compatible_with(var.dtype):
raise ValueError(
'Incompatible type conversion requested to type {!r} for variable '
'of type {!r}'.format(dtype.name, var.dtype.name))
if as_ref:
raise NotImplementedError(
"ShardedVariable doesn't support being used as a reference.")
# We use op dispatch mechanism to override embedding_lookup ops when called
# with ShardedVariable. This requires embedding_lookup ops to raise TypeError
# when called with ShardedVariable. However since ShardedVariable can be
# converted to a tensor via concat, embedding_lookup ops would silently
# do the convertion and never raise a TypeError. To be able to properly
# raise a TypeError, namescope is used to detect if this method is called
# within a embedding_lookup op.
# NOTE: This doesn't work in eager mode since op namescope is always cleared
# in eager. This also breaks if user sets the name of embedding_lookup op
# with something that doesn't contain str "embedding_lookup".
#
# TODO(chenkai): Find a more robust way to do this, which should not rely
# on namescope.
if 'embedding_lookup' in ops.get_name_scope():
raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
' ops is disallowed.')
return array_ops.concat(var.variables, axis=0)
# Route the Python operators that `ops.Tensor` overloads (all except
# `__getitem__`) through a tensor built by concatenating the shards.
ShardedVariable._overload_all_operators()  # pylint: disable=protected-access

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
tensor_conversion_registry.register_tensor_conversion_function(
    ShardedVariable, _var_to_tensor)
# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  """Looks up `ids` in a `ShardedVariable` by delegating to its shards.

  Args:
    params: A `ShardedVariable`, or a list/tuple whose first element is one.
    ids: Ids to look up, forwarded unchanged.
    partition_strategy: Forwarded to `embedding_ops.embedding_lookup`.
    name: Forwarded to `embedding_ops.embedding_lookup`.
    validate_indices: Forwarded to `embedding_ops.embedding_lookup`.
    max_norm: Forwarded to `embedding_ops.embedding_lookup`.

  Returns:
    The result of `embedding_ops.embedding_lookup` on `params.variables`.
  """
  # Accept any sequence wrapper, not just a list: a tuple holding the
  # ShardedVariable would otherwise reach `.variables` on the tuple itself.
  if isinstance(params, (list, tuple)):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)