prevent pattern matches across mutation ops in inductor pre-grad FX passes

ghstack-source-id: fa14b4dbba1739326ba0502f438c93632ab00d00
Pull Request resolved: #101144
williamwen42 committed May 18, 2023
1 parent 0a0acce commit a370bca
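The idea, in short: in-place ops (copy_, add_, out=...) split an FX graph into "mutation regions", and a pattern is only allowed to match nodes that all live in the same region. Below is a minimal standalone sketch of that bookkeeping, not part of the commit, using only public torch.fx APIs; the function f and the trailing-underscore heuristic are illustrative only.

import torch
import torch.fx


def f(x, y):
    a = torch.sin(x)        # region 0
    x.copy_(y)              # in-place op: starts region 1
    return torch.add(x, a)  # region 1, so add(x, sin(x)) must not be matched


gm = torch.fx.symbolic_trace(f)

region = 0
for node in gm.graph.nodes:
    # mirrors the is_mutation_op heuristic added below: in-place ops end in "_"
    name = node.target if isinstance(node.target, str) else getattr(node.target, "__name__", "")
    if node.op in ("call_function", "call_method") and name.endswith("_"):
        region += 1
    print(node.op, name, "-> region", region)
# placeholders and sin land in region 0; copy_, add, and output land in region 1

The pattern x + sin(x) in the new test below spans both regions, which is exactly the situation the pass now refuses to match.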
Showing 4 changed files with 125 additions and 4 deletions.
45 changes: 45 additions & 0 deletions test/inductor/test_pattern_matcher.py
@@ -1,4 +1,7 @@
# Owner(s): ["module: inductor"]
import copy
import unittest

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import count_calls, counters
@@ -280,6 +283,48 @@ def fn(a):
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)

    def test_match_with_mutation(self):
        from torch._inductor.pattern_matcher import (
            CallFunction,
            KeywordArg,
            PatternMatcherPass,
            register_graph_pattern,
        )

        counter = 0
        test_pass = PatternMatcherPass(prevent_match_across_mutations=True)

        @register_graph_pattern(
            CallFunction(
                torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
            ),
            pass_dict=test_pass,
        )
        def _test(match, x):
            nonlocal counter
            counter += 1

        def fn(x, y):
            a = torch.sin(x)
            x.copy_(y)
            b = torch.add(x, a)
            return b

        args1 = [
            torch.randn(5, 5, device="cuda"),
            torch.randn(5, 5, device="cuda"),
        ]
        args2 = copy.deepcopy(args1)

        with unittest.mock.patch(
            "torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
        ):
            expected = fn(*args1)
            actual = torch.compile(fn)(*args2)
            # should not match
            self.assertEqual(counter, 0)
            torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA and not TEST_WITH_ROCM:
19 changes: 19 additions & 0 deletions test/inductor/test_split_cat_fx_passes.py
@@ -545,6 +545,25 @@ def split_with_cat(x):
            0,
        )

    @torch._inductor.config.patch(split_cat_fx_passes=True)
    def test_split_cat_merge_mutation(self):
        args = [
            torch.randn(2, 32, 32, 16),
        ]

        def split_cat_mutation(x):
            splits = torch.split(x, 4, dim=1)
            splits[1].copy_(splits[0])
            return torch.cat(splits, dim=1)

        expected = split_cat_mutation(*args)
        actual = torch.compile(split_cat_mutation, dynamic=True)(*args)

        torch.testing.assert_close(actual, expected)

        self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0)
        self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA and not TEST_WITH_ROCM:
6 changes: 3 additions & 3 deletions torch/_inductor/fx_passes/pre_grad.py
@@ -22,9 +22,9 @@

log = logging.getLogger(__name__)

-normalize_split_pass = PatternMatcherPass()
-merge_splits_pass = PatternMatcherPass()
-merge_split_cat_pass = PatternMatcherPass()
+normalize_split_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+merge_splits_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+merge_split_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)

pattern_matcher_passes: List[PatternMatcherPass] = [
    normalize_split_pass,
59 changes: 58 additions & 1 deletion torch/_inductor/pattern_matcher.py
@@ -718,10 +718,54 @@ def decorator(handler):
    return decorator


def is_start_of_fx_graph(graph, node):
    # first node in the graph
    return node is next(iter(graph.nodes))


def is_mutation_op(node):
    # heuristic: in-place ops follow the trailing-underscore naming convention,
    # and an op that writes into one of its inputs via out= also mutates
    if node.op == "call_function":
        if node.target.__name__.endswith("_"):
            return True
    elif node.op == "call_method":
        if node.target.endswith("_"):
            return True
    if "out" in node.kwargs:
        if node.kwargs["out"] in node.all_input_nodes:
            return True
    return False


def get_mutation_region_id(graph, node):
    # walk back to the nearest node with a cached region id (or the start of
    # the graph), then walk forward to `node`, caching region ids along the way
    n = node
    while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
        n = n.prev
    mutation_region_id = n.meta.get("mutation_region_id", 0)
    while n is not node:
        n = n.next
        if is_mutation_op(n):
            mutation_region_id += 1
        n.meta["mutation_region_id"] = mutation_region_id
    return mutation_region_id


def should_compute_mutation_region_ids(graph):
    return "mutation_region_id" not in next(iter(graph.nodes)).meta


def compute_mutation_region_ids(graph):
    mutation_region_id = 0
    for nd in graph.nodes:
        if is_mutation_op(nd):
            mutation_region_id += 1
        nd.meta["mutation_region_id"] = mutation_region_id


class PatternMatcherPass:
-    def __init__(self):
+    def __init__(self, prevent_match_across_mutations=False):
        super().__init__()
        self.patterns = defaultdict(list)
        self.prevent_match_across_mutations = prevent_match_across_mutations

    def __getitem__(self, item):
        return self.patterns[item]
@@ -731,6 +775,12 @@ def apply(self, graph):
            return 0
        if isinstance(graph, torch.fx.GraphModule):
            graph = graph.graph
        if self.prevent_match_across_mutations:
            if should_compute_mutation_region_ids(graph):
                compute_mutation_region_ids(graph)
            get_mutation_region_id_partial = functools.partial(
                get_mutation_region_id, graph
            )
        count = 0
        for node in reversed(graph.nodes):
            if (
@@ -747,6 +797,13 @@ def apply(self, graph):
                    if node._erased:
                        break
                    m = entry.pattern.match(node)
                    # pattern match crosses mutation barrier - discard
                    if (
                        self.prevent_match_across_mutations
                        and m
                        and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1
                    ):
                        continue
                    if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
                        log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
                    if m and entry.extra_check(m):
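As a hedged usage sketch (not part of the commit), the helpers above can be exercised directly on a traced graph; this assumes they remain importable from torch._inductor.pattern_matcher and reuses the same illustrative shape of fn as the new test.

import torch
import torch.fx
from torch._inductor.pattern_matcher import (
    compute_mutation_region_ids,
    get_mutation_region_id,
)


def fn(x, y):
    a = torch.sin(x)
    x.copy_(y)  # mutation barrier between sin and add
    return torch.add(x, a)


gm = torch.fx.symbolic_trace(fn)
compute_mutation_region_ids(gm.graph)

nodes = {n.name: n for n in gm.graph.nodes}
# sin is tagged with region 0 and add with region 1, so a candidate match
# containing both yields two distinct ids and apply() skips it via `continue`
regions = {get_mutation_region_id(gm.graph, nodes[name]) for name in ("sin", "add")}
print(regions)  # expected: {0, 1}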
