Commit

Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd

ghstack-source-id: 37029e876bf9889441c2f469b06add4fe754e4d5
Pull Request resolved: #109690
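
A usage sketch, adapted from the tests added in this commit (the "eager" backends below are only to keep the sketch light; any supported backend follows the same path): a backward hook wraps its body in trace_wrapped so make_fx keeps it as a single opaque call, and compiled autograd later inlines it through Dynamo.

    import torch
    from torch._dynamo import compiled_autograd
    from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped

    def _multiply(x):
        return x * x

    def _multiply_invoke(grad):
        # Stays a single call_function node under make_fx; inlined later by Dynamo.
        return trace_wrapped(grad, fn=_multiply)

    def fn(x, y):
        x.register_hook(_multiply_invoke)
        return x + y

    x = torch.tensor([0.5, 0.5], requires_grad=True)
    y = torch.tensor([0.5, 0.5], requires_grad=True)
    fn = torch._dynamo.optimize("eager")(fn)
    out = fn(x, y)

    def compiler_fn(gm):
        return torch.compile(gm, backend="eager", fullgraph=True, dynamic=True)

    grad_out = torch.tensor([2.0, 2.0])
    with compiled_autograd.enable(compiler_fn):
        out.backward(grad_out)
    # The hook squares the incoming gradient: x.grad == grad_out * grad_out == tensor([4., 4.])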

voznesenskym committed Sep 27, 2023
1 parent 34ded74 commit eae245e
Showing 5 changed files with 401 additions and 0 deletions.
251 changes: 251 additions & 0 deletions test/dynamo/test_backward_higher_order_ops.py
@@ -0,0 +1,251 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa

import functools

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped
from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch.fx.experimental.proxy_tensor import make_fx


def _multiply(x):
    return x * x


def _multiply_invoke(grad):
    return trace_wrapped(grad, fn=_multiply)


class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
    def test_invoke_in_eager(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)
        y = torch.tensor([0.5, 0.5], requires_grad=True)

        def fn(x, y):
            x.register_hook(_multiply_invoke)
            return x * y

        out = fn(x, y)
        grad_out = torch.tensor([2.0, 2.0])
        out.backward(grad_out)
        self.assertEqual(x.grad, y * grad_out)

    def test_invoke_in_pt2(self):
        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_multiply_invoke)
                return x * y

            fn = torch._dynamo.optimize(backend)(fn)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            out.backward(grad_out)
            self.assertEqual(x.grad, grad_out * y)

    def test_invoke_make_fx_forward_contrived(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)
        out = make_fx(_multiply_invoke)(x)
        self.assertEqual(out(x), torch.tensor([0.25, 0.25]))
        actual = normalize_gm(out.print_readable(False))

        expected = """\
class _multiply_invoke(torch.nn.Module):
    def forward(self, grad_1: f32[2]):
        trace_wrapped: f32[2] = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
        assert_1: f32[2] = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None
        detach: f32[2] = torch.ops.aten.detach.default(assert_1); assert_1 = None
        detach_1: f32[2] = torch.ops.aten.detach.default(detach); detach = None
        detach_2: f32[2] = torch.ops.aten.detach.default(detach_1); detach_1 = None
        detach_3: f32[2] = torch.ops.aten.detach.default(detach_2); detach_2 = None
        return detach_3
"""
        self.assertExpectedInline(actual, expected)

    def test_invoke_make_bw(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)

        def fwd(x):
            z = x * x
            return z + z

        res = fwd(x)
        res.backward(torch.tensor([1.0, 1.0]))
        out = make_fx(_multiply_invoke)(x.grad)
        self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0]))
        actual = normalize_gm(out.print_readable(False))

        expected = """\
class _multiply_invoke(torch.nn.Module):
    def forward(self, grad_1: f32[2]):
        trace_wrapped: f32[2] = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(grad_1); grad_1 = None
        assert_1: f32[2] = torch__dynamo__trace_wrapped_higher_order_op__assert_meta(trace_wrapped, (2,), (1,), torch.float32); trace_wrapped = None
        return assert_1
"""
        self.assertExpectedInline(actual, expected)

    def test_invoke_in_pt2_compiled_autograd(self):
        graph = None

        def compiler_fn(gm):
            def inner_compiler(gm_, example_inputs_):
                nonlocal graph
                self.assertEqual(graph, None)
                graph = gm_
                return inductor.compile(gm_, example_inputs_)

            return torch.compile(
                gm, backend=inner_compiler, fullgraph=True, dynamic=True
            )

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_multiply_invoke)
                return x + y

            fn = torch._dynamo.optimize(backend)(fn)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            with compiled_autograd.enable(compiler_fn):
                out.backward(grad_out)
                actual = normalize_gm(graph.print_readable(False))
            self.assertEqual(x.grad, grad_out * grad_out)
            expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_0_ : torch.Tensor):
        getitem = L_inputs_0_
        new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None
        call_hook = getitem * getitem; getitem = None
        new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None
        return (copy_, copy__1)
"""
            self.assertExpectedInline(actual, expected)

            graph = None

    def test_invoke_in_pt2_compiled_autograd_side_effect(self):
        def _side_effect_stateful_fn2(x, obj):
            obj.counter = obj.counter + 1
            return _multiply(x)

        def _side_effectful_invoke2(grad, fn):
            return trace_wrapped(grad, fn=fn)

        graph = None

        def compiler_fn(gm):
            def inner_compiler(gm_, example_inputs_):
                nonlocal graph
                self.assertEqual(graph, None)
                graph = gm_
                return inductor.compile(gm_, example_inputs_)

            return torch.compile(
                gm, backend=inner_compiler, fullgraph=True, dynamic=True
            )

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            class MyObj:
                def __init__(self):
                    self.counter = 0

            obj = MyObj()
            inner_fn = functools.partial(_side_effect_stateful_fn2, obj=obj)
            hook_fn = functools.partial(_side_effectful_invoke2, fn=inner_fn)
            x.register_hook(hook_fn)

            def fn(x, y):
                return x + y

            fn = torch._dynamo.optimize(backend, nopython=True)(fn)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            with compiled_autograd.enable(compiler_fn):
                out.backward(grad_out)
                actual = normalize_gm(graph.print_readable(False))
            self.assertEqual(obj.counter, 1)
            self.assertEqual(x.grad, grad_out + grad_out)
            expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_0_ : torch.Tensor):
        getitem = L_inputs_0_
        new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None
        call_hook = getitem * getitem; getitem = None
        new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None
        return (copy_, copy__1)
"""
            self.assertExpectedInline(actual, expected)

            out = fn(x, y)
            out.backward(grad_out)
            self.assertEqual(obj.counter, 2)

            out = fn(x, y)
            out.backward(grad_out)
            self.assertEqual(obj.counter, 3)
            graph = None

    def test_invoke_in_pt2_compiled_autograd_graph_breaks(self):
        def _graph_breaking_fn(x):
            print("Boo!")
            return _multiply(x)

        def _graph_break_invoke(grad):
            return trace_wrapped(grad, fn=_graph_breaking_fn)

        def compiler_fn(gm):
            return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_graph_break_invoke)
                return x + y

            fn = torch._dynamo.optimize(backend, nopython=True)(fn)
            out = fn(x, y)
            grad_out = torch.tensor([2.0, 2.0])
            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported,
                "print",
            ):
                with compiled_autograd.enable(compiler_fn):
                    out.backward(grad_out)

            graph = None
128 changes: 128 additions & 0 deletions torch/_dynamo/_trace_wrapped_higher_order_op.py
@@ -0,0 +1,128 @@
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented

from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode

from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode


__all__ = ["trace_wrapped"]


# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist:
# if you make_fx trace through this call, we will not actually trace into fn; instead,
# we will directly insert it as a call_function to fn in the graph.
# (Unlike make_fx, Dynamo WILL inline into fn.)
# You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing.
#
# Because proxy tensor tracing does not actually run the function, there are
# requirements on the behavior of fn. We are still figuring it out, but here is the current state:
#
# 1) fn SHOULD only take a single argument, which must be a tensor
# 2) fn MUST return a new tensor with the same metadata as the original tensor
# (e.g., zeros_like(input) is a permissible implementation of fn).
# This is verified via an extra assert that is inserted into the traced graph.
# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors
# participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state)
# These requirements stem from the requirement that we need to continue performing proxy tensor tracing,
# which assumes accurate fake tensor metadata, without actually running fn.
# In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns.
#
# Note that tensors / Python state are allowed to be mutated.
# This relaxed constraint is not always sound, but it is sound for backward tracing with fake
# tensors as it takes place in AOTAutograd, because the backward pass is guaranteed not to depend on concrete
# tensor values (via fake tensor) or on Python state (the autograd engine doesn't depend on Python).
#
# The intended use case for this function is to allow AOTAutograd to defer complex
# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves
# the function call as is in the graph, and only when we Dynamo through the backward graph in
# compiled autograd do we inline into the function.
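#
# An illustrative sketch of the intended pattern (not part of this module's code; the names
# below are hypothetical and mirror the tests added alongside this file):
#
#     def _double(grad):
#         return grad * 2  # same size/stride/dtype as the input, per requirement (2)
#
#     def hook(grad):
#         return trace_wrapped(grad, fn=_double)
#
#     x = torch.tensor([1.0], requires_grad=True)
#     x.register_hook(hook)
#
# Under make_fx, the backward graph then contains a single call to the wrapped op rather than
# an inlined `grad * 2`; Dynamo (e.g. in compiled autograd) inlines _double later.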


def trace_wrapped(*args, fn):
    return _trace_wrapped_op(*args, fn=fn)


_trace_wrapped_op = HigherOrderOperator("trace_wrapped")


def _assert_meta(grad, size, stride, dtype):
    assert grad.size() == size, "size mismatch"
    assert grad.stride() == stride, "stride mismatch"
    assert grad.dtype == dtype, "dtype mismatch"
    return grad


@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode)
def inner_trace(mode, *args, fn):
    import torch

    assert len(args) == 1
    grad = args[0]
    assert isinstance(grad, torch.Tensor)

    def self_invoke(*args):
        return _trace_wrapped_op(*args, fn=fn)

    proxy_args = (mode.tracer.unwrap_proxy(grad),)
    out_proxy = mode.tracer.create_proxy(
        "call_function", self_invoke, proxy_args, {}, name="trace_wrapped"
    )
    grad = torch.zeros_like(grad)
    grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)

    # We have a little shortcut here, wherein we DO NOT yet run a meta func, and so
    # we take on an assumption that input and output meta matches. As such, we must introduce
    # a runtime assert
    proxy_args = (
        mode.tracer.unwrap_proxy(grad),
        grad.size(),
        grad.stride(),
        grad.dtype,
    )
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        _assert_meta,
        proxy_args,
        {},
        name="assert",
    )
    grad = torch.empty_like(grad)
    grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
    return grad


@_trace_wrapped_op.py_impl(FakeTensorMode)
def inner_fake(*args, fn):
    raise RuntimeError("This op should never be invoked here")


@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def _trace_wrapped_op_dense(*args, fn):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    return fn(*args)


_trace_wrapped_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_trace_wrapped_op, deferred_error=True)
)


@_trace_wrapped_op.py_functionalize_impl
def _trace_wrapped_functionalized(ctx, *args, fn):
    unwrapped_args = ctx.unwrap_tensors(args)
    wrapped_fn = ctx.functionalize(fn)
    with ctx.redispatch_to_next():
        return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, fn=wrapped_fn))


# TODO(voz): Make this automatic for keys, this is very ugly atm
_trace_wrapped_op.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
_trace_wrapped_op.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
_trace_wrapped_op.fallthrough(DispatchKey.ADInplaceOrView)
_trace_wrapped_op.fallthrough(DispatchKey.BackendSelect)
_trace_wrapped_op.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined]
_trace_wrapped_op.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined]
1 change: 1 addition & 0 deletions torch/_dynamo/side_effects.py
@@ -522,6 +522,7 @@ def is_empty(self):
        return not (
            any(map(self.is_modified, self.id_to_variable.values()))
            or self.save_for_backward
            or self.tensor_hooks
        )

    def clear(self):
4 changes: 4 additions & 0 deletions torch/_dynamo/skipfiles.py
@@ -175,6 +175,10 @@ def _module_dir(m: types.ModuleType):
    _module_dir(torch) + "distributed/_tensor/device_mesh.py",
}

FILENAME_ALLOWLIST |= {
    _module_dir(torch) + "_dynamo/_trace_wrapped_higher_order_op.py",
}

SKIP_DIRS_RE = None

is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode()
