@@ -674,6 +674,87 @@ def body_fn(*flat_args):
674674 raise AssertionError ("scan is not lowered to while_loop" )
675675
676676
677+ def register_addmm_activation_replacement ():
678+ def addmm_activation_replacement (inp , mat1 , mat2 , beta = 1 , alpha = 1 ):
679+ return aten ._addmm_activation (inp , mat1 , mat2 , beta = beta , alpha = alpha )
680+
681+ def gen_addmm_activation_patterns ():
682+ for is_beta_eq_one , is_alpha_eq_one in itertools .product ((True , False ), repeat = 2 ):
683+ if is_alpha_eq_one :
684+ def apply_mm (mat1 , mat2 , alpha = 1 ):
685+ return mat1 @ mat2
686+ else :
687+ def apply_mm (mat1 , mat2 , alpha = 1 ):
688+ return alpha * (mat1 @ mat2 )
689+
690+ if is_beta_eq_one :
691+ def apply_bias (inp , beta = 1 ):
692+ return inp
693+ else :
694+ def apply_bias (inp , beta = 1 ):
695+ return beta * inp
696+
697+ for activation in (aten .relu , aten .gelu ):
698+ def bias_add_mm_activation_pattern (inp , mat1 , mat2 , beta = 1 , alpha = 1 ):
699+ return activation (apply_bias (inp , beta ) + apply_mm (mat1 , mat2 , alpha ))
700+
701+ def mm_add_bias_activation_pattern (inp , mat1 , mat2 , beta = 1 , alpha = 1 ):
702+ return activation (apply_mm (mat1 , mat2 , alpha ) + apply_bias (inp , beta ))
703+
704+ yield bias_add_mm_activation_pattern
705+ yield mm_add_bias_activation_pattern
706+
707+ def is_valid_addmm_activation_fusion (match : Match ):
708+ if config .max_autotune_gemm :
709+ return False
710+
711+ inp = match .kwargs ["input" ].meta ["val" ]
712+ mat1 = match .kwargs ["mat1" ].meta ["val" ]
713+ mat2 = match .kwargs ["mat2" ].meta ["val" ]
714+ beta = match .kwargs ["beta" ].meta ["val" ]
715+
716+ if beta != 1 :
717+ return False
718+
719+ if not inp .is_cuda :
720+ return False
721+
722+ if not (mat1 .dim () == 2 and mat2 .dim () == 2 ):
723+ return False
724+
725+ if inp .size (0 ) != mat2 .size (1 ):
726+ return False
727+
728+ if inp .dtype != mat1 .dtype or inp .dtype != mat2 .dtype :
729+ return False
730+
731+ return not has_uses_tagged_as (
732+ match .output_node (),
733+ (torch .Tag .pointwise , torch .Tag .reduction ),
734+ )
735+
736+ args = (
737+ torch .empty (5 ), # bias
738+ torch .empty (3 , 4 ), # mat1
739+ torch .empty (4 , 5 ), # mat2
740+ )
741+
742+ for addmm_activation_pattern in gen_addmm_activation_patterns ():
743+ for beta , alpha in itertools .product ((0.5 ,), repeat = 2 ):
744+ register_replacement (
745+ # pyrefly: ignore [bad-argument-type]
746+ addmm_activation_pattern ,
747+ # pyrefly: ignore [bad-argument-type]
748+ addmm_activation_replacement ,
749+ [* args , beta , alpha ],
750+ # pyrefly: ignore [bad-argument-type]
751+ trace_fn = fwd_only ,
752+ # pyrefly: ignore [bad-argument-type]
753+ pass_dicts = pass_patterns [2 ],
754+ extra_check = is_valid_addmm_activation_fusion ,
755+ )
756+
757+
677758@init_once_fakemode
678759def lazy_init ():
679760 if torch ._C ._has_mkldnn :
@@ -699,6 +780,8 @@ def lazy_init():
699780 extra_check = prepare_softmax_extra_check ,
700781 )
701782
783+ register_addmm_activation_replacement ()
784+
702785
703786def reorder_for_locality (graph : torch .fx .Graph ):
704787 if torch .distributed .is_available ():
@@ -1528,7 +1611,7 @@ def should_prefer_unfused_addmm(match):
15281611 alpha = KeywordArg ("alpha" ),
15291612 ),
15301613 # pyrefly: ignore [bad-argument-type]
1531- pass_dict = pass_patterns [2 ],
1614+ pass_dict = pass_patterns [1 ],
15321615 extra_check = should_prefer_unfused_addmm ,
15331616)
15341617def unfuse_bias_add_to_pointwise (match : Match , mat1 , mat2 , * , inp , alpha , beta ):
@@ -1575,7 +1658,7 @@ def is_valid_addmm_fusion(match):
15751658 KeywordArg ("inp" ),
15761659 ),
15771660 # pyrefly: ignore [bad-argument-type]
1578- pass_dict = pass_patterns [2 ],
1661+ pass_dict = pass_patterns [1 ],
15791662 extra_check = is_valid_addmm_fusion ,
15801663)
15811664@register_graph_pattern (
@@ -1585,7 +1668,7 @@ def is_valid_addmm_fusion(match):
15851668 CallFunction (aten .mm , Arg (), Arg ()),
15861669 ),
15871670 # pyrefly: ignore [bad-argument-type]
1588- pass_dict = pass_patterns [2 ],
1671+ pass_dict = pass_patterns [1 ],
15891672 extra_check = is_valid_addmm_fusion ,
15901673)
15911674def addmm (match , mat1 , mat2 , * , inp ):
0 commit comments