-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtpu_util.py
795 lines (652 loc) · 31.4 KB
/
tpu_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU-specific utilities for DTensor."""
import functools
import time
from typing import List, Optional, Dict
import numpy as np
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import dtensor_device
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.util.tf_export import tf_export
# Name of the single dimension of the one-time mesh built while merging core
# IDs during TPU system initialization.
_MESH_DIM_X = "x"
# Device type string used when querying `config` for TPU devices.
_TPU_DEVICE_TYPE = "TPU"

# A dedicated, hidden device used to make C++ API calls.
# Assigned by `initialize_tpu_system`; None until then.
_dtensor_device = None

# `_tpu_topology.mesh_shape` contains the TPU hardware slice size.
# `_tpu_topology.device_coordinates` holds the [x, y, z, core] location of
# every core, arranged as [task, device ordinal within task, 4].
# Assigned by `initialize_tpu_system`; None until then.
_tpu_topology = None

# Cache core ID <-> location mappings so we need not make repeated C++ calls.
# Both are indexed by TF task-device ordinals.
# Assigned by `tpu_system_init_helper`; None until then.
_all_core_ids = None
_all_core_locations = None
class _CoreLocation:
"""Represents a TPU core's location in the mesh."""
def __init__(self, x: int = 0, y: int = 0, z: int = 0, core: int = 0):
self.x = x
self.y = y
self.z = z
self.core = core
def __eq__(self, other):
if not isinstance(other, _CoreLocation):
return False
return self.x == other.x and self.y == other.y and self.z == other.z and self.core == other.core
def __ne__(self, other):
if not isinstance(other, _CoreLocation):
return True
return not self == other
def __hash__(self):
return hash((self.x, self.y, self.z, self.core))
def __repr__(self):
return f"{type(self).__name__}(x={self.x}, y={self.y}, z={self.z}, core={self.core})"
def to_list(self):
return [self.x, self.y, self.z, self.core]
def _create_device_array(shape, device_type, host_id, local_device_ids=None):
  """Builds the ID and device lists needed to construct a mesh.

  Args:
    shape: Shape that the global device IDs are reshaped to.
    device_type: Device type passed to `config` to count and list devices.
    host_id: Index of this host. Only used to derive the default
      `local_device_ids`.
    local_device_ids: Optional; explicit local device IDs. When empty or None,
      a contiguous block of IDs starting at `host_id * num_local_devices` is
      used instead.

  Returns:
    A `(global_device_ids, local_device_ids, local_device_list)` tuple.
  """
  global_count = config.num_global_devices(device_type)
  global_device_ids = np.arange(global_count).reshape(shape)
  local_device_list = config.local_devices(device_type)

  # Callers may supply local_device_ids; otherwise fall back to the default
  # multi-host numbering.
  if not local_device_ids:
    local_count = len(local_device_list)
    local_device_ids = [
        ordinal + host_id * local_count for ordinal in range(local_count)
    ]
  return global_device_ids, local_device_ids, local_device_list
def _create_tpu_topology(core_locations: List[_CoreLocation], num_tasks: int,
                         num_devices_per_task: int) -> topology.Topology:
  """Returns a Topology object built from a _CoreLocation list.

  The mesh shape is derived from the largest coordinate seen on each axis; the
  smallest coordinate on every axis is asserted to be 0.

  Args:
    core_locations: A list of _CoreLocation objects sorted first by TF task ID
      and then by per-task device ordinals.
    num_tasks: The number of TF tasks in the cluster.
    num_devices_per_task: The number of TPU devices local to each task.
  """
  axis_names = ("x", "y", "z", "core")

  # Every axis must start at the origin.
  for axis in axis_names:
    assert min(getattr(loc, axis) for loc in core_locations) == 0

  # Extent of each axis is its largest coordinate plus one.
  mesh_shape = [
      max(getattr(loc, axis) for loc in core_locations) + 1
      for axis in axis_names
  ]

  coordinates = np.asarray(
      [[loc.x, loc.y, loc.z, loc.core] for loc in core_locations])
  device_coordinates = coordinates.reshape(
      (num_tasks, num_devices_per_task, 4))

  return topology.Topology(
      mesh_shape=mesh_shape, device_coordinates=device_coordinates)
