Add commonly used score_mod functions for templated attention (#124670)
Fixes #ISSUE_NUMBER

Pull Request resolved: #124670
Approved by: https://github.com/Chillee
yanboliang authored and pytorchmergebot committed Apr 27, 2024
1 parent df08140 commit 7478b7f
Showing 4 changed files with 107 additions and 67 deletions.
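
For context, a minimal usage sketch of the score_mod helpers added here, driven through the private _templated_attention API; the shapes, device, and dtype below are illustrative assumptions, not part of this PR:

# Hypothetical example; assumes a CUDA device and the private API as of this commit.
import torch
from torch.nn.attention._templated_attention import (
    _causal,
    _generate_alibi_bias,
    _templated_attention,
)

# Assumed layout: (batch, heads, sequence, head_dim).
q, k, v = (
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

out_causal = _templated_attention(q, k, v, _causal)  # causal masking
out_alibi = _templated_attention(q, k, v, _generate_alibi_bias(8))  # ALiBi bias for 8 heads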
86 changes: 24 additions & 62 deletions test/inductor/test_templated_attention.py
@@ -13,7 +13,15 @@
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention._templated_attention import _compose, _templated_attention
from torch.nn.attention._templated_attention import (
_causal,
_compose,
_generate_alibi_bias,
_identity,
_rel_bias,
_rel_causal,
_templated_attention,
)
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
@@ -48,9 +56,13 @@ def create_attention(score_mod):
if common_utils.TEST_WITH_ROCM:
    test_dtypes = [torch.float32]


def _identity_mod(score, b, h, m, n):
    return score
test_score_mods = [
    _identity,
    _causal,
    _rel_bias,
    _rel_causal,
    _generate_alibi_bias(8),
]


def _causal_mod(score, b, h, token_q, token_kv):
@@ -90,58 +102,8 @@ def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_identity(self, dtype: torch.dtype):
        def score_mod(score, b, h, m, n):
            return score

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_causal_mask(self, dtype: torch.dtype):
        def score_mod(score, b, h, token_q, token_kv):
            return torch.where(token_q >= token_kv, score, float("-inf"))

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_rel_bias(self, dtype: torch.dtype):
        def score_mod(score, b, h, m, n):
            return score + (m - n)

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_alibi_bias(self, dtype: torch.dtype):
        def score_mod(score, b, h, m, n):
            return score + (m - n) * h

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_rel_causal(self, dtype: torch.dtype):
        def score_mod(score, b, h, m, n):
            return torch.where(m <= n, score + (m - n), float("-inf"))

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_skip_odd_keys(self, dtype: torch.dtype):
        def score_mod(score, b, h, q, kv):
            return torch.where(kv % 2 == 0, score, float("-inf"))

        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_alibi_causal(self, dtype: torch.dtype):
        def score_mod(score, b, h, m, n):
            return torch.where(m <= n, score + (m - n) * h, float("-inf"))

    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
        self.run_test(score_mod, dtype)

    @supported_platform
@@ -302,7 +264,7 @@ def test_backwards_fails(self):
            requires_grad=True,
        )
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        out = _templated_attention(q, k, v, _identity_mod)
        out = _templated_attention(q, k, v, _identity)
        with self.assertRaisesRegex(
            RuntimeError, "Autograd not implemented for templated_attention"
        ):
@@ -316,15 +278,15 @@ def test_mixed_dtypes_fails(self):
        with self.assertRaisesRegex(
            ValueError, "Expected query, key, and value to have the same dtype"
        ):
            _templated_attention(query, key, value, _identity_mod)
            _templated_attention(query, key, value, _identity)

    @supported_platform
    def test_different_sequence_length_fails(self):
        query = torch.randn((1, 1, 2048, 64), dtype=torch.float32, device="cuda")
        key = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
        value = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
        with self.assertRaisesRegex(ValueError, "NYI: The target sequence length"):
            _templated_attention(query, key, value, _identity)
_templated_attention(query, key, value, _identity)

    @supported_platform
    @patch.object(torch._inductor.config, "max_autotune", True)
@@ -351,7 +313,7 @@ def bias_mod(score, batch, head, token_q, token_kv):

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", [_identity_mod, _causal_mod])
    @common_utils.parametrize("score_mod", [_identity, _causal])
    def test_logsumexp_correctness(self, dtype, score_mod):
        @torch.compile
        def sdpa_hop(q, k, v, score_mod):
@@ -414,7 +376,7 @@ def func(q, k, v, score_mod):
            lse_2 = lse * 2
            return lse_2

        _, code = run_and_get_code(func, q, k, v, _identity_mod)
        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that two kernels are generated
        FileCheck().check_count(".run(", 2, True).run(code[0])

@@ -435,7 +397,7 @@ def func(q, k, v, score_mod):
            lse_2 = lse * 2
            return out, lse_2

        _, code = run_and_get_code(func, q, k, v, _identity_mod)
        _, code = run_and_get_code(func, q, k, v, _identity)
        # Ensure that two kernels are generated
        FileCheck().check_count(".run(", 2, True).run(code[0])

25 changes: 20 additions & 5 deletions torch/_inductor/kernel/templated_attention.py
@@ -3,7 +3,7 @@
from typing import Any, List

import torch
from .. import config
from .. import config, utils
from ..lowering import empty_strided, lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate

@@ -173,6 +173,24 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta):
)


def _get_default_config(query):
    default_config = None
    is_big_shared_mem = utils.get_gpu_shared_memory() > 128 * 1024

    if is_big_shared_mem:
        if query.get_dtype() == torch.float32:
            default_config = (64, 64, 4, 3)
        else:
            default_config = (128, 64, 4, 3)
    else:
        if query.get_dtype() == torch.float32:
            default_config = (32, 32, 4, 3)
        else:
            default_config = (64, 32, 4, 3)

    return default_config
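
A hedged note on the tuple returned above: the four fields presumably map to the Triton template's (BLOCK_M, BLOCK_N, num_warps, num_stages) knobs, though that mapping is not spelled out in this hunk. A standalone sketch of the same heuristic, with the dtype passed directly instead of an inductor buffer:

# Standalone sketch only; mirrors _get_default_config without inductor internals.
import torch

def sketch_default_config(dtype: torch.dtype, shared_mem_bytes: int):
    # Larger tiles when the device offers more than 128 KiB of shared memory per SM.
    if shared_mem_bytes > 128 * 1024:
        return (64, 64, 4, 3) if dtype == torch.float32 else (128, 64, 4, 3)
    return (32, 32, 4, 3) if dtype == torch.float32 else (64, 32, 4, 3)

print(sketch_default_config(torch.float16, 160 * 1024))  # -> (128, 64, 4, 3)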


# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.templated_attention, type_promotion_kind=None)
def templated_attention(*args, **kwargs):
@@ -274,10 +292,7 @@ def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
    )
    choices: List[Any] = []
    configs: List[Any] = []
    if query.get_dtype() == torch.float32:
        configs.append((64, 64, 4, 3))
    else:
        configs.append((128, 64, 4, 3))
    configs.append(_get_default_config(query))
    if config.max_autotune:
        configs += [
            (128, 64, 4, 3),
6 changes: 6 additions & 0 deletions torch/_inductor/utils.py
@@ -1197,6 +1197,12 @@ def get_gpu_dram_gbps():
    return get_dram_gbps()


def get_gpu_shared_memory():
    from triton.runtime import driver

    return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)
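
As a hedged illustration, the Triton driver call used by the new helper can be exercised directly; the device index 0 and the "max_shared_mem" key come from this diff, while the printed value varies by GPU:

# Illustration only; requires a CUDA device and a Triton build exposing driver.active.
from triton.runtime import driver

props = driver.active.utils.get_device_properties(0)
shared_mem = props.get("max_shared_mem", 0)
print(shared_mem, shared_mem > 128 * 1024)  # the boolean drives the tile choice earlier in this diff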


def is_welford_reduction(reduction_type):
    return reduction_type.startswith("welford")

57 changes: 57 additions & 0 deletions torch/nn/attention/_templated_attention.py
@@ -90,3 +90,60 @@ def score_mod(

    # Drop the logsumexp value since this is only needed for backwards
    return out


"""Some common used score_mod functions for templated attention in PyTorch."""


def _identity(
    score: torch.Tensor,
    batch: torch.Tensor,
    head: torch.Tensor,
    token_q: torch.Tensor,
    token_kv: torch.Tensor,
) -> torch.Tensor:
    return score


def _causal(
    score: torch.Tensor,
    batch: torch.Tensor,
    head: torch.Tensor,
    token_q: torch.Tensor,
    token_kv: torch.Tensor,
) -> torch.Tensor:
    return torch.where(token_q >= token_kv, score, float("-inf"))


def _rel_bias(
    score: torch.Tensor,
    batch: torch.Tensor,
    head: torch.Tensor,
    token_q: torch.Tensor,
    token_kv: torch.Tensor,
) -> torch.Tensor:
    return score + (token_q - token_kv)


def _rel_causal(
    score: torch.Tensor,
    batch: torch.Tensor,
    head: torch.Tensor,
    token_q: torch.Tensor,
    token_kv: torch.Tensor,
) -> torch.Tensor:
    return torch.where(token_q <= token_kv, score + (token_q - token_kv), float("-inf"))


def _generate_alibi_bias(num_heads: int):
    def _alibi_bias(
        score: torch.Tensor,
        batch: torch.Tensor,
        head: torch.Tensor,
        token_q: torch.Tensor,
        token_kv: torch.Tensor,
    ) -> torch.Tensor:
        scale = torch.exp2(-((head + 1) * 8.0 / num_heads))
        return score + (token_kv - token_q) * scale

    return _alibi_bias
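
A brief sanity check of the slopes produced by the closure above: with num_heads=8 (the value used in the new test list), the per-head scale works out to 2^-(h+1), i.e. 1/2 down to 1/256, and the bias grows more negative with query-key distance, matching the usual ALiBi schedule. A small worked example:

# Worked example; values follow directly from the formula in _generate_alibi_bias.
import torch

num_heads = 8
head = torch.arange(num_heads, dtype=torch.float32)
slopes = torch.exp2(-((head + 1) * 8.0 / num_heads))
print(slopes)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

# For head 0, a query at position 10 and a key at position 7:
# bias = (7 - 10) * 0.5 = -1.5, so more distant keys are penalized more.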
