-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcross_device_ops.py
1395 lines (1165 loc) · 56 KB
/
cross_device_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different algorithms of reduction and broadcasting."""
import collections
import copy
import multiprocessing.dummy
import multiprocessing.pool
import threading
import six
from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
def check_destinations(destinations):
  """Returns whether `destinations` is non-empty.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    True if `destinations` is non-empty, False otherwise. Resource variables
    and tensors are judged by the truthiness of their `device` attribute.
  """
  # Truth-testing a ResourceVariable is not allowed, so inspect its device.
  device_carriers = (resource_variable_ops.BaseResourceVariable, ops.Tensor)
  if isinstance(destinations, device_carriers):
    target = destinations.device
  else:
    target = destinations
  return bool(target)
def validate_destinations(destinations):
  """Raises unless `destinations` is a supported, non-empty destination.

  Args:
    destinations: the candidate destination object.

  Raises:
    ValueError: if `destinations` is not a `DistributedValues`, tensor,
      `IndexedSlices`, aggregating/TPU-mirrored variable, resource variable or
      string, or if it is empty according to `check_destinations`.
  """
  accepted_types = (value_lib.DistributedValues, ops.Tensor,
                    indexed_slices.IndexedSlices,
                    ps_values.AggregatingVariable, six.string_types,
                    tpu_values.TPUMirroredVariable)
  is_accepted = (
      isinstance(destinations, accepted_types) or
      resource_variable_ops.is_resource_variable(destinations))
  if not is_accepted:
    raise ValueError("destinations must be one of a `DistributedValues` object,"
                     " a tf.Variable object, or a device string.")
  if not check_destinations(destinations):
    raise ValueError("destinations can not be empty")
def reduce_non_distributed_value(reduce_op,
                                 value,
                                 destinations,
                                 num_replicas_in_graph,
                                 canonicalize_devices=True):
  """Reduce a non-DistributedValue `value` to `destinations`.

  Args:
    reduce_op: a `reduce_util.ReduceOp`.
    value: a value that is not a `DistributedValues`.
    destinations: where the result should be placed; validated with
      `validate_destinations` before broadcasting.
    num_replicas_in_graph: number of replicas in the graph.
    canonicalize_devices: forwarded to `simple_broadcast`.

  Returns:
    0 for a non-tensor zero; `value` itself for MEAN; otherwise (only when
    there is a single replica) the result of broadcasting `value`.

  Raises:
    ValueError: if `value` is a `DistributedValues`, or if a non-MEAN op is
      requested with more than one replica.
  """
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValues` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # When every replica holds the same value, the PerReplica collapses to a
  # single value; a plain (non-tensor) zero reduces to zero.
  # TODO:(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tf_type(value) and value == 0:
    return 0

  # The mean of identical copies is the value itself, on every destination.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value

  if num_replicas_in_graph != 1:
    # Combining identical copies with any other op (e.g. SUM) is not clearly
    # defined; this path is hit from MirroredVariable assign functions.
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))

  validate_destinations(destinations)
  return simple_broadcast(
      value, destinations, canonicalize_devices=canonicalize_devices)
def _make_tensor_into_per_replica(input_tensor):
  """Wraps `input_tensor` in a single-component `PerReplica`.

  `DistributedValues` inputs are returned unchanged. Non-tensor inputs are
  first converted with `ops.convert_to_tensor`.

  Raises:
    ValueError: if the (converted) tensor has no `device` attribute.
  """
  if isinstance(input_tensor, value_lib.DistributedValues):
    return input_tensor
  if tensor_util.is_tensor(input_tensor):
    tensor = input_tensor
  else:
    tensor = ops.convert_to_tensor(input_tensor)
  if not hasattr(tensor, "device"):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                     "because it doesn't have device set.")
  return value_lib.PerReplica((tensor,))
def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts the value of each (value, destinations) pair into a PerReplica.

  Args:
    value_destination_pairs: an iterable of `(value, destinations)` tuples.
      Each `value` is passed through `_make_tensor_into_per_replica`.

  Returns:
    A list of `(PerReplica, destinations)` tuples, in input order.

  Raises:
    ValueError: if `value_destination_pairs` is not iterable, or if any of its
      elements is not a tuple of size 2.
  """
  # Any iterable is accepted. Only the "not iterable at all" TypeError is
  # translated into the documented ValueError; errors raised while iterating
  # a user-supplied generator propagate unchanged.
  try:
    pair_iterator = iter(value_destination_pairs)
  except TypeError:
    raise ValueError("`value_destination_pairs` should be a list or tuple")
  # Materialize first so that no element is converted if iteration fails.
  pairs = list(pair_iterator)

  result = []
  for pair in pairs:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")
    per_replica = _make_tensor_into_per_replica(pair[0])
    result.append((per_replica, pair[1]))
  return result
