[dynamo] handle setting .data on a tensor (#113080)
**Dynamo**

We don't want `setattr` in the graph. Setting `.data` has interesting implications for both aliasing and the autograd engine.

The safe recipe is:

1) Disable grad
2) Call `set_()`
3) Manually lower the version counter on the object to hide the mutation from the autograd engine

This is effectively the same thing as setting `.data`, and it composes properly with aot_autograd and inductor.
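
For intuition, here is a minimal eager-mode sketch of that recipe (the helper name is just for illustration; `_unsafe_set_version_counter` is the same private autograd hook the traced `_lower_version_count_by_1` helper uses below):

```python
import torch

def set_data_like(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Step 1: disable grad so set_() is allowed even on a leaf that requires grad
    with torch.no_grad():
        # Step 2: swap x's storage and metadata for y's
        x.set_(y)
    # Step 3: lower the version counter so the autograd engine does not see
    # the mutation, mirroring what `x.data = y` does in eager mode
    if x._version > 0:
        torch._C._autograd._unsafe_set_version_counter(x, x._version - 1)
    return x
```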

**aot_autograd**

For aot_autograd, there's another snag.

Specifically, when we invoke aot_autograd, we call `fake_mode.from_tensor()`, relying on memoization to get the right tensor out. For `.data` mutations, this doesn't work, because the memoized fake tensor is in the state it will be in at the end of the trace, not at the beginning. This means the `.data` mutation has already been applied, so the tensor shape mismatches (as in these tests): aot_autograd produces an invalid graph, with illegal calls like `torch.ops.aten.view.default(primals_2, [0])` where `primals_2` is actually sized `[6]` on input.
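
For reference, the failure mode comes from a `.data` mutation that changes shape mid-trace, along the lines of `func2` in the `test_tensor_set_data` test added below (a sketch, not the exact test):

```python
import torch

def f(x, y):
    x.data = y                  # x now aliases y's storage
    y.data = torch.zeros([0])   # y's metadata changes to shape [0] mid-trace
    return x

x = torch.ones([6], requires_grad=True)
y = torch.ones([6], requires_grad=True)

# Without the fix below, aot_autograd re-fakifies the inputs in their
# end-of-trace state and can emit calls like aten.view.default(primals_2, [0])
# even though the real input is sized [6].
out = torch.compile(f, backend="aot_eager", fullgraph=True)(x, y)
```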

The new plan here is to:
1) Record tensor fakification policy in dynamo
2) Provide a fresh fake mode to all backends
3) Invoke `from_tensor` with the stored policy to get fresh fake tensors in aot_autograd

Pull Request resolved: #113080
Approved by: https://github.com/bdhirsh
voznesenskym authored and pytorchmergebot committed Dec 2, 2023
1 parent 77c4565 commit 4cfe997
Showing 6 changed files with 178 additions and 18 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/TensorShape.cpp
@@ -432,7 +432,7 @@ Tensor& set__symint(Tensor& result, const Tensor& storage, c10::SymInt storage_o

Tensor& set_tensor_(Tensor& result, const Tensor& source) {
if (result.unsafeGetTensorImpl() != source.unsafeGetTensorImpl()) {
return result.set_(source.storage(), source.storage_offset(), source.sizes(), source.strides());
return result.set__symint(source.storage(), source.sym_storage_offset(), source.sym_sizes(), source.sym_strides());
}
return result;
}
12 changes: 12 additions & 0 deletions test/dynamo/test_aot_autograd.py
@@ -992,6 +992,18 @@ def f(x):
),
)

def test_aot_autograd_raises_invalid_leaf_set(self):
@torch.compile
def f(x):
x.set_(torch.ones(2))

# We still want to make sure that this raises
x = torch.ones(2, requires_grad=True)
with self.assertRaisesRegex(
RuntimeError, "is being used in an in-place operation"
):
f(x)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
48 changes: 37 additions & 11 deletions test/dynamo/test_input_attr_tracking.py
@@ -1,9 +1,15 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter
from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
normalize_gm,
)


class TestInputAttrTracking(torch._dynamo.test_case.TestCase):
@@ -290,22 +296,42 @@ def fn(x, y):

eager_result = fn(x, y)

counter = CompileCounter()
eager_and_record = EagerAndRecordGraphs()

counter = CompileCounterWithBackend(eager_and_record)

fn = torch._dynamo.optimize(counter, nopython=True)(fn)

compile_result = fn(x, y)

graph = eager_and_record.graphs[0]
actual = normalize_gm(graph.print_readable(False))

self.assertEqual(compile_result, eager_result)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 2)
# Graph for reference
# __compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs
# ------------- ------ ----------------------- ------------ --------
# placeholder l_x_ L_x_ () {}
# placeholder l_y_ L_y_ () {}
# call_method detach detach (l_y_,) {}
# call_function mul <built-in function mul> (l_x_, l_y_) {}
# output output output ((mul,),) {}
self.assertEqual(counter.op_count, 6)
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
l_y_ = L_y_
l_x_ = L_x_
detach = l_y_.detach()
_set_grad_enabled = torch._C._set_grad_enabled(False)
set_ = torch_Tensor_set_(l_x_, detach); detach = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True)
_lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = None
mul = l_x_ * l_y_; l_x_ = l_y_ = None
return (mul,)
""",
)

# Note - this does not actually get captured in the graph yet.
# The plan of record is to introduce a set_data op, entirely subsume the operation into a call_function
56 changes: 56 additions & 0 deletions test/dynamo/test_repros.py
@@ -3677,6 +3677,62 @@ def fn(x=None):
self.assertEqual(opt_fn("10"), fn("10"))
self.assertEqual(cnt.frame_count, 4)

def test_tensor_set_data(self):
# https://github.com/pytorch/pytorch/issues/113030
def func1(x, y):
x.data = y
x.add_(1)
return x

def func2(x, y):
x.data = y
y.data = torch.zeros([0])
return x

def func3(x, y):
z = x
x.data = y
y.data = torch.zeros([0])
return x is z

for backend in ["eager", "aot_eager", "inductor"]:
for func in [func1, func2, func3]:
if backend != "eager" and func is func1:
# add_ not working w/ aot_autograd?
continue
torch._dynamo.reset()
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

compiled_fn = torch.compile(func, backend=cnt, fullgraph=True)
requires_grad = func is not func1
for i in range(0, 5):
# Inputs
eager_a = torch.ones([6], requires_grad=requires_grad)
compiled_a = torch.ones([6], requires_grad=requires_grad)

eager_b = torch.ones([6], requires_grad=requires_grad)
compiled_b = torch.ones([6], requires_grad=requires_grad)

# Eager
out_eager = func(eager_a, eager_b)
# Compiled
out_compiled = compiled_fn(compiled_a, compiled_b)
self.assertEqual(eager_a, compiled_a)
self.assertEqual(eager_b, compiled_b)
self.assertEqual(out_eager, out_compiled)

# func1 hits "a leaf Variable that requires grad is being used in an in-place operation"
if requires_grad:
bwd_inp_eager = torch.randn([6])
bwd_inp_compiled = torch.clone(bwd_inp_eager)
eager_a.backward(bwd_inp_eager)
compiled_a.backward(bwd_inp_compiled)
self.assertEqual(eager_a.grad, compiled_a.grad)

# Prove guarding works - we run the compiled_fn 5 times
# frame_count should stay at 1.
self.assertEqual(cnt.frame_count, 1)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
75 changes: 70 additions & 5 deletions torch/_dynamo/variables/builtin.py
@@ -1,3 +1,4 @@
import contextlib
import functools
import inspect
import itertools
@@ -1216,11 +1217,62 @@ def call_setattr(
and name_var.is_python_constant()
):
name = name_var.as_python_constant()
if name == "requires_grad" and isinstance(obj, variables.TensorVariable):
unimplemented(
"mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
"the middle of the graph, which aot_autograd does not currently know how to handle. "
)
if isinstance(obj, variables.TensorVariable):
from .builder import wrap_fx_proxy

if name == "requires_grad":
# TODO(voz): Make it work properly
unimplemented(
"mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
"the middle of the graph, which aot_autograd does not currently know how to handle. "
)
if name == "data":
# Remove the old reference in tracked fakes - if we don't do this
# new .data value size and shape differences will cause
# tracked fakes to produce incorrect guards. This is sound because the TensorVariable
# coming out of set_() below will be a new one, and get
# installed in tracked fakes.
to_remove = []
for tf in tx.output.tracked_fakes:
if tf.source == obj.source:
to_remove.append(tf)
for tf in to_remove:
tx.output.tracked_fakes.remove(tf)

# Step 1 - disable grads
with dynamo_disable_grad(tx), torch.no_grad():
# Step 2 - call `set_`
out = wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_function",
torch.Tensor.set_,
*proxy_args_kwargs([obj, val], {}),
),
)

# Step 3 - drop the version counter - this is a step required to get
# .data setting to play correctly with the autograd engine.
# Essentially, dynamo is trying to faithfully preserve the (absurd)
# behavior of .data= from eager mode
def _lower_version_count_by_1(x):
version = x._version
if version > 0:
version = version - 1
torch._C._autograd._unsafe_set_version_counter(x, version)
return x

tx.output.create_proxy(
"call_function",
_lower_version_count_by_1,
(out.as_proxy(),),
{},
)
_lower_version_count_by_1(obj.as_proxy().node.meta["example_value"])
# This handles options propagation, guards, and ends with a clone
# Step 4 - replace all references to the current object with the new one
return out

tx.output.side_effects.store_attr(obj, name, val)
return val
elif isinstance(obj, variables.UserDefinedObjectVariable):
@@ -1554,3 +1606,16 @@ def call_all(self, tx, *args, **kwargs):
return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.all), args, kwargs
)


@contextlib.contextmanager
def dynamo_disable_grad(tx):
from . import GradModeVariable

org_value = torch.is_grad_enabled()
gmv = GradModeVariable.create(tx, False)
try:
gmv.enter(tx)
yield
finally:
gmv.exit(tx)
3 changes: 2 additions & 1 deletion torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -139,7 +139,8 @@ def runtime_wrapper(*args):
if trace_joint:
assert isinstance(updated_inpt, TensorAlias)
updated_inpt = updated_inpt.alias
original_inpt.set_(updated_inpt)
with torch.no_grad():
original_inpt.set_(updated_inpt)
continue
if meta.mutates_metadata and not meta.mutates_data:
if trace_joint:
