Skip to content

Commit

Permalink
Extend subgraph_rewriter logic (#51532)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #51532

- Change output of `replace_pattern` to `List[Match]` reflecting the
pattern(s) matched in the original graph
- Ensure that all Callables (not just FunctionType objects) work with
the rewriter
- Fix incorrect matching in degenerate case (`test_subgraph_rewriter_correct_output_replacement`)
- Verify that pattern matching works when pattern and original graph are
the same

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D26193082

Pulled By: ansley

fbshipit-source-id: 7f40c3862012a44adb88f403ade7afc37e50417f
  • Loading branch information
Ansley Ussery authored and facebook-github-bot committed Feb 4, 2021
1 parent 627ec8b commit f2c4dea
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 38 deletions.
140 changes: 112 additions & 28 deletions test/fx/test_subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ def comparison(x):
val = torch.neg(x) + torch.relu(x)
return torch.add(val, val)

traced_module = symbolic_trace(M())
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)

x = torch.rand(1, 3)

# Replace `pattern` with the same pattern (shouldn't change
# the underlying logic)
subgraph_rewriter.replace_pattern(traced_module, pattern, pattern)
subgraph_rewriter.replace_pattern(traced, pattern, pattern)

traced_module.graph.lint(traced_module)
traced.graph.lint(traced)

ref_output = comparison_fn(x)
test_output = traced_module.forward(x)
test_output = traced.forward(x)
self.assertEqual(ref_output, test_output)

def test_subgraph_rewriter_with_oneliner_pattern(self):
Expand All @@ -60,17 +60,17 @@ def comparison(x):
val = torch.relu(x)
return torch.add(val, val)

traced_module = symbolic_trace(M())
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)

x = torch.rand(1, 3)

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced_module.graph.lint(traced_module)
traced.graph.lint(traced)

ref_output = comparison_fn(x)
test_output = traced_module.forward(x)
test_output = traced.forward(x)
self.assertEqual(ref_output, test_output)

def test_subgraph_rewriter_single_pattern_match(self):
Expand All @@ -89,24 +89,21 @@ def comparison(x):
val = torch.relu(x)
return torch.add(val, val)

traced_module = symbolic_trace(M())
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)

x = torch.rand(1, 3)

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced_module.graph.lint(traced_module)
traced.graph.lint(traced)

ref_output = comparison_fn(x)
test_output = traced_module.forward(x)
test_output = traced.forward(x)
self.assertEqual(ref_output, test_output)

def test_subgraph_rewriter_multiple_pattern_match(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, w1, w2):
m1 = torch.cat([w1, w2]).sum()
m2 = torch.cat([w1, w2]).sum()
Expand All @@ -123,26 +120,23 @@ def comparison(x, w1, w2):
m2 = torch.stack([w1, w2])
return x + torch.max(m1) + torch.max(m2)

traced_module = symbolic_trace(M())
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)

x = torch.rand(1, 3)
w1 = torch.rand(1, 3)
w2 = torch.rand(1, 3)

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced_module.graph.lint(traced_module)
traced.graph.lint(traced)

ref_outs = comparison_fn(x, w1, w2)
test_outs = traced_module.forward(x, w1, w2)
test_outs = traced.forward(x, w1, w2)
self.assertEqual(ref_outs, test_outs)

def test_subgraph_rewriter_graph_argument_order(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.mm(x, y)

Expand All @@ -152,18 +146,108 @@ def pattern(x, y):
def comparison(x, y):
return torch.mm(x, y)

traced_module = symbolic_trace(M())
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)

x = torch.randn(3, 4)
y = torch.randn(4, 5)

subgraph_rewriter.replace_pattern(traced_module, pattern, pattern)
subgraph_rewriter.replace_pattern(traced, pattern, pattern)

traced_module.graph.lint(traced_module)
traced.graph.lint(traced)

ref_outs = comparison_fn(x, y)
test_outs = traced_module.forward(x, y)
test_outs = traced.forward(x, y)
self.assertEqual(ref_outs, test_outs)