def shutdown_tpu_system():
  """Shuts down the TPU system and logs the outcome.

  The shutdown op is only run when TFRT is enabled; otherwise the shutdown is
  reported as successful without running anything.
  """

  @def_function.function
  def _shutdown_tpu_system():
    return gen_dtensor_ops.shutdown_tpu_system()

  if context.is_tfrt_enabled():
    success = _shutdown_tpu_system()
  else:
    success = True

  if not success:
    logging.warning("TPU system fails to shut down.")
    return
  logging.info("TPU system shut down.")
def tpu_system_init_helper(task_id,
                           num_tasks,
                           num_devices,
                           use_tfrt_host_runtime=True):
  """A helper function to initialize multi-client tpu system.

  Initializes the TPU runtime, merges the TPU core IDs of all tasks with an
  AllReduce, registers the merged core IDs as the default mapping, and builds
  the TPU topology from the resulting core locations.

  As a side effect, the module-level `_all_core_ids` and `_all_core_locations`
  caches are overwritten.

  Args:
    task_id: Index of this task. Local core IDs are written at offset
      `task_id * num_devices_per_task` of the merged array.
    num_tasks: Number of tasks. `num_devices` is divided by it to get the
      number of devices per task.
    num_devices: Total number of TPU devices; the size of the merged core ID
      array and of the one-time mesh.
    use_tfrt_host_runtime: Forwarded to the
      `configure_and_initialize_global_tpu` op.

  Returns:
    A `(tpu_topology, device)` tuple: the `Topology` built from all core
    locations and the `DTensorDevice` created for the merge.
  """

  @def_function.function
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu(
        use_tfrt_host_runtime=use_tfrt_host_runtime)

  @def_function.function
  def _set_global_tpu_array_fn(topology_proto):
    gen_dtensor_ops.d_tensor_set_global_tpu_array(topology_proto)

  # NOTE(review): the extent of this `with` block was reconstructed; it is
  # assumed to cover only the init call -- confirm against the canonical file.
  with ops.device("/job:" + config.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
    my_core_ids = _tpu_init_fn()
  logging.info("TPU core IDs: %s", my_core_ids)

  # `my_core_ids` contains the IDs of TPU cores attached to this host.
  #
  # To generate correct and efficient XLA AllReduce group assignment, we must
  # merge these arrays from all hosts and broadcast the result back to all
  # hosts, so all hosts can use these mappings in their MLIR passes.
  #
  # This is essentially doing what WaitForDistributedTpuOp and
  # SetGlobalTPUArrayOp do, in our multi-client environment.
  num_devices_per_task = int(num_devices / num_tasks)

  # Create a one-time use mesh and layout just for merging core IDs.
  mesh = layout_lib.Mesh([_MESH_DIM_X],
                         *_create_device_array((num_devices,), _TPU_DEVICE_TYPE,
                                               config.client_id()))
  layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
  device = dtensor_device.DTensorDevice(meshes=[mesh])
  logging.info("TPU core locations: %s",
               device.tpu_core_ids_to_locations(my_core_ids))

  # At this point, we don't know which cores are attached to other hosts.
  # The core ID mappings in the runtime haven't been set yet.
  #
  # The core ID merging AllReduce below is carefully written so it works
  # without needing correct core mappings to be set in the runtime. We will
  # use this AllReduce's result to set the core ID mappings, and all future
  # user-initiated AllReduces will use the mappings.
  #
  # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
  all_core_ids = np.zeros([num_devices], dtype=np.int32)
  # Place this task's core IDs in its own slot; all other slots stay zero so
  # that summing across tasks reconstructs the full array.
  for i in range(len(my_core_ids)):
    all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

  # Only one local device gets valid input: 8 local core IDs among
  # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using task ID as offset.
  # The other 7 local devices get zero inputs. All devices on all host
  # participate in one AllReduce, whose result will be core IDs arranged by
  # task-device ordinals.
  all_core_ids = constant_op.constant([all_core_ids])
  zeros = array_ops.zeros_like(all_core_ids)
  all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

  # NOTE(review): the extent of this `with` block was reconstructed; it is
  # assumed to cover pack, reduce and unpack -- confirm against the canonical
  # file.
  with ops.device(device.name):
    all_core_ids = device.pack(all_core_ids, layout)
    all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
    unpacked_all_tpu_ids = device.unpack(all_core_ids)

  # Only the first unpacked component is read back to the host.
  all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
  logging.info("All TPU core IDs: %s", all_core_ids)

  # Set the default core ID mappings in the runtime for legacy code and tests.
  #
  # Legacy code and tests create TPU meshes directly without using the
  # `create_tpu_mesh` function below. Those meshes have global device IDs
  # equal to TF task-device ordinals. The `all_core_ids` array happens to
  # arrange core IDs by TF task-device ordinals. Using this array on those
  # meshes guarantee correct although inefficient results.
  device.set_tpu_core_ids("", all_core_ids)

  # Remember enough global, immutable information to be able to build any ring
  # we want prescribed by `create_tpu_mesh` in the future.
  global _all_core_ids
  _all_core_ids = all_core_ids

  all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
  all_core_locations = [
      _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
  ]
  global _all_core_locations
  _all_core_locations = all_core_locations
  logging.info("All TPU core locations: %s", all_core_locations)

  tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                      num_devices_per_task)

  # Hand the serialized topology to the runtime.
  _set_global_tpu_array_fn(tpu_topology.serialized())
  return tpu_topology, device
