From 3a79621c9dce17f77fbddc06aab21f6bc477f313 Mon Sep 17 00:00:00 2001
From: "Liao, Xuan"
Date: Wed, 30 Aug 2023 05:07:32 +0000
Subject: [PATCH] [Inductor] Add fused_attention pattern matcher with
 additional clone (#108141)

A previous PR, https://github.com/pytorch/pytorch/pull/106274, decomposes
`aten.dropout` and creates a `clone()` when the model is in `eval()` mode or
when `p=0`. As a result, many SDPA-related models no longer match the existing
fused_attention pattern matchers.

This PR adds new fused_attention pattern matchers that include the additional
`clone()`, re-enabling matching to the SDPA op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108141
Approved by: https://github.com/jgong5, https://github.com/eellison
---
 test/inductor/test_fused_attention.py       | 70 +++++++++++++++
 torch/_inductor/fx_passes/fuse_attention.py | 96 +++++++++++++++++++++
 2 files changed, 166 insertions(+)

diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py
index cc6ef73e8c621..df018354827f0 100644
--- a/test/inductor/test_fused_attention.py
+++ b/test/inductor/test_fused_attention.py
@@ -443,6 +443,58 @@ def dot_prod_attention(
 
         self._check_common(dot_prod_attention, contains=False, has_dropout=True)
 
+    @skipIfRocm
+    def _test_sdpa_rewriter_13(self):
+        def dot_prod_attention(
+            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+        ) -> torch.Tensor:
+            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
+            return (
+                torch.matmul(query, key.transpose(-2, -1))
+                .div(math.sqrt(key.shape[-1]))
+                .softmax(dim=-1)
+                .clone()
+                .matmul(value)
+            )
+
+        self._check_common(dot_prod_attention)
+        self._check_common(checkpoint_wrapper(dot_prod_attention))
+
+    @skipIfRocm
+    def _test_sdpa_rewriter_14(self):
+        def dot_prod_attention(
+            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+        ) -> torch.Tensor:
+            return (
+                torch.matmul(query, key.transpose(-2, -1))
+                .mul(1.0 / math.sqrt(key.shape[-1]))
+                .softmax(dim=-1)
+                .clone()
+                .matmul(value)
+            )
+
+        self._check_common(dot_prod_attention)
+        self._check_common(checkpoint_wrapper(dot_prod_attention))
+
+    @skipIfRocm
+    def _test_sdpa_rewriter_15(self):
+        def dot_prod_attention(
+            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+        ) -> torch.Tensor:
+            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
+            q = query.transpose(1, 2)
+            k = key.transpose(1, 2)
+            v = value.transpose(1, 2)
+            return (
+                torch.matmul(q, k.transpose(-2, -1))
+                .div(math.sqrt(key.shape[-1]))
+                .softmax(dim=-1)
+                .clone()
+                .matmul(v)
+            )
+
+        self._check_common(dot_prod_attention)
+
 
 if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_SDPA:
 
@@ -493,6 +545,15 @@ class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
         test_sdpa_rewriter_12_cuda = (
             TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
         )
+        test_sdpa_rewriter_13_cuda = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
+        )
+        test_sdpa_rewriter_14_cuda = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
+        )
+        test_sdpa_rewriter_15_cuda = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
+        )
 
 
 if HAS_CPU:
@@ -517,6 +578,15 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
         test_sdpa_rewriter_12_cpu = (
             TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
         )
+        test_sdpa_rewriter_13_cpu = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
+        )
+        test_sdpa_rewriter_14_cpu = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
+        )
+        test_sdpa_rewriter_15_cpu = (
+            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
+        )
 
 
 if __name__ == "__main__":
diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py
index 444d4dd3e3398..333eafa1d5a26 100644
--- a/torch/_inductor/fx_passes/fuse_attention.py
+++ b/torch/_inductor/fx_passes/fuse_attention.py
@@ -303,6 +303,81 @@ def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
     )
 
 
+def _sfdp_pattern_13(query, key, value, inv_scale):
+    # dropout would create a clone() if eval() or p = 0
+    return (
+        torch.matmul(query, key.transpose(-2, -1))
+        .div(inv_scale)
+        .softmax(dim=-1)
+        .clone()
+        .matmul(value)
+    )
+
+
+def _sfdp_replacement_13(query, key, value, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    return aten.scaled_dot_product_attention(
+        query.contiguous(),
+        key.contiguous(),
+        value.contiguous(),
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
+def _sfdp_pattern_14(query, key, value, scale_factor):
+    # dropout would create a clone() if eval() or p = 0
+    return (
+        torch.matmul(query, key.transpose(-2, -1))
+        .mul(scale_factor)
+        .softmax(dim=-1)
+        .clone()
+        .matmul(value)
+    )
+
+
+def _sfdp_replacement_14(query, key, value, scale_factor):
+    counters["inductor"]["fuse_attention"] += 1
+    return aten.scaled_dot_product_attention(
+        query.contiguous(),
+        key.contiguous(),
+        value.contiguous(),
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=scale_factor,
+    )
+
+
+def _sfdp_pattern_15(query, key, value, inv_scale):
+    # dropout would create a clone() if eval() or p = 0
+    q = query.permute(0, 2, 1, 3)
+    k = key.permute(0, 2, 1, 3)
+    v = value.permute(0, 2, 1, 3)
+    return (
+        torch.matmul(q, k.transpose(-2, -1))
+        .div(inv_scale)
+        .softmax(dim=-1)
+        .clone()
+        .matmul(v)
+    )
+
+
+def _sfdp_replacement_15(query, key, value, inv_scale):
+    counters["inductor"]["fuse_attention"] += 1
+    return aten.scaled_dot_product_attention(
+        query.transpose(1, 2),
+        key.transpose(1, 2),
+        value.transpose(1, 2),
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=1.0 / inv_scale,
+    )
+
+
 def _sfdp_params_check(match):
     assert all(k in match.kwargs for k in ("query", "key", "value"))
     query = match.kwargs["query"].meta["val"]
@@ -450,6 +525,27 @@ def _sfdp_init():
             d,
             _sfdp_scale_factor_check(aten.div.Tensor),
         ),
+        (
+            _sfdp_pattern_13,
+            _sfdp_replacement_13,
+            [g(), g(), g(), c()],
+            {},
+            _sfdp_scale_factor_check(aten.div.Tensor),
+        ),
+        (
+            _sfdp_pattern_14,
+            _sfdp_replacement_14,
+            [g(), g(), g(), c()],
+            {},
+            _sfdp_scale_factor_check(aten.mul.Tensor),
+        ),
+        (
+            _sfdp_pattern_15,
+            _sfdp_replacement_15,
+            [g(), g(), g(), c()],
+            {},
+            _sfdp_scale_factor_check(aten.div.Tensor),
+        ),
     ]:
        args = [*args, *workaround.values()]
        register_replacement(
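
Reviewer sketch of the equivalence the new pattern/replacement pairs rely on,
kept outside the diff above and illustrative only: it reproduces the
`_sfdp_pattern_13` chain in eager mode, including the trailing `.clone()` that
the `aten.dropout` decomposition leaves behind at `p=0` / `eval()`, and checks
it against `torch.nn.functional.scaled_dot_product_attention` with
`scale=1.0 / inv_scale`, mirroring `_sfdp_replacement_13`. It assumes PyTorch
>= 2.1 (for the `scale` keyword of `scaled_dot_product_attention`); the helper
name `eager_pattern_13` is made up for this example.

    # Minimal sketch, not part of the patch: eager pattern 13 vs. its SDPA replacement.
    import math

    import torch
    import torch.nn.functional as F


    def eager_pattern_13(query, key, value, inv_scale):
        # Same chain _sfdp_pattern_13 matches; the trailing .clone() mirrors what the
        # aten.dropout decomposition leaves behind when p=0 or the model is in eval().
        return (
            torch.matmul(query, key.transpose(-2, -1))
            .div(inv_scale)
            .softmax(dim=-1)
            .clone()
            .matmul(value)
        )


    if __name__ == "__main__":
        batch, heads, seq_len, head_dim = 2, 4, 8, 16
        q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))
        inv_scale = math.sqrt(head_dim)

        out_pattern = eager_pattern_13(q, k, v, inv_scale)
        # The replacement rewrites the chain above into a single SDPA call with
        # scale=1.0 / inv_scale, dropout_p=0.0, and no mask.
        out_sdpa = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=1.0 / inv_scale
        )
        torch.testing.assert_close(out_pattern, out_sdpa, rtol=1e-4, atol=1e-4)
        print("pattern 13 and its SDPA replacement agree")

End to end, the rewrite itself is exercised by the new `_test_sdpa_rewriter_13`
through `_test_sdpa_rewriter_15` tests via `self._check_common`, and each
`_sfdp_replacement_*` bumps `counters["inductor"]["fuse_attention"]` when it
fires.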