
Commit

Update documentation for Mesh and Layout.
Also clean up other docstrings in api.py

PiperOrigin-RevId: 441292565
srujun authored and tensorflow-jenkins committed Apr 12, 2022
1 parent 8727d03 commit f9f3a19
Showing 2 changed files with 105 additions and 14 deletions.
22 changes: 20 additions & 2 deletions tensorflow/dtensor/python/api.py
@@ -175,13 +175,31 @@ def unpack(tensor: Any) -> Sequence[Any]:

@tf_export("experimental.dtensor.fetch_layout", v1=[])
def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout:
"""Returns the layout of a DTensor."""
"""Fetches the layout of a DTensor.
Args:
tensor: The DTensor whose layout is to be fetched.
Returns:
The `Layout` of this DTensor.
Raises:
RuntimeError: When not called eagerly.
"""
return _dtensor_device().fetch_layout(tensor)


@tf_export("experimental.dtensor.check_layout", v1=[])
def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None:
"""Asserts that the layout of `tensor` is `layout`."""
"""Asserts that the layout of the DTensor is `layout`.
Args:
tensor: A DTensor whose layout is to be checked.
layout: The `Layout` to compare against.
Raises:
ValueError: If the layout of `tensor` does not match the supplied `layout`.
"""
if fetch_layout(tensor) != layout:
  raise ValueError("Layout of tensor: " + str(fetch_layout(tensor)) +
                   ", did not match expected layout: " + str(layout))
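Taken together, a minimal usage sketch for these two helpers (a sketch only, assuming the DTensor API under `tf.experimental.dtensor` and enough local devices for a small mesh; the mesh shape is illustrative):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# A 1-D mesh over two local devices; create_mesh is the single-client
# helper referenced in the Mesh docstring below.
mesh = dtensor.create_mesh([("x", 2)])
layout = dtensor.Layout.replicated(mesh, rank=1)

# Place a tensor on the mesh, then inspect and assert its layout.
d = dtensor.copy_to_mesh(tf.constant([1.0, 2.0]), layout)
print(dtensor.fetch_layout(d))   # the Layout of the DTensor
dtensor.check_layout(d, layout)  # passes silently; ValueError on mismatch
```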
97 changes: 85 additions & 12 deletions tensorflow/dtensor/python/layout.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python definitions for Mesh and Layout."""
"""Python definitions for `Mesh` and `Layout`."""

import collections
import itertools
@@ -39,9 +39,31 @@

@tf_export('experimental.dtensor.Mesh', v1=[])
class Mesh(object):
"""Represents a Mesh configuration over a certain list of Mesh Dimensions."""
"""Represents a Mesh configuration over a certain list of Mesh Dimensions.
A mesh consists of named dimensions with sizes, which describe how a set of
devices are arranged. Defining tensor layouts in terms of mesh dimensions
allows us to efficiently determine the communication required when computing
an operation with tensors of different layouts.
A mesh provides information not only about the placement of the tensors but
also the topology of the underlying devices. For example, we can group 8 TPUs
as a 1-D array for data parallelism or a `2x4` grid for (2-way) data
parallelism and (4-way) model parallelism.
Note: the utilities `dtensor.create_mesh` and
`dtensor.create_distributed_mesh` provide a simpler API to create meshes for
single- or multi-client use cases.
"""

_dim_dict: Dict[str, MeshDimension]
_dim_names: List[str]
_local_device_ids: List[int]
_global_device_ids: np.ndarray
_name: str
_local_devices: List[tf_device.DeviceSpec]
_global_devices: Optional[List[tf_device.DeviceSpec]]
_device_type: str

def __init__(self,
dim_names: List[str],
@@ -52,17 +74,25 @@ def __init__(self,
global_devices: Optional[List[tf_device.DeviceSpec]] = None):
"""Builds a Mesh.

The `dim_names` and `global_device_ids` arguments describe the dimension
names and shape for the mesh.
For example,

```python
dim_names = ('x', 'y'),
global_device_ids = [[0, 1],
                     [2, 3],
                     [4, 5]]
```
defines a 2D mesh of shape 3x2. A reduction over the 'x' dimension will
reduce across columns (0, 2, 4) and (1, 3, 5), and a reduction over the 'y'
dimension reduces across rows.
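A sketch of building that 3x2 mesh through this constructor (single-host CPU device specs are assumed for illustration; the import aliases follow this file's own imports):

```python
import numpy as np
from tensorflow.python.framework import device as tf_device
from tensorflow.dtensor.python import layout as layout_lib

devices = [
    tf_device.DeviceSpec.from_string(f"/job:localhost/task:0/device:CPU:{i}")
    for i in range(6)
]
mesh = layout_lib.Mesh(
    dim_names=["x", "y"],
    global_device_ids=np.array([[0, 1], [2, 3], [4, 5]]),
    local_device_ids=list(range(6)),
    local_devices=devices,
)
print(mesh.dim_size("x"), mesh.dim_size("y"))  # 3 2
```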
Note: the utilities `dtensor.create_mesh` and
`dtensor.create_distributed_mesh` provide a simpler API to create meshes for
single- or multi-client use cases.
Args:
  dim_names: A list of strings indicating dimension names.
@@ -174,8 +204,7 @@ def host_mesh(self):
len(self._local_devices), self.to_string(),
v_cpus_counts))
device_array = np.asarray([
    spec.replace(device_type='CPU') for spec in self._local_devices
]).reshape((len(self._local_devices), 1))
global_devices = None
if self._global_devices:
@@ -212,12 +241,15 @@ def unravel_index(self):
"""Returns a dictionary from device ID to {dim_name: dim_index}.
For example, for a 3x2 mesh, return this:
```
{ 0: {'x': 0, 'y', 0},
1: {'x': 0, 'y', 1},
2: {'x': 1, 'y', 0},
3: {'x': 1, 'y', 1},
4: {'x': 2, 'y', 0},
5: {'x': 2, 'y', 1} }.
5: {'x': 2, 'y', 1} }
```
"""
idx_ranges = [
range(self.dim_size(dim_name)) for dim_name in self._dim_names
@@ -399,14 +431,54 @@ def __eq__(self, other):
# TODO(hthu): Consider making this class immutable.
@tf_export('experimental.dtensor.Layout', v1=[])
class Layout(object):
"""Represents the layout information for a Tensor."""
"""Represents the layout information of a DTensor.
A layout describes how a distributed tensor is partitioned across a mesh (and
thus across devices). For each axis of the tensor, the corresponding
sharding spec indicates which dimension of the mesh it is sharded over. A
special sharding spec `UNSHARDED` indicates that axis is replicated on
all the devices of that mesh.
For example, let's consider a 1-D mesh:
```
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
```
This mesh arranges 6 TPU devices into a 1-D array. `Layout([UNSHARDED], mesh)`
is a layout for rank-1 tensor which is replicated on the 6 devices.
For another example, let's consider a 2-D mesh:
```
Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
[("x", 3), ("y", 2)])
```
This mesh arranges 6 TPU devices into a `3x2` 2-D array.
`Layout(["x", UNSHARDED], mesh)` is a layout for rank-2 tensor whose first
axis is sharded on mesh dimension "x" and the second axis is replicated. If we
place `np.arange(6).reshape((3, 2))` using this layout, the individual
components tensors would look like:
```
Device | Component
TPU:0 [[0, 1]]
TPU:1 [[0, 1]]
TPU:2 [[2, 3]]
TPU:3 [[2, 3]]
TPU:4 [[4, 5]]
TPU:5 [[4, 5]]
```
"""

def __init__(self, sharding_specs: List[str], mesh: Mesh):
"""Builds a Layout from a list of dimension names and a Mesh.
Args:
sharding_specs: List of sharding specifications, each corresponding to a
tensor dimension. Each specification (dim_sharding) can either be a mesh
tensor axis. Each specification (dim_sharding) can either be a mesh
dimension or the special value UNSHARDED.
mesh: A mesh configuration for the Tensor.
Expand Down Expand Up @@ -571,5 +643,6 @@ def delete(self, dims: List[int]) -> 'Layout':
if not isinstance(dims, list):
  dims = [dims]
new_specs = [
    spec for i, spec in enumerate(self.sharding_specs) if i not in dims
]
return Layout(new_specs, self.mesh)
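A small sketch of `delete` in action (mesh construction as in the examples above; `delete` drops the sharding specs for the given tensor axes):

```python
from tensorflow.experimental import dtensor

mesh = dtensor.create_mesh([("x", 3), ("y", 2)])
layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)

reduced = layout.delete([0])   # remove the spec for tensor axis 0
print(reduced.sharding_specs)  # ['unsharded']
```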
