DTensor + compile with non-contiguous output failure #118596

@bdhirsh

Description

Smaller repro generated from an internal issue:

    def test_dtensor_noncontiguous_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # test passing in DTensor as inputs/outputs and running some tensor computation
        def fn(x, y, z):
            x_transposed = x.permute(0, 2, 1).contiguous()
            tmp = torch._C._nn.linear(x_transposed, y, z)
            return tmp.permute(0, 2, 1)

        x_inner = torch.randn(4, 16, 4, requires_grad=True)
        y_inner = torch.randn(4, 16, requires_grad=True)
        z_inner = torch.randn(4, requires_grad=True)
        x = DTensor.from_local(x_inner, mesh, [Shard(1)], run_check=False)
        y = DTensor.from_local(y_inner, mesh, [Shard(1)], run_check=False)
        z = DTensor.from_local(z_inner, mesh, [Replicate()], run_check=False)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(x, y, z)
        out.contiguous().sum().backward()
+

Fails with:

   File "/data/users/hirsheybar/b/pytorch/torch/_subclasses/fake_tensor.py", line 2190, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/data/users/hirsheybar/b/pytorch/torch/_ops.py", line 571, in __call__
    return self_._op(*args, **kwargs)
  File "/data/users/hirsheybar/b/pytorch/torch/_refs/__init__.py", line 4475, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/data/users/hirsheybar/b/pytorch/torch/_refs/__init__.py", line 3643, in _reshape_view_helper
    raise ValueError(msg)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
ValueError: Cannot view a tensor with shape torch.Size([4, 4, 4]) and strides (16, 1, 4) as a tensor with shape (16, 4)!
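For context, the strides in the error message are consistent with the repro's final permute: (16, 1, 4) is exactly what a contiguous (4, 4, 4) tensor's strides (16, 4, 1) become after permute(0, 2, 1), and a tensor with those strides cannot be reinterpreted as (16, 4) without a copy, so view() must raise. A minimal pure-Python sketch of that bookkeeping (no torch; the helper names `contiguous_strides` and `permute` are illustrative, not PyTorch APIs):

```python
def contiguous_strides(shape):
    # Row-major (contiguous) strides: stride[i] = product of the dims after i.
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

def permute(shape, strides, dims):
    # permute reorders shape and strides in lockstep; no data movement.
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)

shape, strides = (4, 4, 4), contiguous_strides((4, 4, 4))  # strides (16, 4, 1)
p_shape, p_strides = permute(shape, strides, (0, 2, 1))

print(p_strides)                                 # (16, 1, 4), as in the error
print(p_strides == contiguous_strides(p_shape))  # False: the output is non-contiguous
```

Since the permuted output is non-contiguous, flattening its first two dims into (16, 4) is not expressible as a view, which matches the `_reshape_view_helper(..., allow_copy=False)` failure in the traceback.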

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @ezyang @msaroufim @anijain2305 @zou3519

Metadata

Labels

oncall: distributed — Add this issue/PR to distributed oncall triage queue
oncall: pt2
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
