Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd

ghstack-source-id: 0518747bd22e427555bb72c60356c04827313b3b
Pull Request resolved: #109690
1 parent 6b760ff · commit 1ce318a
Showing 8 changed files with 410 additions and 11 deletions.
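The new higher order op lets a user-defined backward hook survive tracing as a single opaque call, so compiled autograd can capture the hook in its backward graph instead of graph breaking on it. Below is a minimal usage sketch distilled from the test file added in this commit; the _scale_grad helper, the tensor values, and the backend choice are illustrative assumptions, while the _trace_wrapped import path, torch._dynamo.optimize, and compiled_autograd.enable usage mirror the tests.

import torch
import torch._dynamo
from torch._dynamo import compiled_autograd
from torch._dynamo._trace_wrapped_higher_order_op import _trace_wrapped


def _scale_grad(grad):
    # Illustrative user-defined hook body; any leaf function could go here.
    return grad * 2


def _scale_grad_invoke(grad):
    # Wrapping the hook body in _trace_wrapped keeps it as one opaque call
    # while dynamo / compiled autograd trace through the backward pass.
    return _trace_wrapped(grad, fn=_scale_grad)


def compiler_fn(gm):
    # Compile the graph that compiled autograd builds for the backward pass.
    return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)


x = torch.tensor([0.5, 0.5], requires_grad=True)
y = torch.tensor([0.5, 0.5], requires_grad=True)
x.register_hook(_scale_grad_invoke)


def fn(x, y):
    return x + y


fn = torch._dynamo.optimize("eager")(fn)
out = fn(x, y)
with compiled_autograd.enable(compiler_fn):
    out.backward(torch.tensor([2.0, 2.0]))

print(x.grad)  # the hook doubled the incoming gradient: tensor([4., 4.])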
@@ -0,0 +1,246 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd
from torch._dynamo._trace_wrapped_higher_order_op import _trace_wrapped
from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch.fx.experimental.proxy_tensor import make_fx
import functools

def _multiply(x):
    return x * x


def _graph_breaking_fn(x):
    print("Boo!")
    return _multiply(x)


def _side_effect_stateful_fn2(x, obj):
    obj.counter = obj.counter + 1
    return _multiply(x)


def _multiply_invoke(grad):
    return _trace_wrapped(grad, fn=_multiply)


def _graph_break_invoke(grad):
    return _trace_wrapped(grad, fn=_graph_breaking_fn)


def _side_effectful_invoke2(grad, fn):
    return _trace_wrapped(grad, fn=fn)

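# The tests below exercise the hook-through-_trace_wrapped path in plain eager
# mode, under torch._dynamo with the eager/aot_eager/inductor backends, under
# make_fx tracing, and under compiled autograd, including a side-effectful
# hook and a hook that graph breaks.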
class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
    def test_invoke_in_eager(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)
        y = torch.tensor([0.5, 0.5], requires_grad=True)

        def fn(x, y):
            x.register_hook(_multiply_invoke)
            return x * y

        out = fn(x, y)
        out.backward(torch.tensor([2.0, 2.0]))
        self.assertEqual(x.grad, 2 * x)

    def test_invoke_in_pt2(self):
        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_multiply_invoke)
                return x * y

            fn = torch._dynamo.optimize(backend)(fn)
            out = fn(x, y)
            out.backward(torch.tensor([2.0, 2.0]))
            self.assertEqual(x.grad, 2 * x)

    def test_invoke_make_fx_forward_contrived(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)
        out = make_fx(_multiply_invoke)(x)
        self.assertEqual(out(x), torch.tensor([0.25, 0.25]))
        actual = normalize_gm(out.print_readable(False))

        expected = """\
class _multiply_invoke(torch.nn.Module):
    def forward(self, grad_1: f32[2]):
        invocation: f32[2] = functools_self_invoke(grad_1); grad_1 = None
        assert_1: f32[2] = torch._functional_assert_tensor_metadata(invocation, (2,), (1,), torch.float32); invocation = None
        detach: f32[2] = torch.ops.aten.detach.default(assert_1); assert_1 = None
        detach_1: f32[2] = torch.ops.aten.detach.default(detach); detach = None
        detach_2: f32[2] = torch.ops.aten.detach.default(detach_1); detach_1 = None
        detach_3: f32[2] = torch.ops.aten.detach.default(detach_2); detach_2 = None
        return detach_3
"""
        self.assertExpectedInline(actual, expected)

    def test_invoke_make_bw(self):
        x = torch.tensor([0.5, 0.5], requires_grad=True)

        def fwd(x):
            z = x * x
            return z + z

        res = fwd(x)
        res.backward(torch.tensor([1.0, 1.0]))
        out = make_fx(_multiply_invoke)(x.grad)
        self.assertEqual(out(x.grad), torch.tensor([4.0, 4.0]))
        actual = normalize_gm(out.print_readable(False))

        expected = """\
class _multiply_invoke(torch.nn.Module):
    def forward(self, grad_1: f32[2]):
        invocation: f32[2] = functools_self_invoke(grad_1); grad_1 = None
        assert_1: f32[2] = torch._functional_assert_tensor_metadata(invocation, (2,), (1,), torch.float32); invocation = None
        return assert_1
"""
        self.assertExpectedInline(actual, expected)

    def test_invoke_in_pt2_compiled_autograd(self):
        graph = None

        def compiler_fn(gm):
            def inner_compiler(gm_, example_inputs_):
                nonlocal graph
                self.assertEqual(graph, None)
                graph = gm_
                return inductor.compile(gm_, example_inputs_)

            return torch.compile(
                gm, backend=inner_compiler, fullgraph=True, dynamic=True
            )

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_multiply_invoke)
                return x + y

            fn = torch._dynamo.optimize(backend)(fn)
            out = fn(x, y)
            with compiled_autograd.enable(compiler_fn):
                out.backward(torch.tensor([2.0, 2.0]))
            actual = normalize_gm(graph.print_readable(False))
            self.assertEqual(x.grad, torch.tensor([4.0, 4.0]))
            expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_0_ : torch.Tensor):
        getitem = L_inputs_0_
        new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None
        call_hook = getitem * getitem; getitem = None
        new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None
        return (copy_, copy__1)
"""
            self.assertExpectedInline(actual, expected)

            graph = None

    def test_invoke_in_pt2_compiled_autograd_side_effect(self):
        graph = None

        def compiler_fn(gm):
            def inner_compiler(gm_, example_inputs_):
                nonlocal graph
                self.assertEqual(graph, None)
                graph = gm_
                return inductor.compile(gm_, example_inputs_)

            return torch.compile(
                gm, backend=inner_compiler, fullgraph=True, dynamic=True
            )

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            class MyObj:
                def __init__(self):
                    self.counter = 0

            obj = MyObj()
            inner_fn = functools.partial(_side_effect_stateful_fn2, obj=obj)
            hook_fn = functools.partial(_side_effectful_invoke2, fn=inner_fn)
            x.register_hook(hook_fn)

            def fn(x, y):
                return x + y

            fn = torch._dynamo.optimize(backend, nopython=True)(fn)
            out = fn(x, y)
            with compiled_autograd.enable(compiler_fn):
                out.backward(torch.tensor([2.0, 2.0]))
            actual = normalize_gm(graph.print_readable(False))
            self.assertEqual(obj.counter, 1)
            self.assertEqual(x.grad, torch.tensor([4.0, 4.0]))
            expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_inputs_0_ : torch.Tensor):
        getitem = L_inputs_0_
        new_empty_strided = torch.ops.aten.new_empty_strided.default(getitem, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy_ = torch.ops.aten.copy_.default(new_empty_strided, getitem); new_empty_strided = None
        call_hook = getitem * getitem; getitem = None
        new_empty_strided_1 = torch.ops.aten.new_empty_strided.default(call_hook, [2], [1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
        copy__1 = torch.ops.aten.copy_.default(new_empty_strided_1, call_hook); new_empty_strided_1 = call_hook = None
        return (copy_, copy__1)
"""
            self.assertExpectedInline(actual, expected)

            out = fn(x, y)
            out.backward(torch.tensor([2.0, 2.0]))
            self.assertEqual(obj.counter, 2)

            out = fn(x, y)
            out.backward(torch.tensor([2.0, 2.0]))
            self.assertEqual(obj.counter, 3)
            graph = None

    def test_invoke_in_pt2_compiled_autograd_graph_breaks(self):
        def compiler_fn(gm):
            return torch.compile(
                gm, backend="inductor", fullgraph=True, dynamic=True
            )

        for backend in ["eager", "aot_eager", "inductor"]:
            torch._dynamo.reset()
            x = torch.tensor([0.5, 0.5], requires_grad=True)
            y = torch.tensor([0.5, 0.5], requires_grad=True)

            def fn(x, y):
                x.register_hook(_graph_break_invoke)
                return x + y

            fn = torch._dynamo.optimize(backend, nopython=True)(fn)
            out = fn(x, y)
            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported,
                "print",
            ):
                with compiled_autograd.enable(compiler_fn):
                    out.backward(torch.tensor([2.0, 2.0]))

            graph = None