1 change: 1 addition & 0 deletions test/neuron/run_tests.sh
@@ -257,6 +257,7 @@ function run_xla_op_tests3 {
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -257,6 +257,7 @@ function run_xla_op_tests3 {
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
38 changes: 38 additions & 0 deletions test/spmd/test_xla_sharded_tensor.py
@@ -0,0 +1,38 @@
+import sys
+import unittest
+import test_xla_sharding_base
+from torch.distributed.tensor import DTensor
+from torch_xla.distributed.spmd import XLAShardedTensor
+
+import torch
+
+
+class XlaShardedTensorTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_xlashardedtensor_is_dtensor(self):
+    """Test that XLAShardedTensor is a subclass of DTensor."""
+    xt = torch.randn(128, 128).to('xla')
+    xla_tensor = XLAShardedTensor(xt)
+    self.assertIsInstance(xla_tensor, DTensor)
+
+  def test_xlashardedtensor_gradient(self):
+    """Test accessing gradients of an XLAShardedTensor (triggers __torch_function__)."""
+    xt = torch.randn(128, 128).to('xla')
+    xla_tensor = XLAShardedTensor(xt, requires_grad=True)
+    result = xla_tensor.sum()
+    result.backward()
+
+    # this should trigger __torch_function__
+    grad = xla_tensor.grad
+
+    self.assertIsNotNone(grad)
+    self.assertEqual(grad.shape, xla_tensor.shape)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
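
For context, a minimal sketch (not part of the PR) of what the new subclass relationship buys callers beyond the assertions above: any code that branches on `isinstance(x, DTensor)`, as parts of `torch.distributed` do, now takes the DTensor path for an `XLAShardedTensor`. The `summarize` helper below is hypothetical and assumes an XLA device is available.

import torch
from torch.distributed.tensor import DTensor
from torch_xla.distributed.spmd import XLAShardedTensor


def summarize(t: torch.Tensor) -> str:
  # Hypothetical helper: branches on DTensor the way some distributed
  # utilities do when deciding how to treat an incoming tensor.
  if isinstance(t, DTensor):
    return f"distributed tensor, global shape {tuple(t.shape)}"
  return f"plain tensor, shape {tuple(t.shape)}"


# With XLAShardedTensor subclassing DTensor, the wrapper takes the first branch.
xt = XLAShardedTensor(torch.randn(4, 4).to('xla'))
print(summarize(xt))  # distributed tensor, global shape (4, 4)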
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -63,6 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
 run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+run_test "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"
5 changes: 3 additions & 2 deletions torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -11,6 +11,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial
 from torch.utils._pytree import tree_map_only
+from torch.distributed.tensor import DTensor
 
 
 @dataclass
@@ -63,7 +64,7 @@ def no_dispatch() -> Iterator[None]:
   del guard
 
 
-class XLAShardedTensor(torch.Tensor):
+class XLAShardedTensor(DTensor):
   """
   A wrapper around `torch.Tensor` with sharding annotation
   for XLA SPMD auto-sharding. The wrapped tensors are unwrapped
@@ -300,4 +301,4 @@ def redistribute(self, device_mesh, placements, *, async_op: bool = False):
 
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
-    return super().__torch_function__(func, types, args, kwargs)
+    return super(DTensor, cls).__torch_function__(func, types, args, kwargs)
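
The changed `__torch_function__` line is the subtle part: the two-argument `super(DTensor, cls)` starts the method lookup after `DTensor` in the MRO, so dispatch falls through to `torch.Tensor.__torch_function__` and `DTensor`'s own handler is bypassed, which is presumably what keeps plain-tensor behavior such as `.grad` access working on the XLA wrapper (as the new gradient test exercises). A toy sketch of the MRO-skipping pattern, using unrelated classes purely for illustration:

class Base:

  @classmethod
  def handler(cls):
    return "Base.handler"


class Middle(Base):

  @classmethod
  def handler(cls):
    return "Middle.handler"


class Leaf(Middle):

  @classmethod
  def handler(cls):
    # super().handler() would resolve to Middle.handler; the two-argument
    # form starts the lookup after Middle, so Base.handler runs instead,
    # mirroring super(DTensor, cls).__torch_function__ in the diff above.
    return super(Middle, cls).handler()


assert Leaf.handler() == "Base.handler"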