AOTAutograd: allow input mutations in the bw that occur under no_grad #127572

@bdhirsh

Description
#127551 was an original attempt to fix this, but @ezyang pointed out that we can do a bit better:

During normal execution of the backward (unless create_graph=True), grad mode is always disabled while the backward runs. So if there is user code that mutates a graph input during the backward, it should be safe to include in the graph, because grad mode is not active.
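This behavior is easy to check directly. The standalone probe below (not from the issue) records `torch.is_grad_enabled()` while a custom backward runs:

```python
import torch

class Probe(torch.autograd.Function):
    """Records whether grad mode is enabled while its backward runs."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        Probe.grad_enabled = torch.is_grad_enabled()
        return grad

x = torch.randn(3, requires_grad=True)

Probe.apply(x).sum().backward()
plain_backward = Probe.grad_enabled  # False: grad mode is off in a normal backward

x.grad = None
Probe.apply(x).sum().backward(create_graph=True)
double_backward = Probe.grad_enabled  # True: create_graph=True keeps grad mode on
```

So any input mutation executed inside a normal backward runs under disabled grad mode, and only create_graph=True changes that.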

The problem: right now we error at compile time, while tracing out the joint graph here, if there are any mutations during the backward on inputs that require grad. Based on the above, it doesn't really make sense to error at compile time. Instead, we should:

(1) unconditionally trace out the backward data mutations into copy_() nodes in the backward graph.

(2) at runtime, add an assert in the backward that raises an error if we know there are backward mutations in the graph and grad mode is currently enabled (meaning the user is running the backward with create_graph=True)
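A hypothetical sketch of what (2) could look like at runtime. The wrapper and flag names below are illustrative, not AOTAutograd's actual internals:

```python
import torch

def run_compiled_backward(bw_graph, saved_tensors, grads, bw_mutates_inputs):
    # bw_mutates_inputs would be determined at compile time, when the
    # backward input mutations are traced into copy_() nodes (step 1).
    if bw_mutates_inputs and torch.is_grad_enabled():
        # Grad mode enabled inside the backward implies create_graph=True,
        # where replaying the traced input mutations would be unsound.
        raise RuntimeError(
            "backward input mutations are not supported with create_graph=True"
        )
    return bw_graph(saved_tensors, grads)
```

Under a normal backward call grad mode is already off, so the check is a no-op; it only fires when the user runs the backward with create_graph=True.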

Example repro:

import torch
@torch.library.custom_op("mylib::clone", mutates_args={})
def f(x: torch.Tensor) -> torch.Tensor:
    return x.clone()
def f_fake(x):
    return torch.empty_like(x)
def backward(ctx, grad):
    ctx.x.zero_()
    return grad
def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.x = x
f.register_fake(f_fake)
f.register_autograd(backward, setup_context=setup_context)
x = torch.randn(3, requires_grad=True)
y = f(x)
y.sum().backward()
def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.mylib.clone(x)
print(x.grad)  # ones
print(x)  # zeros
torch.compile(fn, backend="aot_eager", fullgraph=True)(x)  # currently errors at compile time
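For the repro above, step (1) means the ctx.x.zero_() would be functionalized and landed in the backward graph as an explicit copy_() node, roughly like this (a hand-written sketch, not actual AOTAutograd output):

```python
import torch

def traced_backward_sketch(x_saved, grad_y):
    # Functionalized form of ctx.x.zero_(): compute the new value out of
    # place, then apply the input mutation as an explicit copy_() node.
    zeros = torch.zeros_like(x_saved)
    x_saved.copy_(zeros)
    # The custom op's backward returns the incoming grad unchanged.
    return grad_y
```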

cc @ezyang @msaroufim @anijain2305 @chauhang @zou3519

Metadata

Labels

internal ramp-up task (Tasks that are suitable for new folks w/ high-touch guidance from senior PyTorch folks)
module: aotdispatch (umbrella label for AOTAutograd issues)
module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op)
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
