From aab7c2b3291d0446f14f73db96015aaa16ee0a7a Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Tue, 8 Jul 2025 22:30:37 +0000
Subject: [PATCH] Add DTensor placement test

---
 test/run_tests.sh                      | 18 ++++
 test/spmd/test_dtensor_integration3.py | 81 ++++++++++++++++++++++++++
 2 files changed, 99 insertions(+)
 create mode 100644 test/spmd/test_dtensor_integration3.py

diff --git a/test/run_tests.sh b/test/run_tests.sh
index 08b38c041868..7b9e3110c3cc 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -62,6 +62,23 @@ function run_test_without_functionalization {
   XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@"
 }
 
+function run_test_multi_devices {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running with multiple CPU devices: $@"
+  # Use multiple CPU devices so SPMD/DTensor tests can shard across a real mesh.
+  PJRT_DEVICE=CPU CPU_NUM_DEVICES=4 run_coverage "$@"
+}
+
+function run_test_multi_devices_without_func {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running with XLA_DISABLE_FUNCTIONALIZATION: $@"
+  XLA_DISABLE_FUNCTIONALIZATION=1 run_test_multi_devices "$@"
+}
+
 function run_use_bf16 {
   if ! test_is_selected "$1"; then
     return
@@ -235,6 +252,7 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
   run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
   run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
+  run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
diff --git a/test/spmd/test_dtensor_integration3.py b/test/spmd/test_dtensor_integration3.py
new file mode 100644
index 000000000000..6bff5ecf8269
--- /dev/null
+++ b/test/spmd/test_dtensor_integration3.py
@@ -0,0 +1,81 @@
+import sys
+
+import torch
+from torch.distributed.tensor import (Replicate, Shard, distribute_tensor,
+                                      init_device_mesh)
+import torch_xla
+import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
+
+import unittest
+
+import test_xla_sharding_base
+
+
+# This integration test passes when run independently.
+class DTensorIntegrationTest3(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  # This test fails with functionalization, so run_tests.sh runs it with
+  # XLA_DISABLE_FUNCTIONALIZATION=1.
+  def test_xla_placement(self):
+
+    class Model(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.in_proj = torch.nn.Linear(32, 16, bias=False)
+        self.out_proj = torch.nn.Linear(16, 8, bias=False)
+
+      def forward(self, hidden):
+        hidden = self.in_proj(hidden)
+        hidden = torch.relu(hidden)
+        hidden = self.out_proj(hidden)
+        return hidden
+
+    # Functional forward pass that takes the weights explicitly, so it can be
+    # run on DTensor-sharded parameters.
+    def forward_pure(hidden, in_proj_weight, out_proj_weight):
+      hidden = torch.matmul(hidden, in_proj_weight.T)
+      hidden = torch.relu(hidden)
+      hidden = torch.matmul(hidden, out_proj_weight.T)
+      return hidden
+
+    model = Model()
+    model.to('xla')
+    device_count = xr.global_runtime_device_count()
+    device_mesh = init_device_mesh(
+        device_type='xla', mesh_shape=(device_count,))
+
+    # Tensor-parallel placements: replicate the input, shard the weights.
+    inputs_sharding = [Replicate()]
+    in_proj_weight_sharding = [Shard(0)]
+    out_proj_weight_sharding = [Shard(1)]
+
+    torch.manual_seed(15213)
+    inputs = torch.rand(2, 32)
+    inputs = inputs.to('xla')
+    # Reference output from the unsharded model.
+    outputs_unsharded = model(inputs)
+    xm.mark_step()
+    outputs_unsharded = outputs_unsharded.cpu()
+    # Distribute the input and weights, then run the same computation sharded.
+    inputs = distribute_tensor(inputs, device_mesh, placements=inputs_sharding)
+    in_proj_weight = distribute_tensor(
+        model.in_proj.weight, device_mesh, placements=in_proj_weight_sharding)
+    out_proj_weight = distribute_tensor(
+        model.out_proj.weight, device_mesh, placements=out_proj_weight_sharding)
+    outputs_sharded = forward_pure(inputs, in_proj_weight, out_proj_weight)
+    xm.mark_step()
+    outputs_sharded = outputs_sharded.cpu()
+    # The sharded result must match the unsharded reference.
+    torch.testing.assert_close(outputs_sharded.global_tensor.numpy(),
+                               outputs_unsharded.detach().numpy())
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
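
Note: the new run_test_multi_devices_without_func wrapper amounts to running the test file directly with the environment variables it sets. A minimal local invocation, assuming run_coverage falls back to plain python3 when coverage collection is not enabled (the exact invocation is illustrative, not part of this patch):

  # Same env that run_test_multi_devices_without_func sets up in run_tests.sh
  XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CPU CPU_NUM_DEVICES=4 \
    python3 test/spmd/test_dtensor_integration3.py

Running without XLA_DISABLE_FUNCTIONALIZATION=1 should exercise the functionalization failure mentioned in the test's comment.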