# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python definitions for `Mesh` and `Layout`."""
import collections
import itertools
from typing import List, Dict, Optional
import numpy as np
from tensorflow.dtensor.proto import layout_pb2
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.util.tf_export import tf_export
# UNSHARDED indicates a tensor dimension is not sharded over any mesh dimension.
UNSHARDED = 'unsharded'
MATCH = 'match'
tf_export('experimental.dtensor.UNSHARDED', v1=[]).export_constant(
__name__, 'UNSHARDED')
tf_export('experimental.dtensor.MATCH', v1=[]).export_constant(
__name__, 'MATCH')
MeshDimension = collections.namedtuple('MeshDimension', ['name', 'size'])
@tf_export('experimental.dtensor.Mesh', v1=[])
class Mesh(object):
"""Represents a Mesh configuration over a certain list of Mesh Dimensions.
A mesh consists of named dimensions with sizes, which describe how a set of
devices are arranged. Defining tensor layouts in terms of mesh dimensions
allows us to efficiently determine the communication required when computing
an operation with tensors of different layouts.
A mesh provides information not only about the placement of the tensors but
also the topology of the underlying devices. For example, we can group 8 TPUs
as a 1-D array for data parallelism or a `2x4` grid for (2-way) data
parallelism and (4-way) model parallelism.
Note: the utilities `dtensor.create_mesh` and
`dtensor.create_distributed_mesh` provide a simpler API to create meshes for
single- or multi-client use cases.
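
  For example, the following sketch builds a `2x4` mesh with the
  `dtensor.create_mesh` utility (assuming eight local CPU devices are
  available; the dimension names are illustrative):

  ```python
  mesh = dtensor.create_mesh([('data', 2), ('model', 4)], device_type='CPU')
  ```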
"""
_dim_dict: Dict[str, MeshDimension]
_dim_names: List[str]
_local_device_ids: List[int]
_global_device_ids: np.ndarray
_name: str
  _local_devices: List[tf_device.DeviceSpec]
  _global_devices: Optional[List[tf_device.DeviceSpec]]
_device_type: str
def __init__(self,
dim_names: List[str],
global_device_ids: np.ndarray,
local_device_ids: List[int],
local_devices: List[tf_device.DeviceSpec],
mesh_name: str = '',
global_devices: Optional[List[tf_device.DeviceSpec]] = None):
"""Builds a Mesh.
The `dim_names` and `global_device_ids` arguments describe the dimension
names and shape for the mesh.
For example,
```python
dim_names = ('x', 'y'),
global_device_ids = [[0, 1],
[2, 3],
[4, 5]]
```
defines a 2D mesh of shape 3x2. A reduction over the 'x' dimension will
reduce across columns (0, 2, 4) and (1, 3, 5), and a reduction over the 'y'
dimension reduces across rows.
Note: the utilities `dtensor.create_mesh` and
`dtensor.create_distributed_mesh` provide a simpler API to create meshes for
single- or multi-client use cases.
Args:
dim_names: A list of strings indicating dimension names.
      global_device_ids: An ndarray of global device IDs used to compose
        DeviceSpecs describing the mesh. The shape of this array determines the
        size of each mesh dimension. Values in this array should increment
        sequentially from 0. This argument is the same for every DTensor client.
local_device_ids: A list of local device IDs equal to a subset of values
in global_device_ids. They indicate the position of local devices in the
global mesh. Different DTensor clients must contain distinct
local_device_ids contents. All local_device_ids from all DTensor clients
must cover every element in global_device_ids.
local_devices: The list of devices hosted locally. The elements correspond
1:1 to those of local_device_ids.
mesh_name: The name of the mesh. Currently, this is rarely used, and is
mostly used to indicate whether it is a CPU, GPU, or TPU-based mesh.
global_devices (optional): The list of global devices. Set when multiple
device meshes are in use.
"""
# Check if input args are valid.
if not isinstance(global_device_ids, np.ndarray):
raise ValueError('Variable global_device_ids must be an ndarray.')
if global_device_ids.size == 0:
raise ValueError('Variable global_device_ids must be non-empty.')
flat_global_device_ids = global_device_ids.flatten()
# global_device_ids are expected to be consecutive numbers.
# LINT.IfChange
distance = flat_global_device_ids[0]
if any(
(gid - i != distance) for i, gid in enumerate(flat_global_device_ids)):
raise ValueError('global_device_ids must sequentially increase: %s' %
global_device_ids)
# LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_device.cc)
if len(dim_names) != global_device_ids.ndim:
raise ValueError(
'Number of mesh dimensions does not match number of dimension names.')
if not isinstance(local_device_ids, list):
raise ValueError('Variable local_device_ids must be a list of integers.')
if not isinstance(local_devices, list):
raise ValueError('Variable local_devices must be a list of DeviceSpecs.')
if global_devices and not isinstance(global_devices, list):
raise ValueError('Variable global_devices must be a list of DeviceSpecs.')
if not local_devices and not global_devices:
raise ValueError('Empty list of devices not allowed.')
local_devices_set = set(local_devices)
local_device_only_contains_host_cpu = (
len(local_devices_set) == 1 and
list(local_devices_set)[0].device_type == 'CPU')
if not local_device_only_contains_host_cpu and len(local_devices) != len(
local_devices_set):
raise ValueError('Duplicate devices found in mesh specification %s.' %
[d for d in local_devices if local_devices.count(d) > 1])
if len(local_device_ids) != len(local_devices):
raise ValueError(
'Variable local_device_ids does not have same size as local_devices.')
if len(local_device_ids) > len(np.ravel(global_device_ids)):
      raise ValueError('Cannot have more local than global device IDs.')
device_types = set([device.device_type for device in local_devices])
if not device_types:
device_types = set([device.device_type for device in global_devices])
if None in device_types:
raise ValueError('device_type is required')
if len(device_types) > 1:
raise ValueError('Devices containing multiple device_types : %s' %
device_types)
# Set object's state.
self._device_type = device_types.pop()
self._dim_names = dim_names
self._dim_dict = {
dim_name: MeshDimension(dim_name, global_device_ids.shape[i])
for i, dim_name in enumerate(dim_names)
}
self._global_device_ids = global_device_ids
self._local_device_ids = local_device_ids
self._local_devices = local_devices
self._global_devices = global_devices
self._name = mesh_name
@property
def dim_names(self) -> List[str]:
return self._dim_names
@property
def name(self) -> str:
return self._name
def is_remote(self) -> bool:
return not self._local_device_ids and self._global_device_ids.size > 0
def host_mesh(self):
"""Returns the 1-1 mapped host mesh."""
if self.device_type().upper() == 'CPU':
return self
v_cpus_counts = len(tf_config.list_logical_devices('CPU'))
if v_cpus_counts < len(self._local_devices):
raise ValueError('Must have at least {0} virtual CPUs for mesh : {1}, '
'but got : {2} virtual CPUs.'.format(
len(self._local_devices), self.to_string(),
v_cpus_counts))
device_array = np.asarray([
spec.replace(device_type='CPU') for spec in self._local_devices
]).reshape((len(self._local_devices), 1))
global_devices = None
if self._global_devices:
global_devices = [
spec.replace(device_type='CPU') for spec in self._global_devices
]
h_mesh = Mesh(
self._dim_names,
self._global_device_ids,
self.local_device_ids(),
np.ravel(device_array).tolist(),
global_devices=global_devices)
return h_mesh
def device_type(self) -> str:
return self._device_type
def contains_dim(self, dim_name: str) -> bool:
return dim_name in self._dim_dict
def __contains__(self, dim_name: str) -> bool:
return self.contains_dim(dim_name)
def dim_size(self, dim_name: str) -> int:
"""Returns the size of a dimension."""
if dim_name not in self._dim_dict.keys():
raise ValueError(('"{dim_name}" not a dimension name in current mesh. ' +
'Dimension names: {dim_names}.').format(
dim_name=dim_name,
dim_names=list(self._dim_dict.keys())))
return self._dim_dict[dim_name].size
def unravel_index(self):
"""Returns a dictionary from device ID to {dim_name: dim_index}.
For example, for a 3x2 mesh, return this:
```
    { 0: {'x': 0, 'y': 0},
      1: {'x': 0, 'y': 1},
      2: {'x': 1, 'y': 0},
      3: {'x': 1, 'y': 1},
      4: {'x': 2, 'y': 0},
      5: {'x': 2, 'y': 1} }
```
"""
idx_ranges = [
range(self.dim_size(dim_name)) for dim_name in self._dim_names
]
mesh_pos = itertools.product(*idx_ranges)
mapping = {}
for device_id, device_pos in enumerate(mesh_pos):
device_loc = {}
for dim_name, dim_index in zip(self._dim_names, device_pos):
device_loc[dim_name] = dim_index
mapping[device_id] = device_loc
return mapping
def min_global_device_id(self) -> int:
"""Returns the minimum global device ID."""
# global_device_ids sequentially increases.
return self._global_device_ids.flatten()[0]
def local_device_ids(self) -> List[int]:
"""Returns a list of local device IDs."""
return self._local_device_ids
def local_device_locations(self) -> List[Dict[str, int]]:
"""Returns a list of local device locations.
A device location is a dictionary from dimension names to indices on those
dimensions.
"""
mapping = self.unravel_index()
return [mapping[device_id] for device_id in self.local_device_ids()]
def local_devices(self) -> List[str]:
"""Returns a list of local device specs represented as strings."""
return [d.to_string() for d in self._local_devices]
def num_local_devices(self) -> int:
"""Returns the number of local devices."""
return len(self._local_devices)
def to_string(self) -> str:
"""Returns string representation of Mesh."""
# Get proto representation
mesh_proto = self.as_proto()
# Separate individual elements with ','.
name = mesh_proto.name
dim_str = ','.join(
dim.name + '=' + str(dim.size) for dim in mesh_proto.mesh_dimensions)
global_ids = ','.join(str(id) for id in mesh_proto.global_device_ids)
local_ids = ','.join(str(id) for id in mesh_proto.local_device_ids)
devices = ','.join(dev for dev in mesh_proto.local_devices)
components = [name, dim_str, global_ids, local_ids, devices]
if mesh_proto.global_devices:
global_devices = ','.join(dev for dev in mesh_proto.global_devices)
components.append(global_devices)
# Separate mesh components with '|'.
return '|'.join(components)
def as_proto(self) -> layout_pb2.MeshProto:
"""Returns mesh protobuffer."""
mesh_proto = layout_pb2.MeshProto()
mesh_proto.name = self._name
for i, mesh_dimension in enumerate(self._dim_names):
dim = mesh_proto.mesh_dimensions.add()
dim.name = mesh_dimension
dim.size = self._global_device_ids.shape[i]
for d in np.ravel(self._global_device_ids):
mesh_proto.global_device_ids.append(d)
for d in self._local_device_ids:
mesh_proto.local_device_ids.append(d)
for d in self._local_devices:
mesh_proto.local_devices.append(d.to_string())
if self._global_devices:
for d in self._global_devices:
mesh_proto.global_devices.append(d.to_string())
return mesh_proto
@staticmethod
def from_string(mesh_str: str) -> 'Mesh':
"""Construct a mesh instance from input `proto`."""
# Separate elements of mesh.
mesh_parts = mesh_str.split('|')
global_dev_str = None
if len(mesh_parts) == 5:
name, mesh_dim_strs, global_id_str, local_id_str, dev_str = mesh_parts
elif len(mesh_parts) == 6:
(name, mesh_dim_strs, global_id_str, local_id_str, dev_str,
global_dev_str) = mesh_parts
else:
raise ValueError('Invalid mesh string : %s' % mesh_str)
# Load mesh proto.
mesh_proto = layout_pb2.MeshProto()
mesh_proto.name = name
for mesh_dim_str in mesh_dim_strs.split(','):
name, size_str = mesh_dim_str.split('=')
dim = mesh_proto.mesh_dimensions.add()
dim.name = name
dim.size = int(size_str)
for global_id in global_id_str.split(','):
mesh_proto.global_device_ids.append(int(global_id))
if local_id_str:
for local_id in local_id_str.split(','):
mesh_proto.local_device_ids.append(int(local_id))
if dev_str:
for dev in dev_str.split(','):
mesh_proto.local_devices.append(dev)
if global_dev_str:
for dev in global_dev_str.split(','):
mesh_proto.global_devices.append(dev)
return Mesh.from_proto(mesh_proto)
@staticmethod
def from_proto(proto: layout_pb2.MeshProto) -> 'Mesh':
"""Construct a mesh instance from input `proto`."""
shape = [dim.size for dim in proto.mesh_dimensions]
# Convert global_device ids list back into array form
global_device_ids = [int(d) for d in proto.global_device_ids]
global_device_ids = np.asarray(global_device_ids).reshape(shape)
# Construct local_device_ids list
local_device_ids = [int(d) for d in proto.local_device_ids]
# Convert local devices list back to array form
local_devices = [
tf_device.DeviceSpec.from_string(d) for d in proto.local_devices
]
# Convert global devices list back to array form
global_devices = [
tf_device.DeviceSpec.from_string(d) for d in proto.global_devices
]
name = proto.name
dims = [dim.name for dim in proto.mesh_dimensions]
return Mesh(dims, global_device_ids, local_device_ids, local_devices, name,
global_devices)
def shape(self) -> List[int]:
return [self.dim_size(dim) for dim in self._dim_names]
@property
def size(self) -> int:
return len(np.ravel(self._global_device_ids))
def __getitem__(self, dim_name: str) -> MeshDimension:
if dim_name not in self._dim_dict:
raise KeyError(
f'Dimension {dim_name} not defined in mesh: {self._dim_dict.keys()}')
return self._dim_dict[dim_name]
# TODO(b/168730933): Define a nicer mesh ID.
def __hash__(self):
return hash(self.as_proto().SerializeToString(deterministic=True))
def __eq__(self, other):
if not isinstance(other, type(self)) and not isinstance(self, type(other)):
raise ValueError('comparing with type : {0} but expecting : {1}'.format(
type(other), type(self)))
return self.as_proto().SerializeToString() == other.as_proto(
).SerializeToString()
# TODO(hthu): Consider making this class immutable.
@tf_export('experimental.dtensor.Layout', v1=[])
class Layout(object):
"""Represents the layout information of a DTensor.
A layout describes how a distributed tensor is partitioned across a mesh (and
thus across devices). For each axis of the tensor, the corresponding
sharding spec indicates which dimension of the mesh it is sharded over. A
  special sharding spec `UNSHARDED` indicates that the axis is replicated on
  all the devices of the mesh.
For example, let's consider a 1-D mesh:
```
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
```
This mesh arranges 6 TPU devices into a 1-D array. `Layout([UNSHARDED], mesh)`
  is a layout for a rank-1 tensor that is replicated on the 6 devices.
For another example, let's consider a 2-D mesh:
```
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
[("x", 3), ("y", 2)])
```
This mesh arranges 6 TPU devices into a `3x2` 2-D array.
`Layout(["x", UNSHARDED], mesh)` is a layout for rank-2 tensor whose first
axis is sharded on mesh dimension "x" and the second axis is replicated. If we
place `np.arange(6).reshape((3, 2))` using this layout, the individual
components tensors would look like:
```
Device | Component
TPU:0 [[0, 1]]
TPU:1 [[0, 1]]
TPU:2 [[2, 3]]
TPU:3 [[2, 3]]
TPU:4 [[4, 5]]
TPU:5 [[4, 5]]
```
"""
def __init__(self, sharding_specs: List[str], mesh: Mesh):
"""Builds a Layout from a list of dimension names and a Mesh.
Args:
sharding_specs: List of sharding specifications, each corresponding to a
tensor axis. Each specification (dim_sharding) can either be a mesh
dimension or the special value UNSHARDED.
mesh: A mesh configuration for the Tensor.
    Returns:
      A valid Layout built from the given sharding specs and mesh.
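
    Example (a sketch; assumes `mesh` is a 2-D mesh with an 'x' dimension):

    ```python
    # Shard the first tensor axis over mesh dimension 'x'; replicate the rest.
    layout = Layout(['x', UNSHARDED], mesh)
    ```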
"""
# Validate mesh
if not isinstance(mesh, Mesh):
raise ValueError('mesh is not a valid Mesh object.')
# Validate sharding spec
    for dim_sharding in sharding_specs:
      # Special values need no uniqueness check; skip them.
if dim_sharding == UNSHARDED or dim_sharding == MATCH:
continue
# Check dim_sharding is unique.
if sharding_specs.count(dim_sharding) > 1:
raise ValueError(
('Mesh dimension {mesh_dim} was repeated in sharding ' +
'specification {sharding_specs}. Mesh dimensions must be unique ' +
'in a layout.').format(
mesh_dim=dim_sharding, sharding_specs=sharding_specs))
# Check dim_sharding is mesh dimension.
if dim_sharding not in mesh:
raise ValueError(
('{dim_sharding}: A dimension sharding must either be a ' +
'valid mesh dimension or UNSHARDED.').format(
dim_sharding=dim_sharding))
# Set object's state
self.sharding_specs = sharding_specs
self.rank = len(sharding_specs)
self.mesh = mesh
self.shape = [self.num_shards(i) for i in range(self.rank)]
@staticmethod
def from_string(layout_str: str) -> 'Layout':
"""Parses layout string."""
layout_parts = layout_str.split(' ')
if len(layout_parts) != 2:
raise ValueError(
'layout string must contain two parts: specs and mesh. But got {}.'
.format(layout_str))
sharding_specs_str = layout_parts[0].replace('sharding_specs:', '')
mesh_str = layout_parts[1].replace('mesh:', '')
sharding_specs = sharding_specs_str.split(',')[:-1]
mesh = Mesh.from_string(mesh_str)
layout = Layout(sharding_specs, mesh)
return layout
@staticmethod
def from_str(layout_str: bytes) -> 'Layout':
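    """Creates a Layout from a serialized `LayoutProto` byte string."""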
layout_proto = layout_pb2.LayoutProto()
layout_proto.ParseFromString(layout_str)
sharding_specs = [
sharding_spec.sharding_spec
for sharding_spec in layout_proto.sharding_specs
]
mesh = Mesh.from_proto(layout_proto.mesh_config)
return Layout(sharding_specs, mesh)
def offset_to_shard(self):
"""Mapping from offset in a flattened list to shard index."""
unravel_index = self.mesh.unravel_index()
locations = [None] * self.mesh.size
for offset, mesh_loc in unravel_index.items():
loc = []
for dim_sharding in self.sharding_specs:
if dim_sharding == UNSHARDED:
loc.append(0)
else:
loc.append(mesh_loc[dim_sharding])
locations[offset] = tuple(loc)
return locations
def offset_tuple_to_global_index(self, offset_tuple):
"""Mapping from offset to index in global tensor."""
index = 0
for i, o in enumerate(offset_tuple):
m = 1
for x in range(i + 1, self.rank):
m = m * self.num_shards(x)
index = index + m * o
return index
def unravel(self, unpacked_tensors: List[np.ndarray]) -> np.ndarray:
"""Convert a flattened list of shards into a sharded array."""
    unravelled = np.ndarray([self.num_shards(i) for i in range(self.rank)],
                            dtype=object)
for offset, loc in enumerate(self.offset_to_shard()):
unravelled[loc] = unpacked_tensors[offset]
return unravelled
def num_shards(self, idx: int) -> int:
"""Returns the number of shards for tensor dimension `idx`."""
dim_sharding = self.sharding_specs[idx]
if dim_sharding == UNSHARDED:
return 1
if dim_sharding == MATCH:
return -1
return self.mesh.dim_size(dim_sharding)
def as_proto(self) -> layout_pb2.LayoutProto:
"""Create a proto representation of a layout."""
layout_proto = layout_pb2.LayoutProto()
for dim_sharding in self.sharding_specs:
tensor_dim = layout_proto.sharding_specs.add()
tensor_dim.sharding_spec = dim_sharding
layout_proto.mesh_config.CopyFrom(self.mesh_proto())
return layout_proto
def mesh_proto(self) -> layout_pb2.MeshProto:
return self.mesh.as_proto()
def is_fully_replicated(self) -> bool:
return all([self.num_shards(i) == 1 for i in range(self.rank)])
# A layout with no sharding specs is acceptable, therefore we only check the
# mesh.
def to_string(self) -> str:
"""Returns string representation of Layout."""
sharding_spec_str = 'sharding_specs:'
# Add comma after each instruction.
for spec in self.sharding_specs:
sharding_spec_str += spec + ','
mesh_str = 'mesh:' + self.mesh.to_string()
return sharding_spec_str + ' ' + mesh_str
def serialized_string(self) -> bytes:
return self.as_proto().SerializeToString()
def __eq__(self, other) -> bool:
return self.serialized_string() == other.serialized_string()
def __repr__(self) -> str:
return str(self.as_proto())
@staticmethod
def replicated(mesh: Mesh, rank: int) -> 'Layout':
"""Returns a replicated layout of rank `rank`."""
return Layout([UNSHARDED] * rank, mesh)
@staticmethod
def batch_sharded(mesh: Mesh, batch_dim: str, rank: int) -> 'Layout':
"""Returns a layout sharded on batch dimension."""
return Layout([batch_dim] + [UNSHARDED] * (rank - 1), mesh)
@staticmethod
def inner_sharded(mesh: Mesh, inner_dim: str, rank: int) -> 'Layout':
"""Returns a layout sharded on inner dimension."""
return Layout([UNSHARDED] * (rank - 1) + [inner_dim], mesh)
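  # For example, given a mesh with a 'batch' dimension (illustrative):
  #   Layout.replicated(mesh, rank=2)              # [UNSHARDED, UNSHARDED]
  #   Layout.batch_sharded(mesh, 'batch', rank=2)  # ['batch', UNSHARDED]
  #   Layout.inner_sharded(mesh, 'batch', rank=2)  # [UNSHARDED, 'batch']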
def delete(self, dims: List[int]) -> 'Layout':
"""Returns the layout with the give dimensions deleted."""
if not isinstance(dims, list):
dims = [dims]
new_specs = [
spec for i, spec in enumerate(self.sharding_specs) if i not in dims
]
return Layout(new_specs, self.mesh)