def initialize_tpu_system():
  """Initializes the TPU system.

  On success, the module-level `_tpu_topology` and `_dtensor_device` are set
  from the result of `tpu_system_init_helper`, and the eager context caches
  are cleared.

  Raises:
    errors.NotFoundError: If initialization raised `InvalidArgumentError`.
    errors.InternalError: Re-raised after logging a hint.
  """
  global _tpu_topology, _dtensor_device

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  use_tfrt_host_runtime = context.context().use_tfrt
  logging.info("Using TFRT host runtime is set to %s", use_tfrt_host_runtime)

  try:
    tpu_topology, device = tpu_system_init_helper(
        config.client_id(),
        config.num_clients(),
        config.num_global_devices(_TPU_DEVICE_TYPE),
        use_tfrt_host_runtime=use_tfrt_host_runtime)

    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)
    _dtensor_device = device

    context.async_wait()
  except errors.InvalidArgumentError as err:
    raise errors.NotFoundError(
        None, None,
        "Initialization failed, no valid TPUs found. " + str(err)) from err
  except errors.InternalError as err:
    logging.error(
        "Hit internal error during TPU system initialization. "
        "It is likely hareware failure. \nPlease check the error "
        "messages above to see whether that's the case. \nIf so, "
        "consider to restart the job or try another machine.")
    raise err

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access
def _enumerate_cores(bounds: List[int], ring_bounds: List[int],
ring_sizes: List[int], host_bounds: List[int],
host_sizes: List[int]) -> List[List[int]]:
"""Enumerates cores within `bounds` from fatest to slowest varying axes.
Args:
bounds: Upper bounds of axes, from fastest to slowest varying.
ring_bounds: Upper bounds of ring size per axis in the same axis order.
ring_sizes: Number consecutive cores in the ring built so far, cumulatively.
host_bounds: Number of axis values per host in the same axis order.
host_sizes: Number consecutive cores on one host, cumulatively.
Returns:
Cores represented as a list of 4 integers in the same axis order.
"""
if not bounds:
return [[]]
# Recursively enumerate cores under all but the slowest varying axis.
partials = _enumerate_cores(bounds[:-1], ring_bounds[:-1], ring_sizes[:-1],
host_bounds[:-1], host_sizes[:-1])
# Append the slowest varying axis to the end of all partial results.
# From ring_i|j to host_i|j to core_i|j, use progressively smaller or equal
# iteration groupings until every one of the bounds[-1] * len(partials)
# combinations is iterated on.
# Despite the six levels of nested loops below, the total time complexity for
# this invocation is O(N), where N is the number of cores in the topology.
results = []
for ring_i in range(0, bounds[-1], ring_bounds[-1]):
for ring_j in range(0, len(partials), ring_sizes[-1]):
for host_i in range(ring_i, ring_i + ring_bounds[-1], host_bounds[-1]):
for host_j in range(ring_j, ring_j + ring_sizes[-1], host_sizes[-1]):
for i in range(host_i, host_i + host_bounds[-1]):
for j in range(host_j, host_j + host_sizes[-1]):
results.append(partials[j] + [i])
return results
def _enumerate_core_locations(bounds: List[int], ring_bounds: List[int],
                              axes: List[str],
                              can_split_host_across_rings: bool,
                              ring_size: int) -> List[_CoreLocation]:
  """Enumerates all possible core locations under the axis iteration order.

  Args:
    bounds: A list of 4 positive integers, upper bound values for x, y, z, core.
    ring_bounds: A list of 4 positive integers, upper bound values for ring size
      in x, y, z, core axes.
    axes: A permutation of ["x", "y", "z", "core"], the axis iteration order.
    can_split_host_across_rings: If true, devices attached to the same host may
      get assigned to different rings.
    ring_size: Number of devices in a ring, only for argument validation.

  Returns:
    A list of all CoreLocation objects defined in a TPU slice of shape `bounds`,
    sorted by axis iteration order specified by `axes`.

    For example, given bounds=[2, 2, 1, 2] and axes=["core", "z", "y", "x"],
    return 8 core locations expressed in (x, y, z, core) format but iterated in
    core -> z -> y -> x order (fastest to slowest varying):

    [_CoreLocation(0, 0, 0, 0),
     _CoreLocation(0, 0, 0, 1),
     _CoreLocation(0, 1, 0, 0),
     _CoreLocation(0, 1, 0, 1),
     _CoreLocation(1, 0, 0, 0),
     _CoreLocation(1, 0, 0, 1),
     _CoreLocation(1, 1, 0, 0),
     _CoreLocation(1, 1, 0, 1)]

  Raises:
    ValueError: If the number of cores per chip (`bounds[3]`) is neither 1 nor
      2, if ring_size cannot be fulfilled without splitting hosts, or if
      `ring_bounds` is smaller than the host bounds on any axis.
  """
  num_cores_per_chip = bounds[3]
  if num_cores_per_chip not in (1, 2):
    # Wrap in a 1-tuple so a tuple-valued `bounds` formats correctly.
    raise ValueError("Unsupported TPU slice size: %s" % (bounds,))

  # Translate `axes` from string to integer format.
  axis_index = {"x": 0, "y": 1, "z": 2, "core": 3}
  axes = [axis_index[axis] for axis in axes]
  # Reorder bounds from fastest to slowest varying axes.
  bounds = [bounds[i] for i in axes]

  # Set and validate host_bounds.
  if can_split_host_across_rings:
    # If we can split hosts, shrink every host to effectively contain 1 device.
    host_bounds = [1, 1, 1, 1]
  elif np.prod(bounds) <= 2:
    # We must be running on 1x1 or 1x1x1 Forge.
    host_bounds = [[1, 1, 1, num_cores_per_chip][i] for i in axes]
  else:
    # Other cases including 2x2 Forge and Borg must use a full donut.
    host_bounds = [[2, 2, 1, num_cores_per_chip][i] for i in axes]
  # host_sizes is the cumulative products of host_bounds.
  host_sizes = [1]
  for host_bound in host_bounds:
    host_sizes.append(host_sizes[-1] * host_bound)
  # The last cumulative product is the total number of cores per host; the
  # remaining entries line up one-to-one with the axes.
  host_size = host_sizes.pop()
  # When can_split_host_across_rings is false, a ring must contain at least as
  # many devices as a host has.
  if ring_size < host_size:
    assert not can_split_host_across_rings
    raise ValueError(
        "Rings too small for can_split_host_across_rings = False: %d" %
        ring_size)

  # Reorder ring_bounds and validate it's element-wise >= host_bounds.
  ring_bounds = [ring_bounds[i] for i in axes]
  # Compare axis by axis. A plain `ring_bounds < host_bounds` would compare the
  # two lists lexicographically and miss a violation on any axis that follows
  # an axis where the ring bound is larger.
  if any(ring < host for ring, host in zip(ring_bounds, host_bounds)):
    raise ValueError("ring_bounds %s should be >= host_bounds %s" %
                     (ring_bounds, host_bounds))
  ring_sizes = [1]
  # ring_sizes is the cumulative products of ring_bounds.
  for ring_bound in ring_bounds:
    ring_sizes.append(ring_sizes[-1] * ring_bound)
  ring_sizes.pop()

  # Enumerate cores in the given iteration order. Each core is represented as a
  # list of int, which are offsets from fastest to slowest varying axes.
  cores = _enumerate_cores(bounds, ring_bounds, ring_sizes, host_bounds,
                           host_sizes)
  # Reorder offsets of each core back to the x, y, z, core order.
  core_locations = []
  for core in cores:
    x, y, z, core_index = (core[axes.index(i)] for i in range(4))
    core_locations.append(_CoreLocation(x, y, z, core_index))
  return core_locations
