From 3d72ac9fd810d5e2c4a87b2bf52594ef8a4581e4 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Fri, 24 Jun 2022 01:32:51 +0000 Subject: [PATCH 1/2] Remove op overloads --- torchdynamo/optimizations/python_key.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torchdynamo/optimizations/python_key.py b/torchdynamo/optimizations/python_key.py index 1a18159abe..7c7031121b 100644 --- a/torchdynamo/optimizations/python_key.py +++ b/torchdynamo/optimizations/python_key.py @@ -44,6 +44,18 @@ def debug_node(n: Node): return f"{n.op} {target} {n.args} {n.kwargs}" +# TODO: remove, copied from functorch +def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. + Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() + def python_key_normalize(gm: torch.fx.GraphModule, example_inputs, decompositions={}): """ Use AOT autograd for normalizing IR in inference mode. This is useful @@ -125,6 +137,12 @@ def unpack(x): tracer: torch.fx.Tracer = PythonKeyTracer() graph = tracer.trace(fake_signature(fn_for_tracing, nargs)) traced = GraphModule(tracer.root, graph, "python_key_traced") + # https://github.com/pytorch/pytorch/pull/80013 switched over + # tracing to trace op overloads, however op lowerings are currently + # registered to the overload packet. TODO: switch over to registering + # to overloads after branch cut for 1.12 + strip_overloads(traced) + traced.recompile() # record_graph_stats(traced) From aa8b765b99a05866751c8da8e549647696620748 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Fri, 24 Jun 2022 02:07:43 +0000 Subject: [PATCH 2/2] lint --- torchdynamo/optimizations/python_key.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchdynamo/optimizations/python_key.py b/torchdynamo/optimizations/python_key.py index 7c7031121b..80b1a98eb0 100644 --- a/torchdynamo/optimizations/python_key.py +++ b/torchdynamo/optimizations/python_key.py @@ -56,6 +56,7 @@ def strip_overloads(gm): node.target = node.target.overloadpacket gm.recompile() + def python_key_normalize(gm: torch.fx.GraphModule, example_inputs, decompositions={}): """ Use AOT autograd for normalizing IR in inference mode. This is useful @@ -137,13 +138,12 @@ def unpack(x): tracer: torch.fx.Tracer = PythonKeyTracer() graph = tracer.trace(fake_signature(fn_for_tracing, nargs)) traced = GraphModule(tracer.root, graph, "python_key_traced") - # https://github.com/pytorch/pytorch/pull/80013 switched over + # https://github.com/pytorch/pytorch/pull/80013 switched over # tracing to trace op overloads, however op lowerings are currently # registered to the overload packet. TODO: switch over to registering # to overloads after branch cut for 1.12 strip_overloads(traced) - traced.recompile() # record_graph_stats(traced)