test/distributed/_tensor/test_dtensor.py (23 additions, 0 deletions)

@@ -245,6 +245,29 @@ def test_to_local(self):
         except RuntimeError:
             self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
 
+    @with_comms
+    def test_dtensor_new_empty_strided(self):
+        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
+        my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
+        new_strided_dtensor = my_dtensor.new_empty_strided(
+            (8, 8), (8, 1), requires_grad=True
+        )
+        # test that the op produces a new DTensor and that autograd works
+        self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
+        new_strided_dtensor.sum().backward()
+        self.assertIsNotNone(new_strided_dtensor.grad)
+        self.assertIsInstance(new_strided_dtensor.grad, DTensor)
+
+        # test that backward through new_empty_strided works correctly with sharding
+        my_dtensor.to_local().sum().backward()
+        local_tensor.sum().backward()
+        self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
+        self.assertEqual(
+            my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
+            local_tensor.grad,
+        )
+
     @with_comms
     def test_dtensor_async_output(self):
         # Tests that if the output of some dtensor operations isn't used in any compute,
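For readers unfamiliar with the op under test, here is a minimal single-process sketch (not part of this PR) of the plain-tensor new_empty_strided semantics that the DTensor test above mirrors; the shapes and strides match the test, and no device mesh is needed.

import torch

# new_empty_strided allocates uninitialized storage with an explicit shape
# and stride, inheriting dtype and device from the source tensor.
src = torch.randn(8, 8, requires_grad=True)
out = src.new_empty_strided((8, 8), (8, 1), requires_grad=True)

assert out.shape == src.shape
assert out.stride() == (8, 1)
assert out.dtype == src.dtype

# out is a new leaf tensor with requires_grad=True, so autograd works on
# computations built on top of it, which is what the test checks via
# .sum().backward() on the DTensor result.
out.sum().backward()
assert out.grad is not None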
torch/distributed/_tensor/ops/tensor_ops.py (1 addition, 1 deletion)

@@ -43,7 +43,6 @@
         aten.detach.default,
         aten.equal.default,
         aten.is_same_size.default,
-        aten.new_empty_strided.default,  # TODO: re-think new_empty_strided
     ]
 )
 def default_strategy(
@@ -109,6 +108,7 @@ def create_like_strategy(
         aten.new_full.default,
         aten.new_ones.default,
         aten.new_zeros.default,
+        aten.new_empty_strided.default,  # TODO: re-think new_empty_strided
     ]
 )
 def new_factory_strategy(
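Why the move matters: default_strategy follows the input's sharding, whereas the new_* factory ops allocate a tensor whose contents are unrelated to the input's data, so the output placement is decided separately. Below is a simplified, self-contained model of that distinction; the names are hypothetical stand-ins, not the actual DTensor sharding-strategy code, and it assumes the factory path defaults to replicating the new tensor.

from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Placement:
    # hypothetical stand-in for DTensor's Shard/Replicate placements
    kind: str


def default_like_strategy(input_placements: List[Placement]) -> List[Placement]:
    # Ops such as detach/equal/is_same_size view or compare the same data,
    # so the output keeps the input's placements unchanged.
    return list(input_placements)


def factory_like_strategy(input_placements: List[Placement]) -> List[Placement]:
    # new_full/new_ones/new_zeros/new_empty_strided allocate fresh storage
    # whose contents are independent of the input, so the output placement
    # is chosen on its own (replicated here, by assumption).
    return [Placement("Replicate")]


# Usage: an input sharded on dim 0 keeps its sharding under the default
# strategy but comes out replicated under the factory strategy.
sharded = [Placement("Shard(0)")]
assert default_like_strategy(sharded) == sharded
assert factory_like_strategy(sharded) == [Placement("Replicate")]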