Skip to content

Commit

Permalink
[ONNX] Safely set node name for 'replace_placeholder_name_and_target' (
Browse files Browse the repository at this point in the history
…#98633)

Pull Request resolved: #98633
Approved by: https://github.com/wschin
  • Loading branch information
BowenBao authored and pytorchmergebot committed Apr 11, 2023
1 parent ad1d842 commit 2b38bd5
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 4 deletions.
61 changes: 61 additions & 0 deletions test/onnx/test_fx_passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Owner(s): ["module: onnx"]
from __future__ import annotations

import torch._dynamo
import torch.fx
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils


class TestFxPasses(common_utils.TestCase):
def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
def func(x, y, z):
return x + y + z

x = torch.randn(3)
y = torch.randn(3)
z = torch.randn(3)
gm, _ = torch._dynamo.export(func, x, y, z)
torch._dynamo.reset()

# Purposely name the nodes in a way that will cause a recursive collision later.
# See :func:`set_node_name` for name collision renaming logic.
base_name = "tensor"
nodes = list(gm.graph.nodes)
for i, node in enumerate(nodes[1:]):
if i == 0:
node.name = base_name
else:
node.name = f"{base_name}.{i}"

# Run `set_node_name` and verify that the names are correct.
name_to_node = {node.name: node for node in gm.graph.nodes}
pass_utils.set_node_name(nodes[0], base_name, name_to_node)
assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
assert len({node.name for node in nodes}) == len(
nodes
), f"Expected all names to be unique, got {nodes}"

def test_set_node_name_succeeds_when_no_name_collisions(self):
def func(x, y, z):
return x + y + z

x = torch.randn(3)
y = torch.randn(3)
z = torch.randn(3)
gm, _ = torch._dynamo.export(func, x, y, z)
torch._dynamo.reset()

# Run `set_node_name` and verify that the names are correct.
new_name = "some_tensor"
nodes = list(gm.graph.nodes)
name_to_node = {node.name: node for node in nodes}
pass_utils.set_node_name(nodes[1], new_name, name_to_node)
assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}"
assert len({node.name for node in nodes}) == len(
nodes
), f"Expected all names to be unique, got {nodes}"


if __name__ == "__main__":
common_utils.run_tests()
64 changes: 60 additions & 4 deletions torch/onnx/_internal/fx/passes/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
"""
from __future__ import annotations

from typing import Callable
import collections

import re

from typing import Callable, Dict, Optional, Tuple

import torch.fx
import torch.fx.traceback as fx_traceback
Expand All @@ -28,6 +32,54 @@ def wrapped(*args):
return wrapped


def _get_node_base_name(node_name: str) -> Tuple[str, Optional[int]]:
pattern = r"(.*)\.(\d+)"
match = re.match(pattern, node_name)
if match is not None:
base_name, count_str = match.groups()
return base_name, int(count_str)
return node_name, None


@_beartype.beartype
def set_node_name(
node: torch.fx.Node,
new_name: str,
name_to_node_cache: Dict[str, torch.fx.Node],
):
"""Safely set the unique name of a node.
If the new name is already taken by another node, the name of the other node will be
updated. If `new_name` is a string of format f"{base_name}.{count}", where `count`
is an integer, the other node will be renamed as f"{base_name}.{count+1}". If not,
the other node will be renamed as "{new_name}.1". This function will iteratively
update the names until there is no conflict.
``name_to_node_cache`` is required as an argument to avoid recomputation. The caller
is responsible for ensuring the cache is accurate and in sync with the owning module
of the node. The values in the cache will be updated accordingly.
Args:
node: The node to update.
new_name: The new name to use.
name_to_node_cache: A cache of node names to nodes.
"""
module = node.graph.owning_module
node_name_to_set = collections.deque([(node, new_name)])

while node_name_to_set:
node, new_name = node_name_to_set.pop()
if new_name in name_to_node_cache and name_to_node_cache[new_name] != node:
base_name, postfix_count = _get_node_base_name(new_name)
if postfix_count is None:
postfix_count = 0
node_name_to_set.append(
(name_to_node_cache[new_name], f"{base_name}.{postfix_count + 1}")
)
node.name = new_name
name_to_node_cache[new_name] = node


@_beartype.beartype
def replace_placeholder_name_and_target(
module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
Expand All @@ -39,8 +91,8 @@ def replace_placeholder_name_and_target(
function is undefined. This function only does minimal sanity check that the two
modules have the same number of arguments.
TODO(bowbao): Handle potential name conflicts between new names and existing node
names in the graph.
Name conflicts between new names and existing node names in the graph are handled.
Check the documentation of :func:`set_node_name` for more details.
Raises:
RuntimeError: If the two modules have different number of arguments.
Expand All @@ -56,8 +108,12 @@ def replace_placeholder_name_and_target(
f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}"
)

name_to_node: Dict[str, torch.fx.Node] = {}
for node in module.graph.nodes:
name_to_node[node.name] = node

for placeholder, reference_placeholder in zip(placeholders, reference_placeholders):
placeholder.target = reference_placeholder.target
placeholder.name = reference_placeholder.name
set_node_name(placeholder, reference_placeholder.name, name_to_node)

module.recompile()

0 comments on commit 2b38bd5

Please sign in to comment.