Skip to content

Commit 0af94bb

Browse files
committed
[Inductor] Activation(Addmm) fusion
ghstack-source-id: f4e8f59 Pull-Request: #167469
1 parent 56aa531 commit 0af94bb

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

torch/_inductor/fx_passes/post_grad.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,87 @@ def body_fn(*flat_args):
674674
raise AssertionError("scan is not lowered to while_loop")
675675

676676

def register_addmm_activation_replacement():
    """Register replacements that fuse ``activation(beta * inp + alpha * (mat1 @ mat2))``
    into ``aten._addmm_activation``.

    Patterns are generated for every combination of:
      * beta == 1 vs. beta != 1 (bias scaled by an explicit multiply or not),
      * alpha == 1 vs. alpha != 1 (mm result scaled or not),
      * relu vs. gelu activation,
      * ``bias + mm`` vs. ``mm + bias`` operand order of the add.
    """

    def addmm_activation_replacement(inp, mat1, mat2, beta=1, alpha=1):
        # Target op: addmm with a fused activation epilogue.
        return aten._addmm_activation(inp, mat1, mat2, beta=beta, alpha=alpha)

    def make_pattern_fns(activation, apply_bias, apply_mm):
        # Bind activation/apply_bias/apply_mm as arguments rather than via the
        # enclosing loop's names: the pattern functions then do not rely on
        # late-binding closures, so the generator below stays correct even if
        # it is not consumed (traced) eagerly inside the loop. The functions'
        # names and signatures are kept unchanged because register_replacement
        # derives pattern metadata from them.
        def bias_add_mm_activation_pattern(inp, mat1, mat2, beta=1, alpha=1):
            return activation(apply_bias(inp, beta) + apply_mm(mat1, mat2, alpha))

        def mm_add_bias_activation_pattern(inp, mat1, mat2, beta=1, alpha=1):
            return activation(apply_mm(mat1, mat2, alpha) + apply_bias(inp, beta))

        return bias_add_mm_activation_pattern, mm_add_bias_activation_pattern

    def gen_addmm_activation_patterns():
        """Yield every search pattern variant to be replaced."""
        for is_beta_eq_one, is_alpha_eq_one in itertools.product(
            (True, False), repeat=2
        ):
            # alpha == 1 traces as a plain mm; otherwise the scale appears as
            # an explicit multiply in the pattern graph, so a separate pattern
            # function is needed for each case.
            if is_alpha_eq_one:

                def apply_mm(mat1, mat2, alpha=1):
                    return mat1 @ mat2

            else:

                def apply_mm(mat1, mat2, alpha=1):
                    return alpha * (mat1 @ mat2)

            # Same distinction for beta on the bias input.
            if is_beta_eq_one:

                def apply_bias(inp, beta=1):
                    return inp

            else:

                def apply_bias(inp, beta=1):
                    return beta * inp

            for activation in (aten.relu, aten.gelu):
                # Yields bias_add_mm variant first, then mm_add_bias.
                yield from make_pattern_fns(activation, apply_bias, apply_mm)

    def is_valid_addmm_activation_fusion(match: Match):
        """Extra check: only fuse when _addmm_activation is actually usable."""
        # Don't bypass max-autotune GEMM choices with the fused op.
        if config.max_autotune_gemm:
            return False

        # NOTE(review): the pattern parameter is named "inp" but the kwarg is
        # looked up as "input" — confirm the matcher exposes it under this key.
        inp = match.kwargs["input"].meta["val"]
        mat1 = match.kwargs["mat1"].meta["val"]
        mat2 = match.kwargs["mat2"].meta["val"]
        beta = match.kwargs["beta"].meta["val"]

        # Only the beta == 1 case is fused.
        if beta != 1:
            return False

        # The fused op is only registered for CUDA inputs here.
        if not inp.is_cuda:
            return False

        # addmm requires 2D matrix operands.
        if not (mat1.dim() == 2 and mat2.dim() == 2):
            return False

        # Bias must be a vector matching the output's column dimension.
        if inp.size(0) != mat2.size(1):
            return False

        # All three inputs must share a dtype for a single fused GEMM call.
        if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
            return False

        # Skip when the output feeds pointwise/reduction users — presumably
        # because inductor can fuse the epilogue itself there (TODO confirm).
        return not has_uses_tagged_as(
            match.output_node(),
            (torch.Tag.pointwise, torch.Tag.reduction),
        )

    # Example inputs used only to trace the patterns; shapes are arbitrary
    # but consistent: bias (5,), mat1 (3, 4), mat2 (4, 5).
    args = (
        torch.empty(5),  # bias
        torch.empty(3, 4),  # mat1
        torch.empty(4, 5),  # mat2
    )

    for addmm_activation_pattern in gen_addmm_activation_patterns():
        # A single non-1 scalar (0.5) suffices to trace the scaled variants.
        for beta, alpha in itertools.product((0.5,), repeat=2):
            register_replacement(
                # pyrefly: ignore [bad-argument-type]
                addmm_activation_pattern,
                # pyrefly: ignore [bad-argument-type]
                addmm_activation_replacement,
                [*args, beta, alpha],
                # pyrefly: ignore [bad-argument-type]
                trace_fn=fwd_only,
                # pyrefly: ignore [bad-argument-type]
                pass_dicts=pass_patterns[2],
                extra_check=is_valid_addmm_activation_fusion,
            )
677758
@init_once_fakemode
678759
def lazy_init():
679760
if torch._C._has_mkldnn:
@@ -699,6 +780,8 @@ def lazy_init():
699780
extra_check=prepare_softmax_extra_check,
700781
)
701782

783+
register_addmm_activation_replacement()
784+
702785

703786
def reorder_for_locality(graph: torch.fx.Graph):
704787
if torch.distributed.is_available():
@@ -1528,7 +1611,7 @@ def should_prefer_unfused_addmm(match):
15281611
alpha=KeywordArg("alpha"),
15291612
),
15301613
# pyrefly: ignore [bad-argument-type]
1531-
pass_dict=pass_patterns[2],
1614+
pass_dict=pass_patterns[1],
15321615
extra_check=should_prefer_unfused_addmm,
15331616
)
15341617
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
@@ -1575,7 +1658,7 @@ def is_valid_addmm_fusion(match):
15751658
KeywordArg("inp"),
15761659
),
15771660
# pyrefly: ignore [bad-argument-type]
1578-
pass_dict=pass_patterns[2],
1661+
pass_dict=pass_patterns[1],
15791662
extra_check=is_valid_addmm_fusion,
15801663
)
15811664
@register_graph_pattern(
@@ -1585,7 +1668,7 @@ def is_valid_addmm_fusion(match):
15851668
CallFunction(aten.mm, Arg(), Arg()),
15861669
),
15871670
# pyrefly: ignore [bad-argument-type]
1588-
pass_dict=pass_patterns[2],
1671+
pass_dict=pass_patterns[1],
15891672
extra_check=is_valid_addmm_fusion,
15901673
)
15911674
def addmm(match, mat1, mat2, *, inp):

torchgen/fuse/gen_patterns.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33

44
from torch._inductor import pattern_matcher
5-
from torch._inductor.fx_passes import joint_graph
5+
from torch._inductor.fx_passes import joint_graph, post_grad
66

77

88
if __name__ == "__main__":
@@ -17,3 +17,4 @@
1717
# to serialize the patterns as it goes.
1818
os.environ["PYTORCH_GEN_PATTERNS"] = "1"
1919
joint_graph.lazy_init()
20+
post_grad.lazy_init()

0 commit comments

Comments (0)