@@ -1554,27 +1554,14 @@ def repl(inp, x1, x2):
15541554 match .replace_by_example (repl , [inp , mat1 , mat2 ])
15551555
15561556
1557- def is_valid_addmm_fusion (match ):
1558- mat1 , mat2 = match .args
1557+ def can_fuse_bias_in_addmm (match ):
15591558 inp = match .kwargs ["inp" ]
15601559
15611560 if not (
15621561 isinstance (inp , torch .fx .Node ) and isinstance (inp .meta ["val" ], torch .Tensor )
15631562 ):
15641563 return False # Input is a number
15651564
1566- in_shape = inp .meta ["val" ].shape
1567- mm_shape = mat1 .meta ["val" ].shape [0 ], mat2 .meta ["val" ].shape [1 ]
1568- matched = is_expandable_to (in_shape , mm_shape )
1569- if not matched :
1570- return False # Shape mismatch
1571-
1572- inp_dtype = inp .meta ["val" ].dtype
1573-
1574- # aten cublas integration assumes equal dtypes
1575- if inp_dtype != mat1 .meta ["val" ].dtype or inp_dtype != mat2 .meta ["val" ].dtype :
1576- return False
1577-
15781565 return not should_prefer_unfused_addmm (match )
15791566
15801567
@@ -1586,7 +1573,7 @@ def is_valid_addmm_fusion(match):
15861573 ),
15871574 # pyrefly: ignore [bad-argument-type]
15881575 pass_dict = pass_patterns [2 ],
1589- extra_check = is_valid_addmm_fusion ,
1576+ extra_check = can_fuse_bias_in_addmm ,
15901577)
15911578@register_graph_pattern (
15921579 CallFunction (
@@ -1596,7 +1583,7 @@ def is_valid_addmm_fusion(match):
15961583 ),
15971584 # pyrefly: ignore [bad-argument-type]
15981585 pass_dict = pass_patterns [2 ],
1599- extra_check = is_valid_addmm_fusion ,
1586+ extra_check = can_fuse_bias_in_addmm ,
16001587)
16011588def addmm (match , mat1 , mat2 , * , inp ):
16021589 def repl (inp , mat1 , mat2 ):
0 commit comments