
Commit 326120a

[Inductor] refine the logic in (mm + bias) -> addmm

ghstack-source-id: ea3fb8e
Pull-Request: #166300

1 parent c69f15e
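
For context, the pattern this pass targets is a matmul followed by a broadcastable bias add, which aten can execute as a single fused `addmm`. A minimal standalone sketch of the algebraic equivalence behind the rewrite (illustration only, not code from this commit; the shapes are arbitrary):

```python
import torch

mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 16)
inp = torch.randn(16)  # bias, broadcastable to the (4, 16) mm output

# Unfused form matched by the pattern: mm followed by a bias add.
unfused = torch.mm(mat1, mat2) + inp

# Fused form produced by the replacement: a single addmm call,
# which computes beta * inp + alpha * (mat1 @ mat2) with beta = alpha = 1.
fused = torch.addmm(inp, mat1, mat2)

torch.testing.assert_close(unfused, fused)
```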

File changed: torch/_inductor/fx_passes/post_grad.py (+3, -16 lines)
```diff
@@ -1554,27 +1554,14 @@ def repl(inp, x1, x2):
     match.replace_by_example(repl, [inp, mat1, mat2])
 
 
-def is_valid_addmm_fusion(match):
-    mat1, mat2 = match.args
+def can_fuse_bias_in_addmm(match):
     inp = match.kwargs["inp"]
 
     if not (
         isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor)
     ):
         return False  # Input is a number
 
-    in_shape = inp.meta["val"].shape
-    mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1]
-    matched = is_expandable_to(in_shape, mm_shape)
-    if not matched:
-        return False  # Shape mismatch
-
-    inp_dtype = inp.meta["val"].dtype
-
-    # aten cublas integration assumes equal dtypes
-    if inp_dtype != mat1.meta["val"].dtype or inp_dtype != mat2.meta["val"].dtype:
-        return False
-
     return not should_prefer_unfused_addmm(match)
 
 
@@ -1586,7 +1573,7 @@ def is_valid_addmm_fusion(match):
     ),
     # pyrefly: ignore [bad-argument-type]
     pass_dict=pass_patterns[2],
-    extra_check=is_valid_addmm_fusion,
+    extra_check=can_fuse_bias_in_addmm,
 )
 @register_graph_pattern(
     CallFunction(
@@ -1596,7 +1583,7 @@ def is_valid_addmm_fusion(match):
     ),
     # pyrefly: ignore [bad-argument-type]
     pass_dict=pass_patterns[2],
-    extra_check=is_valid_addmm_fusion,
+    extra_check=can_fuse_bias_in_addmm,
 )
 def addmm(match, mat1, mat2, *, inp):
     def repl(inp, mat1, mat2):
```
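
After this change, `can_fuse_bias_in_addmm` is the `extra_check` gating both registered patterns, and the shape-expandability and dtype checks are no longer part of the match-time validation. A hedged sketch of a function this pattern can match under `torch.compile` (illustrative only; whether the rewrite actually fires also depends on `should_prefer_unfused_addmm` and the rest of the pass):

```python
import torch

def linear_like(inp, mat1, mat2):
    # mm followed by a bias add: the graph shape the addmm pattern matches.
    return torch.mm(mat1, mat2) + inp

# Under torch.compile, Inductor's post-grad pattern pass may rewrite this
# into a single addmm when the extra_check accepts the match.
compiled = torch.compile(linear_like)
out = compiled(torch.randn(16), torch.randn(4, 8), torch.randn(8, 16))
```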