def _build_all_reduce_ring(core_locations: List[_CoreLocation],
                           rotate: bool = False) -> List[int]:
  """Reorders a list of TPU cores to optimize for AllReduce performance.

  This is ported from the C++ tensorflow::BuildAllReduceRing function,
  mixed with some logic from TF TPU's device_assignment._ring_3d.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    rotate: If true, scan the cores in a column-major order. False by default.

  Returns:
    A permutation of the input list such that neighbors in the sequence are
    nearby in the TPU topology. The permutation is a list of indices into
    `core_locations`; the input list itself is not modified.
  """
  permutation = list(range(len(core_locations)))
  if not permutation:
    return permutation
  logging.vlog(2, "Core locations in: %s", core_locations)

  first_column = min(loc.x for loc in core_locations)
  first_row = min(loc.y for loc in core_locations)
  same_z = len({loc.z for loc in core_locations}) == 1
  logging.vlog(2, "first_column: %d", first_column)
  logging.vlog(2, "first_row: %d", first_row)
  logging.vlog(2, "same_z: %s", same_z)

  # The row-major and column-major scans are mirror images of each other: the
  # column-major scan is the row-major scan with x and y swapped. "line" is
  # the axis that selects a row (or a column when rotated); "step" is the axis
  # walked along within it.
  if rotate:
    first_step, first_line = first_row, first_column
  else:
    first_step, first_line = first_column, first_row

  def _scan_coords(loc):
    """Returns the (step, line) coordinates of `loc` for the chosen scan."""
    return (loc.y, loc.x) if rotate else (loc.x, loc.y)

  def _cmp_2d(ia: int, ib: int) -> int:
    a = core_locations[ia]
    b = core_locations[ib]
    a_step, a_line = _scan_coords(a)
    b_step, b_line = _scan_coords(b)
    # Cores at the first step position go last in the sequence, except the one
    # on the first line.
    a_last = (a_step == first_step and a_line != first_line)
    b_last = (b_step == first_step and b_line != first_line)
    if a_last != b_last:
      return -1 if b_last else 1
    # Lines go in increasing order, except among those trailing cores where
    # they go in decreasing order.
    if a_line != b_line:
      return b_line - a_line if a_last else a_line - b_line
    # Even lines are walked forward, odd lines backward.
    if a_step != b_step:
      return a_step - b_step if a_line % 2 == 0 else b_step - a_step
    # Order cores in increasing order.
    return a.core - b.core

  def _cmp_3d(ia: int, ib: int) -> int:
    a = core_locations[ia]
    b = core_locations[ib]
    a_corner = (a.x == first_column and a.y == first_row)
    b_corner = (b.x == first_column and b.y == first_row)
    # Corner cores always go after non-corner cores.
    if a_corner != b_corner:
      return -1 if b_corner else 1
    # Both are in the corner: order in reverse z then core order.
    if a_corner:
      if a.z != b.z:
        return b.z - a.z
      return a.core - b.core
    # Both non-corner cores are on different z-planes. Smaller z goes first.
    if a.z != b.z:
      return a.z - b.z
    # Both non-corner cores are on the same z-plane. Reverse odd z-planes.
    planar = _cmp_2d(ia, ib)
    return planar if a.z % 2 == 0 else -planar

  # If all cores are on the same z-plane, order as usual. Otherwise, order
  # neighbor z-planes in opposite orders. Stack all z-planes along the z axis
  # and connect them in one corner.
  comparator = _cmp_2d if same_z else _cmp_3d
  permutation.sort(key=functools.cmp_to_key(comparator))
  logging.vlog(2, "Permutation out: %s", permutation)
  return permutation
def _build_orthogonal_rings(
    core_locations: List[_CoreLocation], ring_size: int,
    rotate_ring_across_rings: bool) -> List[_CoreLocation]:
  """Build two all-reduce rings orthogonal to each other.

  One ring includes every `ring_size` consecutive core locations. It is usually
  applied to the model-parallel dimension of a mesh to achieve best 1D
  all-reduce performance. The other ring includes core locations separated by
  a stride of `ring_size`. It is usually applied to the data-parallel dimension
  of a mesh to get predictable strided all-reduce performance.

  Note that `core_locations` is reordered in place by the first step; the
  final result is returned as a new list.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    ring_size: The number of core locations in the consecutive ring.
    rotate_ring_across_rings: Build column-major secondary rings.

  Returns:
    A permutation of the input list forming the described rings.
  """
  num_cores = len(core_locations)

  def _permute_each_group(items, order, group_len):
    """Reorders every `group_len`-long slice of `items` in place by `order`."""
    for start in range(0, num_cores, group_len):
      items[start:start + group_len] = [
          items[start + order[k]] for k in range(group_len)
      ]

  # Build a ring for the first `ring_size` cores, and apply that permutation to
  # every group of `ring_size` cores.
  _permute_each_group(core_locations,
                      _build_all_reduce_ring(core_locations[:ring_size]),
                      ring_size)
  logging.vlog(1, "Permutated core locations: %s", core_locations)

  # Build a "ring" for the collection of devices consisting of the 0th device
  # from every group, and apply that permutation to every i-th device group.
  # This is achieved by transposing the list and back.
  transposed = [
      core_locations[group_start + offset]
      for offset in range(ring_size)
      for group_start in range(0, num_cores, ring_size)
  ]

  num_rings = int(num_cores / ring_size)
  _permute_each_group(
      transposed,
      _build_all_reduce_ring(
          transposed[:num_rings], rotate=rotate_ring_across_rings),
      num_rings)

  untransposed = [
      transposed[group_start + offset]
      for offset in range(num_rings)
      for group_start in range(0, num_cores, num_rings)
  ]
  logging.vlog(1, "Stride-permutated core locations: %s", untransposed)
  return untransposed
