-
Notifications
You must be signed in to change notification settings - Fork 21.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Higher order op for preserving leaf functions through trace, particul…
…arly for getting user defined hooks to compiled autograd ghstack-source-id: 37029e876bf9889441c2f469b06add4fe754e4d5 Pull Request resolved: #109690 Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd lint lint lint lint more test more test more test more test
- Loading branch information
1 parent
34ded74
commit eae245e
Showing
5 changed files
with
401 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
# Owner(s): ["module: dynamo"] | ||
# flake8: noqa | ||
|
||
import functools | ||
|
||
import torch | ||
|
||
import torch._dynamo.test_case | ||
import torch._dynamo.testing | ||
import torch._dynamo.utils | ||
from torch import _inductor as inductor | ||
from torch._dynamo import compiled_autograd | ||
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped | ||
from torch._dynamo.testing import normalize_gm | ||
from torch._dynamo.utils import counters | ||
from torch.fx.experimental.proxy_tensor import make_fx | ||
|
||
|
||
def _multiply(x): | ||
return x * x | ||
|
||
|
||
def _multiply_invoke(grad): | ||
return trace_wrapped(grad, fn=_multiply) | ||
|
||
|
||
class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase): | ||
def test_invoke_in_eager(self): | ||
x = torch.tensor([0.5, 0.5], requires_grad=True) | ||
y = torch.tensor([0.5, 0.5], requires_grad=True) | ||
|
||
def fn(x, y): | ||
x.register_hook(_multiply_invoke) | ||
return x * y | ||
|
||
out = fn(x, y) | ||
grad_out = torch.tensor([2.0, 2.0]) | ||
out.backward(grad_out) | ||
self.assertEqual(x.grad, y * grad_out) | ||
|
||
def test_invoke_in_pt2(self): | ||
for backend in ["eager", "aot_eager", "inductor"]: | ||
torch._dynamo.reset() | ||
x = torch.tensor([0.5, 0.5], requires_grad=True) | ||
y = torch.tensor([0.5, 0.5], requires_grad=True) | ||
|
||
def fn(x, y): | ||
x.register_hook(_multiply_invoke) | ||
return x * y | ||
|
||
fn = torch._dynamo.optimize(backend)(fn) | ||
out = fn(x, y) | ||
grad_out = torch.tensor([2.0, 2.0]) | ||
out.backward(grad_out) | ||
self.assertEqual(x.grad, grad_out * y) | ||
|
||
def test_invoke_make_fx_forward_contrived(self): | ||
x = torch.tensor([0.5, 0.5], requires_grad=True) | ||
out = make_fx(_multiply_invoke)(x) | ||
self.assertEqual(out(x), torch.tensor([0.25, 0.25])) | ||
actual = normalize_gm(out.print_readable(False)) | ||
|
||
expected = """\ | ||
class _multiply_invoke(torch.nn.Module): | ||
def forward(self, grad_1: f32[2]): | ||
trace_wrapped: f32[2] = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None | ||
assert_1: f32[2] = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None | ||
detach: f32[2] = torch.ops.aten.detach.default(assert_1); assert_1 = None | ||
detach_1: f32[2] = torch.ops.aten.detach.default(detach); detach = None | ||
detach_2: f32[2] = torch.ops.aten.detach.default(detach_1); detach_1 = None | ||
detach_3: f32[2] = torch.ops.aten.detach.default(detach_2); detach_2 = None | ||
return detach_3 | ||
""" | ||
self.assertExpectedInline(actual, expected) | ||
|
||
def test_invoke_make_bw(self): | ||
x = torch.tensor([0.5, 0.5], requires_grad=True) | ||
|
||
def fwd(x): | ||
z = x * x | ||
return z + z | ||
|
||
res = fwd(x) | ||
res.backward(torch.tensor([1.0, 1.0])) | ||
out = make_fx(_multiply_invoke)(x.grad) | ||
self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0])) | ||
actual = normalize_gm(out.print_readable(False)) | ||
|
||
expected = """\ | ||
class _multiply_invoke(torch.nn.Module): | ||
def forward(self, grad_1: f32[2]): | ||
trace_wrapped: f32[2] = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None | ||
assert_1: f32[2] = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None | ||
return assert_1 | ||
""" | ||
self.assertExpectedInline(actual, expected) | ||
|
||
def test_invoke_in_pt2_compiled_autograd(self): | ||
graph = None | ||
|
||
def compiler_fn(gm): | ||
def inner_compiler(gm_, example_inputs_): | ||
nonlocal graph | ||
self.assertEqual(graph, None) | ||
graph = gm_ | ||
return inductor.compile(gm_, example_inputs_) | ||
|
||
return torch.compile( | ||
gm, backend=inner_compiler, fullgraph=True, dynamic=True | ||
) | ||
|
||
for backend in ["eager", "aot_eager", "inductor"]: | ||
torch._dynamo.reset() | ||
x = torch.tensor([0.5, 0.5], requires_grad=True) | ||
y = torch.tensor([0.5, 0.5], requires_grad=True) | ||
|
||
def fn(x, y): | ||
x.register_hook(_multiply_invoke) | ||
return x + y | ||
|
||
fn = torch._dynamo.optimize(backend)(fn) | ||
out = fn(x, y) | ||
grad_out = torch.tensor([2.0, 2.0]) | ||
with compiled_autograd.enable(compiler_fn): | ||
out.backward(grad_out) | ||
actual = normalize_gm(graph.print_readable(False)) | ||
self.assertEqual(x.grad, grad_out * grad_out) | ||
expected = """\ | ||
class GraphModule(torch.nn.Module): | ||
def forward(self, L_inputs_0_ : torch.Tensor): | ||
getitem = L_inputs_0_ | ||
new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) | ||
copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None | ||
call_hook = getitem * getitem; getitem = None | ||
new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) | ||
copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None | ||
return (copy_, copy__1) | ||
""" | ||
self.assertExpectedInline(actual, expected) | ||
|
||
graph = None | ||
|
||
def test_invoke_in_pt2_compiled_autograd_side_effect(self): | ||
def _side_effect_stateful_fn2(x, obj): | ||
obj.counter = obj.counter + 1 | ||
return _multiply(x) | ||
|
||
def _side_effectful_invoke2(grad, fn): | ||
return trace_wrapped(grad, fn=fn) | ||
|
||
graph = None | ||
|
||
def compiler_fn(gm): | ||
def inner_compiler(gm_, example_inputs_): | ||
nonlocal graph | ||
self.assertEqual(graph, None) | ||
graph = gm_ | ||
return inductor.compile(gm_, example_inputs_) | ||
|
||
return torch.compile( | ||
gm, backend=inner_compiler, fullgraph=True, dynamic=True | ||
) | ||
|
||
for backend in ["eager", "aot_eager", "inductor"]: | ||
torch._dynamo.reset() | ||
x = torch.tensor([0.5, 0.5], requires_grad=True) | ||
y = torch.tensor([0.5, 0.5], requires_grad=True) | ||
|
||
class MyObj: | ||
def __init__(self): | ||
self.counter = 0 | ||
|
||
obj = MyObj() | ||
inner_fn = functools.partial(_side_effect_stateful_fn2, obj=obj) | ||
hook_fn = functools.partial(_side_effectful_invoke2, fn=inner_fn) | ||
x.register_hook(hook_fn) | ||
|
||
def fn(x, y): | ||
return x + y | ||
|
||
fn = torch._dynamo.optimize(backend, nopython=True)(fn) | ||
out = fn(x, y) | ||
grad_out = torch.tensor([2.0, 2.0]) | ||
with compiled_autograd.enable(compiler_fn): | ||
out.backward(grad_out) | ||
actual = normalize_gm(graph.print_readable(False)) | ||
self.assertEqual(obj.counter, 1) | ||
self.assertEqual(x.grad, grad_out + grad_out) | ||
expected = """\ | ||
class GraphModule(torch.nn.Module): | ||
def forward(self, L_inputs_0_ : torch.Tensor): | ||
getitem = L_inputs_0_ | ||
new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) | ||
copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None | ||
call_hook = getitem * getitem; getitem = None | ||
new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) | ||
copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None | ||
return (copy_, copy__1) | ||
""" | ||
self.assertExpectedInline(actual, expected) | ||
|
||
out = fn(x, y) | ||
out.backward(grad_out) | ||
self.assertEqual(obj.counter, 2) | ||
|
||
out = fn(x, y) | ||
out.backward(grad_out) | ||
self.assertEqual(obj.counter, 3) | ||
graph = None | ||
|
||
def test_invoke_in_pt2_compiled_autograd_graph_breaks(self): | ||
def _graph_breaking_fn(x): | ||
print("Boo!") | ||
return _multiply(x) | ||
|
||
def _graph_break_invoke(grad): | ||
return trace_wrapped(grad, fn=_graph_breaking_fn) | ||
|
||
def compiler_fn(gm): | ||
return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True) | ||
|
||
for backend in ["eager", "aot_eager", "inductor"]: | ||
torch._dynamo.reset() | ||
x = torch.tensor([0.5, 0.5], requires_grad=True) | ||
y = torch.tensor([0.5, 0.5], requires_grad=True) | ||
|
||
def fn(x, y): | ||
x.register_hook(_graph_break_invoke) | ||
return x + y | ||
|
||
fn = torch._dynamo.optimize(backend, nopython=True)(fn) | ||
out = fn(x, y) | ||
grad_out = torch.tensor([2.0, 2.0]) | ||
with self.assertRaisesRegex( | ||
torch._dynamo.exc.Unsupported, | ||
"print", | ||
): | ||
with compiled_autograd.enable(compiler_fn): | ||
out.backward(grad_out) | ||
|
||
graph = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from torch._C import DispatchKey | ||
from torch._higher_order_ops.utils import autograd_not_implemented | ||
|
||
from torch._ops import HigherOrderOperator | ||
from torch._subclasses import FakeTensorMode | ||
|
||
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree | ||
from torch.utils._python_dispatch import _get_current_dispatch_mode | ||
|
||
|
||
__all__ = ["trace_wrapped"] | ||
|
||
|
||
# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist: | ||
# if you make_fx trace through this call, we will not actually trace into fn; instead, | ||
# we will directly insert it as a call_function to fn in the graph. | ||
# (Unlike make_fx, Dynamo WILL inline into fn.) | ||
# You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing. | ||
# | ||
# Because proxy tensor tracing does not actually run the function, there are | ||
# requirements on the behavior of fn. We are still figuring it out, but here is the current state: | ||
# | ||
# 1) fn SHOULD only take a single argument, which must be a tensor | ||
# 2) fn MUST return a new tensor with the same metadata as the original tensor | ||
# (e.g., zeros_like(input) is a permissible implementation of fn). | ||
# This is verified via an extra assert that is inserted into the traced graph. | ||
# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors | ||
# participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state) | ||
# These requirements stem from the requirement that we need to continue performing proxy tensor tracing, | ||
# which assumes accurate fake tensor metadata, without actually running fn. | ||
# In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns. | ||
# | ||
# Note that tensors / Python state are allowed to be mutated. | ||
# This is relaxed constraint is not always sound, but it is sound for backward tracing with fake | ||
# tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete | ||
# tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python). | ||
# | ||
# The intended use case for this function is to allow AOTAutograd to defer complex | ||
# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves | ||
# the function call as is in the graph, and only when we Dynamo through the backward graph in | ||
# compiled autograd do we inline into the function. | ||
|
||
|
||
def trace_wrapped(*args, fn): | ||
return _trace_wrapped_op(*args, fn=fn) | ||
|
||
|
||
_trace_wrapped_op = HigherOrderOperator("trace_wrapped") | ||
|
||
|
||
def _assert_meta(grad, size, stride, dtype): | ||
assert grad.size() == size, "size mismatch" | ||
assert grad.stride() == stride, "stride mismatch" | ||
assert grad.dtype == dtype, "dtype mismatch" | ||
return grad | ||
|
||
|
||
@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode) | ||
def inner_trace(mode, *args, fn): | ||
import torch | ||
|
||
assert len(args) == 1 | ||
grad = args[0] | ||
assert isinstance(grad, torch.Tensor) | ||
|
||
def self_invoke(*args): | ||
return _trace_wrapped_op(*args, fn=fn) | ||
|
||
proxy_args = (mode.tracer.unwrap_proxy(grad),) | ||
out_proxy = mode.tracer.create_proxy( | ||
"call_function", self_invoke, proxy_args, {}, name="trace_wrapped" | ||
) | ||
grad = torch.zeros_like(grad) | ||
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer) | ||
|
||
# We have a little shortcut here, wherein we DO NOT yet run a meta func, and so | ||
# we take on an assumption that input and output meta matches. As such, we must introduce | ||
# a runtime assert | ||
proxy_args = ( | ||
mode.tracer.unwrap_proxy(grad), | ||
grad.size(), | ||
grad.stride(), | ||
grad.dtype, | ||
) | ||
out_proxy = mode.tracer.create_proxy( | ||
"call_function", | ||
_assert_meta, | ||
proxy_args, | ||
{}, | ||
name="assert", | ||
) | ||
grad = torch.empty_like(grad) | ||
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer) | ||
return grad | ||
|
||
|
||
@_trace_wrapped_op.py_impl(FakeTensorMode) | ||
def inner_fake(*args, fn): | ||
raise RuntimeError("This op should never be invoked here") | ||
|
||
|
||
@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd) | ||
def _trace_wrapped_op_dense(*args, fn): | ||
mode = _get_current_dispatch_mode() | ||
assert mode is None, "Mode should never be enabled for CPU/CUDA key" | ||
return fn(*args) | ||
|
||
|
||
_trace_wrapped_op.py_impl(DispatchKey.Autograd)( | ||
autograd_not_implemented(_trace_wrapped_op, deferred_error=True) | ||
) | ||
|
||
|
||
@_trace_wrapped_op.py_functionalize_impl | ||
def _trace_wrapped_functionalized(ctx, *args, fn): | ||
unwrapped_args = ctx.unwrap_tensors(args) | ||
wrapped_fn = ctx.functionalize(fn) | ||
with ctx.redispatch_to_next(): | ||
return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, fn=wrapped_fn)) | ||
|
||
|
||
# TODO(voz): Make this automatic for keys, this is very ugly atm | ||
_trace_wrapped_op.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] | ||
_trace_wrapped_op.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] | ||
_trace_wrapped_op.fallthrough(DispatchKey.ADInplaceOrView) | ||
_trace_wrapped_op.fallthrough(DispatchKey.BackendSelect) | ||
_trace_wrapped_op.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] | ||
_trace_wrapped_op.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.