Skip to content

Commit

Permalink
[ONNX] Safely set node name for 'replace_placeholder_name_and_target'
Browse files Browse the repository at this point in the history
ghstack-source-id: d6f4f851290df16fe0abfd346dfb8a98ebc60bb8
Pull Request resolved: #98633
  • Loading branch information
BowenBao committed Apr 7, 2023
1 parent 89e5774 commit 1266cc5
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 4 deletions.
38 changes: 38 additions & 0 deletions test/onnx/test_fx_passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Owner(s): ["module: onnx"]
from __future__ import annotations

import torch._dynamo
import torch.fx
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils


class TestFxPasses(common_utils.TestCase):
def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
def func(x, y, z):
return x + y + z

x = torch.randn(3)
y = torch.randn(3)
z = torch.randn(3)
gm, _ = torch._dynamo.export(func, x, y, z)
torch._dynamo.reset()

# Purposely name the nodes in a way that will cause a recursive collision later.
# See :func:`set_node_name` for name collision renaming logic.
base_name = "tensor"
nodes = list(gm.graph.nodes)
for i, node in enumerate(nodes[1:]):
post_fix = f".{1}"
node.name = f"{base_name}{post_fix*i}"

# Run `set_node_name` and verify that the names are correct.
pass_utils.set_node_name(nodes[0], base_name)
assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
assert len({node.name for node in nodes}) == len(
nodes
), f"Expected all names to be unique, got {nodes}"


if __name__ == "__main__":
common_utils.run_tests()
47 changes: 43 additions & 4 deletions torch/onnx/_internal/fx/passes/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
from __future__ import annotations

from typing import Callable
from typing import Callable, Dict, Optional

import torch.fx
import torch.fx.traceback as fx_traceback
Expand All @@ -28,6 +28,40 @@ def wrapped(*args):
return wrapped


@_beartype.beartype
def set_node_name(
node: torch.fx.Node,
new_name: str,
name_to_node_cache: Optional[Dict[str, torch.fx.Node]] = None,
):
"""Safely set the unique name of a node.
If the new name is already taken by another node, the name of the other node will be
updated to be "{new_name}.1". This function will recursively update the names until
there is no conflict.
To avoid recomputing the name_to_node_cache, it can be provided as an argument. If
provided, the caller is responsible for ensuring the cache is accurate and in sync
with the owning module of the node.
Args:
node: The node to update.
new_name: The new name to use.
name_to_node_cache: A cache of node names to nodes. If not provided, this
function will build the cache from the owning module of the node.
"""
module = node.graph.owning_module
name_to_node_cache = name_to_node_cache or {
_node.name: _node for _node in module.graph.nodes
}

if new_name in name_to_node_cache and name_to_node_cache[new_name] != node:
set_node_name(name_to_node_cache[new_name], f"{new_name}.1")

node.name = new_name
name_to_node_cache[new_name] = node


@_beartype.beartype
def replace_placeholder_name_and_target(
module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
Expand All @@ -39,8 +73,9 @@ def replace_placeholder_name_and_target(
function is undefined. This function only does minimal sanity check that the two
modules have the same number of arguments.
TODO(bowbao): Handle potential name conflicts between new names and existing node
names in the graph.
Name conflicts, if discovered, between new names and existing node names in the
graph will be handled. Check the documentation of :func:`set_node_name` for more
details.
Raises:
RuntimeError: If the two modules have different number of arguments.
Expand All @@ -56,8 +91,12 @@ def replace_placeholder_name_and_target(
f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}"
)

name_to_node: Dict[str, torch.fx.Node] = {}
for node in module.graph.nodes:
name_to_node[node.name] = node

for placeholder, reference_placeholder in zip(placeholders, reference_placeholders):
placeholder.target = reference_placeholder.target
placeholder.name = reference_placeholder.name
set_node_name(placeholder, reference_placeholder.name, name_to_node)

module.recompile()

0 comments on commit 1266cc5

Please sign in to comment.