@tf_export("experimental.dtensor.create_tpu_mesh", v1=[])
def create_tpu_mesh(
    mesh_dim_names: List[str],
    mesh_shape: List[int],
    mesh_name: str,
    ring_dims: Optional[int] = None,
    ring_axes: Optional[List[str]] = None,
    ring_bounds: Optional[List[int]] = None,
    can_split_host_across_rings: bool = True,
    build_ring_across_rings: bool = False,
    rotate_ring_across_rings: bool = False,
    use_xla_spmd: bool = layout_lib.USE_XLA_SPMD) -> layout_lib.Mesh:
  """Returns a distributed TPU mesh optimized for AllReduce ring reductions.

  Only as many leading axes of `ring_axes` as necessary are used to build
  rings, as long as the subslice formed by these axes has enough cores to
  contain a ring of the required size. The leftover axes in `ring_axes` won't
  affect results.

  This function always uses all TPU devices, and offers more customization than
  `tf.experimental.dtensor.create_distributed_mesh`.

  Args:
    mesh_dim_names: List of mesh dimension names.
    mesh_shape: Shape of the mesh.
    mesh_name: A unique name for the mesh. If empty, internally generate one.
    ring_dims: Optional; The number of leading (ring_dims > 0) or trailing
      (ring_dims < 0) mesh dimensions to build rings for. If unspecified, build
      rings for all but the first dimension.
    ring_axes: Optional; A permutation of ["x", "y", "z", "core"], specifying
      the order of TPU topology axes to build rings in. If unspecified, default
      to ["core", "x", "y", "z"].
    ring_bounds: Optional; The maximum number of devices on each axis, in the x,
      y, z, core order. If unspecified, default to physical topology limits.
    can_split_host_across_rings: Optional; If true, devices attached to the same
      host (i.e., DTensor client) may get assigned to different rings. Setting
      it to false may cause some combinations of arguments to be infeasible; see
      DeviceAssignmentTest.testCreateMesh[No]SplittingHosts* for examples.
    build_ring_across_rings: Optional; If true, also build a data-parallel ring
      across model-parallel rings. This ring could be strided.
    rotate_ring_across_rings: Optional; If true, build the data-parallel ring in
      column-major instead of row-major order.
    use_xla_spmd: Boolean when True, will use XLA SPMD instead of
      DTensor SPMD.

  Returns:
    A `layout_lib.Mesh` named `mesh_name` whose global device IDs are arranged
    in `mesh_shape`.

  Raises:
    ValueError: If `ring_dims`, `ring_axes` or `ring_bounds` is invalid, if
      `ring_bounds` exceeds the topology size on any axis, if the TPU system
      has not been initialized, or if `mesh_shape` does not cover exactly all
      TPU cores.
  """

  logging.info("Building a TPU mesh %s of shape %s", mesh_name, mesh_shape)
  logging.info("Requested ring_dims: %s", ring_dims)
  logging.info("Requested ring_axes: %s", ring_axes)
  logging.info("Requested ring_bounds: %s", ring_bounds)
  logging.info("Requested can_split_host_across_rings: %s",
               can_split_host_across_rings)
  if not mesh_name:
    mesh_name = "mesh_%f" % time.time()
  logging.info("Requested mesh_name: %s", mesh_name)

  # By default, build rings for all but the first (usually batch) dimension.
  if ring_dims is None:
    ring_dims = 1 - len(mesh_shape)
  elif ring_dims < -len(mesh_shape) or ring_dims > len(mesh_shape):
    raise ValueError("Invalid ring_dims value: %d" % ring_dims)
  logging.info("Actual ring_dims: %s", ring_dims)

  # By default, vary axes in the core -> x -> y -> z order.
  # Error messages wrap their argument in a 1-tuple so that tuple-valued
  # arguments format correctly.
  if ring_axes is None:
    ring_axes = ["core", "x", "y", "z"]
  elif len(ring_axes) != 4:
    raise ValueError("Expected 4 elements in ring_axes, got %s" % (ring_axes,))
  elif sorted(ring_axes) != ["core", "x", "y", "z"]:
    raise ValueError("Invalid ring_axes value: %s" % (ring_axes,))
  logging.info("Actual ring_axes: %s", ring_axes)

  # Validate ring_bounds values.
  if _tpu_topology is None:
    raise ValueError(
        "Invalid TPU topology, run dtensor.initialize_tpu_system() first")
  topology_shape = list(_tpu_topology.mesh_shape)
  if ring_bounds is None:
    ring_bounds = topology_shape
  elif len(ring_bounds) != 4:
    raise ValueError(
        "Expected 4 elements in ring_bounds, got %s" % (ring_bounds,))
  elif any(ring > limit for ring, limit in zip(ring_bounds, topology_shape)):
    # Compare axis by axis. A plain `ring_bounds > topology_shape` would
    # compare the two lists lexicographically and accept an oversized bound on
    # any axis that follows an axis where the ring bound is smaller.
    raise ValueError("ring_bounds %s should be <= topology sizes %s" %
                     (ring_bounds, topology_shape))
  logging.info("Actual ring_bounds: %s", ring_bounds)

  # Compute ring_size, the number of cores in a ring.
  # `np.prod` returns a NumPy scalar; convert it to a plain int before using
  # it for slicing and `range`.
  if ring_dims > 0:
    ring_size = int(np.prod(mesh_shape[:ring_dims]))
  elif ring_dims < 0:
    ring_size = int(np.prod(mesh_shape[ring_dims:]))
  else:
    ring_size = 1  # single-core rings
  logging.info("Actual ring_size: %d", ring_size)

  # Rearrange all cores according to the axis iteration order.
  global_core_locations = _enumerate_core_locations(
      topology_shape, ring_bounds, ring_axes, can_split_host_across_rings,
      ring_size)
  logging.vlog(1, "Enumerated core locations: %s", global_core_locations)
  num_cores = len(global_core_locations)

  # The mesh to be created must use all TPU cores in the system.
  mesh_size = int(np.prod(mesh_shape))
  if mesh_size != num_cores:
    raise ValueError(
        "Invalid mesh size: mesh shape %s cannot 1:1 map to %d TPU cores" %
        (mesh_shape, num_cores))

  # Build a ring for the `ring_size` dimension and, if required, a strided ring
  # for the orthogonal dimension.
  if build_ring_across_rings:
    global_core_locations = _build_orthogonal_rings(global_core_locations,
                                                    ring_size,
                                                    rotate_ring_across_rings)
  else:
    # Compute the ring order once on the first group and reuse it for every
    # other group of `ring_size` cores.
    permutation = _build_all_reduce_ring(global_core_locations[:ring_size])
    for r in range(0, num_cores, ring_size):
      global_core_locations[r:r + ring_size] = [
          global_core_locations[r + permutation[i]] for i in range(ring_size)
      ]
    logging.vlog(1, "Permutated core locations: %s", global_core_locations)

  # For this point on, change from List[CoreLocation] to List[List[int]] for
  # easier interaction with the C++ API.
  global_core_locations = [l.to_list() for l in global_core_locations]
  if _dtensor_device is None:
    raise ValueError("Invalid system device, "
                     "run dtensor.initialize_accelerator_system() first")
  global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
      global_core_locations)

  # Store a per-mesh mapping in the runtime.
  _dtensor_device.set_tpu_core_ids(mesh_name, global_core_ids)

  # Create the mesh by manually specifying local_device_ids.
  # A local device's ID is the position of its core location in the reordered
  # global list.
  local_core_locations = _tpu_topology.device_coordinates[config.client_id()]
  indexes = [
      global_core_locations.index(list(local_core_location))
      for local_core_location in local_core_locations
  ]
  global_device_ids, local_device_ids, local_device_list = _create_device_array(
      mesh_shape, _TPU_DEVICE_TYPE, None, local_device_ids=indexes)
  return layout_lib.Mesh(mesh_dim_names, global_device_ids, local_device_ids,
                         local_device_list, mesh_name, use_xla_spmd)
