From 036296717c58f7bf4dbd007ff8afb836f9902250 Mon Sep 17 00:00:00 2001
From: Faycel Kouteib
Date: Fri, 26 Apr 2024 19:11:13 -0700
Subject: [PATCH] [Test][Distributed] Make more tests multi-threaded.

This conversion covers all tests under the 'test/distributed/_tensor' directory.

Fixes #108744
---
 .../_tensor/experimental/test_local_map.py    |  9 +--
 .../_tensor/experimental/test_tp_transform.py |  8 +-
 test/distributed/_tensor/test_api.py          | 27 +++----
 test/distributed/_tensor/test_attention.py    | 31 ++------
 test/distributed/_tensor/test_common_rules.py | 26 ++-----
 .../_tensor/test_convolution_ops.py           | 11 +--
 test/distributed/_tensor/test_dtensor.py      | 76 ++++++-------
 .../_tensor/test_dtensor_compile.py           | 35 ++++-----
 .../distributed/_tensor/test_embedding_ops.py |  8 +-
 .../_tensor/test_experimental_ops.py          | 16 ++--
 test/distributed/_tensor/test_init.py         | 18 ++---
 test/distributed/_tensor/test_math_ops.py     | 13 +---
 test/distributed/_tensor/test_matrix_ops.py   | 34 +++------
 test/distributed/_tensor/test_optimizers.py   | 36 +++------
 test/distributed/_tensor/test_random_ops.py   | 24 ++----
 test/distributed/_tensor/test_redistribute.py | 33 +++-----
 test/distributed/_tensor/test_tensor_ops.py   | 71 ++++++-----------
 test/distributed/_tensor/test_utils.py        | 16 ++--
 test/distributed/_tensor/test_view_ops.py     |  7 +-
 19 files changed, 158 insertions(+), 341 deletions(-)

diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py
index 1035df2f5f7d8..32a0b8312467c 100644
--- a/test/distributed/_tensor/experimental/test_local_map.py
+++ b/test/distributed/_tensor/experimental/test_local_map.py
@@ -13,8 +13,7 @@
 from torch.distributed._tensor.experimental import local_map
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
+    DTensorOpTestBase,
 )
 
 
@@ -38,13 +37,12 @@ def mul_forward(device_mesh, X, scalar):
     return torch.mul(X, scalar)
 
 
-class TestLocalMap(DTensorTestBase):
+class TestLocalMap(DTensorOpTestBase):
     @property
     def world_size(self):
         return 2
 
     # simple correctness check
-    @with_comms
     def test_local_map_correctness(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -86,7 +84,6 @@ def test_local_map_correctness(self):
         self.assertEqual(Y_dt.to_local(), Y)
 
     # check for `out_placements`
-    @with_comms
     def test_local_map_out_placements(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -108,7 +105,6 @@ def test_local_map_out_placements(self):
         self.assertTrue(not (X.equal(Y)))
 
     # check for `in_placements` handling
-    @with_comms
     def test_local_map_in_placements(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -174,7 +170,6 @@ def test_local_map_in_placements(self):
         self.assertEqual(Y_dt.full_tensor(), Y)
 
     # check for `redistribute_inputs` handling
-    @with_comms
     def test_local_map_redistribute(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
diff --git a/test/distributed/_tensor/experimental/test_tp_transform.py b/test/distributed/_tensor/experimental/test_tp_transform.py
index 636870264f84d..3d4f8483bc7ce 100644
--- a/test/distributed/_tensor/experimental/test_tp_transform.py
+++ b/test/distributed/_tensor/experimental/test_tp_transform.py
@@ -13,8 +13,7 @@
 )
 from torch.testing._internal.common_utils import run_tests
 from 
torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) @@ -52,7 +51,7 @@ def forward(self, x): return self.bn(self.fc(x)) -class TensorParallelTest(DTensorTestBase): +class TensorParallelTest(DTensorOpTestBase): def setUp(self) -> None: super().setUp() @@ -66,7 +65,6 @@ def assert_has_c10d_ops( actual_ops_count[str(node.target)] += 1 self.assertDictEqual(expected_ops_count, actual_ops_count) - @with_comms def test_tp_transform_with_uncovered_op(self): model = DummyModel().to(device=self.device_type) inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),) @@ -96,7 +94,6 @@ def test_tp_transform_with_uncovered_op(self): }, ) - @with_comms def test_tp_transform_e2e(self): torch.manual_seed(0) model = MLPListModule(2).to(device=self.device_type) @@ -134,7 +131,6 @@ def test_tp_transform_e2e(self): }, ) - @with_comms def test_tp_transform_no_bias(self): torch.manual_seed(0) model = MLPListModule(1, bias=False).to(device=self.device_type) diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py index 196bd6407b266..e30cb27c3aa8a 100644 --- a/test/distributed/_tensor/test_api.py +++ b/test/distributed/_tensor/test_api.py @@ -13,8 +13,7 @@ ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) @@ -33,16 +32,15 @@ def reset_parameters(self): m.reset_parameters() -class DTensorAPITest(DTensorTestBase): +class DTensorAPITest(DTensorOpTestBase): @property def world_size(self) -> int: # hard code world size to 4 as we need to test # at least with 2d mesh return 4 - @with_comms def test_distribute_tensor(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] for requires_grad in [True, False]: @@ -63,7 +61,6 @@ def test_distribute_tensor(self): dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec) self.assertEqual(dist_tensor.placements[0].dim, 1) - @with_comms def test_distribute_tensor_errors(self): device_mesh = DeviceMesh( self.device_type, torch.arange(self.world_size).reshape(2, 2) @@ -92,9 +89,8 @@ def test_distribute_tensor_errors(self): new_spec = [Shard(0), Replicate()] distribute_tensor(dtensor, device_mesh, new_spec) - @with_comms def test_distribute_tensor_uneven_sharding(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() input_sizes_and_shard_dims = [ ((self.world_size * 3 + 1, 3, 3), 0), ((self.world_size * 3 + 2, 3, 3), 0), @@ -114,9 +110,8 @@ def test_distribute_tensor_uneven_sharding(self): local_tensor = dist_tensor.to_local() self.assertEqual(local_tensor, splitted_tensor_list[self.rank]) - @with_comms def test_distribute_module(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully shard all linear modules on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type) shard_spec = [Shard(0)] @@ -177,9 +172,8 @@ def shard_fn(name, module, device_mesh): else: self.assertEqual(param.placements, replica_spec) - @with_comms def test_distribute_module_input_fn_output_fn(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully replicate all linear modules module_to_replicate = 
MyModel(20, 1, device=self.device_type) @@ -222,9 +216,8 @@ def replicate_input_fn(mod, inputs, device_mesh): self.assertTrue(isinstance(param_grad, DTensor)) self.assertTrue(isinstance(param_grad.placements[0], Replicate)) - @with_comms def test_distribute_module_input_fn_output_fn_warning(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully replicate all linear modules module_to_replicate = MyModel(20, 1, device=self.device_type) @@ -250,9 +243,8 @@ def output_fn(outputs, device_mesh): self.assertIsInstance(local_out, torch.Tensor) self.assertNotIsInstance(local_out, DTensor) - @with_comms def test_distribute_module_casting(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # check DTensor casting dt = DTensor.from_local(torch.rand(10), device_mesh, [Replicate()]) @@ -288,11 +280,10 @@ def test_distribute_module_casting(self): output = replica_model(dt) self.assertEqual(output.dtype, torch.bfloat16) - @with_comms def test_distribute_module_meta(self): # If the model is too big, the user may first the create entire model on the meta device and then initialize # it on the device in the partition function. - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully shard all parameters on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device="meta") diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py index db5a26d438502..231429997da98 100644 --- a/test/distributed/_tensor/test_attention.py +++ b/test/distributed/_tensor/test_attention.py @@ -4,7 +4,7 @@ import torch from torch import nn -from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard +from torch.distributed._tensor import distribute_tensor, Shard from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.experimental.attention import ( _CausalBehavior, @@ -25,17 +25,16 @@ run_tests, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, + DTensorOpTestBase, ModelArgs, Transformer, - with_comms, ) c10d_functional = torch.ops.c10d_functional -class RingAttentionTest(DTensorTestBase): +class RingAttentionTest(DTensorOpTestBase): @property def world_size(self) -> int: return 2 @@ -44,13 +43,9 @@ def world_size(self) -> int: @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" ) - @with_comms @parametrize("is_causal", [True, False]) def test_ring_attention_sdpa(self, is_causal: bool) -> None: - device_mesh = DeviceMesh( - self.device_type, - torch.arange(0, self.world_size), - ) + device_mesh = self.build_device_mesh() dtype = torch.bfloat16 bs = 8 query_tokens = 8 @@ -168,14 +163,10 @@ def test_ring_attention_sdpa(self, is_causal: bool) -> None: @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" ) - @with_comms @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]) @parametrize("is_causal", [True, False]) def test_ring_attention_native_transformer(self, is_causal: bool) -> None: - device_mesh = DeviceMesh( - self.device_type, - torch.arange(0, self.world_size), - ) + device_mesh = self.build_device_mesh() dtype = torch.bfloat16 bs = 8 ntokens = 8 @@ -250,13 +241,9 @@ def test_is_causal_behavior(self) -> None: @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" ) - 
@with_comms @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]) def test_ring_attention_custom_transformer(self) -> None: - device_mesh = DeviceMesh( - self.device_type, - torch.arange(0, self.world_size), - ) + device_mesh = self.build_device_mesh() dtype = torch.bfloat16 bs = 2 args = ModelArgs() @@ -301,7 +288,6 @@ def test_ring_attention_custom_transformer(self) -> None: @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" ) - @with_comms @parametrize( "attention_fn", [ @@ -311,10 +297,7 @@ def test_ring_attention_custom_transformer(self) -> None: ], ) def test_ring_attention_compile(self, attention_fn: object) -> None: - device_mesh = DeviceMesh( - self.device_type, - torch.arange(0, self.world_size), - ) + device_mesh = self.build_device_mesh() dtype = torch.bfloat16 bs = 8 query_tokens = 8 diff --git a/test/distributed/_tensor/test_common_rules.py b/test/distributed/_tensor/test_common_rules.py index a69cfc2b4ddd5..714b689b5154d 100644 --- a/test/distributed/_tensor/test_common_rules.py +++ b/test/distributed/_tensor/test_common_rules.py @@ -9,14 +9,13 @@ from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) aten = torch.ops.aten -class CommonRulesTest(DTensorTestBase): +class CommonRulesTest(DTensorOpTestBase): @property def world_size(self) -> int: # hard code world size to 4 as we need to test @@ -31,10 +30,9 @@ def _gen_tensor_meta(self, shape): empty_tensor.dtype, ) - @with_comms def test_einop_basic_propagation(self): # plain einsum, mm - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() mm_call = aten.mm.default # propagate col-wise sharding @@ -85,9 +83,8 @@ def test_einop_basic_propagation(self): self.assertIsNotNone(output_spec) self.assertTrue(output_spec.placements[0].is_partial()) - @with_comms def test_einop_pointwise_propagation(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() add_call = aten.add.Tensor # addition @@ -137,7 +134,6 @@ def test_einop_pointwise_propagation(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, -1, -1]) - @with_comms def test_einop_merge_sharding(self): # 2d mesh einop merge sharding mesh_shape = torch.arange(self.world_size).reshape( @@ -163,7 +159,6 @@ def test_einop_merge_sharding(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, 1]) - @with_comms def test_einop_linearity(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 @@ -231,11 +226,9 @@ def test_einop_linearity(self): # mat2 mesh dim 1 should become partial now! 
self.assertTrue(mat2_spec.placements[1].is_partial()) - @with_comms def test_einop_multi_sharding_on_mesh_dim(self): # einop prop with multi sharding on same mesh dim - mesh_shape = torch.arange(self.world_size) - mesh = DeviceMesh(self.device_type, mesh_shape) + mesh = self.build_device_mesh() mm_call = aten.mm.default mat1, mat2 = [0, -1], [0, -1] @@ -260,7 +253,6 @@ def test_einop_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1]) - @with_comms def test_einop_errors(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 @@ -281,9 +273,8 @@ def test_einop_errors(self): with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"): einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {})) - @with_comms def test_pointwise_rules_broadcasting(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() where_call = aten.where.self inp1, inp2, inp3 = [0], [], [-1, -1] @@ -307,9 +298,8 @@ def test_pointwise_rules_broadcasting(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [-1, 0]) - @with_comms def test_pointwise_rules_suggestion(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() lerp_call = aten.lerp.Scalar # propagate point-wise sharding @@ -335,7 +325,6 @@ def test_pointwise_rules_suggestion(self): self.assertEqual(len(schema_suggestion.args_schema), 3) self.assertEqual(schema_suggestion.args_schema[2], -1) - @with_comms def test_pointwise_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( @@ -381,7 +370,6 @@ def test_pointwise_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2) - @with_comms def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( diff --git a/test/distributed/_tensor/test_convolution_ops.py b/test/distributed/_tensor/test_convolution_ops.py index 19687d9e1e329..941170cfd4bcb 100644 --- a/test/distributed/_tensor/test_convolution_ops.py +++ b/test/distributed/_tensor/test_convolution_ops.py @@ -14,8 +14,7 @@ ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) ITER_TIME = 10 @@ -36,15 +35,14 @@ def _conv_fn( module.register_parameter(name, dist_param) -class DistConvolutionOpsTest(DTensorTestBase): +class DistConvolutionOpsTest(DTensorOpTestBase): @property def world_size(self) -> int: # hard code world size to 2 return 2 - @with_comms def test_downsampling_convolution(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024) @@ -109,9 +107,8 @@ def test_downsampling_convolution(self): f"Too large relative mse for bias tensor, expected less equal 1e-6, got {bias_mse_rel}", ) - @with_comms def test_depthwise_convolution(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 
7, 256, 128, 256) diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 653dfcbb5876a..f03794d49c7e4 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -22,8 +22,7 @@ from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) @@ -48,10 +47,9 @@ def reset_parameters(self, *args, **kwargs): self.net2.bias.fill_(1.2) -class DTensorTest(DTensorTestBase): - @with_comms +class DTensorTest(DTensorOpTestBase): def test_dtensor_constructor(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3, requires_grad=True) dist_tensor_shape = torch.Size([self.world_size * 3, 3]) @@ -77,7 +75,6 @@ def test_dtensor_constructor(self): stride=local_tensor.stride(), ) - @with_comms def test_meta_dtensor(self): device_mesh = self.build_device_mesh() dist_specs = [[Shard(0)], [Replicate()]] @@ -100,7 +97,6 @@ def test_meta_dtensor(self): value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5) self.assertEqual(meta_dtensor.to_local(), value_tensor) - @with_comms def test_modules_w_meta_dtensor(self): model = DummyMLP("meta") device_mesh = self.build_device_mesh() @@ -135,9 +131,8 @@ def test_modules_w_meta_dtensor(self): inp = torch.randn(20, 5, device=self.device_type) self.assertEqual(model_tp(inp), model_regular_tp(inp)) - @with_comms def test_dtensor_stride(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = [Shard(0)] local_tensor = torch.randn(4, 8) global_shape = torch.Size([self.world_size * 4, 8]) @@ -161,9 +156,8 @@ def test_dtensor_stride(self): global_stride = (8 * self.world_size, 1, 32 * self.world_size) self.assertEqual(dist_tensor.stride(), global_stride) - @with_comms def test_from_local(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -198,7 +192,6 @@ def test_from_local(self): expected_grad = torch.ones(3, 3) * 9 self.assertEqual(local_tensor_with_grad.grad, expected_grad) - @with_comms def test_from_local_uneven_sharding(self): mesh_shape = (self.world_size,) device_mesh = init_device_mesh(self.device_type, mesh_shape) @@ -224,7 +217,6 @@ def test_from_local_uneven_sharding(self): self.assertEqual(dtensor.size(), global_tensor.size()) self.assertEqual(dtensor.stride(), global_tensor.stride()) - @with_comms def test_from_local_uneven_sharding_raise_error(self): mesh_shape = (self.world_size,) device_mesh = init_device_mesh(self.device_type, mesh_shape) @@ -259,17 +251,15 @@ def test_from_local_uneven_sharding_raise_error(self): stride=global_tensor.stride(), ) - @with_comms def test_from_local_negative_dim(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(-1)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) self.assertEqual(sharded_tensor.placements[0].dim, 1) - @with_comms def test_to_local(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = 
self.build_device_mesh() placements = (Shard(0),) dist_tensor_shape = torch.Size([self.world_size * 3, 3]) local_tensor_with_grad = torch.randn( @@ -318,9 +308,8 @@ def test_to_local(self): except RuntimeError: self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size]) - @with_comms def test_to_local_grad_hint(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -343,9 +332,8 @@ def test_to_local_grad_hint(self): replica_grad = sharded_dtensor.grad.full_tensor() self.assertEqual(replica_grad, global_tensor * self.world_size) - @with_comms def test_full_tensor_sync(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -354,9 +342,8 @@ def test_full_tensor_sync(self): self.assertFalse(isinstance(full_out, AsyncCollectiveTensor)) self.assertEqual(full_out, global_tensor) - @with_comms def test_full_tensor_grad_hint(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -367,9 +354,8 @@ def test_full_tensor_grad_hint(self): replica_grad = sharded_dtensor.grad.full_tensor() self.assertEqual(replica_grad, global_tensor * self.world_size) - @with_comms def test_dtensor_new_empty_strided(self): - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type) my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)]) new_strided_dtensor = my_dtensor.new_empty_strided( @@ -390,12 +376,11 @@ def test_dtensor_new_empty_strided(self): local_tensor.grad, ) - @with_comms def test_dtensor_async_output(self): # Tests that if the output of some dtensor operations isn't used in any compute, # the output should be an AsyncCollectiveTensor (representing the fact that # we haven't synced the collective yet). - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fn(dt): dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True) @@ -432,10 +417,9 @@ def fn(dt): self.assertFalse(isinstance(sync_out, AsyncCollectiveTensor)) self.assertEqual(sync_out.to_local(), x) - @with_comms def test_from_local_then_to_local(self): # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] # step 1. 
construct from construct local tensor @@ -465,9 +449,8 @@ def test_from_local_then_to_local(self): expected_grad = torch.ones(3, 3) * 6 self.assertEqual(local_tensor_with_grad.grad, expected_grad) - @with_comms def test_dtensor_spec_read_only_after_set(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -477,9 +460,8 @@ def test_dtensor_spec_read_only_after_set(self): self.assertTrue(sharded_tensor.placements is not placements) self.assertNotEqual(sharded_tensor.placements, placements) - @with_comms def test_dtensor_spec_hash(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) local_tensor2 = torch.randn(3, 3) @@ -497,15 +479,13 @@ def test_dtensor_spec_hash(self): ) self.assertNotEqual(hash(sharded_tensor._spec), hash(replica_tensor._spec)) - @with_comms def test_dtensor_properties(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) self.assertEqual(sharded_tensor.device.type, self.device_type) - @with_comms def test_dtensor_save_load(self): import io @@ -520,7 +500,7 @@ def test_dtensor_save_load(self): self.assertEqual(sharded_tensor, reloaded_st) -class DTensorMeshTest(DTensorTestBase): +class DTensorMeshTest(DTensorOpTestBase): @property def world_size(self): return 8 @@ -531,10 +511,9 @@ def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor): else: self.assertEqual(tensor, exp_out_of_mesh) - @with_comms def test_dtensor_device_mesh_device_conversion(self): # construct a cuda device mesh - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # construct from a cpu local tensor with cuda device mesh # should automatically convert the dist tensor to cuda @@ -544,16 +523,15 @@ def test_dtensor_device_mesh_device_conversion(self): self.assertEqual(dist_tensor.device.type, self.device_type) self.assertEqual(dist_tensor.to_local().device.type, self.device_type) - @with_comms def test_dtensor_api_device_mesh_context_manager(self): - with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh: + with self.build_device_mesh() as mesh: placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local( local_tensor, device_mesh=mesh, placements=placements ) - with DeviceMesh(self.device_type, list(range(self.world_size))): + with self.build_device_mesh(): placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, placements=placements) @@ -563,7 +541,7 @@ def test_dtensor_api_device_mesh_context_manager(self): replica_tensor.size(), torch.Size([3 * self.world_size, 3]) ) - with DeviceMesh(self.device_type, torch.arange(self.world_size)): + with self.build_device_mesh(): placements = [Shard(0)] global_shape = torch.Size([3 * self.world_size, 3]) global_tensor = torch.randn(global_shape) @@ -583,7 +561,6 @@ def test_dtensor_api_device_mesh_context_manager(self): sharded_after_2d = distribute_tensor(global_tensor, placements=placements) self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3])) - 
@with_comms def test_dtensor_2d_mesh(self): mesh_tensor = torch.arange(self.world_size).reshape(2, 4) # construct a cuda device mesh @@ -606,7 +583,6 @@ def test_dtensor_2d_mesh(self): dist_tensor = DTensor.from_local(local_tensor, mesh, shard_same_dim_spec) self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3])) - @with_comms def test_device_mesh_nd(self): # construct a cuda device mesh mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2) @@ -627,7 +603,6 @@ def test_device_mesh_nd(self): self.assertEqual(dist_tensor.device.type, self.device_type) self.assertEqual(dist_tensor.to_local().device.type, self.device_type) - @with_comms def test_dtensor_spec_local_shard_offset(self): device_mesh = DeviceMesh( self.device_type, torch.arange(self.world_size).reshape(2, 4) @@ -666,7 +641,6 @@ def test_dtensor_spec_local_shard_offset(self): ) self.assertEqual(expected_shard_offsets, offset) - @with_comms def test_from_local_sub_mesh(self): mesh = DeviceMesh(self.device_type, [0, 2]) local_tensor = torch.ones(3, 4) @@ -693,7 +667,6 @@ def test_from_local_sub_mesh(self): dtensor.to_local(), ) - @with_comms def test_default_value_sub_mesh(self): mesh = DeviceMesh(self.device_type, [0, 2]) @@ -731,7 +704,6 @@ def test_default_value_sub_mesh(self): [dt.to_local() for dt in dtensor_list], ) - @with_comms def test_redistribute_sub_mesh(self): mesh = DeviceMesh(self.device_type, [0, 2]) @@ -747,7 +719,6 @@ def test_redistribute_sub_mesh(self): mesh.mesh, torch.ones(4, 3), torch.tensor([]), sharded_again.to_local() ) - @with_comms def test_implicit_replication(self): mesh = init_device_mesh(self.device_type, (self.world_size,)) local_tensor1 = torch.ones(4, 3) @@ -764,7 +735,7 @@ def test_implicit_replication(self): self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3)) -class TestDTensorPlacementTypes(DTensorTestBase): +class TestDTensorPlacementTypes(DTensorOpTestBase): @property def world_size(self): return 8 @@ -778,9 +749,8 @@ def _create_tensor(self, size): else: return tensor - @with_comms def test_split_tensor_1D(self) -> None: - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() shard_placement = Shard(0) for size in range(8): diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index f9ad0278d7e28..e25367b023454 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -40,9 +40,8 @@ run_tests, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, + DTensorOpTestBase, MLPModule, - with_comms, ) from torch.testing._internal.distributed.fake_pg import FakeStore from torch.utils._triton import has_triton @@ -141,13 +140,13 @@ def fn(x): compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn) - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() opt_fn = fn(mesh) compiled_out = compiled_fn(mesh) self.assertEqual(opt_fn, compiled_out) def test_fakify_dtensor(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # pass in DTensor as inputs/outputs to the function def fn(x): @@ -161,7 +160,7 @@ def fn(x): self.assertEqual(res, ref) def test_dynamo_dtensor(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # test passing in DTensor as inputs/outputs and run some tensor computation def 
fn(x): @@ -175,7 +174,7 @@ def fn(x): self.assertEqual(res, ref) def test_dtensor_attribute_access_on_intermediate(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fn(x): tmp = x * 2 @@ -192,7 +191,7 @@ def fn(x): self.assertEqual(res, ref) def test_dtensor_noncontiguous_output(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # test passing in DTensor as inputs/outputs and run some tensor computation def fn(x, y, z): @@ -210,7 +209,7 @@ def fn(x, y, z): out.contiguous().sum().backward() def test_dynamo_dtensor_from_local(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # create DTensor inside fn and run some compute def fn(x): @@ -253,7 +252,7 @@ def from_local_kwargs_fn(x): self.assertEqual(cnt.frame_count, 2) def test_dynamo_to_local_kwargs(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fn(x): return dt.to_local(grad_placements=[Shard(0)]) + 2 @@ -267,7 +266,7 @@ def fn(x): self.assertEqual(out_ref, out_test) def test_dynamo_to_local_kwargs_forward_hook(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fw_hook(module, inp, out): tmp = out.to_local(grad_placements=out.placements) + 2 @@ -295,7 +294,7 @@ def fw_hook(module, inp, out): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_different_gradient_placement(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fn(x, y, z): permute = x.permute(0, 2, 1) @@ -319,7 +318,7 @@ def fn(x, y, z): out_dt.sum().backward() def test_dynamo_dtensor_from_local_redistribute(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # pass in tensor as inputs/outputs, create DTensor and run redistribute # (allgather collective) inside the fn @@ -350,7 +349,7 @@ def redistribute_kwargs_fn(x): self.assertEqual(res, ref) def test_dtensor_dynamo_device_mesh_attrs(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # pass in tensor as inputs/outputs, create DTensor and run redistribute # (allgather collective) inside the fn @@ -370,7 +369,7 @@ def fn(x_dt): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_dtensor_partial_placement_graph_output(self): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fn(x): return x + x @@ -463,15 +462,14 @@ def forward(self, input): @instantiate_parametrized_tests -class TestDTensorCompileE2E(DTensorTestBase): +class TestDTensorCompileE2E(DTensorOpTestBase): @property def world_size(self): return 4 - @with_comms @parametrize("is_seq_parallel", [True, False]) def test_tp_compile_fullgraph(self, is_seq_parallel): - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() model = SimpleModel(self.device_type) @@ -528,7 +526,6 @@ def test_tp_compile_fullgraph(self, is_seq_parallel): self.assertEqual(compiled_out, out) self.assertEqual(cnt.frame_count, 1) - @with_comms @skip_if_lt_x_gpu(4) def test_2d_fsdp_tp_compile(self): data_parallel_size = 2 @@ -579,7 +576,6 @@ def test_2d_fsdp_tp_compile(self): self.assertEqual(out, compiled_output) 
self.assertEqual(cnt.frame_count, 1) - @with_comms @skip_if_lt_x_gpu(4) def test_2d_fsdp_tp_ac_compile(self): dp_degree = 2 @@ -630,7 +626,6 @@ def test_2d_fsdp_tp_ac_compile(self): for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()): self.assertEqual(n.grad, p.grad) - @with_comms @skip_if_lt_x_gpu(4) def test_compile_dtensor_redistribute_backward(self): mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size)) diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py index 3ac61e0b45fe5..47dd39e8ba3c3 100644 --- a/test/distributed/_tensor/test_embedding_ops.py +++ b/test/distributed/_tensor/test_embedding_ops.py @@ -13,8 +13,7 @@ from torch.distributed._tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) if TEST_WITH_DEV_DBG_ASAN: @@ -28,7 +27,7 @@ funcol = torch.ops.c10d_functional -class TestEmbeddingOp(DTensorTestBase): +class TestEmbeddingOp(DTensorOpTestBase): def _apply_sharding(self, embedding_mod, shard_dim, device_mesh): def shard_embedding_fn(name, module, device_mesh): for name, param in module.named_parameters(): @@ -136,7 +135,6 @@ def _run_embedding_op_test( ) self.assertEqual(local_output, sharded_output.full_tensor()) - @with_comms def test_sharded_embedding_colwise(self): mesh = self.build_device_mesh() self._run_embedding_op_test(mesh, 1, [5, 4], 17, 12) @@ -147,7 +145,6 @@ def test_sharded_embedding_colwise(self): self._run_embedding_op_test(mesh, 1, [34], 15, 14, padding_idx=10) self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12) - @with_comms def test_sharded_embedding_colwise_max_norm_errors(self): mesh = self.build_device_mesh() with self.assertRaisesRegex( @@ -158,7 +155,6 @@ def test_sharded_embedding_colwise_max_norm_errors(self): mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0 ) - @with_comms def test_sharded_embedding_rowwise(self): mesh = self.build_device_mesh() # test correctness diff --git a/test/distributed/_tensor/test_experimental_ops.py b/test/distributed/_tensor/test_experimental_ops.py index 803b687a369bd..7066b8daae65a 100644 --- a/test/distributed/_tensor/test_experimental_ops.py +++ b/test/distributed/_tensor/test_experimental_ops.py @@ -6,11 +6,10 @@ import torch.distributed as dist -from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate +from torch.distributed._tensor import distribute_tensor, Replicate from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) @@ -18,15 +17,14 @@ LR = 0.001 -class DistOtherOpsTest(DTensorTestBase): +class DistOtherOpsTest(DTensorOpTestBase): @property def world_size(self) -> int: # hard code world size to 2 return 2 - @with_comms def test_slice(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -75,10 +73,9 @@ def test_slice(self): f"Too large relative mse for gradient, expected less equal 1e-6, got {grad_mse_rel}", ) - @with_comms def test_bernoulli(self): rank = dist.get_rank() - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = 
[Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -138,9 +135,8 @@ def test_bernoulli(self): f"Too large relative mse for gradient, expected less equal 1e-6, got {grad_mse_rel}", ) - @with_comms def test_nll(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] pred_list = torch.rand(ITER_TIME, 1024, 10) diff --git a/test/distributed/_tensor/test_init.py b/test/distributed/_tensor/test_init.py index 11a9596abc75d..391e05d188c1e 100644 --- a/test/distributed/_tensor/test_init.py +++ b/test/distributed/_tensor/test_init.py @@ -5,12 +5,11 @@ from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) -class DTensorInitOpsTest(DTensorTestBase): +class DTensorInitOpsTest(DTensorOpTestBase): def _run_init_op(self, init_op, *args, **kwargs): device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] @@ -24,20 +23,19 @@ def _run_init_op(self, init_op, *args, **kwargs): dtensor = init_op(dtensor, *args, **kwargs) self.assertEqual(local_tensor_clone, dtensor.to_local()) - @with_comms def test_init_ops(self): # NOTE: random init tests are moved to test_random_ops.py self._run_init_op(torch.nn.init.constant_, 2.4) -class DTensorConstructorTest(DTensorTestBase): +class DTensorConstructorTest(DTensorOpTestBase): @property def world_size(self): return 4 def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs): # 1d mesh test - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements_list = [[Shard(0)], [Shard(1)], [Shard(2)], [Replicate()]] # even sharding @@ -90,7 +88,6 @@ def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs): expected_tensor = init_op([], *args, **kwargs) eq_op(expected_tensor, local_tensor) - @with_comms def test_ones(self): self._run_init_op( torch.ones, @@ -99,7 +96,6 @@ def test_ones(self): requires_grad=True, ) - @with_comms def test_empty(self): self._run_init_op( torch.empty, @@ -110,7 +106,6 @@ def test_empty(self): requires_grad=True, ) - @with_comms def test_full(self): self._run_init_op( torch.full, @@ -120,7 +115,6 @@ def test_full(self): requires_grad=True, ) - @with_comms def test_zeros(self): self._run_init_op( torch.zeros, @@ -129,10 +123,9 @@ def test_zeros(self): requires_grad=True, ) - @with_comms def test_zeros_full_mesh(self): # construct a cuda device 1d mesh - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() placements = [Shard(0)] size = [32, 3] dist_tensor = zeros(size, device_mesh=mesh, placements=placements) @@ -194,7 +187,6 @@ def test_zeros_full_mesh(self): elif self.rank == 3: self.assertEqual(local_tensor, torch.zeros([15, 1])) - @with_comms def test_zeros_submesh(self): # default world_size is 4 # construct a cuda device 1d mesh, with no sub pg initialized diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index f810278c3ffe8..7c1e6ea9e157e 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -12,16 +12,15 @@ from torch.distributed._tensor.placement_types import Replicate, Shard from torch.testing._internal.common_utils import run_tests from 
torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, + DTensorOpTestBase, skip_unless_torch_gpu, - with_comms, ) funcol = torch.ops.c10d_functional -class DistMathOpsTest(DTensorTestBase): +class DistMathOpsTest(DTensorOpTestBase): def linear_op_reductions(self, op_str): device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] @@ -50,18 +49,15 @@ def linear_op_reductions(self, op_str): dt_full_reduced = op_dt().full_tensor() self.assertEqual(dt_full_reduced, full_reduced_tensor) - @with_comms def test_linear_op_reductions(self): for op_str in ("all", "sum", "prod", "max", "min"): self.linear_op_reductions(op_str) - @with_comms @skip_unless_torch_gpu def test_mean(self): self.linear_op_reductions("mean") # TODO: forward test can be removed once test_softmax_with_bwd passes on CPU - @with_comms def test_softmax_fwd(self): device_mesh = self.build_device_mesh() @@ -90,7 +86,6 @@ def test_softmax_fwd(self): # TODO: get test_softmax_with_bwd pass on CPU # DTensor's _softmax_backward_data produces wrong result on CPU on certain dimension. # fail_on_cpu_list = [(0, -1), (1, -1)] - @with_comms @skip_unless_torch_gpu def test_softmax_with_bwd(self): device_mesh = self.build_device_mesh() @@ -133,7 +128,6 @@ def test_softmax_with_bwd(self): self.assertTrue(dist_x.grad.placements[0].is_shard(dim=shard_dim)) self.assertEqual(dist_x.grad.full_tensor(), x.grad) - @with_comms @skip_unless_torch_gpu def test_nll_loss_and_cross_entropy(self): device_mesh = self.build_device_mesh() @@ -203,7 +197,6 @@ def test_nll_loss_and_cross_entropy(self): self.assertEqual(dist_x.grad.full_tensor(), x.grad) x.grad.zero_() - @with_comms def test_shard_math_ops(self): mesh_shape = (2, self.world_size // 2) mesh = DeviceMesh( @@ -227,7 +220,6 @@ def test_shard_math_ops(self): fully_shard_full_tensor = op(fully_shard_tensor, 2).full_tensor() self.assertEqual(fully_shard_full_tensor, expect_rs) - @with_comms def test_layer_norm_fwd(self): device_mesh = self.build_device_mesh() @@ -286,7 +278,6 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual(y_local.shape, dtensor_meta.shape) self.assertEqual(y_local, y_dist.full_tensor()) - @with_comms def test_layer_norm_bwd(self): device_mesh = self.build_device_mesh() diff --git a/test/distributed/_tensor/test_matrix_ops.py b/test/distributed/_tensor/test_matrix_ops.py index b303157acf44a..415410d994b94 100644 --- a/test/distributed/_tensor/test_matrix_ops.py +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F -from torch.distributed._tensor import DeviceMesh, distribute_tensor +from torch.distributed._tensor import distribute_tensor from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.placement_types import ( _Partial, @@ -16,16 +16,14 @@ ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, + DTensorOpTestBase, skip_unless_torch_gpu, - with_comms, ) -class DistMatrixOpsTest(DTensorTestBase): - @with_comms +class DistMatrixOpsTest(DTensorOpTestBase): def test_addmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -40,9 +38,8 @@ def test_addmm(self): local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate) self.assertEqual(dist_res.full_tensor(), local_res) - @with_comms def 
test_addmm_empty_operand(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -57,9 +54,8 @@ def test_addmm_empty_operand(self): local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate) self.assertEqual(dist_res.full_tensor(), local_res) - @with_comms def test_addmm_auto_redistribute(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = [Shard(0)] shard1_spec = [Shard(1)] replica_spec = [Replicate()] @@ -88,9 +84,8 @@ def test_addmm_auto_redistribute(self): self.assertIsNotNone(mat2.grad) self.assertEqual(mat2.grad.full_tensor(), tensor_to_shard0.grad) - @with_comms def test_mm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = Shard(0) shard1_spec = Shard(1) replica_spec = Replicate() @@ -118,9 +113,8 @@ def test_placement_comb( for spec in shard_specs_comb: test_placement_comb([spec[0]], [spec[1]]) - @with_comms def test_t(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] tensor_to_transpose = torch.randn(12, 8, requires_grad=True) @@ -132,9 +126,8 @@ def test_t(self): self.assertEqual(tranposed_mat2.size(), torch.Size([12, 8])) self.assertEqual(tranposed_mat2.placements, shard_spec) - @with_comms def test_t_partial(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() a = torch.randn(12, 8) b = torch.randn(8, 4) @@ -156,10 +149,9 @@ def test_t_partial(self): ) # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 - @with_comms @skip_unless_torch_gpu def test_baddbmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True) @@ -221,9 +213,8 @@ def test_placement_comb( [spec[0]], [spec[1]], [spec[2]], beta, alpha, batch_1.grad ) - @with_comms def test_bmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True) mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) local_result = torch.bmm(mat1, mat2) @@ -265,10 +256,9 @@ def test_placement_comb( for spec in shard_specs_comb: test_placement_comb([spec[0]], [spec[1]]) - @with_comms @skip_unless_torch_gpu def test_scaled_dot_product_attention(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # bsz, n_heads, slen, head_dim query = torch.rand( (4, 8, 8, 8), diff --git a/test/distributed/_tensor/test_optimizers.py b/test/distributed/_tensor/test_optimizers.py index e9421d8a45217..7626d781f111a 100644 --- a/test/distributed/_tensor/test_optimizers.py +++ b/test/distributed/_tensor/test_optimizers.py @@ -7,7 +7,6 @@ import torch.nn as nn from torch.distributed._tensor import ( - DeviceMesh, distribute_module, distribute_tensor, DTensor, @@ -17,9 +16,8 @@ from torch.testing._internal.common_utils import run_tests from 
torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, + DTensorOpTestBase, MLPModule, - with_comms, ) @@ -50,7 +48,7 @@ def output_fn(mod, outputs, device_mesh): return outputs.redistribute(placements=[Replicate()] * device_mesh.ndim).to_local() -class TestDTensorOptimizer(DTensorTestBase): +class TestDTensorOptimizer(DTensorOpTestBase): def _assert_optimizer( self, mesh, @@ -84,9 +82,8 @@ def _assert_optimizer( # Default 'rtol' and 'atol' for attr:`~torch.float32` are ``1.3e-6`` and ``1e-5`` self.assertEqual(p1, p2, atol=atol, rtol=rtol) - @with_comms def test_adam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() # TODO: add fused_adam support adam_configs = [ @@ -118,9 +115,8 @@ def test_adam_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_adamw_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() # TODO: add fused_adamw support adamw_configs = [ @@ -160,9 +156,8 @@ def test_adamw_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_sgd_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() sgd_configs = [ {"lr": 0.1}, @@ -200,9 +195,8 @@ def test_sgd_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_adagrad_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adagrad_configs = [ {"lr": 0.1}, @@ -254,9 +248,8 @@ def test_adagrad_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_RMSprop_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() RMSprop_configs = [ {"lr": 0.1}, @@ -313,9 +306,8 @@ def test_RMSprop_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_adadelta_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adadelta_configs = [ {"lr": 0.1}, @@ -353,9 +345,8 @@ def test_adadelta_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_nadam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() nadam_configs = [ {"lr": 0.1}, @@ -392,9 +383,8 @@ def test_nadam_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_radam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() radam_configs = [ {"lr": 0.1}, @@ -431,9 +421,8 @@ def test_radam_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_adamax_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = 
self.build_device_mesh() adamax_configs = [ {"lr": 0.1}, @@ -471,9 +460,8 @@ def test_adamax_1d_sharding(self): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp) - @with_comms def test_asgd_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() asgd_configs = [ {"lr": 0.1}, diff --git a/test/distributed/_tensor/test_random_ops.py b/test/distributed/_tensor/test_random_ops.py index cfd8365fc7116..153ae555af7ba 100644 --- a/test/distributed/_tensor/test_random_ops.py +++ b/test/distributed/_tensor/test_random_ops.py @@ -17,14 +17,13 @@ from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, + DTensorOpTestBase, skip_if_lt_x_gpu, skip_unless_torch_gpu, - with_comms, ) -class DistTensorRandomInitTest(DTensorTestBase): +class DistTensorRandomInitTest(DTensorOpTestBase): def _run_init_op(self, init_op, *args, **kwargs): device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] @@ -62,7 +61,6 @@ def _run_init_op(self, init_op, *args, **kwargs): # other rank should have a different local tensor self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor) - @with_comms def test_init_ops(self): self._run_init_op( torch.nn.init.kaiming_uniform_, @@ -79,8 +77,7 @@ def test_init_ops(self): self._run_init_op(torch.randint_like, low=0, high=100, dtype=dtype) -class DistTensorRandomOpTest(DTensorTestBase): - @with_comms +class DistTensorRandomOpTest(DTensorOpTestBase): @skip_unless_torch_gpu def test_rng_tracker_init(self): torch.cuda.manual_seed(self.rank) @@ -88,23 +85,21 @@ def test_rng_tracker_init(self): broadcast_object_list(object_list) seed_from_rank_0 = int(object_list[0]) - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() # seed synchronization happens after the first `distribute_tensor` call dtensor = distribute_tensor( torch.empty([self.world_size], device="cuda"), device_mesh, [Shard(0)] ) self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng")) - @with_comms @skip_unless_torch_gpu def test_manual_seed(self): - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() manual_seed(1234, device_mesh) self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng")) with self.assertRaisesRegex(RuntimeError, "different seed values"): manual_seed(self.rank, device_mesh) - @with_comms @skip_unless_torch_gpu def test_deterministic_dropout_1d(self): # test suite sets each rank's seed to the same value but in actual @@ -113,7 +108,7 @@ def test_deterministic_dropout_1d(self): # torch random generator keeps different seeds on ranks. 
torch.cuda.manual_seed(self.rank) # TODO: add test before/after enabling distribute region - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() size = [4, 4] dtensor = distribute_tensor( @@ -145,10 +140,9 @@ def test_deterministic_dropout_1d(self): local_tensor[other_slice, :], ) - @with_comms @skip_unless_torch_gpu def test_deterministic_rand_1d(self): - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() size = [4, 4 * self.world_size] for fn in [ @@ -188,7 +182,6 @@ def test_deterministic_rand_1d(self): local_tensor[other_slice, :], ) - @with_comms @skip_if_lt_x_gpu(4) def test_deterministic_uniform_2d(self): mesh = torch.arange(self.world_size).reshape(2, 2) @@ -290,7 +283,6 @@ def test_deterministic_uniform_2d(self): else: self.assertNotEqual(full_tensor[slice_idx], local_tensor) - @with_comms @skip_if_lt_x_gpu(4) def test_meta_tensor_init(self): # test suite sets each rank's seed to the same value but in actual @@ -300,7 +292,7 @@ def test_meta_tensor_init(self): # that Replicate DTensor will have the same initialized results # across ranks. torch.cuda.manual_seed(self.rank) - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() size = [1024, 2048] meta_dtensor = distribute_tensor( torch.empty(*size, device="meta"), device_mesh, [Replicate()] diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py index b07378375965d..bd96a839db949 100644 --- a/test/distributed/_tensor/test_redistribute.py +++ b/test/distributed/_tensor/test_redistribute.py @@ -11,22 +11,20 @@ from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) funcol = torch.ops.c10d_functional -class RedistributeTest(DTensorTestBase): +class RedistributeTest(DTensorOpTestBase): @property def world_size(self): return 4 - @with_comms def test_shard_to_replicate_forward_backward(self): # 1) test shard -> replicate forward - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -65,9 +63,8 @@ def test_shard_to_replicate_forward_backward(self): ) self.assertEqual(comm_mode.get_total_counts(), 0) - @with_comms def test_replicate_to_replicate_forward_backward(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) @@ -93,9 +90,8 @@ def test_replicate_to_replicate_forward_backward(self): self.assertEqual(grad_input.to_local(), torch.ones(12, 3)) self.assertEqual(comm_mode.get_total_counts(), 0) - @with_comms def test_replicate_to_local_partial_grad(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) @@ -112,9 +108,8 @@ def test_replicate_to_local_partial_grad(self): self.assertEqual(comm_mode.get_total_counts(), 1) self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1) - @with_comms def test_replicate_to_shard_forward_backward(self): - device_mesh = DeviceMesh(self.device_type, 
list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -160,13 +155,12 @@ def test_replicate_to_shard_forward_backward(self): comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1 ) - @with_comms def test_partial_to_replicate_forward_backward(self): # Although we don't allow user to reshard to produce a partial # placement (i.e. user can't reshard to partial), we do allow # replicate to partial internally, and also partial to replicate # backward should work as expected - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True) partial_spec = [_Partial()] replica_spec = [Replicate()] @@ -195,9 +189,8 @@ def test_partial_to_replicate_forward_backward(self): self.assertEqual(partial_local.grad, torch.ones_like(partial_local)) self.assertEqual(comm_mode.get_total_counts(), 0) - @with_comms def test_replicate_to_partial(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) partial_spec = _Partial() replica_spec = Replicate() @@ -243,9 +236,8 @@ def test_replicate_to_partial(self): ) self.assertEqual(comm_mode.get_total_counts(), 0) - @with_comms def test_partial_to_shard(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() partial_spec = [_Partial()] my_rank = device_mesh.get_rank() @@ -298,9 +290,8 @@ def test_partial_to_shard(self): comm_mode.get_comm_counts()[funcol.reduce_scatter_tensor], 1 ) - @with_comms def test_redistribute_negative_shard_dim(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) shard_spec = [Shard(1)] shard_minus_spec = [Shard(-1)] @@ -310,7 +301,6 @@ def test_redistribute_negative_shard_dim(self): reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec) self.assertEqual(shard_tensor.placements[0].dim, 1) - @with_comms def test_redistribute_uneven_sharding(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2)) data_to_test = [ @@ -335,12 +325,11 @@ def test_redistribute_uneven_sharding(self): self.assertEqual(dt_full_tensor, input_tensor) -class MultiDimRedistributeTest(DTensorTestBase): +class MultiDimRedistributeTest(DTensorOpTestBase): @property def world_size(self) -> int: return 8 - @with_comms def test_multi_dim_mesh(self): devices = torch.arange(self.world_size) for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]: diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py index e4d1e3ecfd95f..1d4b5f3a3a77f 100644 --- a/test/distributed/_tensor/test_tensor_ops.py +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -9,25 +9,22 @@ from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorConverter, - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) -class DistTensorOpsTest(DTensorTestBase): - @with_comms +class DistTensorOpsTest(DTensorOpTestBase): def test_aten_contiguous(self): # this op not covered by dtensor_ops - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = 
self.build_device_mesh() self._test_op( mesh, lambda x: torch.ops.aten.contiguous(x), torch.randn(16, 32), ) - @with_comms def test_detach(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] tensor_to_detach = torch.randn(12, 8, requires_grad=True) @@ -35,9 +32,8 @@ def test_detach(self): detached_mat = mat.detach() self.assertFalse(detached_mat is mat) - @with_comms def test_clone(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() specs = [[Replicate()], [Shard(0)]] tensor_to_clone = torch.randn(12, 8, requires_grad=True) for spec in specs: @@ -46,9 +42,8 @@ def test_clone(self): self.assertFalse(cloned_mat is mat) self.assertEqual(cloned_mat.to_local(), mat.to_local()) - @with_comms def test_contiguous(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() tensor = torch.rand(3, 5, 6, requires_grad=True) sharding = [Shard(0)] dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) @@ -72,9 +67,8 @@ def test_contiguous(self): new_dt.to_local().sum().backward() self.assertEqual(tensor.grad, torch.ones(3, 5, 6)) - @with_comms def test_inplace_op(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() input_tensor = torch.randn((12, 3), device=self.device_type) dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)]) dt_to_mul = dt_to_add.clone() @@ -99,9 +93,8 @@ def test_inplace_op(self): self.assertTrue(res is dt_to_inplace_add) self.assertTrue(res.placements == tuple(shard_spec)) - @with_comms def test_op_out_variant(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() input_tensor = torch.randn((12, 3), device=self.device_type) sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)]) expected_dt = sharded_dt_input.clone() + 3 @@ -120,9 +113,8 @@ def test_op_out_variant(self): self.assertTrue(res.placements == tuple(replica_spec)) self.assertEqual(replicate_out.to_local(), expected_dt.to_local()) - @with_comms def test_empty_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -131,9 +123,8 @@ def test_empty_like(self): # empty is not deterministic, so we only check that the shard propagation worked self.assertEqual((4, 8), empty_like_dt.to_local().shape) - @with_comms def test_fill_inplace(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -143,9 +134,8 @@ def test_fill_inplace(self): self.assertEqual(full_expected, full_like_dt.to_local()) self.assertEqual(full_expected, dist_tensor.to_local()) - @with_comms def test_full_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -154,9 +144,8 @@ def test_full_like(self): full_expected = torch.full((4, 8), 42.0) self.assertEqual(full_expected, full_like_dt.to_local()) - @with_comms def test_ones_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() 
shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -165,9 +154,8 @@ def test_ones_like(self): ones_expected = torch.ones(4, 8) self.assertEqual(ones_expected, ones_like_dt.to_local()) - @with_comms def test_ones_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [_Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -178,9 +166,8 @@ def test_ones_like_partial_sum(self): ones_expected = torch.ones(dist_tensor.shape) self.assertEqual(ones_expected, ones_like_dt.full_tensor()) - @with_comms def test_fill_inplace_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [_Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -194,9 +181,8 @@ def test_fill_inplace_partial_sum(self): ) self.assertEqual(fill_expected, dist_tensor.full_tensor()) - @with_comms def test_zeros_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [_Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -207,9 +193,8 @@ def test_zeros_like_partial_sum(self): zeros_expected = torch.zeros(dist_tensor.shape) self.assertEqual(zeros_expected, zeros_like_dt.full_tensor()) - @with_comms def test_zero_inplace(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -219,9 +204,8 @@ def test_zero_inplace(self): self.assertEqual(zeros_expected, zeros_like_dt.to_local()) self.assertEqual(zeros_expected, dist_tensor.to_local()) - @with_comms def test_zeros_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -230,7 +214,6 @@ def test_zeros_like(self): zeros_expected = torch.zeros(4, 8) self.assertEqual(zeros_expected, zeros_like_dt.to_local()) - @with_comms @skip_if_lt_x_gpu(4) def test_stack(self): mesh_2d = DeviceMesh( @@ -248,7 +231,7 @@ def test_stack(self): self.assertEqual(stack_dt.placements, tuple(partial_placement)) self.assertEqual(stack_dt.shape, (2, 4, 8)) - mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh_1d = self.build_device_mesh() # stack before/after shard dim global_input = torch.randn(8, 8) shard1_input = distribute_tensor(global_input, mesh_1d, [Shard(1)]) @@ -268,9 +251,8 @@ def test_stack(self): torch.stack([global_input, global_input], dim=1), ) - @with_comms def test_equal(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor_1 = torch.ones(4, 4) @@ -318,9 +300,8 @@ def _test_op(self, mesh, op_call, *args, **kwargs): d_out = op_call(*d_args, **d_kwargs) self.assertEqual(d_out.full_tensor(), out) - @with_comms def test_new_full(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [[Shard(0)], [Replicate()]] for placement in placements: global_tensor = torch.randn(12, 8) @@ -337,9 +318,8 @@ def test_new_full(self): self.assertTrue(new_full_dt.placements[0].is_replicate()) self.assertEqual(new_full_expected, new_full_dt.to_local()) - @with_comms 
def test_gather(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() # case 1 all replicate: input replicated, index replicated, output replicated @@ -386,10 +366,9 @@ def test_gather(self): self.assertEqual(output_dt.placements, [Shard(gather_dim)]) self.assertEqual(output_dt.full_tensor(), global_output) - @with_comms def test_index(self): meshes = [ - DeviceMesh(self.device_type, list(range(self.world_size))), # 1D mesh + self.build_device_mesh(), # 1D mesh # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh ] @@ -495,9 +474,8 @@ def test_index(self): torch.randint(5, (12, 8, 12)), ) - @with_comms def test_where_type_promotion(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh + mesh = self.build_device_mesh() # 1D mesh specs = [[Shard(0)], [Replicate()]] for spec in specs: @@ -507,9 +485,8 @@ def test_where_type_promotion(self): ref = torch.where(global_tensor > 0, 1, 0) self.assertEqual(res.full_tensor(), ref) - @with_comms def test_dtensor_dtype_conversion(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] # by default we start from bf16 dtype local_tenor = torch.randn(2, 8, dtype=torch.bfloat16) diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index 7ba49ae5204db..6a1a3308121cf 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -15,19 +15,17 @@ from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) c10d_functional = torch.ops.c10d_functional -class UtilTest(DTensorTestBase): +class UtilTest(DTensorOpTestBase): @property def world_size(self): return 8 - @with_comms def test_compute_local_shape_2d_uneven(self): # mesh: 4 * 2 mesh_tensor = torch.arange(self.world_size).reshape(4, 2) @@ -57,7 +55,7 @@ def test_compute_local_shape_2d_uneven(self): else: self.assertEqual(local_size3[1], 3) - @with_comms + def test_compute_local_shape_and_global_offset_1D(self): one_d_placements = [[Shard(0)], [Replicate()]] @@ -67,8 +65,7 @@ def test_compute_local_shape_and_global_offset_1D(self): # 2) sharding resulting in shards of different size across different ranks # 3) sharding resulting in non-empty shards of same size across all ranks for size in range(self.world_size * 2 + 1): - mesh_tensor = torch.arange(self.world_size) - device_mesh = DeviceMesh(self.device_type, mesh_tensor) + device_mesh = self.build_device_mesh() global_tensor = torch.arange(size) global_shape = global_tensor.size() @@ -88,7 +85,6 @@ def test_compute_local_shape_and_global_offset_1D(self): global_tensor[dim0_start:dim0_end], ) - @with_comms def test_compute_local_shape_and_global_offset_2D(self): two_d_placements_options = [Shard(0), Shard(1), Replicate()] # Generating 6 two-d placements combinations @@ -123,12 +119,11 @@ def test_compute_local_shape_and_global_offset_2D(self): ) -class Test2DStridedLocalShard(DTensorTestBase): +class Test2DStridedLocalShard(DTensorOpTestBase): @property def world_size(self): return 4 - @with_comms def test_fsdp1_tp_2d_dtensor_local_shards_and_offsets(self): # We are mimicking the behavior of FSDP1 + TP. 
# Currently, the 2D DTensor's local shard is correct, since from_local + redistribute incurs a all_gather behind the scene. @@ -162,7 +157,6 @@ def test_fsdp1_tp_2d_dtensor_local_shards_and_offsets(self): self.assertEqual(local_size, torch.Size([1, 2])) self.assertEqual(global_offset, torch.Size([self.rank, 0])) - @with_comms def test_fsdp2_tp_2d_dtensor_local_shards_and_offsets(self): # We are mimicking the behavior of FSDP2 + TP. # Currently, the 2D DTensor's local shard is incorrect for resharding, since we want to avoid extra communication. diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 429e62588651a..55d290327f544 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -22,13 +22,12 @@ from torch.distributed._tensor.placement_types import Placement from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, + DTensorOpTestBase, ) from torch.utils import _pytree as pytree -class TestViewOps(DTensorTestBase): +class TestViewOps(DTensorOpTestBase): def test_view_groups(self): self.assertEqual( view_groups([2, 3], [3, 2]), @@ -184,7 +183,6 @@ def dimmap_test(self, op, args, expected_rule_output): self.assertEqual(rules, expected_rule_output) self.call_dt_test(op, args, {}, self.device_mesh) - @with_comms def test_view_ops(self): self.device_mesh = DeviceMesh( self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) @@ -481,7 +479,6 @@ def test_view_ops(self): # Split(InputDim(1), (13, 2), 1), # ), # ) - @with_comms def test_complex_view_ops(self): self.device_mesh = DeviceMesh( self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
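
Every file in this patch applies the same three-part conversion: the test class switches its base from DTensorTestBase to DTensorOpTestBase, the per-test @with_comms decorator is dropped, and explicit DeviceMesh(self.device_type, list(range(self.world_size))) construction is replaced by self.build_device_mesh(). The sketch below shows what a converted test looks like end to end; ExampleDTensorTest and its test body are illustrative only and not part of this patch, and the sketch assumes DTensorOpTestBase exposes world_size and build_device_mesh() exactly as the hunks above use them.

import torch
from torch.distributed._tensor import distribute_tensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorOpTestBase,
)


class ExampleDTensorTest(DTensorOpTestBase):  # hypothetical class, for illustration only
    @property
    def world_size(self) -> int:
        return 4

    # No @with_comms: the multi-threaded base class brings up its own per-rank
    # state, so no per-test communicator setup/teardown decorator is needed.
    def test_shard_roundtrip(self):
        # replaces DeviceMesh(self.device_type, list(range(self.world_size)))
        mesh = self.build_device_mesh()
        global_tensor = torch.ones(4 * self.world_size, 8)
        dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])
        # Shard(0) over 4 ranks splits dim 0 evenly: each rank holds a (4, 8) shard
        self.assertEqual(dtensor.to_local().shape, (4, 8))
        # gathering the shards reproduces the original global tensor
        self.assertEqual(dtensor.full_tensor(), global_tensor)


if __name__ == "__main__":
    run_tests()

Because the multi-threaded base performs that per-rank setup itself, simply deleting @with_comms, rather than substituting another decorator, is sufficient, which is why the hunks above contain no replacement for it.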