[DTensor][Shampoo] add _tensor.zeros function (#95863)
Summary:
Pull Request resolved: #95863

Implement a zeros function inside the DTensor API: the user specifies the shape of the zeros tensor, and the function creates the local zero tensor according to the given placement information.

Test Plan:
- unit test for the compute_local_tensor_size util function
- unit test for _tensor.zeros

Reviewed By: wanchaol

Differential Revision: D43630718

fbshipit-source-id: 14b40863b40df696bc87a92166953f5ba2c5425d
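For context, a minimal usage sketch of the new API. This is hedged: the exact signature is assumed from the summary (global shape plus device_mesh and placements arguments), and it would need to run under an initialized process group.

# Hypothetical usage sketch; assumes torch.distributed._tensor.zeros takes the
# global shape plus device_mesh/placements, per the commit summary.
import torch
import torch.distributed._tensor as dtensor
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard

# 4 x 2 mesh over 8 ranks (same layout as the tests below).
mesh = DeviceMesh("cpu", torch.arange(8).reshape(4, 2))

# Global shape [8, 6]: sharded on tensor dim 0 across mesh dim 0 (size 4),
# replicated across mesh dim 1, so each rank holds a [2, 6] local zero tensor.
dt = dtensor.zeros(8, 6, device_mesh=mesh, placements=[Shard(0), Replicate()])
# dt.to_local() -> torch.zeros(2, 6) on every rank in the mesh.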
1 parent d9f822b · commit 7760204
Showing 5 changed files with 319 additions and 8 deletions.
@@ -0,0 +1,79 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.utils import compute_local_tensor_size
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class UtilTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_compute_local_tensor_size_2d(self):
        # mesh: 4 * 2
        mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        size = torch.Size([8, 6])

        # replicate, replicate
        placements1 = [Replicate(), Replicate()]
        local_size1 = compute_local_tensor_size(size, mesh, placements1)
        self.assertEqual(local_size1, torch.Size([8, 6]))

        # replicate, shard
        placements2 = [Replicate(), Shard(0)]
        local_size2 = compute_local_tensor_size(size, mesh, placements2)
        self.assertEqual(local_size2, torch.Size([4, 6]))

        # shard, shard
        placements3 = [Shard(0), Shard(1)]
        local_size3 = compute_local_tensor_size(size, mesh, placements3)
        self.assertEqual(local_size3, torch.Size([2, 3]))

    @with_comms
    def test_compute_local_tensor_size_2d_not_evenly(self):
        # mesh: 4 * 2
        mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        size = torch.Size([7, 7])
        rank_coordinates = mesh.get_coordinate()

        # replicate, shard: mesh dim 1 (size 2) shards tensor dim 0, 7 -> [4, 3]
        placements2 = [Replicate(), Shard(0)]
        local_size2 = compute_local_tensor_size(size, mesh, placements2)
        if rank_coordinates[1] < 1:
            self.assertEqual(local_size2, torch.Size([4, 7]))
        else:
            self.assertEqual(local_size2, torch.Size([3, 7]))

        # shard, shard: 7 -> [2, 2, 2, 1] across mesh dim 0 (size 4),
        # 7 -> [4, 3] across mesh dim 1 (size 2)
        placements3 = [Shard(0), Shard(1)]
        local_size3 = compute_local_tensor_size(size, mesh, placements3)
        # first dim
        if rank_coordinates[0] < 3:
            self.assertEqual(local_size3[0], 2)
        else:
            self.assertEqual(local_size3[0], 1)
        # second dim
        if rank_coordinates[1] < 1:
            self.assertEqual(local_size3[1], 4)
        else:
            self.assertEqual(local_size3[1], 3)


if __name__ == "__main__":
    run_tests()
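The expected sizes in the uneven case follow torch.chunk-style splitting: every shard gets ceil(dim_len / num_chunks) elements except the tail, which absorbs whatever remains. A standalone sketch of that arithmetic follows; local_shard_size here is a hypothetical helper for illustration, not the PR's _local_shard_size_on_dim.

def local_shard_size(dim_len: int, num_chunks: int, coordinate: int) -> int:
    # torch.chunk-style sizing: full chunks of ceil(dim_len / num_chunks),
    # with the last shard(s) taking whatever is left (possibly nothing).
    full = -(-dim_len // num_chunks)  # ceil division
    start = coordinate * full
    return max(min(dim_len - start, full), 0)

# 7 rows over 4 shards -> [2, 2, 2, 1], matching the first-dim asserts above.
assert [local_shard_size(7, 4, c) for c in range(4)] == [2, 2, 2, 1]
# 7 columns over 2 shards -> [4, 3], matching the second-dim asserts above.
assert [local_shard_size(7, 2, c) for c in range(2)] == [4, 3]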
@@ -0,0 +1,44 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Owner(s): ["oncall: distributed"]

from typing import Optional, Sequence

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard


def compute_local_tensor_size(
    size: torch.Size, device_mesh: DeviceMesh, placements: Sequence[Placement]
) -> Optional[torch.Size]:
    """
    Compute the size of the local tensor on the current rank.

    Args:
        size (torch.Size): the shape of the whole DTensor.
        device_mesh (:class:`DeviceMesh`): contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` types: Shard, Replicate.
    Returns:
        A :class:`torch.Size` for the local tensor on the device_mesh, or
        ``None`` if the current rank is not part of the mesh.
    """
    rank_coordinates = device_mesh.get_coordinate()
    if rank_coordinates is None:
        return None

    local_size = list(size)
    for idx, placement in enumerate(placements):
        if isinstance(placement, Replicate):
            # replication leaves the dimension size unchanged
            continue
        elif isinstance(placement, Shard):
            curr_coordinate = rank_coordinates[idx]
            shard_dim = placement.dim
            len_before_shard = local_size[shard_dim]
            num_chunks = device_mesh.size(idx)

            len_after_shard, _ = placement._local_shard_size_on_dim(
                len_before_shard, num_chunks, curr_coordinate
            )
            local_size[shard_dim] = len_after_shard
        else:
            raise RuntimeError(f"placement type {type(placement)} not supported!")

    return torch.Size(local_size)
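For reference, a minimal usage sketch of compute_local_tensor_size (assumes torch.distributed is already initialized with 8 ranks so that DeviceMesh can resolve this rank's coordinate):

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.utils import compute_local_tensor_size

# 4 x 2 mesh over 8 ranks: mesh dim 0 shards tensor dim 0, mesh dim 1 replicates.
mesh = DeviceMesh("cpu", torch.arange(8).reshape(4, 2))
local_size = compute_local_tensor_size(
    torch.Size([8, 6]), mesh, [Shard(0), Replicate()]
)
# Evenly divisible case: torch.Size([2, 6]) on every rank in the mesh.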