def get_device_ids(mesh: layout_lib.Mesh,
                   client_id: Optional[int] = None) -> List[int]:
  """Returns the device IDs of all TPU cores local to the given client.

  A device ID is a non-negative integer that uniquely identifies a device in the
  mesh. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of [0, 1, 2, 3].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If empty, query this client.

  Raises:
    ValueError: If `mesh` is not a TPU mesh.
    NotImplementedError: If `client_id` names a client other than this one.
  """
  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  # It's not clear we should ever allow a client to query other clients for
  # their device IDs.
  if client_id is not None and client_id != config.client_id():
    raise NotImplementedError(
        "Looking up other clients' device IDs is not supported")

  return mesh.local_device_ids()
def get_device_locations(
    mesh: layout_lib.Mesh,
    client_id: Optional[int] = None) -> List[Dict[str, int]]:
  """Returns the device locations of all TPU cores local to the given client.

  A device location is a dictionary from dimension names to indices on those
  dimensions. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of this list:

    [{'x': 0, 'y': 0},
     {'x': 0, 'y': 1},
     {'x': 1, 'y': 0},
     {'x': 1, 'y': 1}].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If empty, query this client.

  Raises:
    ValueError: If `mesh` is not a TPU mesh.
    NotImplementedError: If `client_id` names a client other than this one.
  """
  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  # It's not clear we should ever allow a client to query other clients for
  # their device locations.
  if client_id is not None and client_id != config.client_id():
    raise NotImplementedError(
        "Looking up other clients' device locations is not supported")

  return mesh.local_device_locations()
# TODO(b/245589661): Remove dtensor_initialize_tpu_system() and
# dtensor_shutdown_tpu_system() once users have stopped using them.
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Deprecated way to initialize the TPU system.

  Forwards to `accelerator_util.initialize_accelerator_system` with the device
  type set to "TPU".

  Args:
    enable_coordination_service: Passed through unchanged to
      `accelerator_util.initialize_accelerator_system`.
  """
  # Imported inside the function rather than at module level.
  # NOTE(review): presumably to avoid an import cycle with accelerator_util --
  # confirm.
  from . import accelerator_util  # pylint: disable=g-import-not-at-top
  accelerator_util.initialize_accelerator_system(
      "TPU", enable_coordination_service=enable_coordination_service)
def dtensor_shutdown_tpu_system():
  """Deprecated way to shut down the TPU system.

  Forwards to `accelerator_util.shutdown_accelerator_system`.
  """
  # Imported inside the function rather than at module level.
  # NOTE(review): presumably to avoid an import cycle with accelerator_util --
  # confirm.
  from . import accelerator_util  # pylint: disable=g-import-not-at-top
  accelerator_util.shutdown_accelerator_system()