[DTensor][Shampoo] add _tensor.zeros function (#95863)
Summary:
Pull Request resolved: #95863

Implement a zeros function inside the DTensor API.
- the user specifies the shape of the zeros tensor, and the function creates the local zero tensor on each rank given the placement information
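
As a rough usage sketch (not part of this diff; it assumes an initialized 4-rank job on a CUDA host, matching the setup used in the tests below):

import torch
from torch.distributed._tensor import DeviceMesh, Shard, zeros

# 1d mesh over 4 ranks; shard dim 0 of the global [32, 3] zeros tensor
mesh = DeviceMesh("cuda", torch.arange(4))
dist_tensor = zeros([32, 3], device_mesh=mesh, placements=[Shard(0)])
print(dist_tensor.size())             # torch.Size([32, 3]) -- global shape
print(dist_tensor.to_local().size())  # torch.Size([8, 3]) -- per-rank shard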

Test Plan:
- unit test for the compute_local_tensor_size util function ({F889157756})
- unit test for _tensor.zeros

Reviewed By: wanchaol

Differential Revision: D43630718

fbshipit-source-id: 85adfdf0f2c6f5e2fcbc9df84e90e8d6abc7d2c5
minddrummer authored and facebook-github-bot committed Mar 2, 2023
1 parent d9f822b commit 51e1778
Showing 5 changed files with 319 additions and 8 deletions.
125 changes: 121 additions & 4 deletions test/distributed/_tensor/test_init.py
@@ -2,10 +2,7 @@
# Owner(s): ["oncall: distributed"]

import torch
-from torch.distributed._tensor import (
-    DTensor,
-    Shard,
-)
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
@@ -39,6 +36,126 @@ def test_init_ops(self):
        self._run_init_op(torch.nn.init.uniform_, a=0, b=1.2)
        self._run_init_op(torch.nn.init.constant_, 2.4)

    @with_comms
    def test_zeros_full_mesh(self):
        # construct a cuda device 1d mesh
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([8, 3]))

        local_tensor = torch.zeros(8, 3)
        self.assertEqual(dist_tensor.to_local(), local_tensor)

        self.assertEqual(dist_tensor.device.type, self.device_type)

        # 1d sharded unevenly
        size = [31, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        if self.rank <= 2:
            self.assertEqual(local_tensor.size(), torch.Size([8, 3]))
            self.assertEqual(torch.zeros(8, 3), local_tensor)
        else:
            self.assertEqual(local_tensor.size(), torch.Size([7, 3]))
            self.assertEqual(torch.zeros(7, 3), local_tensor)

        # construct a cuda device mesh with 2d: shard, replicate
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
        placements = [Shard(0), Replicate()]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 4]))
        self.assertEqual(local_tensor, torch.zeros([16, 4]))

        # construct a cuda device mesh with 2d: shard, shard
        placements = [Shard(0), Shard(1)]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 2]))
        self.assertEqual(local_tensor, torch.zeros([16, 2]))

        # 2d sharded unevenly
        placements = [Shard(0), Shard(1)]
        size = [31, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        if self.rank == 0:
            self.assertEqual(local_tensor, torch.zeros([16, 2]))
        elif self.rank == 1:
            self.assertEqual(local_tensor, torch.zeros([16, 1]))
        elif self.rank == 2:
            self.assertEqual(local_tensor, torch.zeros([15, 2]))
        elif self.rank == 3:
            self.assertEqual(local_tensor, torch.zeros([15, 1]))

    @with_comms
    def test_zeros_submesh(self):
        # default world_size is 4
        # construct a cuda device 1d mesh
        sub_mesh_list = [0, 3]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))

        # construct a cuda device 1d mesh: unevenly
        sub_mesh_list = [0, 1, 3]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            if self.rank != 3:
                self.assertEqual(local_tensor.size(), torch.Size([11, 3]))
                self.assertEqual(local_tensor, torch.zeros([11, 3]))
            else:
                self.assertEqual(local_tensor.size(), torch.Size([10, 3]))
                self.assertEqual(local_tensor, torch.zeros([10, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))

        # construct a cuda device 2d mesh
        sub_mesh_list = [[0], [3]]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0), Shard(1)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in [0, 3]:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))


if __name__ == "__main__":
    run_tests()
79 changes: 79 additions & 0 deletions test/distributed/_tensor/test_utils.py
@@ -0,0 +1,79 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard
from torch.distributed._tensor.utils import compute_local_tensor_size
from torch.testing._internal.common_distributed import TEST_SKIPS

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class UtilTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_compute_local_tensor_size_2d(self):
        # mesh: 4 * 2
        mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        size = torch.Size([8, 6])

        # replicate, replicate
        placements1 = [Replicate(), Replicate()]
        local_size1 = compute_local_tensor_size(size, mesh, placements1)
        self.assertEqual(local_size1, torch.Size([8, 6]))

        # replicate, shard
        placements2 = [Replicate(), Shard(0)]
        local_size2 = compute_local_tensor_size(size, mesh, placements2)
        self.assertEqual(local_size2, torch.Size([4, 6]))

        # shard, shard
        placements3 = [Shard(0), Shard(1)]
        local_size3 = compute_local_tensor_size(size, mesh, placements3)
        self.assertEqual(local_size3, torch.Size([2, 3]))

    @with_comms
    def test_compute_local_tensor_size_2d_not_evenly(self):
        # mesh: 4 * 2
        mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        size = torch.Size([7, 7])
        rank_coordinates = mesh.get_coordinate()

        # replicate, shard
        placements2 = [Replicate(), Shard(0)]
        local_size2 = compute_local_tensor_size(size, mesh, placements2)
        if rank_coordinates[1] < 1:
            self.assertEqual(local_size2, torch.Size([4, 7]))
        else:
            self.assertEqual(local_size2, torch.Size([3, 7]))

        # shard, shard
        placements3 = [Shard(0), Shard(1)]
        local_size3 = compute_local_tensor_size(size, mesh, placements3)
        # first dim
        if rank_coordinates[0] < 3:
            self.assertEqual(local_size3[0], 2)
        else:
            self.assertEqual(local_size3[0], 1)
        # second dim
        if rank_coordinates[1] < 1:
            self.assertEqual(local_size3[1], 4)
        else:
            self.assertEqual(local_size3[1], 3)


if __name__ == "__main__":
    run_tests()
75 changes: 73 additions & 2 deletions torch/distributed/_tensor/__init__.py
@@ -2,11 +2,12 @@
from typing import Callable, cast, Optional, Sequence

# Import all builtin dist tensor ops
import torch
import torch.distributed._tensor.ops
-from torch.distributed._tensor.api import DTensor, distribute_tensor, distribute_module
+from torch.distributed._tensor.api import distribute_module, distribute_tensor, DTensor
from torch.distributed._tensor.device_mesh import DeviceMesh, get_global_device_mesh
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard

from torch.distributed._tensor.utils import compute_local_tensor_size

# All public APIs from dtensor package
__all__ = [
@@ -17,3 +18,73 @@
"Shard",
"Replicate",
]


def zeros(
    *size,
    requires_grad: bool = False,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 0.

    Args:
        size (int...): a sequence of integers defining the shape of the output
            DTensor. Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))

    Keyword args:
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        dtype (:class:`torch.dtype`, optional): the desired data type of the returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of the returned tensor.
            Default: ``torch.strided``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` types: Shard, Replicate, _Partial

    Returns:
        A :class:`DTensor` object on each rank
    """
    # if device_mesh is None, use the global one
    device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
    # set default placements to replicated if not specified
    if placements is None:
        placements = [Replicate() for _ in range(device_mesh.ndim)]
    assert device_mesh.ndim == len(
        placements
    ), "mesh dimension does not match the length of placements"

    # size can be passed variadically, e.g. zeros(2, 3), or as a single sequence, e.g. zeros([2, 3])
    if len(size) == 1 and isinstance(size[0], Sequence):
        torch_size = size[0]
    else:
        torch_size = list(size)
    torch_size = torch.Size(torch_size)
    assert layout == torch.strided, "layout value not supported!"
    torch_stride = torch._prims_common.make_contiguous_strides_for(torch_size)

    # compute the per-rank local shape; None means this rank is not part of the mesh
    local_size = compute_local_tensor_size(torch_size, device_mesh, placements)
    if local_size is None:
        local_tensor = torch.tensor([], dtype=dtype, requires_grad=requires_grad)
    else:
        local_tensor = torch.zeros(
            local_size,
            device=device_mesh.device_type,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
        )

    dtensor = DTensor(
        local_tensor=local_tensor,
        device_mesh=device_mesh,
        placements=placements,
        shape=torch_size,
        dtype=local_tensor.dtype,
        stride=torch_stride,
        requires_grad=requires_grad,
    )

    return dtensor
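
For ranks that are not part of the device_mesh, compute_local_tensor_size returns None and zeros falls back to an empty local tensor. A minimal sketch of that branch (assuming a 4-rank job, mirroring test_zeros_submesh above):

import torch
from torch.distributed._tensor import DeviceMesh, Shard, zeros

# submesh over ranks 0 and 3 only (assumes world_size == 4)
mesh = DeviceMesh("cuda", [0, 3])
dist_tensor = zeros([32, 3], device_mesh=mesh, placements=[Shard(0)])
local = dist_tensor.to_local()
# ranks 0 and 3 each hold a [16, 3] shard of zeros; ranks 1 and 2 hold an empty tensor
print(local.size())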
4 changes: 2 additions & 2 deletions torch/distributed/_tensor/device_mesh.py
@@ -293,8 +293,8 @@ def get_rank(self) -> int:

    def get_coordinate(self) -> Optional[List[int]]:
        """
-        Return the relative index of this rank relative to a given
-        dimension of the mesh. If this rank is not part of the mesh, return None.
+        Return the relative indices of this rank relative to all
+        dimensions of the mesh. If this rank is not part of the mesh, return None.
        """
        return self._coordinate_on_dim if self._coordinate_on_dim else None
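
To make the coordinate semantics concrete: a rank's coordinate is its multi-dimensional index inside the mesh tensor. A small standalone sketch of that lookup (illustrative only; coordinate_of is a hypothetical helper, not the DeviceMesh implementation):

import torch

def coordinate_of(mesh_tensor: torch.Tensor, rank: int):
    # return the rank's index along every mesh dimension, or None if the rank is absent
    hits = (mesh_tensor == rank).nonzero()
    return hits[0].tolist() if hits.numel() > 0 else None

mesh_tensor = torch.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
print(coordinate_of(mesh_tensor, 3))  # [1, 1]
print(coordinate_of(mesh_tensor, 7))  # None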

44 changes: 44 additions & 0 deletions torch/distributed/_tensor/utils.py
@@ -0,0 +1,44 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Owner(s): ["oncall: distributed"]

from typing import Optional, Sequence

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard


def compute_local_tensor_size(
    size: torch.Size, device_mesh: DeviceMesh, placements: Sequence[Placement]
) -> Optional[torch.Size]:
    """
    Args:
        size (torch.Size): defines the shape of the whole DTensor.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` types: Shard, Replicate

    Returns:
        A :class:`torch.Size` for the local tensor on the device_mesh
    """
    if device_mesh.get_coordinate() is None:
        # this rank is not part of the mesh, so it holds no local shard
        return None
    else:
        local_size = list(size)
        rank_coordinates = device_mesh.get_coordinate()
        for idx, placement in enumerate(placements):
            if isinstance(placement, Replicate):
                continue
            elif isinstance(placement, Shard):
                curr_coordinate = rank_coordinates[idx]
                shard_dim = placement.dim
                len_before_shard = local_size[shard_dim]
                num_chunks = device_mesh.size(idx)

                len_after_shard, _ = placement._local_shard_size_on_dim(
                    len_before_shard, num_chunks, curr_coordinate
                )
                local_size[shard_dim] = len_after_shard
            else:
                raise RuntimeError(f"placement type {type(placement)} not supported!")

        return torch.Size(local_size)
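
For the uneven cases, the shard sizes seen in the tests can be reproduced by splitting a dimension as evenly as possible and giving the first shards the remainder. A small standalone sketch of that arithmetic (an illustration of the expected sizes, not the actual code inside Shard._local_shard_size_on_dim):

def shard_sizes(dim_len: int, num_chunks: int):
    # split dim_len as evenly as possible, giving the first shards one extra element each
    base, rem = divmod(dim_len, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

print(shard_sizes(31, 4))  # [8, 8, 8, 7]  -> the uneven 1d case in test_init.py
print(shard_sizes(7, 4))   # [2, 2, 2, 1]  -> Shard(0) over the 4-row mesh dim in test_utils.py
print(shard_sizes(32, 3))  # [11, 11, 10]  -> the 3-rank submesh case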
