Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd #109690

Closed. Wants to merge 16 commits.

Commits (16):
2208e96  Higher order op for preserving leaf functions through trace, particul…  (voznesenskym, Sep 20, 2023)
fcc5d22  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 20, 2023)
d5ed407  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 20, 2023)
f15a8b8  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 20, 2023)
90fb411  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 20, 2023)
5909aa5  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 21, 2023)
7564ee8  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 21, 2023)
d3e0685  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 21, 2023)
e793656  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 21, 2023)
722d3de  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 21, 2023)
18c134b  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 22, 2023)
9b5f464  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 22, 2023)
78a4b1f  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 22, 2023)
7fd2e88  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 26, 2023)
735a68d  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 26, 2023)
807641b  Update on "Higher order op for preserving leaf functions through trac…  (voznesenskym, Sep 27, 2023)
251 changes: 251 additions & 0 deletions test/dynamo/test_backward_higher_order_ops.py
@@ -0,0 +1,251 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa

import functools

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped
from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch.fx.experimental.proxy_tensor import make_fx


def _multiply(x):
return x * x


def _multiply_invoke(grad):
return trace_wrapped(grad, fn=_multiply)


class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
def test_invoke_in_eager(self):
x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)

def fn(x, y):
x.register_hook(_multiply_invoke)
return x * y

out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
out.backward(grad_out)
self.assertEqual(x.grad, y * grad_out)

def test_invoke_in_pt2(self):
for backend in ["eager", "aot_eager", "inductor"]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)

def fn(x, y):
x.register_hook(_multiply_invoke)
return x * y

fn = torch._dynamo.optimize(backend)(fn)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
out.backward(grad_out)
self.assertEqual(x.grad, grad_out * y)

def test_invoke_make_fx_forward_contrived(self):
x = torch.tensor([0.5, 0.5], requires_grad=True)
out = make_fx(_multiply_invoke)(x)
self.assertEqual(out(x), torch.tensor([0.25, 0.25]))
actual = normalize_gm(out.print_readable(False))

expected = """\
class _multiply_invoke(torch.nn.Module):
def forward(self, grad_1: f32[2]):
invocation: f32[2] = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
assert_1: f32[2] = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(invocation, (2,), (1,), torch.float32); invocation = None
detach: f32[2] = torch.ops.aten.detach.default(assert_1); assert_1 = None
detach_1: f32[2] = torch.ops.aten.detach.default(detach); detach = None
detach_2: f32[2] = torch.ops.aten.detach.default(detach_1); detach_1 = None
detach_3: f32[2] = torch.ops.aten.detach.default(detach_2); detach_2 = None
return detach_3
"""
self.assertExpectedInline(actual, expected)

def test_invoke_make_bw(self):
x = torch.tensor([0.5, 0.5], requires_grad=True)

def fwd(x):
z = x * x
return z + z

res = fwd(x)
res.backward(torch.tensor([1.0, 1.0]))
out = make_fx(_multiply_invoke)(x.grad)
self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0]))
actual = normalize_gm(out.print_readable(False))

expected = """\
class _multiply_invoke(torch.nn.Module):
def forward(self, grad_1: f32[2]):
invocation: f32[2] = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
assert_1: f32[2] = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(invocation, (2,), (1,), torch.float32); invocation = None
return assert_1
"""
self.assertExpectedInline(actual, expected)

def test_invoke_in_pt2_compiled_autograd(self):
graph = None

def compiler_fn(gm):
def inner_compiler(gm_, example_inputs_):
nonlocal graph
self.assertEqual(graph, None)
graph = gm_
return inductor.compile(gm_, example_inputs_)

return torch.compile(
gm, backend=inner_compiler, fullgraph=True, dynamic=True
)

for backend in ["eager", "aot_eager", "inductor"]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)

def fn(x, y):
x.register_hook(_multiply_invoke)
return x + y

fn = torch._dynamo.optimize(backend)(fn)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
with compiled_autograd.enable(compiler_fn):
out.backward(grad_out)
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(x.grad, grad_out * grad_out)
expected = """\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_0_ : torch.Tensor):
getitem = L_inputs_0_

new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))

copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None

call_hook = getitem * getitem; getitem = None

new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))

copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None
return (copy_, copy__1)
"""
self.assertExpectedInline(actual, expected)

graph = None

def test_invoke_in_pt2_compiled_autograd_side_effect(self):
def _side_effect_stateful_fn2(x, obj):
obj.counter = obj.counter + 1
return _multiply(x)

def _side_effectful_invoke2(grad, fn):
return trace_wrapped(grad, fn=fn)

graph = None

def compiler_fn(gm):
def inner_compiler(gm_, example_inputs_):
nonlocal graph
self.assertEqual(graph, None)
graph = gm_
return inductor.compile(gm_, example_inputs_)

return torch.compile(
gm, backend=inner_compiler, fullgraph=True, dynamic=True
)

for backend in ["eager", "aot_eager", "inductor"]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)

class MyObj:
def __init__(self):
self.counter = 0

obj = MyObj()
inner_fn = functools.partial(_side_effect_stateful_fn2, obj=obj)
hook_fn = functools.partial(_side_effectful_invoke2, fn=inner_fn)
x.register_hook(hook_fn)

def fn(x, y):
return x + y

fn = torch._dynamo.optimize(backend, nopython=True)(fn)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
with compiled_autograd.enable(compiler_fn):
out.backward(grad_out)
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(obj.counter, 1)
self.assertEqual(x.grad, grad_out + grad_out)
expected = """\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_0_ : torch.Tensor):
getitem = L_inputs_0_

new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))

copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None

call_hook = getitem * getitem; getitem = None

new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))

copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None
return (copy_, copy__1)
"""
self.assertExpectedInline(actual, expected)

out = fn(x, y)
out.backward(grad_out)
self.assertEqual(obj.counter, 2)

out = fn(x, y)
out.backward(grad_out)
self.assertEqual(obj.counter, 3)
graph = None

def test_invoke_in_pt2_compiled_autograd_graph_breaks(self):
def _graph_breaking_fn(x):
print("Boo!")
return _multiply(x)

def _graph_break_invoke(grad):
return trace_wrapped(grad, fn=_graph_breaking_fn)

def compiler_fn(gm):
return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)

for backend in ["eager", "aot_eager", "inductor"]:
torch._dynamo.reset()
x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)

def fn(x, y):
x.register_hook(_graph_break_invoke)
return x + y

fn = torch._dynamo.optimize(backend, nopython=True)(fn)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"print",
):
with compiled_autograd.enable(compiler_fn):
out.backward(grad_out)

graph = None
111 changes: 111 additions & 0 deletions torch/_dynamo/_trace_wrapped_higher_order_op.py
@@ -0,0 +1,111 @@
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented

from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode

from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode


__all__ = ["trace_wrapped"]


# trace_wrapped is a higher-order op meant both for invoking a bound function
# and for registering it as a call_function in the graph.
# This allows us to re-enter Dynamo during compiled autograd to trace (or graph break in)
# the functions as needed. While there is nothing backward-specific about this op, the way it is
# written means we can support backward functions written in complex Python. It can be thought of
# as an allow_in_graph for our aten graph. If we did not do this, the functions would get inlined
# into their composing aten ops, and we would lose the Python state mutation.
Review thread:

ezyang (Contributor), Sep 27, 2023:
Here is a proposed rewrite of the top-level comment:

trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist: if you make_fx trace through this call, we will not actually trace into fn; instead, we will directly insert it as a call_function to fn in the graph. (Unlike make_fx, Dynamo WILL inline into fn.) You can think of this as a one-off allow_in_graph equivalent for proxy tensor tracing.

Because proxy tensor tracing does not actually run the function, there are requirements on the behavior of fn. We are still figuring them out, but here is the current state:

  • fn can only take a single argument, which must be a tensor
  • fn must return a new tensor with the same metadata as the original tensor (e.g., empty_like(input) is a permissible implementation of fn). This is verified via an extra assert that is inserted into the traced graph.
  • fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state)

These requirements stem from the need to continue performing proxy tensor tracing, which assumes accurate fake tensor metadata, without actually running fn. In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns.

Note that tensors / Python state are allowed to be mutated. This relaxed constraint is not always sound, but it is sound for backward tracing with fake tensors as it takes place in AOTAutograd: the backward pass is guaranteed not to depend on concrete tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python).

The intended use case for this function is to allow AOTAutograd to defer complex backward hooks to compiled autograd. AOTAutograd performs a make_fx trace that preserves the function call as-is in the graph, and only when we Dynamo through the backward graph in compiled autograd do we inline into the function.

voznesenskym (Collaborator, Author):
Sure, this is better. Thanks for rewriting it. E.g., zeros_like(input), I suppose.

voznesenskym (Collaborator, Author):
If you are using MAY and MUST as in https://www.rfc-editor.org/rfc/rfc2119, let's use SHOULD and MUST ;) Thank you again.
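To make the proposed semantics concrete, here is a minimal usage sketch closely mirroring test_invoke_make_fx_forward_contrived in this PR's test file; _double and _double_invoke are illustrative names of my own, not part of the diff.

import torch
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped
from torch.fx.experimental.proxy_tensor import make_fx


def _double(grad):
    # The wrapped fn must return a tensor with the same metadata as its input.
    return grad * 2


def _double_invoke(grad):
    return trace_wrapped(grad, fn=_double)


x = torch.tensor([0.5, 0.5], requires_grad=True)
gm = make_fx(_double_invoke)(x)
print(gm(x))  # runs _double for real: tensor([1., 1.], ...)
# The traced graph records a call_function to the wrapper (plus the metadata
# assert); the body of _double is not traced into.
print(gm.print_readable(False))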

def trace_wrapped(*args, fn):
Review thread:

Reviewer (Contributor):
Is there any reason to take this variadically? You only support one argument 🤔

voznesenskym (Collaborator, Author):
I went back and forth. This feels better in case we want to use it in autograd.Function, where we take multiple args.

return _trace_wrapped_op(*args, fn=fn)


_trace_wrapped_op = HigherOrderOperator("trace_wrapped")


def _assert_meta(grad, size, stride, dtype):
Review thread:

Reviewer (Contributor):
Noob Dynamo question: how does Dynamo know to execute these asserts at compile time (while Dynamo is tracing), instead of automatically trying to add these asserts and metadata calls as proxies into the backward graph?

Reviewer (Contributor):
So long as this function is not allowed in graph, Dynamo must inline into it.

Reviewer (Contributor):
Oh right, thanks!

assert grad.size() == size, "size mismatch"
assert grad.stride() == stride, "stride mismatch"
assert grad.dtype == dtype, "dtype mismatch"
return grad
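To illustrate the answer in the thread above, here is a small standalone example of my own (not part of this PR) showing that when Dynamo inlines a function, a metadata assert like these is evaluated while tracing and leaves no node in the captured graph:

import torch


def checked_add(x, y):
    # Evaluated at trace time: x.dtype is statically known to Dynamo.
    assert x.dtype == torch.float32, "dtype mismatch"
    return x + y


captured = []


def capture_backend(gm, example_inputs):
    captured.append(gm)
    return gm.forward


compiled = torch.compile(checked_add, backend=capture_backend)
compiled(torch.ones(2), torch.ones(2))
# The captured graph contains the add but no assert node.
print(captured[0].graph)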


@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode)
def inner_trace(mode, *args, fn):
import torch

assert len(args) == 1
grad = args[0]
assert isinstance(grad, torch.Tensor)

# This higher-order operator stays opaque under proxy tensor tracing; Dynamo, however,
# is aware of it and traces through to the wrapped function's real behavior.
# The operator's purpose is to let us invoke non-traceable functions and embed them
# directly into the graph. Essentially, this turns function calls into "leaf modules"
# in traditional FX terminology.
# Note: We did not name it "allow_in_graph" because that name might imply the function is
# traceable, whereas this function is intrinsically non-traceable.
# Note2: I hate this name
Review thread:

Reviewer (Contributor):
This would be subsumed by the comment above I think.

def self_invoke(*args):
return _trace_wrapped_op(*args, fn=fn)

proxy_args = (mode.tracer.unwrap_proxy(grad),)
out_proxy = mode.tracer.create_proxy(
"call_function", self_invoke, proxy_args, {}, name="invocation"
Review thread:

Reviewer (Contributor):
nit: call this trace_wrapped instead?

)
grad = torch.zeros_like(grad)
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)

# We take a small shortcut here: we do NOT yet run a meta function, and so we assume
# that input and output metadata match. As such, we must introduce a runtime assert.
proxy_args = pytree.tree_map(
mode.tracer.unwrap_proxy, (grad, grad.size(), grad.stride(), grad.dtype)
)
Review thread:

Reviewer (Contributor):
nit: the tree_map here is also unnecessary, just s/grad/out_proxy/

out_proxy = mode.tracer.create_proxy(
"call_function",
_assert_meta,
proxy_args,
{},
name="assert",
)
grad = torch.empty_like(grad)
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
Review thread:

Reviewer (Contributor):
I'm confused, why do you need to do this twice?

Reviewer (Contributor):
I think I understand why you did this (you need to prevent the assert from getting DCE'd), but I don't think this is the right way to do it. Let me think...

Reviewer (Contributor):
I don't think you need to prevent this from being DCE'd. Like, the assert can just have no data deps and you don't have to track at all. What happens when you do that?

voznesenskym (Collaborator, Author):
Hmm, let me try. I thought it was because of DCE, but now I do not remember.

voznesenskym (Collaborator, Author):
Yeah, you get DCE if you don't track with create_proxy. However, if we change it to create_node, it breaks in other ways because none of the rest of this is nodes; it's all proxies. Is there a way to pass proxies to node creation? It seems like crossing streams...

Reviewer (Contributor):
Every proxy has a node, so you can extract the node from it.

voznesenskym (Collaborator, Author):
Of course, but is that kosher here? Is that better than just repeating the proxy binding code? Does it actually make a difference? I defer to you.

Reviewer (Contributor):
If there is some DCE thing, it will happen whether or not you create_proxy or create_node. I guess this is fine. Actually, why don't you just shove this into self_invoke? That will also prevent DCE.

return grad
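As a footnote to the DCE thread above, here is a small toy example of my own (not this PR's code) showing why an untracked node with no users would disappear: FX's eliminate_dead_code drops pure nodes whose results are unused, which is the behavior the second create_proxy/track_tensor_tree pair is working around.

import torch
import torch.fx


def f(x):
    x.size()  # pure call whose result is never used
    return x + 1


gm = torch.fx.symbolic_trace(f)
num_before = len(gm.graph.nodes)
gm.graph.eliminate_dead_code()
gm.recompile()
num_after = len(gm.graph.nodes)
# num_after < num_before: the unused size() node was eliminated, just as an
# assert node with no users (and no tracking) would be.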


@_trace_wrapped_op.py_impl(FakeTensorMode)
def inner_fake(*args, fn):
raise RuntimeError("This op should never be invoked here")


@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def _trace_wrapped_op_dense(*args, fn):
mode = _get_current_dispatch_mode()
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
return fn(*args)


_trace_wrapped_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(_trace_wrapped_op, deferred_error=True)
)


@_trace_wrapped_op.py_functionalize_impl
def _trace_wrapped_functionalized(ctx, *args, fn):
unwrapped_args = ctx.unwrap_tensors(args)
wrapped_fn = ctx.functionalize(fn)
with ctx.redispatch_to_next():
return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, fn=wrapped_fn))


# TODO(voz): Make this automatic for keys, this is very ugly atm
_trace_wrapped_op.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
_trace_wrapped_op.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
_trace_wrapped_op.fallthrough(DispatchKey.ADInplaceOrView)
_trace_wrapped_op.fallthrough(DispatchKey.BackendSelect)
_trace_wrapped_op.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined]
_trace_wrapped_op.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined]
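Regarding the TODO above, one possible cleanup sketch (my suggestion, not part of the PR, and it does not make the registration fully automatic) is to register the fallthrough keys in a loop, reusing the DispatchKey import and _trace_wrapped_op already defined in this module:

_FALLTHROUGH_KEYS = [
    DispatchKey.PythonDispatcher,
    DispatchKey.PythonTLSSnapshot,
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,
    DispatchKey.AutocastCUDA,
]

for _key in _FALLTHROUGH_KEYS:
    _trace_wrapped_op.fallthrough(_key)  # type: ignore[attr-defined]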
1 change: 1 addition & 0 deletions torch/_dynamo/side_effects.py
@@ -522,6 +522,7 @@ def is_empty(self):
return not (
any(map(self.is_modified, self.id_to_variable.values()))
or self.save_for_backward
or self.tensor_hooks
)

def clear(self):
4 changes: 4 additions & 0 deletions torch/_dynamo/skipfiles.py
@@ -175,6 +175,10 @@ def _module_dir(m: types.ModuleType):
_module_dir(torch) + "distributed/_tensor/device_mesh.py",
}

FILENAME_ALLOWLIST |= {
_module_dir(torch) + "_dynamo/_trace_wrapped_higher_order_op.py",
}

SKIP_DIRS_RE = None

is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode()