Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent pattern matches across mutation ops in inductor pre-grad FX passes #101144

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 46 additions & 0 deletions test/inductor/test_pattern_matcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Owner(s): ["module: inductor"]
import copy
import unittest

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import count_calls, counters
Expand Down Expand Up @@ -262,6 +265,49 @@ def fn(a):
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)

def test_match_with_mutation(self):
from torch._inductor.pattern_matcher import (
CallFunction,
KeywordArg,
PatternMatcherPass,
register_graph_pattern,
)

counter = 0
test_pass = PatternMatcherPass()

@register_graph_pattern(
CallFunction(
torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
),
pass_dict=test_pass,
)
def _test(match, x):
breakpoint()
nonlocal counter
counter += 1

def fn(x, y):
a = torch.sin(x)
x.copy_(y)
b = torch.add(x, a)
return b

args1 = [
torch.randn(5, 5, device="cuda"),
torch.randn(5, 5, device="cuda"),
]
args2 = copy.deepcopy(args1)

with unittest.mock.patch(
"torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
):
expected = fn(*args1)
actual = torch.compile(fn)(*args2)
# should not match
self.assertEqual(counter, 0)
torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
if IS_LINUX and HAS_CUDA:
Expand Down
33 changes: 33 additions & 0 deletions torch/_inductor/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,36 @@ def decorator(handler):
return decorator


def is_start_of_fx_graph(node):
return len(node.all_input_nodes) == 0
williamwen42 marked this conversation as resolved.
Show resolved Hide resolved


def is_mutation_op(node):
if node.op == "call_function":
if node.target.__name__.endswith("_"):
return True
elif node.op == "call_method":
if node.target.endswith("_"):
return True
if "out" in node.kwargs:
if node.kwargs["out"] in node.all_input_nodes:
return True
return False


def get_mutation_region_id(node):
n = node
while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(n):
n = n.prev
mutation_region_id = n.meta.get("mutation_region_id", 0)
while n is not node:
n = n.next
if is_mutation_op(n):
mutation_region_id += 1
n.meta["mutation_region_id"] = mutation_region_id
return mutation_region_id


class PatternMatcherPass:
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -740,6 +770,9 @@ def apply(self, graph):
if node._erased:
break
m = entry.pattern.match(node)
# pattern match crosses mutation barrier - discard
if m and len(set(map(get_mutation_region_id, m.nodes))) != 1:
continue
if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
log.warning("%s%s %s %s", node, node.args, m, entry.pattern)
if m and entry.extra_check(m):
Expand Down