[Inductor] Add fused_attention pattern matcher with additional clone (#108141)

A previous PR, #106274, decomposes `aten.dropout`, and the decomposition inserts a `clone()` when the model is in `eval()` mode or when `p=0`. This causes many SDPA-related models to fail to match the fused_attention pattern matchers.

This PR adds new fused_attention pattern matchers that include the additional `clone()`, re-enabling matching to the SDPA op.
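For illustration, a minimal sketch (shapes and names assumed, not taken from the diff) of the eager attention graph these patterns now cover; after the dropout decomposition, `dropout` in `eval()` mode or with `p=0` leaves a `clone()` between the softmax and the value matmul:

import math

import torch


def attention_after_dropout_decomp(query, key, value):
    # query/key/value assumed to be (batch_size, n_head, seq_len, head_dim)
    scores = torch.matmul(query, key.transpose(-2, -1)).div(math.sqrt(key.shape[-1]))
    # dropout(p=0) / eval() decomposes to a no-op clone(), which previously
    # prevented the fused_attention pattern from matching
    probs = scores.softmax(dim=-1).clone()
    return torch.matmul(probs, value)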

Pull Request resolved: #108141
Approved by: https://github.com/jgong5, https://github.com/eellison
Valentine233 authored and pytorchmergebot committed Aug 30, 2023
1 parent e45b391 commit 3a79621
Showing 2 changed files with 166 additions and 0 deletions.
test/inductor/test_fused_attention.py (70 additions, 0 deletions)

@@ -443,6 +443,58 @@ def dot_prod_attention(

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

    @skipIfRocm
    def _test_sdpa_rewriter_13(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention)
        self._check_common(checkpoint_wrapper(dot_prod_attention))

    @skipIfRocm
    def _test_sdpa_rewriter_14(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention)
        self._check_common(checkpoint_wrapper(dot_prod_attention))

    @skipIfRocm
    def _test_sdpa_rewriter_15(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(v)
            )

        self._check_common(dot_prod_attention)
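The new tests exercise these functions through `_check_common`; as a rough usage sketch outside the test harness (shapes assumed, not part of the diff), the same signal can be observed by compiling such a function and reading Inductor's fuse_attention counter:

import math

import torch
from torch._dynamo.utils import counters


def dot_prod_attention(query, key, value):
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .div(math.sqrt(key.shape[-1]))
        .softmax(dim=-1)
        .clone()
        .matmul(value)
    )


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
counters.clear()
torch.compile(dot_prod_attention)(q, k, v)
# Expected to be non-zero once the clone-containing pattern is rewritten.
print(counters["inductor"]["fuse_attention"])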


if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_SDPA:

@@ -493,6 +545,15 @@ class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
        test_sdpa_rewriter_12_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_rewriter_13_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
        )
        test_sdpa_rewriter_14_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )


if HAS_CPU:
@@ -517,6 +578,15 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        test_sdpa_rewriter_12_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_rewriter_13_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13
        )
        test_sdpa_rewriter_14_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )


if __name__ == "__main__":
torch/_inductor/fx_passes/fuse_attention.py (96 additions, 0 deletions)

@@ -303,6 +303,81 @@ def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p):
    )


def _sfdp_pattern_13(query, key, value, inv_scale):
    # dropout would create a clone() if eval() or p = 0
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale)
        .softmax(dim=-1)
        .clone()
        .matmul(value)
    )


def _sfdp_replacement_13(query, key, value, inv_scale):
    counters["inductor"]["fuse_attention"] += 1
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )


def _sfdp_pattern_14(query, key, value, scale_factor):
    # dropout would create a clone() if eval() or p = 0
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .mul(scale_factor)
        .softmax(dim=-1)
        .clone()
        .matmul(value)
    )


def _sfdp_replacement_14(query, key, value, scale_factor):
    counters["inductor"]["fuse_attention"] += 1
    return aten.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=scale_factor,
    )
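As a hedged sanity check (shapes and tolerances assumed, not part of the diff), patterns 13 and 14 differ only in how the scaling enters the eager graph, which is why the replacements pass `scale=1.0 / inv_scale` and `scale=scale_factor` respectively:

import math

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
inv_scale = math.sqrt(k.shape[-1])  # pattern 13 divides by this
scale_factor = 1.0 / inv_scale  # pattern 14 multiplies by this

eager = (
    torch.matmul(q, k.transpose(-2, -1))
    .div(inv_scale)  # equivalently .mul(scale_factor) for pattern 14
    .softmax(dim=-1)
    .matmul(v)
)
fused = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=1.0 / inv_scale
)
torch.testing.assert_close(eager, fused, atol=2e-3, rtol=2e-3)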


def _sfdp_pattern_15(query, key, value, inv_scale):
    # dropout would create a clone() if eval() or p = 0
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return (
        torch.matmul(q, k.transpose(-2, -1))
        .div(inv_scale)
        .softmax(dim=-1)
        .clone()
        .matmul(v)
    )


def _sfdp_replacement_15(query, key, value, inv_scale):
    counters["inductor"]["fuse_attention"] += 1
    return aten.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=1.0 / inv_scale,
    )
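Similarly, a small sketch (inputs assumed to be (batch_size, seq_len, n_head, head_dim), not part of the diff) of why replacement 15 can feed `transpose(1, 2)` views straight into the fused op: the `permute(0, 2, 1, 3)` in the pattern produces the same layout:

import math

import torch
import torch.nn.functional as F

query, key, value = (torch.randn(2, 8, 4, 16) for _ in range(3))
inv_scale = math.sqrt(key.shape[-1])

eager = (
    torch.matmul(query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3).transpose(-2, -1))
    .div(inv_scale)
    .softmax(dim=-1)
    .matmul(value.permute(0, 2, 1, 3))
)
fused = F.scaled_dot_product_attention(
    query.transpose(1, 2),
    key.transpose(1, 2),
    value.transpose(1, 2),
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=1.0 / inv_scale,
)
torch.testing.assert_close(eager, fused, atol=2e-3, rtol=2e-3)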


def _sfdp_params_check(match):
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    query = match.kwargs["query"].meta["val"]
@@ -450,6 +525,27 @@ def _sfdp_init():
            d,
            _sfdp_scale_factor_check(aten.div.Tensor),
        ),
        (
            _sfdp_pattern_13,
            _sfdp_replacement_13,
            [g(), g(), g(), c()],
            {},
            _sfdp_scale_factor_check(aten.div.Tensor),
        ),
        (
            _sfdp_pattern_14,
            _sfdp_replacement_14,
            [g(), g(), g(), c()],
            {},
            _sfdp_scale_factor_check(aten.mul.Tensor),
        ),
        (
            _sfdp_pattern_15,
            _sfdp_replacement_15,
            [g(), g(), g(), c()],
            {},
            _sfdp_scale_factor_check(aten.div.Tensor),
        ),
    ]:
        args = [*args, *workaround.values()]
        register_replacement(
