Skip to content

Commit

Permalink
Fix symbolic scalar on "[inductor] Make sure unfuse_addmm and addmm p…
Browse files Browse the repository at this point in the history
…atterns don't overlap"

Inductor has two opposing patterns,
```
addmm -> add + mm
add + mm -> addmm
```

This uses the `extra_check` to disable the addmm fusion pattern when the
heuristic to unfuse add is met, for consistency.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
  • Loading branch information
peterbell10 committed Sep 29, 2023
2 parents 1c44988 + 6ab2d0b commit fab9c1e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/inductor/test_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,20 @@ def fn(a, b, c):
self.assertEqual(counters["inductor"]["pattern_matcher_count"], count)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], nodes)

def test_addmm_symbolic_scalar(self):
def fn(m1, m2):
bias = m1.size(0)
return torch.add(bias, torch.mm(m1, m2)), torch.mm(m1, m2) + bias

m1 = torch.randn(16, 16, device="cuda")
m2 = torch.randn(16, 16, device="cuda")

counters.clear()
expect = fn(m1, m2)
actual = torch.compile(fn, dynamic=True)(m1, m2)
self.assertEqual(expect, actual)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)

def test_cat_mm(self):
def fn(a, b, c):
return torch.cat(
Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/fx_passes/post_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,9 @@ def is_valid_addmm_fusion(match):
mat1, mat2 = match.args
inp = match.kwargs["inp"]

if not isinstance(inp, torch.fx.Node):
if not (
isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
):
return False # Input is a number

in_shape = inp.meta["val"].shape
Expand Down

0 comments on commit fab9c1e

Please sign in to comment.