test/distributed/tensor/test_utils.py (54 additions, 1 deletion)
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import itertools
+from typing import Any
 
 import torch
 from torch.distributed.device_mesh import init_device_mesh
@@ -9,11 +10,18 @@
 from torch.distributed.tensor._utils import (
     _compute_local_shape_and_global_offset,
     _explicit_order_placements,
+    compute_global_tensor_info,
     compute_global_tensor_shape,
     compute_local_shape_and_global_offset,
 )
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard
+from torch.distributed.tensor.placement_types import (
+    _StridedShard,
+    Partial,
+    Placement,
+    Replicate,
+    Shard,
+)
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
@@ -442,6 +450,51 @@ def test_strided_sharding_assumption_in_meta_compute(self):
 )
 
 
+class UtilSingleDeviceTest(TestCase):
+    def test_compute_global_tensor_info_unsupported_placement(self):
+        class MockDeviceMesh:
+            def size(self, x):
+                return x
+
+        class FakePlacement(Placement):
+            pass
+
+        device_mesh: Any = MockDeviceMesh()
+        local_tensor = torch.tensor([1])
+        with self.assertRaises(RuntimeError):
+            compute_global_tensor_info(local_tensor, device_mesh, [FakePlacement()])
+
+    def test_compute_global_tensor_info_non_shard_placements(self):
+        class MockDeviceMesh:
+            def size(self, x):
+                return x
+
+        device_mesh: Any = MockDeviceMesh()
+        local_tensor = torch.tensor([[1], [2]])
+        global_size, global_stride = compute_global_tensor_info(
+            local_tensor, device_mesh, [Replicate(), Partial()]
+        )
+        self.assertEqual(global_size, local_tensor.size())
+        self.assertEqual(global_stride, local_tensor.stride())
+
+    def test_compute_global_tensor_info_shard_placement(self):
+        class MockDeviceMesh:
+            def size(self, dim):
+                return dim + 2
+
+        device_mesh: Any = MockDeviceMesh()
+        local_tensor = torch.tensor([[[1], [2], [3]], [[4], [5], [6]]])
+        global_size, global_stride = compute_global_tensor_info(
+            local_tensor, device_mesh, [Shard(0), Shard(1), Shard(2)]
+        )
+        self.assertEqual(
+            global_size, [(i + 2) * x for (i, x) in enumerate(local_tensor.size())]
+        )
+        self.assertEqual(global_stride[0], local_tensor.stride()[0] * 3 * 4)
+        self.assertEqual(global_stride[1], local_tensor.stride()[1])
+        self.assertEqual(global_stride[2], local_tensor.stride()[2] * 3)
+
+
 class TestStridedSharding(DTensorTestBase):
     @property
     def world_size(self):
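For context, here is a minimal sketch of the size/stride rule the new shard-placement test asserts: each `Shard(i)` on mesh dim `d` multiplies the global size of tensor dim `i` by `mesh.size(d)` and scales the stride of every other dim whose stride is at least the sharded dim's stride, while `Replicate()` and `Partial()` leave both untouched. The helper name and signature below are hypothetical, not the `torch.distributed.tensor._utils` implementation.

```python
# Hypothetical sketch of the rule exercised by the tests above; this is
# NOT torch.distributed.tensor._utils.compute_global_tensor_info itself.
from typing import Optional, Sequence


def sketch_global_tensor_info(
    local_size: Sequence[int],
    local_stride: Sequence[int],
    mesh_sizes: Sequence[int],  # mesh.size(d) for each mesh dim d
    shard_dims: Sequence[Optional[int]],  # tensor dim sharded on mesh dim d, or None
) -> tuple[list[int], list[int]]:
    size, stride = list(local_size), list(local_stride)
    for d, tensor_dim in enumerate(shard_dims):
        if tensor_dim is None:  # Replicate()/Partial(): sizes and strides unchanged
            continue
        size[tensor_dim] *= mesh_sizes[d]
        # Dims laid out "outside" the sharded dim get their stride scaled too.
        for j in range(len(stride)):
            if j != tensor_dim and stride[j] >= stride[tensor_dim]:
                stride[j] *= mesh_sizes[d]
    return size, stride


# Reproduces test_compute_global_tensor_info_shard_placement: a (2, 3, 1)
# local tensor with strides (3, 1, 1), mesh sizes (2, 3, 4), Shard(0)/(1)/(2).
assert sketch_global_tensor_info([2, 3, 1], [3, 1, 1], [2, 3, 4], [0, 1, 2]) == (
    [4, 9, 4],  # (i + 2) * local_size[i]
    [36, 1, 3],  # stride[0] * 3 * 4, stride[1], stride[2] * 3
)
```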