Closed
Labels
oncall: distributed, oncall: pt2, triaged
Description
Smaller repro generated from an internal issue:
+    def test_dtensor_noncontiguous_output(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        # test passing in DTensor as inputs/outputs and run some tensor computation
+        def fn(x, y, z):
+            x_transposed = x.permute(0, 2, 1).contiguous()
+            tmp = torch._C._nn.linear(x_transposed, y, z)
+            return tmp.permute(0, 2, 1)
+
+        x_inner = torch.randn(4, 16, 4, requires_grad=True)
+        y_inner = torch.randn(4, 16, requires_grad=True)
+        z_inner = torch.randn(4, requires_grad=True)
+        x = DTensor.from_local(x_inner, mesh, [Shard(1)], run_check=False)
+        y = DTensor.from_local(y_inner, mesh, [Shard(1)], run_check=False)
+        z = DTensor.from_local(z_inner, mesh, [Replicate()], run_check=False)
+        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(x, y, z)
+        out.contiguous().sum().backward()
+
Fails with:
  File "/data/users/hirsheybar/b/pytorch/torch/_subclasses/fake_tensor.py", line 2190, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/data/users/hirsheybar/b/pytorch/torch/_ops.py", line 571, in __call__
    return self_._op(*args, **kwargs)
  File "/data/users/hirsheybar/b/pytorch/torch/_refs/__init__.py", line 4475, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/data/users/hirsheybar/b/pytorch/torch/_refs/__init__.py", line 3643, in _reshape_view_helper
    raise ValueError(msg)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
ValueError: Cannot view a tensor with shape torch.Size([4, 4, 4]) and strides (16, 1, 4) as a tensor with shape (16, 4)!
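For readers unfamiliar with why this view is rejected, here is a minimal pure-Python sketch of the stride arithmetic. It is not PyTorch's `_reshape_view_helper`; the helper names (`contiguous_strides`, `permute`, `can_flatten_leading_dims`) are hypothetical and only model the one case in the error message: merging the first two dimensions of a `(4, 4, 4)` tensor into `(16, 4)` without copying. The strides `(16, 1, 4)` in the error are what `permute(0, 2, 1)` produces from a contiguous `(4, 4, 4)` tensor.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a densely packed tensor."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def permute(shape, strides, dims):
    """Reorder shape and strides as a permute view would; no data is moved."""
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)


def can_flatten_leading_dims(shape, strides):
    """True if dims 0 and 1 can be merged into one dim without a copy.

    Simplified rule (an assumption of this sketch): two adjacent dims are
    mergeable when one step along the outer dim equals a full sweep of the
    inner dim.
    """
    return strides[0] == shape[1] * strides[1]


base_shape = (4, 4, 4)
base_strides = contiguous_strides(base_shape)

# Same layout change as `tmp.permute(0, 2, 1)` in the repro.
perm_shape, perm_strides = permute(base_shape, base_strides, (0, 2, 1))

print("contiguous:", base_shape, base_strides,
      "flatten to (16, 4) ok:", can_flatten_leading_dims(base_shape, base_strides))
print("permuted:  ", perm_shape, perm_strides,
      "flatten to (16, 4) ok:", can_flatten_leading_dims(perm_shape, perm_strides))
```

Under this simplified rule the contiguous layout can be flattened to `(16, 4)` as a view, while the permuted layout cannot, so a `view` call with `allow_copy=False` has no valid strides to return.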
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @ezyang @msaroufim @anijain2305 @zou3519