#127551 was an earlier attempt to fix this, but @ezyang pointed out that we can do a bit better:
During normal execution of the backward (unless create_graph=True), grad mode will always be disabled while the backward runs. So if there is user code that mutates a graph input during the backward, it should be safe to include in the graph, because grad mode is not active.
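For illustration (this is a standalone check, not part of the repro below), a tensor hook runs while the autograd engine executes the backward, so it can report whether grad mode is on: it is off for a plain backward() and on when create_graph=True is passed.

import torch

x = torch.randn(3, requires_grad=True)

def report_grad_mode(grad):
    # Runs inside the autograd engine while the backward executes
    print("grad enabled during backward:", torch.is_grad_enabled())
    return grad

x.register_hook(report_grad_mode)

(x * 2).sum().backward()                   # prints False: grad mode is off
(x * 2).sum().backward(create_graph=True)  # prints True: grad mode stays on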
The problem: right now we error at compile time if there are any mutations during the backward on inputs that require grad, while we are tracing out the joint graph here. Based on the above, it doesn't really make sense to error at compile time. Instead, we should:
(1) unconditionally trace out the backward data mutations into copy_() nodes in the backward graph.
(2) at runtime, add an assert in the backward that raises an error if we know there are backward mutations in the graph and grad mode is currently enabled (meaning the user is running the backward with create_graph=True); a rough sketch of this check follows the list.
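A minimal sketch of what the runtime check in (2) could look like. The flag name and the function it lives in are hypothetical; the real implementation would thread this through AOTAutograd's runtime wrappers around the compiled backward.

import torch

# Hypothetical flag, recorded at compile time when the traced backward
# graph contains copy_() nodes for input mutations.
backward_has_input_mutations = True

def backward_runtime_check():
    # Grad mode being enabled here means the user invoked the backward
    # with create_graph=True, in which case replaying the traced
    # mutations is not safe, so raise instead of silently diverging.
    if backward_has_input_mutations and torch.is_grad_enabled():
        raise RuntimeError(
            "aot_autograd: input mutations in the backward are not "
            "supported when running backward with create_graph=True"
        )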
Example repro:
import torch

@torch.library.custom_op("mylib::clone", mutates_args={})
def f(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

def f_fake(x):
    return torch.empty_like(x)

def backward(ctx, grad):
    # mutates a graph input during the backward
    ctx.x.zero_()
    return grad

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.x = x

f.register_fake(f_fake)
f.register_autograd(backward, setup_context=setup_context)

def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.mylib.clone(x)

# eager: the mutation in the backward is allowed (grad mode is off)
x = torch.randn(3, requires_grad=True)
y = f(x)
y.sum().backward()
print(x.grad)  # ones
print(x)       # zeros

# compiled: currently errors at compile time while tracing the joint graph
torch.compile(fn, backend="aot_eager", fullgraph=True)(x)
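Once (1) and (2) land, one way to sanity check the behavior (a sketch building on the repro above, not an actual test from this PR) is that the compiled run should reproduce the eager mutation semantics as long as create_graph is not used:

# assumes f and fn from the repro are already defined
x_eager = torch.randn(3, requires_grad=True)
x_compiled = x_eager.detach().clone().requires_grad_(True)

f(x_eager).sum().backward()
torch.compile(fn, backend="aot_eager", fullgraph=True)(x_compiled).sum().backward()

assert torch.equal(x_eager, x_compiled)            # both zeroed by the backward
assert torch.equal(x_eager.grad, x_compiled.grad)  # both all ones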