diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py new file mode 100644 index 000000000000..d45f23af1488 --- /dev/null +++ b/test/dynamo/test_higher_order_ops.py @@ -0,0 +1,351 @@ +# Owner(s): ["module: dynamo"] +import re +import unittest + +import torch + +import torch._dynamo.test_case +from torch._dynamo.utils import counters +from torch._ops import wrap + + +class MockBackend: + def __init__(self): + self.graphs = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + self.graphs.append(gm) + return gm.forward + + +global_var = torch.randn(3) +global_num = 3.14 + + +class TestHigherOrderOps(torch._dynamo.test_case.TestCase): + def test_no_freevars(self): + mock = MockBackend() + + def f(x): + return wrap(lambda x: torch.sin(x), x) + + x = torch.randn(3) + expected = f(x) + result = torch.compile(f, backend=mock)(x) + + self.assertEqual(result, expected) + self.assertEqual(len(mock.graphs), 1) + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", mock.graphs[0].code)) + + def test_capture_untracked_global(self): + counters.clear() + mock = MockBackend() + + def f(x): + return wrap(lambda x: x + global_var, x) + + x = torch.randn(3) + expected = f(x) + result = torch.compile(f, backend=mock)(x) + + self.assertEqual(result, expected) + self.assertEqual(len(mock.graphs), 1) + # wrap(fn, x, global_var) + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code)) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_capture_untracked_global_nested(self): + mock = MockBackend() + counters.clear() + + @torch.compile(backend=mock) + def f(x): + return wrap(lambda x: wrap(lambda x: x + global_var, x), x) + + x = torch.randn(3) + result = f(x) + + self.assertEqual(result, x + global_var) + self.assertEqual(len(mock.graphs), 1) + gm = mock.graphs[0] + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.code)) + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.cond_body_1.code)) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_capture_untracked_nonlocal(self): + counters.clear() + mock = MockBackend() + + x = torch.randn(3, 3) + y = torch.randn(3, 3) + + def f(x, y): + @torch.compile(backend=mock) + def g(x): + return wrap(lambda x: x + y, x) + + return g(x) + + result = f(x, y) + expected = x + y + + self.assertEqual(result, expected) + self.assertEqual(len(mock.graphs), 1) + # wrap(fn, x, y) + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code)) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_capture_tracked(self): + counters.clear() + mock = MockBackend() + + x = torch.randn(3, 3) + y = torch.randn(3, 3) + + @torch.compile(backend=mock) + def f(x, y): + return wrap(lambda x: x + y, x) + + result = f(x, y) + + self.assertEqual(result, x + y) + self.assertEqual(len(mock.graphs), 1) + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code)) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_inlined_functions(self): + counters.clear() + mock = MockBackend() + + x = torch.randn(3, 3) + y = torch.randn(3, 3) + + def g(x, y): + return x + y + + @torch.compile(backend=mock) + def f(x, y): + return wrap(lambda x: g(x, y), x) + + result = f(x, y) + + self.assertEqual(result, x + y) + self.assertEqual(len(mock.graphs), 1) + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", mock.graphs[0].code)) + self.assertEqual(len(counters["graph_break"]), 0) + + def 
test_capture_value_created_in_subgraph(self): + counters.clear() + mock = MockBackend() + + x = torch.randn(3, 3) + y = torch.randn(3, 3) + + def inner(x, y): + z = x + y + return wrap(lambda x: wrap(lambda x: x + z, x), x) + + @torch.compile(backend=mock) + def f(x, y): + return wrap(inner, x, y) + + result = f(x, y) + + self.assertEqual(result, x + y + x) + self.assertEqual(len(mock.graphs), 1) + gm = mock.graphs[0] + # Two inputs: no lifting + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.code)) + # z should have been lifted to input + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+, \w+\);", gm.cond_body_2.code)) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_capture_global_num(self): + counters.clear() + mock = MockBackend() + x = torch.zeros([]) + + @torch.compile(backend=mock) + def f(x): + return wrap(lambda x: x + global_num, x) + + global global_num + result = f(x) + self.assertEqual(result, x + global_num) + self.assertEqual(len(mock.graphs), 1) + gm = mock.graphs[0] + # Numbers don't get lifted + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", gm.code)) + + # Check that we still guard on the number + global_num = torch.randn([]).item() + result = f(x) + self.assertEqual(result, x + global_num) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_capture_input_num(self): + counters.clear() + mock = MockBackend() + x = torch.zeros([]) + y = 3.14 + + @torch.compile(backend=mock) + def f(x, y): + return wrap(lambda x: x + y, x) + + result = f(x, y) + self.assertEqual(result, x + y) + self.assertEqual(len(mock.graphs), 1) + gm = mock.graphs[0] + # Numbers don't get lifted + self.assertIsNotNone(re.search(r"wrap\(\w+, \w+\);", gm.code)) + self.assertEqual(len(counters["graph_break"]), 0) + + # TODO: Ideally we would error out if there are any new live side + # effects (for example, if the body function mutates a global variable). + # I don't know how to detect this in a robust way, because it conflicts with + # benign side effects like storing and loading cells that is necessary for + # capturing variables. + @unittest.expectedFailure + def test_side_effect_in_body(self): + from torch._dynamo.utils import counters + + counters.clear() + + mock = MockBackend() + x = torch.randn([]) + y = torch.randn([]) + + def inner(x): + nonlocal y + y = x + return x.clone() + + @torch.compile(backend=mock) + def f(x): + return wrap(inner, x) + + f(x) + self.assertEqual(y, x) + self.assertEqual( + dict(counters["graph_break"]), + {"side effects in HigherOrderOperator body": 1}, + ) + + def test_fallback_on_graph_break_simple(self): + # In the future, there should be a per-HigherOrderOperator switch + # on whether or not to fallback or raise a loud error. + # For now we just fallback by default. 
+ mock = MockBackend() + x = torch.randn([]) + + def inner(x): + y = x.sin() + torch._dynamo.graph_break() + z = y.sin() + return z + + @torch.compile(backend=mock) + def f(x): + return wrap(inner, x) + + result = f(x) + self.assertEqual(result, inner(x)) + # It's unclear if this is correct: dynamo graph breaks on wrap but + # then interposes on wrap.__call__, which invokes fn(*args), + # leading to two graphs being compiled + self.assertEqual(len(mock.graphs), 2) + + def test_fallback_on_graph_break_complicated(self): + mock = MockBackend() + x = torch.randn([]) + + def inner(x): + y = x.sin() + y = y * global_var + torch._dynamo.graph_break() + z = y.sin() + return z + + @torch.compile(backend=mock) + def f(x): + x = x.clone() + result = wrap(inner, x) + return result.clone() + + result = f(x) + self.assertEqual(result, inner(x)) + # It's unclear if this is correct: dynamo graph breaks on wrap but + # then interposes on wrap.__call__, which invokes fn(*args), + # leading to four graphs being compiled: clone, sin, sin, clone + self.assertEqual(len(mock.graphs), 4) + + def test_fallback_on_modules(self): + # We can likely support this in the future, I just don't want to deal + # with it right now + from torch._dynamo.utils import counters + + counters.clear() + mock = MockBackend() + mod = torch.nn.Linear(3, 3) + x = torch.randn(3, 3) + + @torch.compile(backend=mock) + def f(x): + return wrap(lambda x: mod(x), x) + + result = f(x) + + self.assertEqual(result, mod(x)) + self.assertEqual(len(mock.graphs), 1) + self.assertEqual( + dict(counters["graph_break"]), + {"Invoking an nn.Module inside HigherOrderOperator": 1}, + ) + + def test_access_module_attr(self): + # We can likely support this in the future, I just don't want to deal + # with it right now + counters.clear() + mock = MockBackend() + mod = torch.nn.Linear(3, 3) + x = torch.randn(3, 3) + + @torch.compile(backend=mock) + def f(x): + y = mod(x) + return wrap(lambda y: y - mod.bias, y) + + result = f(x) + self.assertEqual(result, mod(x) - mod.bias) + self.assertEqual(len(mock.graphs), 2) + self.assertEqual( + dict(counters["graph_break"]), + {"accessing attribute of nn.Module inside HigherOrderOperator": 1}, + ) + + def test_make_closure(self): + counters.clear() + mock = MockBackend() + x = torch.randn(3, 3) + y = torch.randn(3, 3) + + def f(x, y): + def g(x): + return x + y + + return g(x) + + @torch.compile(backend=mock) + def h(x, y): + return wrap(f, x, y) + + result = h(x, y) + self.assertEqual(result, x + y) + self.assertEqual(len(counters["graph_break"]), 0) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py index 875b45bccfbd..81116c495b7a 100644 --- a/torch/_dynamo/allowed_functions.py +++ b/torch/_dynamo/allowed_functions.py @@ -168,6 +168,11 @@ def _find_torch_objects(module): torch_object_ids[id(module)] = module.__name__ for name, obj in list(module.__dict__.items()): if id(obj) not in torch_object_ids: + # Don't handle HigherOrderOperator as builtin + import torch._ops + + if isinstance(obj, torch._ops.HigherOrderOperator): + continue if isinstance(obj, types.ModuleType): if obj.__name__.startswith("torch.") and _is_allowed_module_prefix( obj diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 67eaa45ce655..61dc84d70ef7 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1,4 +1,5 @@ import collections +import contextlib import copy 
import functools import itertools @@ -18,6 +19,7 @@ import torch._logging import torch.nn +import torch.utils._pytree as pytree from torch import fx from torch._guards import ( Checkpointable, @@ -289,7 +291,7 @@ def real_value_cache(self): return self.current_tracer.real_value_cache def create_graph_input(self, *args, **kwargs): - return self.current_tracer.create_graph_input(*args, **kwargs) + return self.root_tracer.create_graph_input(*args, **kwargs) def create_proxy(self, *args, **kwargs): return self.current_tracer.create_proxy(*args, **kwargs) @@ -303,6 +305,15 @@ def remove_node(self, *args, **kwargs): def create_arg(self, *args, **kwargs): return self.current_tracer.create_arg(*args, **kwargs) + @contextlib.contextmanager + def new_subtracer(self): + try: + tracer = SubgraphTracer(self, parent=self.current_tracer) + self.tracers.append(tracer) + yield tracer + finally: + self.tracers.pop() + @property def output(self): return self @@ -454,6 +465,10 @@ def register_attr_or_module( assert not isinstance(source, ParamBufferSource) if isinstance(target, torch.Tensor): + if len(self.tracers) > 1: + unimplemented( + "accessing attribute of nn.Module inside HigherOrderOperator" + ) if not is_constant_source(source): options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH)) @@ -920,7 +935,7 @@ class SubgraphTracer(fx.Tracer): compiling and executing the graph. """ - def __init__(self, output_graph): + def __init__(self, output_graph, parent=None): super(SubgraphTracer, self).__init__() self.output_graph = weakref.proxy(output_graph) self.graph = torch.fx.Graph() @@ -931,6 +946,18 @@ def __init__(self, output_graph): # Node => computed real value (see utils.get_real_value) self.real_value_cache: Dict[fx.Node, torch.Tensor] = {} + # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design] + self.parent = parent + # A list of proxies that exist in the graph being traced. We use this + # list to determine that, when tracing the body function of a HigherOrderOperator, + # if a new proxy is actually a free variable. + self.seen_proxies = set({}) + # A list of previously free variables that we lifted to being inputs + # of the graph. If we are tracing a HigherOrderOperator's body_fn, + # then we need to keep track of this so we can rewrite the + # HigherOrderOperator call using the traced body_fn. 
+ self.lifted_freevars = set({}) + def create_proxy( self, kind, @@ -941,6 +968,21 @@ def create_proxy( type_expr=None, proxy_factory_fn=None, ): + # If there are any freevars, lift them to being inputs + if self.parent is not None: + flat_args, _ = pytree.tree_flatten(args) + for arg in flat_args: + if not isinstance(arg, torch.fx.Proxy): + # Is a constant + continue + if arg in self.seen_proxies: + continue + if not hasattr(arg, "node"): + continue + if arg.node.name in self.input_name_to_proxy: + continue + self.lift_freevar_to_input(arg) + rv = super().create_proxy( kind, target, args, kwargs, name, type_expr, proxy_factory_fn ) @@ -955,6 +997,8 @@ def create_proxy( if kind in {"call_function", "call_method"}: rv.node.meta["source_fn"] = target elif kind == "call_module": + if self.parent is not None: + unimplemented("Invoking an nn.Module inside HigherOrderOperator") # For modules we store the class rv.node.meta["source_fn"] = rv.node.meta["nn_module_stack"][target][1] @@ -969,6 +1013,7 @@ def create_proxy( msgs = traceback.StackSummary.from_list(frame_summaries).format() # type: ignore[arg-type] rv.node.stack_trace = "".join(msgs) + self.seen_proxies.add(rv) return rv def create_node(self, *args, **kwargs): @@ -1015,3 +1060,51 @@ def create_graph_input(self, name, type_expr=None, before=False): else: self.input_name_to_proxy[name] = proxy return proxy + + def is_name_bound(self, name): + if name in self.input_name_to_proxy: + return True + for proxy in self.seen_proxies: + if proxy.node.name == name: + return True + return False + + def lift_freevar_to_input(self, proxy): + self.create_graph_input(proxy.node.name) + self.lifted_freevars.add(proxy) + if self.parent is not None and not self.parent.is_name_bound(proxy.node.name): + self.parent.lift_freevar_to_input(proxy) + + +# NOTE: [HigherOrderOperator tracing design] +# Ignoring HigherOrderOperators for a moment, +# OutputGraph represents the graph being built by Dynamo that may be compiled +# and executed. It holds a root SubgraphTracer where the FX graph is built. +# +# HigherOrderOperators are operators that take functions (represented by +# torch.fx.GraphModule) as their arguments. When Dynamo encounters a +# HigherOrderOperator, then it attempts to introspect the function passed +# to it (call this the "body function"), capture it into a GraphModule, +# and rewrite the call to the HigherOrderOperator to use the GraphModule. +# +# The way we handle the capture of body functions is through having +# (possibly nested) SubgraphTracers, one per body function. +# +# Mechanically, we do the introspection by: +# - Creating a new SubgraphTracer via OutputGraph.new_subtracer +# - Executing the body function. +# This constructs the graph of the body function in the new SubgraphTracer +# while modifying the state of the OutputGraph. For example: +# - the OutputGraph can receive new GraphArgs (if we discover any new +# untracked Tensors) +# - side effects from the body function get accumulated into +# OutputGraph.side_effects +# - guards produced by the body function get accumulated into OutputGraph.guards +# +# The traced function has some special properties that make it easier for us +# to transform later down the line: +# - we lift all free variables to being inputs. +# +# If the introspection fails (due to the existence of graph breaks), then +# we roll back the current OutputGraph state and graph break on the +# HigherOrderOperator. 
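To make the free-variable lifting described in the NOTE above concrete, here is a minimal usage sketch (assuming this patch is applied). The `record_backend` helper and the shape of the printed graph code are illustrative only; they mirror the `MockBackend` pattern used in the new tests rather than any public API.

```python
import torch
from torch._ops import wrap

graphs = []

def record_backend(gm: torch.fx.GraphModule, example_inputs):
    # Record each GraphModule Dynamo produces so we can inspect it afterwards.
    graphs.append(gm)
    return gm.forward

y = torch.randn(3)  # a free variable from the body lambda's point of view

@torch.compile(backend=record_backend)
def f(x):
    # `y` is not an argument of the body lambda, so the nested SubgraphTracer
    # lifts it to an input of the captured body GraphModule and appends it to
    # the arguments of the emitted wrap() call.
    return wrap(lambda x: x + y, x)

f(torch.randn(3))
print(graphs[0].code)  # expect a call of the form wrap(<body_fn>, <x>, <lifted y>)
```
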
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 2bf2233bc282..6712a4def4e1 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -828,7 +828,7 @@ def get_comparable_state(state): ) ) - def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint): + def old_speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint): if isinstance(f, NestedUserFunctionVariable) and f.closure is not None: # closure vars other than 'self' are not in scope of generated code, so error early # TODO(avik): we should eventually support this. @@ -911,6 +911,53 @@ def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint): comparable_state, ) + # See NOTE [HigherOrderOperator tracing design] for more details. + def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint): + try: + with tx.output.new_subtracer() as tracer: + args = [] + # One argument to graph per sub_args + for a in sub_args: + if isinstance(a, TensorVariable): + tracer.create_graph_input(a.as_proxy().node.name) + args.append(a) + else: + # call_function() needs a TensorVariable, therefore we construct + # one with inner graph proxy. + assert isinstance(a, torch.Tensor) + proxy = tracer.create_graph_input("arg") + args.append( + wrap_fx_proxy(tx=tx, proxy=proxy, example_value=a) + ) + + output = f.call_function(tx, args, {}) + # Register output to graph + # Modeled off of compile_and_call_fx_graph + # TODO: support non single Tensor output + assert isinstance(output, TensorVariable) + + tx.output.guards.update(output.guards) + tx.output.create_node( + "output", + "output", + (tracer.create_arg((output.as_proxy(),))), + {}, + ) + + graph = tx.output.graph + lifted_freevars = tracer.lifted_freevars + + return ( + output, + graph, + lifted_freevars, + ) + + except torch._dynamo.exc.Unsupported as ex: + tx.output.graph = graph_checkpoint + tx.restore_graphstate(checkpoint) + raise + if self.value.__name__ == "cond": # TODO(voz): Support fake tensor dispatch for recursive # ops - see torch/dispatch/_dispatcher.py @@ -983,7 +1030,7 @@ def speculate_branch(branch): try: # NB: 0 is predicate ix = 1 if branch else 2 - return speculate_subgraph( + return old_speculate_subgraph( args[ix], operands, graph_checkpoint, checkpoint ) except ArgsMismatchError as e: @@ -1054,7 +1101,7 @@ def speculate_branch(branch): body_guards, body_nn_modules, body_cmp, - ) = speculate_subgraph( + ) = old_speculate_subgraph( args[0], [ get_fake_value(args[1].as_proxy().node, tx)[0], @@ -1102,6 +1149,34 @@ def speculate_branch(branch): example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode) p_args = (lowered_node,) + p_args + elif self.value.__name__ == "wrap": + # See NOTE [HigherOrderOperator tracing design] for more details + checkpoint = tx.copy_graphstate() + graph_checkpoint = tx.output.graph + ( + body_r, + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + args[0], + [ + *args[1:], + ], + graph_checkpoint, + checkpoint, + ) + + body_name = add_subgraph( + "body", torch.fx.GraphModule(tx.output.nn_modules, body_graph) + ) + body_node = make_attr(body_name) + p_args = ( + body_node, + *(arg.as_proxy() for arg in args[1:]), + *(arg for arg in body_lifted_freevars), + ) + r = body_r.as_proxy().node.meta["example_value"] + example_value = r else: unimplemented(f"HigherOrderOperator {self.value.__name__}") diff --git a/torch/_ops.py b/torch/_ops.py index e8460e6920e2..55f6ef695ede 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -788,5 +788,17 @@ def 
load_library(self, path): self.loaded_libraries.add(path) +class Wrap(HigherOrderOperator): + def __init__(self): + super().__init__("wrap") + + def __call__(self, func, *args): + result = func(*args) + return result + + +wrap = Wrap() + + # The ops "namespace" ops = _Ops()
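For reference, a short sketch of how `wrap` behaves on both sides of the compiler boundary (assuming this patch is applied; `wrap` lives in the private `torch._ops` module and is an experimental operator, not a public API):

```python
import torch
from torch._ops import wrap

x = torch.randn(3)

# Eagerly, Wrap.__call__ simply invokes the body function:
# wrap(f, *args) is equivalent to f(*args).
assert torch.allclose(wrap(lambda t: t.sin(), x), x.sin())

# Under torch.compile, Dynamo intercepts the call, traces the body into a
# separate GraphModule via a nested SubgraphTracer, and rewrites the call to
# wrap(<body GraphModule>, <args>, <lifted free variables>), as exercised by
# the tests above.
compiled = torch.compile(lambda t: wrap(lambda u: u.sin(), t), backend="eager")
assert torch.allclose(compiled(x), x.sin())
```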