diff --git a/torchdynamo/optimizations/python_key.py b/torchdynamo/optimizations/python_key.py
index 1a18159abe..80b1a98eb0 100644
--- a/torchdynamo/optimizations/python_key.py
+++ b/torchdynamo/optimizations/python_key.py
@@ -44,6 +44,19 @@ def debug_node(n: Node):
     return f"{n.op} {target} {n.args} {n.kwargs}"
 
 
+# TODO: remove, copied from functorch
+def strip_overloads(gm):
+    """
+    Modifies the target of graph nodes in :attr:`gm` to strip overloads.
+    Args:
+        gm(fx.GraphModule): The input Fx graph module to be modified
+    """
+    for node in gm.graph.nodes:
+        if isinstance(node.target, torch._ops.OpOverload):
+            node.target = node.target.overloadpacket
+    gm.recompile()
+
+
 def python_key_normalize(gm: torch.fx.GraphModule, example_inputs, decompositions={}):
     """
     Use AOT autograd for normalizing IR in inference mode.  This is useful
@@ -125,6 +138,11 @@ def unpack(x):
     tracer: torch.fx.Tracer = PythonKeyTracer()
     graph = tracer.trace(fake_signature(fn_for_tracing, nargs))
     traced = GraphModule(tracer.root, graph, "python_key_traced")
+    # https://github.com/pytorch/pytorch/pull/80013 switched over
+    # tracing to trace op overloads, however op lowerings are currently
+    # registered to the overload packet. TODO: switch over to registering
+    # to overloads after branch cut for 1.12
+    strip_overloads(traced)
     traced.recompile()
     # record_graph_stats(traced)
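
For illustration, here is a minimal sketch (not part of the patch) of the transformation `strip_overloads` performs. The hand-built one-op graph is hypothetical; it assumes a PyTorch build from around 1.12 where `torch.ops.aten.add.Tensor` is an `OpOverload` and `torch.ops.aten.add` is the corresponding `OpOverloadPacket`, and that `strip_overloads` from the diff above is in scope:

```python
# Hedged sketch: construct an FX graph whose call_function target is the
# OpOverload aten.add.Tensor (the form tracing produces after pytorch#80013),
# then strip it back to the OpOverloadPacket that lowerings are keyed on.
import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder("x")
out = graph.call_function(torch.ops.aten.add.Tensor, (x, x))
graph.output(out)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

strip_overloads(gm)  # from the diff above; also recompiles gm

node = next(n for n in gm.graph.nodes if n.op == "call_function")
assert isinstance(node.target, torch._ops.OpOverloadPacket)  # now aten.add
print(gm(torch.ones(2)))  # tensor([2., 2.])
```

Calling the packet instead of the overload still dispatches correctly at runtime; the rewrite only matters because the lowering tables in this file look nodes up by packet, which is why the patch applies `strip_overloads(traced)` before handing the graph to the backend.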