Initial version of Dynamo capture for HigherOrderOperator
This PR introduces a `wrap(body_fn, *args)` higher order operator.
The semantics of `wrap(body_fn, *args)` are to simply run `body_fn(*args)`.
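For example (a minimal illustration of the stated semantics, using `wrap`
from `torch._ops` as the tests below do):

```python
import torch
from torch._ops import wrap

x = torch.randn(3)
# wrap(body_fn, *args) simply runs body_fn(*args):
assert torch.equal(wrap(torch.sin, x), torch.sin(x))
```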

Underneath Dynamo, this PR makes it so that we rewrite calls to
`wrap(body_fn, *args)` as `wrap(new_fn, *new_args)`, where `new_fn` has
no free variables. This PR does not update cond/map to use the new
mechanism yet (we do not support nn.Modules yet; that will come in the future).
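Conceptually, the rewrite lifts the body function's free variables into
explicit arguments (the names here are illustrative, not Dynamo's generated
names):

```python
# Before: the body function closes over the free variable `y`.
def f(x, y):
    return wrap(lambda x: x + y, x)

# After the rewrite (conceptually): `y` is lifted to an explicit argument,
# so the new body function has no free variables.
def f_rewritten(x, y):
    return wrap(lambda x, y: x + y, x, y)
```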

The design we take is:
- OutputGraph represents the graph being built by Dynamo that may be
compiled and executed.
- OutputGraph owns a root SubgraphTracer, where it builds the FX graph.
- OutputGraph may own multiple nested SubgraphTracers.
- When we need to trace the body function of a HigherOrderOperator, we
construct a new SubgraphTracer to build the graph of the body function
(see the sketch below).
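A rough sketch of this ownership structure (names and fields simplified;
this is not the exact Dynamo internals):

```python
import torch.fx


class SubgraphTracer:
    def __init__(self, output_graph, parent=None):
        self.output_graph = output_graph
        self.parent = parent            # None for the root tracer
        self.graph = torch.fx.Graph()   # the FX graph being built


class OutputGraph:
    def __init__(self):
        self.root_tracer = SubgraphTracer(self)
        # Stack of tracers; the top is where new nodes are recorded.
        self.tracers = [self.root_tracer]

    def new_subtracer(self):
        tracer = SubgraphTracer(self, parent=self.tracers[-1])
        self.tracers.append(tracer)
        return tracer
```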

Mechanically, when Dynamo sees a new `wrap` HigherOrderOperator with a
body function, it:
- Creates a new SubgraphTracer via OutputGraph.new_subtracer
- Executes the body function

This captures the body function into the graph on the new
SubgraphTracer while modifying the state of the OutputGraph. For
example, the OutputGraph may receive new GraphArgs, new guards, and new
side effects.

If capture of the body function fails, then Dynamo graph breaks on the
HigherOrderOperator.
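In pseudocode, the flow looks roughly like this (a sketch only: the helper
name `trace_wrap` is hypothetical, and the error handling is simplified
relative to the real implementation):

```python
from torch._dynamo.exc import Unsupported


def trace_wrap(output_graph, body_fn, args):
    # Hypothetical helper, illustrating the mechanism described above.
    tracer = output_graph.new_subtracer()
    try:
        # Tracing the body records FX nodes into tracer.graph; it may also
        # add new GraphArgs, guards, and side effects to the OutputGraph.
        result = body_fn(*args)
        return tracer.graph, result
    except Unsupported:
        # Capture failed: graph break on the HigherOrderOperator instead.
        raise
```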

Test Plan:
- added test/dynamo/test_higher_order_ops.py

Future:
- We're not actually able to tell Dynamo to completely graph break on the
HigherOrderOperator. Instead, when we do graph break, Dynamo begins
introspecting `HigherOrderOperator.__call__`. It should probably not do
this.
- Ideally we would error out on new SideEffects. I don't know how to do
this yet.
- We don't support dealing with nn.Modules yet (e.g. calling nn.Modules
or accessing attributes of tracked nn.Modules from a body_fn). There's
an open question on what should actually happen here.
- Ideally we would rewrite map/cond to use the new mechanism but we need
to fix the previous bullet point before we can get there.

ghstack-source-id: 8a3bae02fe0464c6fc8a2a0599f1b81412063830
Pull Request resolved: #99988
zou3519 committed Apr 26, 2023
1 parent c7b27a3 commit 24935e5
Showing 5 changed files with 541 additions and 5 deletions.
351 changes: 351 additions & 0 deletions test/dynamo/test_higher_order_ops.py
@@ -0,0 +1,351 @@
# Owner(s): ["module: dynamo"]
import re
import unittest

import torch

import torch._dynamo.test_case
from torch._dynamo.utils import counters
from torch._ops import wrap


class MockBackend:
    def __init__(self):
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        return gm.forward


global_var = torch.randn(3)
global_num = 3.14


class TestHigherOrderOps(torch._dynamo.test_case.TestCase):
    def test_no_freevars(self):
        mock = MockBackend()

        def f(x):
            return wrap(lambda x: torch.sin(x), x)

        x = torch.randn(3)
        expected = f(x)
        result = torch.compile(f, backend=mock)(x)

        self.assertEqual(result, expected)
        self.assertEqual(len(mock.graphs), 1)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", mock.graphs[0].code))

    def test_capture_untracked_global(self):
        counters.clear()
        mock = MockBackend()

        def f(x):
            return wrap(lambda x: x + global_var, x)

        x = torch.randn(3)
        expected = f(x)
        result = torch.compile(f, backend=mock)(x)

        self.assertEqual(result, expected)
        self.assertEqual(len(mock.graphs), 1)
        # wrap(fn, x, global_var)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_untracked_global_nested(self):
        mock = MockBackend()
        counters.clear()

        @torch.compile(backend=mock)
        def f(x):
            return wrap(lambda x: wrap(lambda x: x + global_var, x), x)

        x = torch.randn(3)
        result = f(x)

        self.assertEqual(result, x + global_var)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.code))
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.cond_body_1.code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_untracked_nonlocal(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            @torch.compile(backend=mock)
            def g(x):
                return wrap(lambda x: x + y, x)

            return g(x)

        result = f(x, y)
        expected = x + y

        self.assertEqual(result, expected)
        self.assertEqual(len(mock.graphs), 1)
        # wrap(fn, x, y)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_tracked(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(lambda x: x + y, x)

        result = f(x, y)

        self.assertEqual(result, x + y)
        self.assertEqual(len(mock.graphs), 1)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_inlined_functions(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def g(x, y):
            return x + y

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(lambda x: g(x, y), x)

        result = f(x, y)

        self.assertEqual(result, x + y)
        self.assertEqual(len(mock.graphs), 1)
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_value_created_in_subgraph(self):
        counters.clear()
        mock = MockBackend()

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def inner(x, y):
            z = x + y
            return wrap(lambda x: wrap(lambda x: x + z, x), x)

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(inner, x, y)

        result = f(x, y)

        self.assertEqual(result, x + y + x)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        # Two inputs: no lifting
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.code))
        # z should have been lifted to input
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.cond_body_2.code))
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_global_num(self):
        counters.clear()
        mock = MockBackend()
        x = torch.zeros([])

        @torch.compile(backend=mock)
        def f(x):
            return wrap(lambda x: x + global_num, x)

        global global_num
        result = f(x)
        self.assertEqual(result, x + global_num)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        # Numbers don't get lifted
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", gm.code))

        # Check that we still guard on the number
        global_num = torch.randn([]).item()
        result = f(x)
        self.assertEqual(result, x + global_num)
        self.assertEqual(len(counters["graph_break"]), 0)

    def test_capture_input_num(self):
        counters.clear()
        mock = MockBackend()
        x = torch.zeros([])
        y = 3.14

        @torch.compile(backend=mock)
        def f(x, y):
            return wrap(lambda x: x + y, x)

        result = f(x, y)
        self.assertEqual(result, x + y)
        self.assertEqual(len(mock.graphs), 1)
        gm = mock.graphs[0]
        # Numbers don't get lifted
        self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", gm.code))
        self.assertEqual(len(counters["graph_break"]), 0)

    # TODO: Ideally we would error out if there are any new live side
    # effects (for example, if the body function mutates a global variable).
    # I don't know how to detect this in a robust way, because it conflicts
    # with benign side effects, like storing and loading cells, which are
    # necessary for capturing variables.
    @unittest.expectedFailure
    def test_side_effect_in_body(self):
        from torch._dynamo.utils import counters

        counters.clear()

        mock = MockBackend()
        x = torch.randn([])
        y = torch.randn([])

        def inner(x):
            nonlocal y
            y = x
            return x.clone()

        @torch.compile(backend=mock)
        def f(x):
            return wrap(inner, x)

        f(x)
        self.assertEqual(y, x)
        self.assertEqual(
            dict(counters["graph_break"]),
            {"side effects in HigherOrderOperator body": 1},
        )

    def test_fallback_on_graph_break_simple(self):
        # In the future, there should be a per-HigherOrderOperator switch
        # for whether to fall back or raise a loud error.
        # For now we just fall back by default.
        mock = MockBackend()
        x = torch.randn([])

        def inner(x):
            y = x.sin()
            torch._dynamo.graph_break()
            z = y.sin()
            return z

        @torch.compile(backend=mock)
        def f(x):
            return wrap(inner, x)

        result = f(x)
        self.assertEqual(result, inner(x))
        # It's unclear if this is correct: dynamo graph breaks on wrap but
        # then interposes on wrap.__call__, which invokes fn(*args),
        # leading to two graphs being compiled
        self.assertEqual(len(mock.graphs), 2)

    def test_fallback_on_graph_break_complicated(self):
        mock = MockBackend()
        x = torch.randn([])

        def inner(x):
            y = x.sin()
            y = y * global_var
            torch._dynamo.graph_break()
            z = y.sin()
            return z

        @torch.compile(backend=mock)
        def f(x):
            x = x.clone()
            result = wrap(inner, x)
            return result.clone()

        result = f(x)
        self.assertEqual(result, inner(x))
        # It's unclear if this is correct: dynamo graph breaks on wrap but
        # then interposes on wrap.__call__, which invokes fn(*args),
        # leading to four graphs being compiled: clone, sin, sin, clone
        self.assertEqual(len(mock.graphs), 4)

    def test_fallback_on_modules(self):
        # We can likely support this in the future; I just don't want to
        # deal with it right now
        from torch._dynamo.utils import counters

        counters.clear()
        mock = MockBackend()
        mod = torch.nn.Linear(3, 3)
        x = torch.randn(3, 3)

        @torch.compile(backend=mock)
        def f(x):
            return wrap(lambda x: mod(x), x)

        result = f(x)

        self.assertEqual(result, mod(x))
        self.assertEqual(len(mock.graphs), 1)
        self.assertEqual(
            dict(counters["graph_break"]),
            {"Invoking an nn.Module inside HigherOrderOperator": 1},
        )

    def test_access_module_attr(self):
        # We can likely support this in the future; I just don't want to
        # deal with it right now
        counters.clear()
        mock = MockBackend()
        mod = torch.nn.Linear(3, 3)
        x = torch.randn(3, 3)

        @torch.compile(backend=mock)
        def f(x):
            y = mod(x)
            return wrap(lambda y: y - mod.bias, y)

        result = f(x)
        self.assertEqual(result, mod(x) - mod.bias)
        self.assertEqual(len(mock.graphs), 2)
        self.assertEqual(
            dict(counters["graph_break"]),
            {"accessing attribute of nn.Module inside HigherOrderOperator": 1},
        )

    def test_make_closure(self):
        counters.clear()
        mock = MockBackend()
        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        def f(x, y):
            def g(x):
                return x + y

            return g(x)

        @torch.compile(backend=mock)
        def h(x, y):
            return wrap(f, x, y)

        result = h(x, y)
        self.assertEqual(result, x + y)
        self.assertEqual(len(counters["graph_break"]), 0)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
5 changes: 5 additions & 0 deletions torch/_dynamo/allowed_functions.py
@@ -168,6 +168,11 @@ def _find_torch_objects(module):
    torch_object_ids[id(module)] = module.__name__
    for name, obj in list(module.__dict__.items()):
        if id(obj) not in torch_object_ids:
            # Don't handle HigherOrderOperator as builtin
            import torch._ops

            if isinstance(obj, torch._ops.HigherOrderOperator):
                continue
            if isinstance(obj, types.ModuleType):
                if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                    obj
