compiled autograd + dynamic shapes fails with constraint violation #133575

@bdhirsh

Description

Example code:

import torch


@torch.compile(backend="aot_eager_decomp_partition")
def f(x):
    return x.sin().sin()


with torch._dynamo.utils.maybe_enable_compiled_autograd(True):
    # First call: dim 0 is marked dynamic, so the graph is compiled with a symbolic size s0.
    x = torch.randn(2, 3, requires_grad=True)
    torch._dynamo.mark_dynamic(x, 0)
    out = f(x)
    out.sum().backward()

    # Second call with a different size for dim 0: the backward fails with a
    # constraint violation instead of reusing the dynamic graph.
    x = torch.randn(4, 3, requires_grad=True)
    torch._dynamo.mark_dynamic(x, 0)
    out = f(x)
    out.sum().backward()

This fails with:

torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['inputs'][1].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['inputs'][1].size()[0]) are valid because L['inputs'][1].size()[0] was inferred to be a constant (2).

The proximate cause is that:

(1) Compiled autograd passes a FakeTensor tangent into the backward that has only static shapes, even though its shape is supposed to be dynamic: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1709

(2) When AOTDispatcher traces the backward graph, it runs compute like aten.mul(activation, tangent). The activation has size (s0, 3), while the tangent has the static shape (2, 3), so the ShapeEnv infers that s0 == 2 and incorrectly specializes the shape (see the standalone sketch below).
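
To make (2) concrete, here is a minimal standalone sketch (not taken from the issue) that mixes a symbolically-shaped fake "activation" with a statically-shaped fake "tangent" under one ShapeEnv; the size comparison during broadcasting pins the symbols to constants. The exact internals printed at the end (e.g. shape_env.replacements) may differ across PyTorch versions.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

# With a ShapeEnv attached and no static_shapes, the activation gets
# symbolic sizes, roughly (s0, s1) with hints 2 and 3.
activation = fake_mode.from_tensor(torch.randn(2, 3))

# The tangent is fakeified with static shapes, mimicking what compiled
# autograd currently passes into the backward.
tangent = fake_mode.from_tensor(torch.randn(2, 3), static_shapes=True)

with fake_mode:
    # Broadcasting compares s0 with 2 (and s1 with 3), which specializes
    # the symbols to those constants.
    out = activation * tangent

print(activation.shape, tangent.shape, out.shape)
print(shape_env.replacements)  # expect something like {s0: 2, s1: 3}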

I'm not entirely sure how compiled autograd decides whether to fakeify tensors with dynamic or static shapes; maybe we need to plumb this information through properly? (A hypothetical sketch of that kind of plumbing is below.)
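
For reference, a purely hypothetical sketch of the kind of plumbing that question is about: fakeify the tangent with a symbolic_context derived from the forward's dynamic dims instead of with static shapes. The helper name and the dynamic_dims argument are made up for illustration; this is not the actual compiled autograd code path.

from torch.fx.experimental.symbolic_shapes import DimDynamic, StatelessSymbolicContext


def fakeify_tangent(fake_mode, real_tangent, dynamic_dims):
    # dynamic_dims: indices of dims that were marked dynamic in the forward
    # (hypothetical -- compiled autograd would need to carry this information).
    ctx = StatelessSymbolicContext(
        dynamic_sizes=[
            DimDynamic.DYNAMIC if i in dynamic_dims else DimDynamic.STATIC
            for i in range(real_tangent.dim())
        ],
    )
    return fake_mode.from_tensor(real_tangent, symbolic_context=ctx)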

cc @ezyang @chauhang @penguinwu