def _validate_value_destination_pairs(value_destination_pairs):
  """Checks the shape of `value_destination_pairs` without raising.

  Returns:
    True iff `value_destination_pairs` is a non-empty list or tuple whose
    elements are all tuples with a `PerReplica` as first element.
  """
  # TODO(yuefengz): raise exceptions instead of returning False.
  if not value_destination_pairs or not isinstance(value_destination_pairs,
                                                   (list, tuple)):
    return False
  every_pair_is_tuple = all(
      isinstance(pair, tuple) for pair in value_destination_pairs)
  return every_pair_is_tuple and all(
      isinstance(pair[0], value_lib.PerReplica)
      for pair in value_destination_pairs)
# TODO(yuefengz): consider calling this function in the caller of
# CrossDeviceOps.
def get_devices_from(destinations, canonicalize_devices=True):
  """Returns the devices described by `destinations`.

  Args:
    destinations: a `DistributedValues`, a device string, or any object with a
      `device` attribute.
    canonicalize_devices: if True, each device is passed through
      `device_util.resolve`; otherwise only through
      `device_util.canonicalize_without_job_and_task`, leaving the rest to the
      placer.

  Returns:
    For a `DistributedValues`, its `_devices` as-is; otherwise a 1-tuple with
    the processed device string.
  """
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  if isinstance(destinations, six.string_types):
    device = destinations
  else:
    device = destinations.device
  if canonicalize_devices:
    return (device_util.resolve(device),)
  # Let placer canonicalize and resolve destination devices.
  return (device_util.canonicalize_without_job_and_task(device),)
def _devices_match(left, right, canonicalize_devices=True):
  """Returns True if `left` and `right` are identical or share a device set."""
  if left is right:
    return True
  left_devices = set(get_devices_from(left, canonicalize_devices))
  right_devices = set(get_devices_from(right, canonicalize_devices))
  return left_devices == right_devices
def _all_devices_match(value_destination_pairs, canonicalize_devices=True):
  """Returns True if all values and destinations share the same devices.

  Every value must match its own destination, and every value after the
  first must match the first value.
  """
  for value, destination in value_destination_pairs:
    if not _devices_match(value, destination, canonicalize_devices):
      return False
  for value, _ in value_destination_pairs[1:]:
    # The first value is looked up lazily so an empty input never indexes.
    if not _devices_match(value, value_destination_pairs[0][0],
                          canonicalize_devices):
      return False
  return True