def test_subgraph_rewriter_correct_output_replacement(self):
    class M(torch.nn.Module):
        def forward(self, x, y):
            # Only the `relu(x)` call should match `pattern`; the
            # `neg(y)` branch must be left untouched by the rewrite
            val = torch.neg(y) + torch.relu(x)
            return torch.add(val, val)

    def pattern(x):
        return torch.relu(x)

    def replacement(x):
        return torch.neg(x)

    def comparison(x, y):
        val = torch.neg(y) + torch.neg(x)
        return torch.add(val, val)

    inputs = (torch.randn(4, 4), torch.randn(4, 4))

    traced = symbolic_trace(M())
    comparison_fn = symbolic_trace(comparison)

    subgraph_rewriter.replace_pattern(traced, pattern, replacement)
    traced.graph.lint(traced)

    # The rewritten module must agree with the hand-written reference
    self.assertEqual(comparison_fn(*inputs), traced.forward(*inputs))

def test_subgraph_rewriter_traced_as_callable(self):
    class M(torch.nn.Module):
        def forward(self, x):
            val = torch.neg(x) + torch.relu(x)
            return torch.add(val, val)

    # Pattern and replacement are given as already-traced GraphModules
    # (Callables) rather than plain Python functions
    class Pattern(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x) + torch.relu(x)

    class Replacement(torch.nn.Module):
        def forward(self, x):
            return torch.sigmoid(x)

    def comparison(x):
        val = torch.sigmoid(x)
        return torch.add(val, val)

    traced = symbolic_trace(M())
    comparison_fn = symbolic_trace(comparison)

    sample = torch.randn(3, 4)

    subgraph_rewriter.replace_pattern(
        traced, symbolic_trace(Pattern()), symbolic_trace(Replacement())
    )
    traced.graph.lint(traced)

    # Rewritten module must match the reference implementation
    self.assertEqual(comparison_fn(sample), traced.forward(sample))

def test_subgraph_rewriter_pattern_is_entire_graph(self):
    class M(torch.nn.Module):
        def forward(self, x):
            a = torch.neg(x)
            return torch.add(a, a)

    # `pattern` is structurally identical to the whole of M's graph,
    # so the replacement should swap out the entire module body
    def pattern(x):
        a = torch.neg(x)
        return torch.add(a, a)

    def replacement(x):
        a = torch.sigmoid(x)
        return torch.cat([a, a])

    traced = symbolic_trace(M())
    # After the rewrite, `traced` should compute exactly `replacement`
    comparison_fn = symbolic_trace(replacement)

    sample = torch.randn(3, 4)

    subgraph_rewriter.replace_pattern(traced, pattern, replacement)
    traced.graph.lint(traced)

    self.assertEqual(comparison_fn(sample), traced.forward(sample))

def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self):
Expand Down Expand Up @@ -216,9 +300,9 @@ def replacement(x, w1, w2, b1, b2):

traced = symbolic_trace(M())

# Result should be None since no matches can be found
# Result should be [] since no matches can be found
res = subgraph_rewriter.replace_pattern(traced, pattern, replacement)

traced.graph.lint(traced)

self.assertEqual(res, None)
self.assertEqual(res, [])
69 changes: 59 additions & 10 deletions torch/fx/subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, pattern : Graph) -> None:
if len(pattern.nodes) == 0:
raise ValueError("SubgraphMatcher cannot be initialized with an "
"empty pattern")
# `self.pattern_anchor` is the output Node in `pattern`
self.pattern_anchor = next(iter(reversed(pattern.nodes)))
# Ensure that there is only a single output value in the pattern
# since we don't support multiple outputs
Expand All @@ -39,6 +40,7 @@ def matches_subgraph_from_anchor(self, anchor : Node) -> bool:

# Compare the pattern node `pn` against the graph node `gn`
def _match_nodes(self, pn : Node, gn : Node) -> bool:

