Description
Example code:
import torch

@torch.compile(backend="aot_eager_decomp_partition")
def f(x):
    return x.sin().sin()

with torch._dynamo.utils.maybe_enable_compiled_autograd(True):
    x = torch.randn(2, 3, requires_grad=True)
    torch._dynamo.mark_dynamic(x, 0)
    out = f(x)
    out.sum().backward()

    x = torch.randn(4, 3, requires_grad=True)
    torch._dynamo.mark_dynamic(x, 0)
    breakpoint()
    out = f(x)
    out.sum().backward()
This fails with:
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['inputs'][1].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['inputs'][1].size()[0]) are valid because L['inputs'][1].size()[0] was inferred to be a constant (2).
The proximate cause is that:
(1) compiled autograd is passing in a FakeTensor tangent in the backward that has only static shapes, when its shape is supposed to be dynamic: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1709
(2) when we trace the backward graph from AOTDispatcher, we do some compute like aten.mul(activation, tangent). The activation has size (s0, 3), while the tangent has static shape (2, 3), so we infer that s0 == 2 and incorrectly specialize the shape (see the sketch after this list).
I'm not entirely sure how compiled autograd decides whether it should fakeify tensors with dynamic or static shapes, but maybe we need to properly plumb this information through?
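For what it's worth, one hypothetical direction (just a sketch, not a claim about the actual fix): if compiled autograd knew which tangent dims the forward treated as dynamic, it could fakeify tangents with a matching symbolic context rather than all-static shapes. The fakeify_tangent helper and dynamic_dims argument below are made up for illustration:

from torch.fx.experimental.symbolic_shapes import DimDynamic, StatelessSymbolicContext

def fakeify_tangent(fake_mode, tangent, dynamic_dims):
    # dynamic_dims: indices of dims the forward treated as dynamic
    # (hypothetical; compiled autograd would need to plumb this through).
    ctx = StatelessSymbolicContext(
        dynamic_sizes=[
            DimDynamic.DYNAMIC if i in dynamic_dims else DimDynamic.STATIC
            for i in range(tangent.dim())
        ],
    )
    return fake_mode.from_tensor(tangent, symbolic_context=ctx)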