Initial version of Dynamo capture for HigherOrderOperator
This PR introduces a `wrap(body_fn, *args)` higher-order operator. The semantics of `wrap(body_fn, *args)` is simply to run `body_fn(*args)`. Underneath Dynamo, this PR rewrites calls to `wrap(body_fn, *args)` as `wrap(new_fn, *new_args)`, where `new_fn` has no free variables. This PR does not yet update cond/map to use the new mechanism (we do not support nn.Modules yet; that will come in the future).

The design we take is:
- OutputGraph represents the graph being built by Dynamo that may be compiled and executed.
- OutputGraph owns a root SubgraphTracer, where it builds the FX graph.
- OutputGraph may own multiple nested SubgraphTracers.
- When we need to trace the body function of a HigherOrderOperator, we construct a new SubgraphTracer to build the graph of the body function.

Mechanically, when Dynamo sees a new `wrap` HigherOrderOperator with a body function, it:
- Creates a new SubgraphTracer via OutputGraph.new_subtracer
- Executes the body function

This captures the body function into the graph on the new SubgraphTracer while modifying the state of the OutputGraph. For example, the OutputGraph may receive new GraphArgs, new guards, and new side effects. If capture of the body function fails, Dynamo graph breaks on the HigherOrderOperator.

Test Plan:
- Added test/dynamo/test_higher_order_ops.py

Future:
- We're not actually able to tell Dynamo to completely graph break on the HigherOrderOperator. Instead, when we do graph break, Dynamo begins introspecting `HigherOrderOperator.__call__`. It should probably not do this.
- Ideally we would error out on new SideEffects. I don't know how to do this yet.
- We don't support nn.Modules yet (e.g., calling an nn.Module or accessing attributes of a tracked nn.Module from a body_fn). There's an open question on what should actually happen here.
- Ideally we would rewrite map/cond to use the new mechanism, but we need to fix the previous bullet point before we can get there.

ghstack-source-id: 8a3bae02fe0464c6fc8a2a0599f1b81412063830
Pull Request resolved: #99988
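To make the lifting concrete, here is a minimal sketch of the rewrite in plain Python. `f_rewritten` and `new_fn` are hypothetical names for exposition only; Dynamo performs this transformation internally on the traced graph, not by editing user code. Because `wrap(body_fn, *args)` simply runs `body_fn(*args)`, both spellings compute the same value.

import torch
from torch._ops import wrap  # same import the tests below use

y = torch.randn(3)

def f(x):
    # body_fn closes over `y`, a free variable
    return wrap(lambda x: x + y, x)

# What Dynamo conceptually rewrites the call into: the closed-over
# tensor becomes an explicit argument, so the body function has no
# free variables and the captured subgraph is self-contained.
def f_rewritten(x):
    new_fn = lambda x, y: x + y  # no free variables
    return wrap(new_fn, x, y)    # `y` lifted to an explicit input

x = torch.randn(3)
assert torch.allclose(f(x), f_rewritten(x))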
Showing 5 changed files with 541 additions and 5 deletions.
test/dynamo/test_higher_order_ops.py
@@ -0,0 +1,351 @@
# Owner(s): ["module: dynamo"]
import re
import unittest

import torch

import torch._dynamo.test_case
from torch._dynamo.utils import counters
from torch._ops import wrap


# Minimal backend that records every GraphModule Dynamo hands it,
# so tests can inspect the captured graphs.
class MockBackend:
    def __init__(self):
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        return gm.forward


global_var = torch.randn(3)
global_num = 3.14


class TestHigherOrderOps(torch._dynamo.test_case.TestCase):
    def test_no_freevars(self):
        mock = MockBackend()

        def f(x):
            return wrap(lambda x: torch.sin(x), x)

        x = torch.randn(3)
        expected = f(x)
        result = torch.compile(f, backend=mock)(x)

        self.assertEqual(result, expected)
        self.assertEqual(len(mock.graphs), 1)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", mock.graphs[0].code))

    def test_capture_untracked_global(self):
        counters.clear()
        mock = MockBackend()

        def f(x):
            return wrap(lambda x: x + global_var, x)

        x = torch.randn(3)
        expected = f(x)
        result = torch.compile(f, backend=mock)(x)

        self.assertEqual(result, expected)
        self.assertEqual(len(mock.graphs), 1)
        # wrap(fn, x, global_var)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_untracked_global_nested(self):
        mock = MockBackend()
        counters.clear()

        @torch.compile(backend=mock)
        def f(x):
            return wrap(lambda x: wrap(lambda x: x + global_var, x), x)

        x = torch.randn(3)
        result = f(x)

        self.assertEqual(result, x + global_var)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.code))
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.cond_body_1.code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_untracked_nonlocal(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            @torch.compile(backend=mock)
            def g(x):
                return wrap(lambda x: x + y, x)

            return g(x)

        result = f(x, y)
        expected = x + y

        self.assertEqual(result, expected)
        self.assertEqual(len(mock.graphs), 1)
        # wrap(fn, x, y)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_tracked(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(lambda x: x + y, x)

        result = f(x, y)

        self.assertEqual(result, x + y)
        self.assertEqual(len(mock.graphs), 1)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_inlined_functions(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def g(x, y):
            return x + y

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(lambda x: g(x, y), x)

        result = f(x, y)

        self.assertEqual(result, x + y)
        self.assertEqual(len(mock.graphs), 1)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_value_created_in_subgraph(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def inner(x, y):
            z = x + y
            return wrap(lambda x: wrap(lambda x: x + z, x), x)

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(inner, x, y)

        result = f(x, y)

        self.assertEqual(result, x + y + x)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        # Two inputs: no lifting
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.code))
        # z should have been lifted to input
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.cond_body_2.code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_global_num(self):
        counters.clear()
        mock = MockBackend()
        x = torch.zeros([])

        @torch.compile(backend=mock)
        def f(x):
            return wrap(lambda x: x + global_num, x)

        global global_num
        result = f(x)
        self.assertEqual(result, x + global_num)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        # Numbers don't get lifted
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", gm.code))

        # Check that we still guard on the number
        global_num = torch.randn([]).item()
        result = f(x)
        self.assertEqual(result, x + global_num)
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_input_num(self):
        counters.clear()
        mock = MockBackend()
        x = torch.zeros([])
        y = 3.14

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(lambda x: x + y, x)

        result = f(x, y)
        self.assertEqual(result, x + y)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        # Numbers don't get lifted
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", gm.code))
        self.assertEqual(len(counters["graph_break"]), 0)

    # TODO: Ideally we would error out if there are any new live side
    # effects (for example, if the body function mutates a global variable).
    # I don't know how to detect this in a robust way, because it conflicts with
    # benign side effects like storing and loading cells that are necessary for
    # capturing variables.
    @unittest.expectedFailure
    def test_side_effect_in_body(self):
        from torch._dynamo.utils import counters

        counters.clear()

        mock = MockBackend()
        x = torch.randn([])
        y = torch.randn([])

        def inner(x):
            nonlocal y
            y = x
            return x.clone()

        @torch.compile(backend=mock)
        def f(x):
            return wrap(inner, x)

        f(x)
        self.assertEqual(y, x)
        self.assertEqual(
            dict(counters["graph_break"]),
            {"side effects in HigherOrderOperator body": 1},
        )

    def test_fallback_on_graph_break_simple(self):
        # In the future, there should be a per-HigherOrderOperator switch
        # on whether or not to fallback or raise a loud error.
        # For now we just fallback by default.
        mock = MockBackend()
        x = torch.randn([])

        def inner(x):
            y = x.sin()
            torch._dynamo.graph_break()
            z = y.sin()
            return z

        @torch.compile(backend=mock)
        def f(x):
            return wrap(inner, x)

        result = f(x)
        self.assertEqual(result, inner(x))
        # It's unclear if this is correct: dynamo graph breaks on wrap but
        # then interposes on wrap.__call__, which invokes fn(*args),
        # leading to two graphs being compiled
        self.assertEqual(len(mock.graphs), 2)

    def test_fallback_on_graph_break_complicated(self):
        mock = MockBackend()
        x = torch.randn([])

        def inner(x):
            y = x.sin()
            y = y * global_var
            torch._dynamo.graph_break()
            z = y.sin()
            return z

        @torch.compile(backend=mock)
        def f(x):
            x = x.clone()
            result = wrap(inner, x)
            return result.clone()

        result = f(x)
        self.assertEqual(result, inner(x))
        # It's unclear if this is correct: dynamo graph breaks on wrap but
        # then interposes on wrap.__call__, which invokes fn(*args),
        # leading to four graphs being compiled: clone, sin, sin, clone
        self.assertEqual(len(mock.graphs), 4)

    def test_fallback_on_modules(self):
        # We can likely support this in the future, I just don't want to deal
        # with it right now
        from torch._dynamo.utils import counters

        counters.clear()
        mock = MockBackend()
        mod = torch.nn.Linear(3, 3)
        x = torch.randn(3, 3)

        @torch.compile(backend=mock)
        def f(x):
            return wrap(lambda x: mod(x), x)

        result = f(x)

        self.assertEqual(result, mod(x))
        self.assertEqual(len(mock.graphs), 1)
        self.assertEqual(
            dict(counters["graph_break"]),
            {"Invoking an nn.Module inside HigherOrderOperator": 1},
        )

    def test_access_module_attr(self):
        # We can likely support this in the future, I just don't want to deal
        # with it right now
        counters.clear()
        mock = MockBackend()
        mod = torch.nn.Linear(3, 3)
        x = torch.randn(3, 3)

        @torch.compile(backend=mock)
        def f(x):
            y = mod(x)
            return wrap(lambda y: y - mod.bias, y)

        result = f(x)
        self.assertEqual(result, mod(x) - mod.bias)
        self.assertEqual(len(mock.graphs), 2)
        self.assertEqual(
            dict(counters["graph_break"]),
            {"accessing attribute of nn.Module inside HigherOrderOperator": 1},
        )

    def test_make_closure(self):
        counters.clear()
        mock = MockBackend()
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            def g(x):
                return x + y

            return g(x)

        @torch.compile(backend=mock)
        def h(x, y):
            return wrap(f, x, y)

        result = h(x, y)
        self.assertEqual(result, x + y)
        self.assertEqual(len(counters["graph_break"]), 0)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()