Fused attention patterns #97741

Closed
wants to merge 13 commits
171 changes: 171 additions & 0 deletions test/inductor/test_fused_attention.py
@@ -0,0 +1,171 @@
# Owner(s): ["module: inductor"]
import itertools
import math

import torch
import torch._inductor.config
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_SDPA
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


class TestSDPAPatternRewriter(TestCase):
@config.patch(fallback_random=True, lowmem_dropout=False)
def _check_common(self, dot_prod_attention, args1=None, contains=True):
tensor_shape = (4, 2, 16, 32)
if args1 is None:
args1 = [
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
]
args2 = [*map(torch.clone, args1)]

for training in [False, True]:
for x in itertools.chain(args1[:3], args2[:3]):
x.requires_grad = training

torch.manual_seed(1234)
result1 = dot_prod_attention(*args1)

counters.clear()
torch.manual_seed(1234)
result2, (source_code,) = run_and_get_code(
torch.compile(dot_prod_attention, fullgraph=True), *args2
)
self.assertGreaterEqual(counters["inductor"]["fuse_attention"], 1)
if contains:
# many of the patterns get re-expanded in dispatcher
self.assertIn(
"aten._scaled_dot_product_efficient_attention", source_code
)
self.assertEqual(result1, result2)

if training:
result1.sum().backward()
result2.sum().backward()

self.assertEqual(args1[0].grad, args2[0].grad)
self.assertEqual(args1[1].grad, args2[1].grad)
self.assertEqual(args1[2].grad, args2[2].grad)

def test_sdpa_rewriter_1(self):
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
"""Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
return (
torch.matmul(query, key.transpose(-2, -1))
.div(math.sqrt(key.shape[-1]))
.softmax(dim=-1)
.matmul(value)
)

self._check_common(dot_prod_attention)

def test_pattern_fails_with_reuse(self):
"""
This test checks that the replacement is not done
when an intermediate result is being used / returned downstream
"""

@torch.compile(fullgraph=True)
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
attn_weights = (
torch.matmul(query, key.transpose(-2, -1))
.div(math.sqrt(key.shape[-1]))
.softmax(dim=-1)
)
return attn_weights.matmul(value), attn_weights

tensor_shape = (2, 4, 8, 16)
args = [
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
torch.randn(tensor_shape, device="cuda"),
]
_, (source_code,) = run_and_get_code(dot_prod_attention, *args)
self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)

def test_sdpa_rewriter_2(self):
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return (
torch.matmul(query, key.transpose(-2, -1))
.mul(1.0 / math.sqrt(key.shape[-1]))
.softmax(dim=-1)
.matmul(value)
)

self._check_common(dot_prod_attention)

def test_sdpa_rewriter_3(self):
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.dropout(
torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1),
p=0.4,
training=True,
inplace=False,
).matmul(value)

self._check_common(dot_prod_attention, contains=False)

def test_sdpa_rewriter_4(self):
def dot_prod_attention(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.dropout(
torch.matmul(query, key.transpose(-2, -1)).mul(0.4).softmax(dim=-1),
p=0.2,
training=True,
inplace=False,
).matmul(value)

self._check_common(dot_prod_attention, contains=False)

def test_sdpa_rewriter_5(self):
def sfdp_pattern_5(query, key, value):
attn_mask = torch.ones(
query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
).tril(diagonal=0)
attn_mask = attn_mask.masked_fill(
torch.logical_not(attn_mask), -float("inf")
)
attn_weight = torch.softmax(
(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
dim=-1,
)
return attn_weight @ value

self._check_common(sfdp_pattern_5, contains=False)

def test_sdpa_rewriter_6(self):
def sfdp_pattern_6(query, key, value):
attn_mask = torch.ones(
query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
).tril(diagonal=0)
attn_mask = attn_mask.masked_fill(
torch.logical_not(attn_mask), -float("inf")
)
attn_weight = torch.softmax(
(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
dim=-1,
)
attn_weight = torch.dropout(attn_weight, 0.5, True)
return attn_weight @ value

self._check_common(sfdp_pattern_6, contains=False)


if __name__ == "__main__":
if IS_LINUX and HAS_CUDA and PLATFORM_SUPPORTS_FUSED_SDPA:
run_tests()
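
A quick equivalence note (not part of the PR diff): the unfused chain in test_sdpa_rewriter_1 computes the same thing as the fused kernel the rewriter emits. A minimal sketch, assuming CPU float32 inputs and slightly loosened tolerances to absorb kernel-level numerical differences:

import math

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(4, 2, 16, 32) for _ in range(3))

# Unfused pattern, as written in the test above.
manual = (q @ k.transpose(-2, -1)).div(math.sqrt(k.shape[-1])).softmax(dim=-1) @ v

# Fused kernel that the pattern rewriter targets.
fused = F.scaled_dot_product_attention(q, k, v)

torch.testing.assert_close(manual, fused, rtol=1e-4, atol=1e-4)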
15 changes: 12 additions & 3 deletions torch/_inductor/compile_fx.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
from . import config, metrics, overrides, pattern_matcher
from .debug import DebugContext
from .decomposition import select_decomp_table
from .fx_passes.joint_graph import joint_graph_passes
from .graph import GraphLowering
from .mkldnn import convert_outplace_to_inplace
from .utils import (
@@ -665,6 +666,10 @@ def compile_fx(

@dynamo_utils.dynamo_timed
def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
if is_inference:
# partition_fn won't be called
joint_graph_passes(model)

fixed = len(example_inputs) - num_example_inputs
# Why convert outplace op to inplace? Inductor can support inplace operations well and for custom
# inplace ops which are lowered as ExternKernel, it is beneficial to performance when the inplace
@@ -683,6 +688,12 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

def partition_fn(graph, joint_inputs, **kwargs):
joint_graph_passes(graph)
return min_cut_rematerialization_partition(
graph, joint_inputs, **kwargs, compiler="inductor"
)

# Save and restore dynamic shapes setting for backwards, as it is
# sometimes done as a context manager which won't be set when we
# hit backwards compile
@@ -713,9 +724,7 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs):
bw_compiler=bw_compiler,
inference_compiler=inference_compiler,
decompositions=decompositions,
partition_fn=functools.partial(
min_cut_rematerialization_partition, compiler="inductor"
),
partition_fn=partition_fn,
keep_inference_input_mutations=True,
)(model_, example_inputs_)

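For context on how the new joint_graph_passes hook surfaces to users, a usage sketch (assumes a CUDA device and a build where the efficient SDPA kernel is available, mirroring the guards in the test file above):

import math

import torch
from torch._dynamo.utils import counters


def attn(q, k, v):
    return (q @ k.transpose(-2, -1)).div(math.sqrt(k.shape[-1])).softmax(dim=-1) @ v


q, k, v = (torch.randn(4, 2, 16, 32, device="cuda") for _ in range(3))
counters.clear()
torch.compile(attn, fullgraph=True)(q, k, v)
# The pattern rewriter bumps this counter when the fusion fires (see the test above).
print(counters["inductor"]["fuse_attention"])
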
6 changes: 6 additions & 0 deletions torch/_inductor/decomposition.py
@@ -45,6 +45,12 @@ def register_decomposition(ops):
return decomp.register_decomposition(ops, decompositions)


@register_decomposition(aten._unsafe_view.default)
def _unsafe_view(self, size):
# this makes pattern matching easier
return self.view(size)


@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
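To illustrate why the _unsafe_view decomposition helps (a sketch, not PR code): aten._unsafe_view produces the same values as a plain view, so lowering it to view() means graphs containing _unsafe_view nodes (for example, views emitted inside composite ops such as matmul) expose the same view ops the SDPA search patterns are traced with.

import torch

x = torch.randn(4, 2, 16, 32)

# _unsafe_view and view give identical results; after the decomposition above,
# the pattern matcher only has to recognize plain view nodes.
a = torch.ops.aten._unsafe_view.default(x, [8, 16, 32])
b = x.view(8, 16, 32)
torch.testing.assert_close(a, b)
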
Empty file.