# Check if we've already matched these nodes in the current
# traversal
if pn in self.nodes_map:
Expand All @@ -65,16 +67,21 @@ def attributes_are_equal(pn : Node, gn : Node) -> bool:
if (pn.op != "output"
and len(pn.all_input_nodes) != len(gn.all_input_nodes)):
return False
match_found = all(self._match_nodes(pn_, gn_) for pn_, gn_
in zip(pn.all_input_nodes, gn.all_input_nodes))
if pn.op == "output":
match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_)
for gn_ in gn.all_input_nodes)
else:
match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes)
and all(self._match_nodes(pn_, gn_) for pn_, gn_
in zip(pn.all_input_nodes, gn.all_input_nodes)))
if not match_found:
self.nodes_map.pop(pn)
return False

return True


def replace_pattern(gm : GraphModule, pattern : Callable, replacement : Callable) -> None:
def replace_pattern(gm : GraphModule, pattern : Callable, replacement : Callable) -> List[Match]:
"""
Matches all possible non-overlapping sets of operators and their
data dependencies (``pattern``) in the Graph of a GraphModule
Expand All @@ -86,6 +93,19 @@ def replace_pattern(gm : GraphModule, pattern : Callable, replacement : Callable
``pattern``: The subgraph to match in ``gm`` for replacement
``replacement``: The subgraph to replace ``pattern`` with
Returns:
List[Match]: A list of ``Match`` objects representing the places
in the original graph that ``pattern`` was matched to. The list
is empty if there are no matches. ``Match`` is defined as:
.. code-block:: python
class Match(NamedTuple):
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
Examples:
.. code-block:: python
Expand Down Expand Up @@ -234,7 +254,7 @@ def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool:
# as part of a pattern match
replaced_nodes: Set[Node] = set()

# Return TRUE if one of the nodes in the current match has already
# Return True if one of the nodes in the current match has already
# been used as part of another match
def overlaps_with_prev_match(match : Match) -> bool:
for n in match.nodes_map.values():
Expand Down Expand Up @@ -287,14 +307,41 @@ def mark_node_as_replaced(n : Node) -> None:
with original_graph.inserting_before(subgraph_output):
copied_output = original_graph.graph_copy(replacement_graph,
val_map)
assert isinstance(copied_output, Node)

# We only want to copy in the output node from `pattern` if we
# have an output-output match. Otherwise, we leave out the
# `pattern` output node so we don't have two outputs in the
# resultant graph
# Hook the output Node of the replacement subgraph in to the
# original Graph at the correct location

# CASE 1: We need to hook the replacement subgraph in somewhere
# in the middle of the graph. We replace the Node in the
# original graph that corresponds to the end of the pattern
# subgraph
if subgraph_output.op != "output":
subgraph_output = subgraph_output.args[0] # type: ignore
# `subgraph_output` may have multiple args. These args could
# be from the original graph, or they could have come from
# the insertion of `replacement_subgraph`. We need to find
# the Node that was originally matched as part of
# `pattern` (i.e. a Node from the original graph). We can
# figure this out by looking in `match.nodes_map`. The map
# was created before `replacement_subgraph` was spliced in,
# so we know that, if a Node is in `match.nodes_map.values`,
# it must have come from the original graph
for n in subgraph_output.all_input_nodes:
if (n.op != "placeholder"
and n in match.nodes_map.values()):
subgraph_output = n
break
assert subgraph_output.op != "output"
# CASE 2: The pattern subgraph match extends to the end of the
# original graph, so we need to change the current graph's
# output Node to reflect the insertion of the replacement graph.
# We'll keep the current output Node, but update its args and
# `_input_nodes` as necessary
else:
subgraph_output.args = ((copied_output,))
if isinstance(copied_output, Node):
subgraph_output._input_nodes = {copied_output: None}

assert isinstance(copied_output, Node)
subgraph_output.replace_all_uses_with(copied_output)

# Erase the `pattern` nodes
Expand All @@ -305,3 +352,5 @@ def mark_node_as_replaced(n : Node) -> None:
# Update the passed-in GraphModule to reflect the new state of
# `original_graph`
gm.recompile()

return matches

0 comments on commit f2c4dea

Please sign in to comment.