diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
index d8ee9a39b03..a68e0671a3b 100755
--- a/test/neuron/run_tests.sh
+++ b/test/neuron/run_tests.sh
@@ -257,6 +257,7 @@ function run_xla_op_tests3 {
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 54c893c7b40..0a7002c410d 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -257,6 +257,7 @@ function run_xla_op_tests3 {
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
diff --git a/test/spmd/test_xla_sharded_tensor.py b/test/spmd/test_xla_sharded_tensor.py
new file mode 100644
index 00000000000..a101fb9bcd7
--- /dev/null
+++ b/test/spmd/test_xla_sharded_tensor.py
@@ -0,0 +1,38 @@
+import sys
+import unittest
+import test_xla_sharding_base
+from torch.distributed.tensor import DTensor
+from torch_xla.distributed.spmd import XLAShardedTensor
+
+import torch
+
+
+class XlaShardedTensorTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_xlashardedtensor_is_dtensor(self):
+    """Test that XLAShardedTensor is a subclass of DTensor."""
+    xt = torch.randn(128, 128).to('xla')
+    xla_tensor = XLAShardedTensor(xt)
+    self.assertIsInstance(xla_tensor, DTensor)
+
+  def test_xlashardedtensor_gradient(self):
+    """Test accessing gradients of an XLAShardedTensor (triggers __torch_function__)."""
+    xt = torch.randn(128, 128).to('xla')
+    xla_tensor = XLAShardedTensor(xt, requires_grad=True)
+    result = xla_tensor.sum()
+    result.backward()
+
+    # Accessing .grad should go through __torch_function__.
+    grad = xla_tensor.grad
+
+    self.assertIsNotNone(grad)
+    self.assertEqual(grad.shape, xla_tensor.shape)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 440db8bd28a..e1ad7c0023a 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -63,6 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
 run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+run_test "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"
diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py
index 5a049b5864e..652a2011cbd 100644
--- a/torch_xla/distributed/spmd/xla_sharded_tensor.py
+++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -11,6 +11,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial
 from torch.utils._pytree import tree_map_only
+from torch.distributed.tensor import DTensor
 
 
 
 @dataclass
@@ -63,7 +64,7 @@ def no_dispatch() -> Iterator[None]:
     del guard
 
 
-class XLAShardedTensor(torch.Tensor):
+class XLAShardedTensor(DTensor):
   """
   A wrapper around `torch.Tensor` with sharding annotation
   for XLA SPMD auto-sharding. The wrapped tensors are unwrapped
@@ -300,4 +301,4 @@ def redistribute(self, device_mesh, placements, *, async_op: bool = False):
 
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
-    return super().__torch_function__(func, types, args, kwargs)
+    return super(DTensor, cls).__torch_function__(func, types, args, kwargs)
\ No newline at end of file
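Usage sketch (not part of the patch): with XLAShardedTensor now subclassing DTensor, isinstance-based DTensor checks accept it directly, and the __torch_function__ override delegates past DTensor to the default torch.Tensor handler, so ordinary behavior such as gradient access keeps working. The snippet below is a minimal illustration only, assuming a torch_xla install with an available XLA device; the shapes are arbitrary.

import torch
from torch.distributed.tensor import DTensor
from torch_xla.distributed.spmd import XLAShardedTensor

# Wrap an XLA tensor; the wrapper is now recognized as a DTensor.
t = torch.randn(16, 16).to('xla')
st = XLAShardedTensor(t, requires_grad=True)
assert isinstance(st, DTensor)

# __torch_function__ skips DTensor's handler, so autograd attributes
# behave like plain torch.Tensor attributes.
st.sum().backward()
assert st.grad is not None and st.grad.shape == st.shape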