def simple_broadcast(value,
                     destinations,
                     always_mirrored=False,
                     canonicalize_devices=True):
  """Broadcast `value` to `destinations` using simple copies.

  Args:
    value: the tensor or `IndexedSlices` to copy.
    destinations: anything accepted by `get_devices_from`.
    always_mirrored: if True, regroup into a `Mirrored` value even when there
      is only one destination device.
    canonicalize_devices: forwarded to `get_devices_from`.

  Returns:
    The single copy when there is one device and `always_mirrored` is False;
    otherwise the result of `distribute_utils.regroup` over one copy per
    device with `wrap_class=value_lib.Mirrored`.
  """
  devices = get_devices_from(destinations, canonicalize_devices)
  copy_to = cross_device_utils.copy_tensor_or_indexed_slices_to_device
  if len(devices) == 1 and not always_mirrored:
    return copy_to(value, devices[0])
  copies = [copy_to(value, device) for device in devices]
  return distribute_utils.regroup(copies, wrap_class=value_lib.Mirrored)
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the components of `per_replica_value` on `reduce_to_device`.

  Args:
    per_replica_value: an object whose `values` holds the components.
    reduce_to_device: the device the reduction ops are placed on.
    accumulation_fn: passed to
      `cross_device_utils.aggregate_tensors_or_indexed_slices`.
    reduce_op: `reduce_util.ReduceOp.SUM` or `reduce_util.ReduceOp.MEAN`.

  Returns:
    The accumulated value, divided by the number of components for MEAN.

  Raises:
    ValueError: if there are no components, or if `reduce_op` is neither SUM
      nor MEAN (checked after accumulation).
  """
  components = per_replica_value.values
  if not components:
    raise ValueError("`per_replica_value` must be non-empty")
  num_components = len(components)

  with ops.device(reduce_to_device), context.device_policy(
      context.DEVICE_PLACEMENT_SILENT):
    result = cross_device_utils.aggregate_tensors_or_indexed_slices(
        components, accumulation_fn)
    if reduce_op == reduce_util.ReduceOp.MEAN:
      result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
          result, num_components)
    elif reduce_op != reduce_util.ReduceOp.SUM:
      raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
  return result
def _simple_gather(per_replica_value, reduce_to_device, axis):
  """Concatenates all components of `per_replica_value` along `axis`.

  The concat is placed on `reduce_to_device` under the silent device
  placement policy.

  Raises:
    ValueError: if `per_replica_value.values` is empty.
  """
  components = per_replica_value.values
  if not components:
    raise ValueError("`per_replica_value` must be non-empty")
  with ops.device(reduce_to_device), context.device_policy(
      context.DEVICE_PLACEMENT_SILENT):
    gathered = array_ops.concat(components, axis)
  return gathered
@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  `tf.distribute.MirroredStrategy` in order to choose among different cross
  device communication implementations. Prefer using the methods of
  `tf.distribute.Strategy` instead of the ones of this class.

  Implementations:
  * `tf.distribute.ReductionToOneDevice`
  * `tf.distribute.NcclAllReduce`
  * `tf.distribute.HierarchicalCopyAllReduce`
  """

  def __init__(self):
    # Forwarded to `get_devices_from`: True means destination devices are
    # fully resolved rather than only stripped of job and task.
    self._canonicalize_devices = True

  @property
  def _num_between_graph_workers(self):
    # Returns 1 by default, the value may be overridden by sub classes.
    return 1

  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
    """Reduce `per_replica_value` to `destinations`.

    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
    the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if options is None:
      options = collective_util.Options()

    per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations, self._canonicalize_devices):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations, options)

  def _gather(self, per_replica_value, destinations, axis, options=None):
    """Gather `per_replica_value` to `destinations`.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`

    Raises:
      NotImplementedError: if `per_replica_value` is an `IndexedSlices`.
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if isinstance(per_replica_value, indexed_slices.IndexedSlices):
      raise NotImplementedError("gather/all_gather does not support "
                                "IndexedSlices")
    if options is None:
      options = collective_util.Options()

    per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations, self._canonicalize_devices):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self._gather_implementation(per_replica_value, destinations, axis,
                                       options)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.

    Overriding this method is useful for subclass implementers.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_gather_implementation method must be implemented in descendants.")

  def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
    """Reduce values to destinations in batches.

    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
    called in the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    if options is None:
      options = collective_util.Options()
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects only contain one value. An empty
    # batch skips the shortcut, which would otherwise index pair [0].
    if (value_destination_pairs and self._num_between_graph_workers == 1 and
        _all_devices_match(value_destination_pairs,
                           self._canonicalize_devices) and
        len(value_destination_pairs[0][0].values) == 1):
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            options)

  def broadcast(self, tensor, destinations):
    """Broadcast `tensor` to `destinations`.

    This can only be called in the cross-replica context.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, this method doesn't update
        the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    """Implementation of `reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "reduce_implementation method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    """Implementation of `batch_reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of `broadcast`.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, this method doesn't update
        the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    return simple_broadcast(
        tensor,
        destinations,
        always_mirrored=True,
        canonicalize_devices=self._canonicalize_devices)

  # ========================== Collective APIs ================================
  #
  # Different than `reduce`, `batch_reduce` and `broadcast` which must be called
  # in cross-replica context, collective APIs are to be called in replica
  # context.

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """All-reduce the `value` across all replicas so that all get the result.

    `value` can be a nested structure of tensors or `IndexedSlices`. The
    implementation should generally batch the all-reduces when possible.
    `options` can be set to hint the batching behavior.

    This API must be called in a replica context.

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: Value to be reduced. A tensor or a nested structure of tensors or
        `IndexedSlices`.
      replica_id: An integer indicating the id of the replica where this
        all_reduce is called under. This is the local replica id that ranges
        from 0 to len(local_devices) - 1.
      options: A `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
      the reduced values. The structure is the same as `value`.
    """
    raise NotImplementedError("_all_reduce must be implemented in descendants.")
@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcasts the reduced values to the destinations. It doesn't support
  efficient batching.

  Here is how you can use `ReductionToOneDevice` in
  `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice())
  ```
  """

  def __init__(self, reduce_to_device=None, accumulation_fn=None):
    """Initializes with a device to reduce to and a way to accumulate.

    Args:
      reduce_to_device: the intermediate device to reduce to. If None, reduce
        to the first device in `destinations` of the `reduce` method.
      accumulation_fn: a function that does accumulation.  If None,
        `tf.math.add_n` is used.
    """
    self.reduce_to_device = reduce_to_device
    self.accumulation_fn = accumulation_fn or math_ops.add_n
    super(ReductionToOneDevice, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # Fall back to the value's own devices when no destinations are given.
    device_source = (
        destinations if check_destinations(destinations) else per_replica_value)
    devices = get_devices_from(device_source, self._canonicalize_devices)
    target_device = self.reduce_to_device or devices[0]
    logging.log_first_n(
        logging.INFO,
        "Reduce to %s then broadcast to %r." % (target_device, devices), 10)
    reduced = _simple_reduce(per_replica_value, target_device,
                             self.accumulation_fn, reduce_op)
    return self.broadcast(reduced, destinations)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    del options  # Unused.
    # Fall back to the value's own devices when no destinations are given.
    device_source = (
        destinations if check_destinations(destinations) else per_replica_value)
    devices = get_devices_from(device_source, self._canonicalize_devices)
    target_device = self.reduce_to_device or devices[0]
    logging.log_first_n(
        logging.INFO,
        "Gather to %s then broadcast to %r." % (target_device, devices), 10)
    gathered = _simple_gather(per_replica_value, target_device, axis)
    return self.broadcast(gathered, destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    # No real batching: each pair is reduced independently.
    return [
        self.reduce_implementation(
            reduce_op, per_replica, destinations=targets, options=options)
        for per_replica, targets in value_destination_pairs
    ]
def _group_value_by_device(per_replica_values):
  """Group values into sublists by their devices.

  This grouping is needed to call the all-reduce library because it expects a
  list of the following form:
    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
     ...
    ]

  Args:
    per_replica_values: a non-empty list of PerReplica objects. All of them
      must have the same `_devices` as the first one.

  Returns:
    a list of lists, each sublist has components for its corresponding device of
      PerReplica objects, paired with a None.
  """
  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
  grouped = [[] for _ in range(len(destinations))]
  for per_replica_value in per_replica_values:
    # The device layout is per value, not per component: check it once.
    assert per_replica_value._devices == destinations  # pylint: disable=protected-access
    for i, v in enumerate(per_replica_value.values):
      grouped[i].append((v, None))
  return grouped
def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Ungroup results from all-reduce and make Mirrored objects.

  If `reduce_op` is MEAN, each all-reduce result is divided by the total
  number of replicas before the Mirrored objects are created. That total is
  the number of destination devices times `num_between_graph_workers`.

  Args:
    grouped_reduced: a list of lists, each sublist has components for each
      device, paired with a None. It is the result from
      cross_device_utils.aggregate_gradients_using*.
    destinations: a value to colocate the result with.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects, one per position in `grouped_reduced[0]`.
  """
  num_replicas = (
      len(get_devices_from(destinations)) * num_between_graph_workers)
  is_mean = reduce_op == reduce_util.ReduceOp.MEAN
  per_value = [[] for _ in range(len(grouped_reduced[0]))]
  for device_results in grouped_reduced:
    for position, (reduced, _) in enumerate(device_results):
      if is_mean:
        # Keep the division on the same device as the reduced tensor.
        with ops.device(reduced.device):
          reduced = reduced / num_replicas
      per_value[position].append(reduced)
  return [
      distribute_utils.regroup(components, wrap_class=value_lib.Mirrored)
      for components in per_value
  ]
class _ConcatAndSplitPacker(object):
"""Concatenate and split tensors for reduction."""
def __init__(self, num_packs=1):
"""Initialize the _ConcatAndSplitPacker object.
Args:
num_packs: specifies the number of split packs that will be
formed.
Raises:
ValueError: if num_packs is not greater than 0.
"""
if num_packs <= 0:
raise ValueError("num_packs must be greater than zero.")
self.num_packs = num_packs
def pack(self, grouped_grads_and_vars):
"""Pack tensors."""
self.grouped_grads_and_vars = grouped_grads_and_vars
self.all_device_shapes = []
self.all_device_sizes = []
device_grad_packs = []
for device_grads_and_vars in grouped_grads_and_vars:
with ops.colocate_with(device_grads_and_vars[0][0]):
# Flatten all the grads.
flat_grads = [
array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
]
# Remember the original shape of all the grads.
device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
# Remember the original sizes of all the grads.
device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
# Concat all the flat grads into a big flat tensor.
concat_grads = array_ops.concat(flat_grads, 0)
# Split the big tensor into num_splits packs. In cases where the
# total size is not divisible num_splits, the last pack gets
# more elements.
# TODO(zhengxq): it is also possible to optimize away all the concat
# as well.
num_splits = self.num_packs
# The array_ops.size function will sometimes remove static shapes. So if
# all gradient shapes are defined, we use another method to get the
# total size.
# TODO(yuefengz): move this logic to array_ops.size.
if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
total_grad_size = sum(
[g.shape.num_elements() for g, _ in device_grads_and_vars])
else:
total_grad_size = array_ops.size(concat_grads)
split_size = total_grad_size // num_splits
split_size_last = total_grad_size - split_size * (num_splits - 1)
split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
grad_packs = array_ops.split(concat_grads, split_sizes)
# Ready to aggregate the repacked gradients, with fake variables.
# TODO(zhengxq): It is hacky to have to use fake variables.
# We should remove the need for variables in
# aggregate_gradients_using*.
device_grad_packs.append(zip(grad_packs, [None] * num_splits))
self.all_device_shapes.append(device_shapes)
self.all_device_sizes.append(device_sizes)
return device_grad_packs
def unpack(self, summed_device_grad_packs):
  """Reverses `pack` on the aggregated packs.

  Args:
    summed_device_grad_packs: one list per device of `(pack, _)` pairs holding
      the aggregated packs, in the same device order that was given to `pack`.

  Returns:
    One list per device of `(gradient, variable)` pairs. Each gradient is
    restored to the shape recorded by `pack` and paired with the variable
    originally passed to `pack`.
  """
  aggregated_device_grads = []
  # The loop variable is deliberately named differently from the parameter so
  # the argument is not shadowed inside the loop.
  for (device_summed_packs, device_grads_and_vars, device_shapes,
       device_sizes) in zip(summed_device_grad_packs,
                            self.grouped_grads_and_vars,
                            self.all_device_shapes, self.all_device_sizes):
    # Build the unpacking ops next to this device's aggregated packs.
    with ops.colocate_with(device_summed_packs[0][0]):
      # Form a list of the summed grad packs.
      device_grad_packs = [g for g, _ in device_summed_packs]
      # Concat them back into a big flat tensor.
      device_grads_concat = array_ops.concat(device_grad_packs, 0)
      # Split the tensors back into their original sizes.
      grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)
      # Reshape the tensors back into their original shapes.
      grads_with_shapes = [
          array_ops.reshape(grad, shape)
          for shape, grad in zip(device_shapes, grads_with_sizes)
      ]
      # Pair each restored gradient with its original variable.
      summed_device_grads = [
          (g, v) for g, (_, v) in zip(grads_with_shapes, device_grads_and_vars)
      ]
      aggregated_device_grads.append(summed_device_grads)
  return aggregated_device_grads
def _pack_tensors(device_grads, num_packs=0):
  """Repacks per-device values before all-reduce when packing is requested.

  Args:
    device_grads: the per-device grouped values to (optionally) pack.
    num_packs: number of packs to split each device's values into. Packing
      happens only when this is greater than zero.

  Returns:
    A `(device_grad_packs, tensor_packer)` tuple. If no packing was done,
    `device_grad_packs` is `device_grads` itself and `tensor_packer` is None.
  """
  if num_packs > 0:
    packer = _ConcatAndSplitPacker(num_packs)
    return packer.pack(device_grads), packer
  return device_grads, None
def _unpack_tensors(reduced, tensor_packer=None):
  """Undoes `_pack_tensors` on all-reduced values if a packer was used.

  Args:
    reduced: the all-reduced, possibly packed, values.
    tensor_packer: the packer returned by `_pack_tensors`, or a falsy value
      (e.g. None) when no packing took place.

  Returns:
    `tensor_packer.unpack(reduced)` when a packer is given, otherwise
    `reduced` unchanged.
  """
  return tensor_packer.unpack(reduced) if tensor_packer else reduced
class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported. Any value other than "nccl" is
        handled with hierarchical copy.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    # Fallback used for reduces that cannot be all-reduces, for sparse values
    # and for gathers.
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or all-reduce, source and destination devices should match,
    # and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    # One batched all-reduce when `_all_devices_match` accepts the pairs;
    # otherwise each pair is reduced individually.
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch.

    Values are split into dense and sparse groups, each group is reduced with
    its own strategy, and the results are stitched back together using the
    indices returned by the split.
    """
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    # device_grad_packs:
    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # Every algorithm other than "nccl" uses hierarchical copy.
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one. The pairs are
    # materialized as a list because `batch_reduce` takes a sequence of
    # (value, destinations) pairs rather than a one-shot iterator.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, list(zip(sparse_values, sparse_values)))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    logging.log_first_n(
        logging.WARN,
        "gather/all_gather with NCCL or HierarchicalCopy is not supported. "
        "Falling back to gather on one device and then broadcast. We're working"
        " on a more efficient implementation.", 3)
    # Reuse the ReductionToOneDevice fallback built in __init__ instead of
    # constructing a new one on every call.
    return self._simple_cross_replica_ops._gather(  # pylint: disable=protected-access
        per_replica_value, destinations, axis, options)
# Backwards-compatible alias: keeps code that still refers to the old
# `AllReduceCrossTowerOps` name working.
AllReduceCrossTowerOps = AllReduceCrossDeviceOps

# Named tuple with the fields `alg`, `shards` and `limit`.
AllReduceSpecTuple = collections.namedtuple(
    "AllReduceSpecTuple", ["alg", "shards", "limit"])
@tf_export("distribute.NcclAllReduce")
class NcclAllReduce(AllReduceCrossDeviceOps):
  """NCCL all-reduce implementation of CrossDeviceOps.

  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
  repacked or aggregated for more efficient cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    is_negative = num_packs < 0
    if is_negative:
      template = "NCCL all-reduce requires num_packs >= 0, but {} is specified"
      raise ValueError(template.format(num_packs))
    super(NcclAllReduce, self).__init__(
        num_packs=num_packs, all_reduce_alg="nccl")
@tf_export("distribute.HierarchicalCopyAllReduce")
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
"""Hierarchical copy all-reduce implementation of CrossDeviceOps.
It reduces to one GPU along edges in some hierarchy and broadcasts back to
each GPU along the same path. For the batch API, tensors will be repacked or
aggregated for more efficient cross-device transportation.