torchdynamo/optimizations/python_key.py (18 additions, 0 deletions)
@@ -44,6 +44,19 @@ def debug_node(n: Node):
    return f"{n.op} {target} {n.args} {n.kwargs}"


# TODO: remove, copied from functorch
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


def python_key_normalize(gm: torch.fx.GraphModule, example_inputs, decompositions={}):
"""
Use AOT autograd for normalizing IR in inference mode. This is useful
@@ -125,6 +138,11 @@ def unpack(x):
    tracer: torch.fx.Tracer = PythonKeyTracer()
    graph = tracer.trace(fake_signature(fn_for_tracing, nargs))
    traced = GraphModule(tracer.root, graph, "python_key_traced")
    # https://github.com/pytorch/pytorch/pull/80013 switched tracing over to
    # op overloads; however, op lowerings are currently registered to the
    # overload packet. TODO: switch to registering to overloads after the
    # 1.12 branch cut.
    strip_overloads(traced)

    traced.recompile()
    # record_graph_stats(traced)
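
For readers unfamiliar with the two target types involved, here is a minimal sketch of the relationship `strip_overloads` relies on: a `torch._ops.OpOverload` names one specific overload of an operator, and its `overloadpacket` attribute points back at the `OpOverloadPacket` that groups all overloads of that operator. The particular op used below is illustrative, not taken from this PR.

```python
import torch

# One specific overload of aten::add, the kind of target that overload-level
# tracing (pytorch/pytorch#80013) records on graph nodes.
overload = torch.ops.aten.add.Tensor
assert isinstance(overload, torch._ops.OpOverload)

# The packet grouping all aten::add overloads; this is what strip_overloads
# rewrites node targets to, and where op lowerings are currently registered.
packet = overload.overloadpacket
assert packet is torch.ops.aten.add
```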
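
And a hedged end-to-end sketch of what the new `strip_overloads` call accomplishes, using functorch's `make_fx` as a stand-in for the `PythonKeyTracer` flow in this file; the function `f` and its input are made up for illustration.

```python
import torch
from functorch import make_fx  # assumes a functorch build that exports make_fx

from torchdynamo.optimizations.python_key import strip_overloads

def f(x):
    return torch.relu(x + 1)

gm = make_fx(f)(torch.randn(4))

# Overload-level targets, e.g. aten.add.Tensor, aten.relu.default.
print([n.target for n in gm.graph.nodes if n.op == "call_function"])

strip_overloads(gm)

# Packet-level targets, e.g. aten.add, aten.relu, matching where the
# lowerings are registered.
print([n.target for n in gm.graph.nodes if n.op == "call_function"])
```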