[DTensor][Shampoo] add _tensor.zeros function #95863

Closed · wants to merge 1 commit
132 changes: 128 additions & 4 deletions test/distributed/_tensor/test_init.py
@@ -2,10 +2,7 @@
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import (
DTensor,
Shard,
)
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@@ -40,5 +37,132 @@ def test_init_ops(self):
self._run_init_op(torch.nn.init.constant_, 2.4)


class DTensorConstructorTest(DTensorTestBase):

@property
def world_size(self):
return 4

@with_comms
def test_zeros_full_mesh(self):
# construct a cuda device 1d mesh
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
placements = [Shard(0)]
size = [32, 3]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()
self.assertEqual(local_tensor.size(), torch.Size([8, 3]))

local_tensor = torch.zeros(8, 3)
self.assertEqual(dist_tensor.to_local(), local_tensor)

self.assertEqual(dist_tensor.device.type, self.device_type)

# 1d sharded unevenly
size = [31, 3]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()
if self.rank <= 2:
self.assertEqual(local_tensor.size(), torch.Size([8, 3]))
self.assertEqual(torch.zeros(8, 3), local_tensor)
else:
self.assertEqual(local_tensor.size(), torch.Size([7, 3]))
self.assertEqual(torch.zeros(7, 3), local_tensor)

# construct a cuda device mesh with 2d: shard, replicate
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
placements = [Shard(0), Replicate()]
size = [32, 4]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()
self.assertEqual(local_tensor.size(), torch.Size([16, 4]))
self.assertEqual(local_tensor, torch.zeros([16, 4]))

# construct a cuda device mesh with 2d: shard, shard
placements = [Shard(0), Shard(1)]
size = [32, 4]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()
self.assertEqual(local_tensor.size(), torch.Size([16, 2]))
self.assertEqual(local_tensor, torch.zeros([16, 2]))

# 2d sharded unevenly
placements = [Shard(0), Shard(1)]
size = [31, 3]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()
if self.rank == 0:
self.assertEqual(local_tensor, torch.zeros([16, 2]))
elif self.rank == 1:
self.assertEqual(local_tensor, torch.zeros([16, 1]))
elif self.rank == 2:
self.assertEqual(local_tensor, torch.zeros([15, 2]))
elif self.rank == 3:
self.assertEqual(local_tensor, torch.zeros([15, 1]))

@with_comms
def test_zeros_submesh(self):
# default world_size is 4
# construct a cuda device 1d mesh
sub_mesh_list = [0, 3]
mesh = DeviceMesh(self.device_type, sub_mesh_list)
placements = [Shard(0)]
size = [32, 3]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()

if self.rank in sub_mesh_list:
self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
self.assertEqual(local_tensor, torch.zeros([16, 3]))
else:
self.assertEqual(local_tensor.size(), torch.Size([0]))
self.assertEqual(local_tensor, torch.tensor([]))

# construct a cuda device 1d mesh: unevenly
sub_mesh_list = [0, 1, 3]
mesh = DeviceMesh(self.device_type, sub_mesh_list)
placements = [Shard(0)]
size = [32, 3]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()

if self.rank in sub_mesh_list:
if self.rank != 3:
self.assertEqual(local_tensor.size(), torch.Size([11, 3]))
self.assertEqual(local_tensor, torch.zeros([11, 3]))
else:
self.assertEqual(local_tensor.size(), torch.Size([10, 3]))
self.assertEqual(local_tensor, torch.zeros([10, 3]))
else:
self.assertEqual(local_tensor.size(), torch.Size([0]))
self.assertEqual(local_tensor, torch.tensor([]))

# construct a cuda device 2d mesh
sub_mesh_list = [[0], [3]]
mesh = DeviceMesh(self.device_type, sub_mesh_list)
placements = [Shard(0), Shard(1)]
size = [32, 3]
dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
self.assertEqual(dist_tensor.size(), torch.Size(size))
local_tensor = dist_tensor.to_local()

if self.rank in [0, 3]:
self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
self.assertEqual(local_tensor, torch.zeros([16, 3]))
else:
self.assertEqual(local_tensor.size(), torch.Size([0]))
self.assertEqual(local_tensor, torch.tensor([]))


if __name__ == "__main__":
run_tests()
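The uneven-shard sizes asserted above follow torch.chunk-style splitting: a sharded dimension is cut into chunks of ceil(n / num_chunks) elements until the remainder runs out, so later ranks can receive a smaller (or empty) shard. Below is a minimal standalone sketch of that rule, under the assumption that DTensor keeps torch.chunk semantics; the `shard_sizes` helper is hypothetical and introduced only for illustration.

```python
from typing import List


def shard_sizes(dim_len: int, num_chunks: int) -> List[int]:
    """Per-chunk lengths when a dimension is split torch.chunk-style."""
    full = -(-dim_len // num_chunks)  # ceil division
    sizes: List[int] = []
    remaining = dim_len
    for _ in range(num_chunks):
        take = min(full, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


# Matches the asserts above: a [31, 3] tensor with Shard(0) on a 1-d mesh of 4 ranks.
assert shard_sizes(31, 4) == [8, 8, 8, 7]
# The 2 x 2 mesh case: dim 0 split two ways, dim 1 split two ways.
assert shard_sizes(31, 2) == [16, 15]
assert shard_sizes(3, 2) == [2, 1]
```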
74 changes: 74 additions & 0 deletions test/distributed/_tensor/test_utils.py
@@ -0,0 +1,74 @@
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.utils import compute_local_tensor_size

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)


class UtilTest(DTensorTestBase):
@property
def world_size(self):
return 8

@with_comms
def test_compute_local_tensor_size_2d(self):
# mesh: 4 * 2
mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor)
size = torch.Size([8, 6])

# replicate, replicate
placements1 = [Replicate(), Replicate()]
local_size1 = compute_local_tensor_size(size, mesh, placements1)
self.assertEqual(local_size1, torch.Size([8, 6]))

# replicate, shard
placements2 = [Replicate(), Shard(0)]
local_size2 = compute_local_tensor_size(size, mesh, placements2)
self.assertEqual(local_size2, torch.Size([4, 6]))

# shard, shard
placements3 = [Shard(0), Shard(1)]
local_size3 = compute_local_tensor_size(size, mesh, placements3)
self.assertEqual(local_size3, torch.Size([2, 3]))

@with_comms
def test_compute_local_tensor_size_2d_not_evenly(self):
# mesh: 4 * 2
mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor)
size = torch.Size([7, 7])
rank_coordinates = mesh.get_coordinate()

# replicate, shard
placements2 = [Replicate(), Shard(0)]
local_size2 = compute_local_tensor_size(size, mesh, placements2)
if rank_coordinates[1] < 1:
self.assertEqual(local_size2, torch.Size([4, 7]))
else:
self.assertEqual(local_size2, torch.Size([3, 7]))

# shard, shard
placements3 = [Shard(0), Shard(1)]
local_size3 = compute_local_tensor_size(size, mesh, placements3)
# first dim
if rank_coordinates[0] < 3:
self.assertEqual(local_size3[0], 2)
else:
self.assertEqual(local_size3[0], 1)
# second dim
if rank_coordinates[1] < 1:
self.assertEqual(local_size3[1], 4)
else:
self.assertEqual(local_size3[1], 3)


if __name__ == "__main__":
run_tests()
75 changes: 73 additions & 2 deletions torch/distributed/_tensor/__init__.py
@@ -2,11 +2,12 @@
from typing import Callable, cast, Optional, Sequence

# Import all builtin dist tensor ops
import torch
import torch.distributed._tensor.ops
from torch.distributed._tensor.api import DTensor, distribute_tensor, distribute_module
from torch.distributed._tensor.api import distribute_module, distribute_tensor, DTensor
from torch.distributed._tensor.device_mesh import DeviceMesh, get_global_device_mesh
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard

from torch.distributed._tensor.utils import compute_local_tensor_size

# All public APIs from dtensor package
__all__ = [
@@ -17,3 +18,73 @@
"Shard",
"Replicate",
]


def zeros(
*size,
requires_grad: bool = False,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
"""
Returns a :class:`DTensor` filled with the scalar value 0.

Args:
size (int...): a sequence of integers defining the shape of the output
DTensor. Can be a variable number of arguments or a collection such as a list or tuple.
E.g.: zeros(1, 2, 3) or zeros([1, 2, 3]) or zeros((1, 2, 3))
Keyword args:
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
device_mesh (:class:`DeviceMesh`, optional): the device mesh holding the participating ranks.
placements (Sequence[:class:`Placement`], optional): a sequence of :class:`Placement` types: ``Shard``, ``Replicate``, ``_Partial``.

Returns:
A :class:`DTensor` object on each rank
"""
# if device_mesh is None, use the global one
device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(device_mesh.ndim)]
assert device_mesh.ndim == len(
placements
), "mesh dimension doesnot match the length of placements"

if len(size) == 1 and isinstance(size[0], Sequence):
torch_size = size[0]
else:
torch_size = list(size)
torch_size = torch.Size(torch_size)
assert layout == torch.strided, "layout value not supported!"
torch_stride = torch._prims_common.make_contiguous_strides_for(torch_size)

local_size = compute_local_tensor_size(torch_size, device_mesh, placements)
if local_size is None:
local_tensor = torch.tensor([], dtype=dtype, requires_grad=requires_grad)
else:
local_tensor = torch.zeros(
local_size,
device=device_mesh.device_type,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
)

dtensor = DTensor(
local_tensor=local_tensor,
device_mesh=device_mesh,
placements=placements,
shape=torch_size,
dtype=local_tensor.dtype,
stride=torch_stride,
requires_grad=requires_grad,
)

return dtensor
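For context, here is a minimal usage sketch of the zeros() API added above, mirroring test_zeros_full_mesh in this PR. It assumes a 4-rank job whose default process group is already initialized (the tests get this from the with_comms decorator) and a CUDA device type.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, zeros

# 2 x 2 device mesh over ranks 0..3
mesh = DeviceMesh("cuda", torch.arange(4).reshape(2, 2))

# Shard dim 0 across the first mesh dimension, replicate across the second:
# the global shape is [32, 4], so every rank holds a zero-filled [16, 4] shard.
dt = zeros(32, 4, device_mesh=mesh, placements=[Shard(0), Replicate()])
assert dt.size() == torch.Size([32, 4])
assert dt.to_local().size() == torch.Size([16, 4])
```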
4 changes: 2 additions & 2 deletions torch/distributed/_tensor/device_mesh.py
@@ -295,8 +295,8 @@ def get_rank(self) -> int:

def get_coordinate(self) -> Optional[List[int]]:
"""
Return the relative index of this rank relative to a given
dimension of the mesh. If this rank is not part of the mesh, return None.
Return the relative indices of this rank relative to all
dimensions of the mesh. If this rank is not part of the mesh, return None.
"""
return self._coordinate_on_dim if self._coordinate_on_dim else None

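As a concrete example of the updated docstring (assuming a mesh built from torch.arange(8).reshape(4, 2)):

```python
# mesh tensor [[0, 1], [2, 3], [4, 5], [6, 7]]: rank 5 sits at row 2, column 1,
# so its get_coordinate() returns [2, 1]; a rank outside the mesh returns None.
```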
43 changes: 43 additions & 0 deletions torch/distributed/_tensor/utils.py
@@ -0,0 +1,43 @@
from typing import Optional, Sequence

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard


def compute_local_tensor_size(
size: torch.Size, device_mesh: DeviceMesh, placements: Sequence[Placement]
) -> Optional[torch.Size]:
"""
Args:
size (torch.Size): the shape of the global DTensor.
device_mesh (:class:`DeviceMesh`): the device mesh holding the participating ranks.
placements (Sequence[:class:`Placement`]): a sequence of :class:`Placement` types: ``Shard``, ``Replicate``.

Returns:
A :class:`torch.Size` for the local tensor on the device_mesh
"""
rank_coordinates = device_mesh.get_coordinate()
# this rank is not part of the device mesh, so it holds no local shard
if rank_coordinates is None:
return None
local_size = list(size)
for idx, placement in enumerate(placements):
if isinstance(placement, Replicate):
continue
elif isinstance(placement, Shard):
curr_coordinate = rank_coordinates[idx]
shard_dim = placement.dim
len_before_shard = local_size[shard_dim]
num_chunks = device_mesh.size(idx)

len_after_shard, _ = placement._local_shard_size_on_dim(
len_before_shard, num_chunks, curr_coordinate
)
local_size[shard_dim] = len_after_shard
else:
raise RuntimeError(f"placement type {type(placement)} not supported!")

return torch.Size(local_size)
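To make the arithmetic concrete, here is a worked example mirroring test_compute_local_tensor_size_2d_not_evenly above, using the same ceil-division chunking that _local_shard_size_on_dim applies per mesh dimension.

```python
# Global size [7, 7] on a 4 x 2 mesh with placements [Shard(0), Shard(1)]:
#   mesh dim 0 has 4 ranks -> dim 0 splits as ceil(7 / 4) = 2 -> [2, 2, 2, 1]
#   mesh dim 1 has 2 ranks -> dim 1 splits as ceil(7 / 2) = 4 -> [4, 3]
# So the rank at coordinate [0, 1] gets torch.Size([2, 3]) and the rank at
# coordinate [3, 0] gets torch.Size([1, 4]); a rank not in the mesh gets None.
```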