-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathstrategy_gather_test.py
669 lines (538 loc) · 26.8 KB
/
strategy_gather_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for common methods in strategy classes."""
from absl.testing import parameterized
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@tf_test_util.with_eager_op_as_function
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.one_device_strategy,
            strategy_combinations.one_device_strategy_gpu,
            strategy_combinations.central_storage_strategy_with_two_gpus,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_one_gpu,
            strategy_combinations.mirrored_strategy_with_two_gpus,
            strategy_combinations.mirrored_strategy_with_two_cpus,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=['eager'],
        pure_eager=[True, False]) + combinations.combine(
            strategy=[
                strategy_combinations.tpu_strategy,
                strategy_combinations.tpu_strategy_packed_var,
                strategy_combinations.tpu_strategy_one_step,
                strategy_combinations.cloud_tpu_strategy,
            ],
            mode=['eager'],
            pure_eager=[False]))
class GatherTest(test.TestCase, parameterized.TestCase):
  """Tests `strategy.gather` and the replica context's `all_gather`.

  Every test method receives two generated parameters:
    strategy: the distribution strategy under test.
    pure_eager: when False, the function under test is wrapped in
      `def_function.function`. The TPU strategy combinations above are only
      generated with `pure_eager=False`.
  """

  def _gather_same_shape_and_verify(self, value_on_replica, axis, pure_eager,
                                    strategy):
    """Gathers an identical value from every replica and checks the result.

    Args:
      value_on_replica: the value each replica holds.
      axis: the axis passed to `strategy.gather`.
      pure_eager: if False, the gather call is wrapped in
        `def_function.function`.
      strategy: the distribution strategy under test.
    """
    distributed_values = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    # The expected result is one copy of the value per replica, concatenated
    # along `axis`.
    all_results = [
        value_on_replica for _ in range(strategy.num_replicas_in_sync)
    ]
    expected_result = array_ops.concat(all_results, axis=axis)
    self.assertAllEqual(expected_result, run().numpy())

  def testGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """Gathers shape-[3] tensors along axis 0.

    With two replicas the gathered tensor has shape [6].
    """
    single_value = constant_op.constant([1, 2, 3])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """Gathers shape-[1, 3] tensors along axis 0.

    With two replicas the gathered tensor has shape [2, 3].
    """
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """Gathers shape-[1, 3] tensors along axis 1.

    With two replicas the gathered tensor has shape [1, 6].
    """
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """Gathers shape-[1, 2, 2] tensors along axis 0.

    With two replicas the gathered tensor has shape [2, 2, 2].
    """
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
    """Gathers shape-[1, 2, 2] tensors along axis 1.

    With two replicas the gathered tensor has shape [1, 4, 2].
    """
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """Gathers shape-[1, 2, 2] tensors along axis 2.

    With two replicas the gathered tensor has shape [1, 2, 4].
    """
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy)

  def testGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `axis`-th (0) dimension: shape [1, 1], [2, 1] -> [3, 1]."""

    def value_fn(ctx):
      # Replica i holds a tensor of ones with shape [i + 1, 1].
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    # Gathered size along axis 0 is 1 + 2 + ... + num_replicas_in_sync.
    expected_result = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))
    self.assertAllEqual(expected_result, run().numpy())

  def testGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `axis`-th (non-0) dimension: shape [1, 1], [1, 2] -> [1, 3]."""

    def value_fn(ctx):
      # Replica i holds a tensor of ones with shape [1, i + 1].
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 1

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    # Gathered size along axis 1 is 1 + 2 + ... + num_replicas_in_sync.
    expected_result = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))
    self.assertAllEqual(expected_result, run().numpy())

  def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different at non-`axis`-th dimension: [1, 1], [1, 2], 0th -> raise error."""
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      # Shapes differ along dimension 1 while the gather below uses axis 0.
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    # NOTE(review): there is no `else` branch, so a strategy that matches
    # neither isinstance check below makes no assertion in this test.
    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                  r'Dimension \d in both shapes must be equal'):
        run()

  def testGatherRaiseSparse(self, strategy, pure_eager):
    """`strategy.gather` on an IndexedSlices raises NotImplementedError."""
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def run(value):
      return strategy.gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'gather does not support IndexedSlices'):
      if pure_eager:
        run(t0)
      else:
        def_function.function(run)(t0)

  def testGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different rank: [1,], [1, 2] -> raise error."""
    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replicas.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    def value_fn(ctx):
      # The shape is range(1, i + 2), so replica i holds a rank-(i + 1) tensor:
      # shape [1] on replica 0, shape [1, 2] on replica 1.
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    distributed_values = strategy.experimental_distribute_values_from_function(
        value_fn)
    axis = 0

    def run():
      return strategy.gather(distributed_values, axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    # The expected exception type and message depend on the strategy type and
    # on whether the call is wrapped in a function.
    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        run()
    elif isinstance(
        strategy,
        (mirrored_strategy.MirroredStrategy,
         central_storage_strategy.CentralStorageStrategy)):
      if pure_eager:
        with self.assertRaises(errors.InvalidArgumentError) as e:
          run()
        # Different error message depending on whether collective ops is used.
        self.assertRegexMatch(
            str(e.exception),
            ['Ranks of all input tensors should match', 'Shape mismatch'])
      else:
        with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
          run()
        self.assertRegexMatch(
            str(e.exception),
            [r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
    # NOTE(review): the class decorator only generates TPU strategies with
    # pure_eager=False, so this branch looks unreachable under the current
    # parameterization - confirm before relying on it.
    elif _is_tpu_strategy(strategy) and pure_eager:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        run()
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Shape must be rank \d but is rank \d'):
        run()

  # Ideally, here we should split them into another test class, AllGatherTest.
  # But doing that makes two initialize_tpu_system() calls and one of them times
  # out, on Kokoro. Integrating two into one avoids it.

  def _all_gather_same_shape_and_verify(self, value_on_replica, axis,
                                        pure_eager, strategy):
    """Runs `all_gather` on an identical value per replica and checks results.

    Args:
      value_on_replica: the value each replica holds.
      axis: the axis passed to `all_gather`.
      pure_eager: if False, the replica function is wrapped in
        `def_function.function`.
      strategy: the distribution strategy under test.
    """
    per_replica_value = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.identity(value_on_replica))

    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      local_value = array_ops.identity(per_replica_value)
      return ctx.all_gather(local_value, axis=axis)

    if not pure_eager:
      replica_fn = def_function.function(replica_fn)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(per_replica_value,)))

    # Each local result is expected to be the concatenation of one copy of the
    # value per replica in sync.
    all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_result = [expect] * _get_num_replicas_per_client(strategy)

    self.assertAllClose(expected_result, result)

  def testAllGatherPerReplicaDense1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) on float32 tensors of shape (3,).

    With two replicas each local result has shape (6,).
    """
    single_value = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) on tensors of shape (1, 3).

    With two replicas each local result has shape (2, 3).
    """
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense2D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1, ...) on tensors of shape (1, 3).

    With two replicas each local result has shape (1, 6).
    """
    single_value = constant_op.constant([[1, 2, 3]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) on tensors of shape (1, 2, 2).

    With two replicas each local result has shape (2, 2, 2).
    """
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 0
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D1Axis(self, strategy, pure_eager):
    """all_gather(..., axis=1, ...) on tensors of shape (1, 2, 2).

    With two replicas each local result has shape (1, 4, 2).
    """
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 1
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherPerReplicaDense3D2Axis(self, strategy, pure_eager):
    """all_gather(..., axis=2, ...) on tensors of shape (1, 2, 2).

    With two replicas each local result has shape (1, 2, 4).
    """
    single_value = constant_op.constant([[[1, 2], [1, 2]]])
    axis = 2
    self._all_gather_same_shape_and_verify(single_value, axis, pure_eager,
                                           strategy)

  def testAllGatherDiffValueTPU(self, strategy, pure_eager):
    """all_gather of different per-replica values, run on TPU strategies only.

    `pure_eager` is not consulted here: the replica function is always wrapped
    in `def_function.function`.
    """
    # Test for TPU only since it can't be tested via testAllGatherDiffShape*
    if not _is_tpu_strategy(strategy):
      self.skipTest('Test for TPU only. For other strategies case already'
                    ' covered in other tests')

    data = [[1], [2], [3], [4], [5], [6], [7], [8]]

    axis = 0
    # A single batch holding all 8 rows is fed through the distributed dataset.
    dataset = dataset_ops.DatasetV2.from_tensor_slices(data).batch(8)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    @def_function.function
    def replica_fn(per_replica_value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(per_replica_value), axis=axis)

    result = strategy.experimental_local_results(
        strategy.run(replica_fn, args=(next(input_iterator),)))

    # Every local result is expected to equal the full `data`.
    expected_result = [data] * _get_num_replicas_per_client(strategy)
    self.assertAllClose(expected_result, result)

  def testAllGatherDiffShapeAtAxis0(self, strategy, pure_eager):
    """Different `axis==0`-th dimension: shape [1, 1], [2, 1] -> [3, 1]."""

    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather different shapes')

    def value_fn(ctx):
      # Replica i holds a tensor of ones with shape [i + 1, 1].
      return constant_op.constant(
          1, shape=(ctx.replica_id_in_sync_group + 1, 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    # Gathered size along axis 0 is 1 + 2 + ... + num_replicas_in_sync.
    expect = constant_op.constant(
        1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherDiffShapeAtAxis1(self, strategy, pure_eager):
    """Different `axis`-th (not 0th) dimension: shape [1, 1], [1, 2] -> [1, 3]."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather different shapes')

    def value_fn(ctx):
      # Replica i holds a tensor of ones with shape [1, i + 1].
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    # Gathered size along axis 1 is 1 + 2 + ... + num_replicas_in_sync.
    expect = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=1)

    if not pure_eager:
      run = def_function.function(run)

    expected_result = [expect] * _get_num_replicas_per_client(strategy)
    result = strategy.experimental_local_results(
        strategy.run(run, args=(per_replica_value,)))
    self.assertAllEqual(expected_result, result)

  def testAllGatherNest(self, strategy, pure_eager):
    """all_gather on a list of two tensors, one with per-replica shapes."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TPU does not support all_gather different shapes')
    axis = 1

    def value_fn(ctx):
      # Replica i holds a tensor of ones with shape [1, i + 1].
      value = constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))
      return value

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    expect_1 = constant_op.constant(
        1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1))))
    expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy)

    # The second element is the same constant on every replica.
    value_2 = constant_op.constant([[[1, 2], [1, 2]]])
    expect_2 = array_ops.concat(
        [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis)
    expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy)

    def run(value):
      value_1 = array_ops.identity(value)
      value_3 = array_ops.identity(value_2)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_1, value_3], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    result = strategy.run(run, args=(per_replica_value,))
    self.assertAllEqual(expected_per_replica_1,
                        strategy.experimental_local_results(result[0]))
    self.assertAllEqual(expected_per_replica_2,
                        strategy.experimental_local_results(result[1]))

  def testAllGatherNest1D0Axis(self, strategy, pure_eager):
    """all_gather(..., axis=0, ...) on a list holding the same tensor twice."""
    single_value = constant_op.constant([1, 2, 3])
    axis = 0

    def run():
      value_identity = array_ops.identity(single_value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather([value_identity, value_identity], axis=axis)

    if not pure_eager:
      run = def_function.function(run)

    all_value = [single_value for _ in range(strategy.num_replicas_in_sync)]
    expect = array_ops.concat(all_value, axis=axis)
    expected_per_replica = [expect] * _get_num_replicas_per_client(strategy)

    result = strategy.run(run)
    # Both elements of the returned list are checked against the same
    # expectation.
    for gathered_result in result:
      self.assertAllEqual(expected_per_replica,
                          strategy.experimental_local_results(gathered_result))

  def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager):
    """Different at non-`axis`-th dimension: [1, 1], [1, 2], all_gather(..., axis=0, ...) -> raise error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replica only.')

    def value_fn(ctx):
      # Shapes differ along dimension 1 while the all_gather below uses axis 0.
      return constant_op.constant(
          1, shape=(1, ctx.replica_id_in_sync_group + 1))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    # NOTE(review): there is no `else` branch, so a strategy that matches
    # neither isinstance check below makes no assertion in this test.
    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherRaiseSparse(self, strategy, pure_eager):
    """`all_gather` on an IndexedSlices raises NotImplementedError."""
    dense_shape = [5, 2]
    t0 = _make_indexed_slices(
        values=[[1., 2.]], indices=[2], dense_shape=dense_shape)

    def replica_fn(value):
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value, axis=0)

    with self.assertRaisesRegex(
        NotImplementedError,
        r'all_gather does not support IndexedSlices'):
      if not pure_eager:
        strategy.run(def_function.function(replica_fn), args=(t0,))
      else:
        strategy.run(replica_fn, args=(t0,))

  def testAllGatherRaiseDifferentRank(self, strategy, pure_eager):
    """Different rank: [1,], [1, 2] -> raise error."""
    if _is_tpu_strategy(strategy):
      self.skipTest('TODO(b/169108777): raise a clear error message in xla.')

    if strategy.num_replicas_in_sync <= 1:
      self.skipTest('Test for more than 1 replicas.')
    if isinstance(strategy, CollectiveAllReduceStrategy
                 ) and _get_num_replicas_per_client(strategy) > 1:
      self.skipTest('b/167331966')

    def value_fn(ctx):
      # The shape is range(1, i + 2), so replica i holds a rank-(i + 1) tensor:
      # shape [1] on replica 0, shape [1, 2] on replica 1.
      return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2)))

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)

    def run(value):
      value_identity = array_ops.identity(value)
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(value_identity, axis=0)

    if not pure_eager:
      run = def_function.function(run)

    # The expected exception type and message depend on the strategy type and
    # on whether the call is wrapped in a function.
    if isinstance(strategy, CollectiveAllReduceStrategy):
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  r'Shape mismatch'):
        strategy.run(run, args=(per_replica_value,))
    elif isinstance(strategy,
                    (mirrored_strategy.MirroredStrategy,
                     central_storage_strategy.CentralStorageStrategy)):
      if pure_eager:
        with self.assertRaises(errors.InvalidArgumentError) as e:
          strategy.run(run, args=(per_replica_value,))
        # Different error message depending on whether collective ops is used.
        self.assertRegexMatch(
            str(e.exception),
            ['Ranks of all input tensors should match', 'Shape mismatch'])
      else:
        with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
          strategy.run(run, args=(per_replica_value,))
        self.assertRegexMatch(
            str(e.exception),
            [r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
    else:
      with self.assertRaisesRegex(ValueError,
                                  r'Dimension \d in both shapes must be equal'):
        strategy.run(run, args=(per_replica_value,))

  def testAllGatherGradient(self, strategy, pure_eager):
    """Checks the gradient taken through `all_gather` of a single tensor."""
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.functions.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

    # Replica i is fed c = i + 1 (see `value_fn` below); the expected gradient
    # component is 1 + 2 + ... + num_replicas_in_sync.
    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      mid = all_gather_fn(x)
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)

  def testAllGatherGradientNest(self, strategy, pure_eager):
    """Checks the gradient taken through `all_gather` of a list of tensors."""
    if pure_eager:
      self.skipTest('`tf.gradients` is not supported with eager execution '
                    'without using tf.functions.')

    def all_gather_fn(value):
      axis = 1
      ctx = ds_context.get_replica_context()
      return ctx.all_gather(array_ops.identity(value), axis)

    # Replica i is fed c = i + 1 (see `value_fn` below); the expected gradient
    # component is 1 + 2 + ... + num_replicas_in_sync.
    gradient_comp = sum(range(1, strategy.num_replicas_in_sync + 1))
    gradient = [[gradient_comp], [gradient_comp]]
    grads_for_all_replicas = [gradient] * _get_num_replicas_per_client(strategy)

    @def_function.function
    def step(c):
      x = constant_op.constant([[3.], [5.]])
      y = constant_op.constant([[2.], [4.]])
      mid = all_gather_fn([x, y])
      # `y` is rebound here; the gradient below is taken w.r.t. `x` only.
      y = mid * c
      return gradients_impl.gradients_v2(y, [x])[0]

    def value_fn(ctx):
      x = [1., 2., 3., 4., 5., 6., 7., 8.]
      return array_ops.constant([x[ctx.replica_id_in_sync_group]])

    per_replica_value = strategy.experimental_distribute_values_from_function(
        value_fn)
    result = strategy.experimental_local_results(
        strategy.run(step, args=(per_replica_value,)))

    self.assertAllEqual(grads_for_all_replicas, result)
def _make_indexed_slices(values, indices, dense_shape):
  """Builds an `IndexedSlices` from plain Python values.

  Args:
    values: data for the `values` component.
    indices: data for the `indices` component.
    dense_shape: data for the `dense_shape` component.

  Returns:
    An `indexed_slices.IndexedSlices` whose three components are constants
    created from the corresponding arguments.
  """
  values_const = constant_op.constant(values)
  indices_const = constant_op.constant(indices)
  dense_shape_const = constant_op.constant(dense_shape)
  return indexed_slices.IndexedSlices(
      values=values_const,
      indices=indices_const,
      dense_shape=dense_shape_const)
def _get_num_replicas_per_client(strategy):
  """Returns the replica count the tests expect in local results.

  Args:
    strategy: the distribution strategy under test.

  Returns:
    For a `CollectiveAllReduceStrategy`, the first value reported by the
    cluster resolver's `num_accelerators()`, but never less than 1. For any
    other strategy, `strategy.num_replicas_in_sync`.
  """
  if not isinstance(strategy, CollectiveAllReduceStrategy):
    return strategy.num_replicas_in_sync

  accelerator_counts = nest.flatten(
      strategy.cluster_resolver.num_accelerators())
  # A reported count of zero still counts as one replica.
  return max(accelerator_counts[0], 1)
def _is_tpu_strategy(strategy):
  """Returns True if `strategy` is an instance of any TPU strategy class."""
  tpu_strategy_types = (
      tpu_strategy.TPUStrategy,
      tpu_strategy.TPUStrategyV1,
      tpu_strategy.TPUStrategyV2,
  )
  return isinstance(strategy, tpu_strategy_types)
# When executed as a script, run the tests through the `main()` entry point of
# the distribute `test_util` module imported above.
if __name__ == '__main__':
  test_util.main()