SDPA: frontend for BSR masks #104042

Closed · wants to merge 15 commits
Changes from 1 commit
43 changes: 43 additions & 0 deletions test/test_sparse_csr.py
@@ -3560,6 +3560,49 @@ def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
out = torch.rand(32, 32, 2, dtype=dtype, device=device).transpose(0, -1)
bsr_dense_mm(lhs, rhs, out=out)

@parametrize("block_size", [16, 32, 64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16, torch.float)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
@precisionOverride({torch.float16: 1e-3})
def test_triton_scaled_dot_product_attention(self, device, dtype, block_size):
from functools import partial
from torch.sparse._triton_ops import _scaled_dot_product_attention

# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2)

def broadcast_input(*ts):
batch_dims = torch.broadcast_shapes(*(t.shape[:-2] for t in ts))
yield from (torch.broadcast_to(t, batch_dims + t.shape[-2:]) for t in ts)

# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
batches = [(), (2,), (2, 2)]
size = [128, 256, 0]

for bam, bq, bk, bv, m, n, k in itertools.product(batches, batches, batches, batches, size, size, size):
query = tensor(bq + (m, k))
key = tensor(bk + (n, k))
value = tensor(bv + (n, k))

# We make attn_mask block lower/upper triangular so that BSR and Strided
# function variants are directly comparable.
# NOTE: only boolean mask is directly compatible with the Strided version
# without any pre-/post-processing.
attn_mask = torch.ones(bam + (m, n), device=device, dtype=torch.bool)
attn_mask = self._to_block_triangular_inplace(attn_mask, block_size, block_size)
attn_mask_bsr = attn_mask.to_sparse_bsr(block_size)

expected = torch.nn.functional.scaled_dot_product_attention(
*broadcast_input(query, key, value, attn_mask)
)
res = _scaled_dot_product_attention(query, key, value, attn_mask_bsr)

self.assertEqual(res, expected)


@parametrize("block_size", [16, 32, 64])
@onlyCUDA
@skipIfRocm
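For reference, a minimal sketch of the mask preparation the test above relies on. torch.tril is used here as a stand-in for the test helper _to_block_triangular_inplace (which is not part of this diff), and block_size must divide both mask dimensions for to_sparse_bsr to accept the tensor; the Triton kernels that consume the resulting mask require CUDA.

import torch

block_size = 16
m, n = 128, 256
# Boolean block lower-triangular pattern over the (m // block_size) x (n // block_size)
# grid of blocks, expanded back to element resolution.
blocks = torch.tril(torch.ones(m // block_size, n // block_size, dtype=torch.bool))
attn_mask = blocks.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
# The BSR conversion keeps only the populated (True) blocks, so the sparsity pattern
# itself encodes which score blocks survive masking.
attn_mask_bsr = attn_mask.to_sparse_bsr(block_size)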
61 changes: 53 additions & 8 deletions torch/sparse/_triton_ops.py
@@ -1,3 +1,4 @@
import math
import torch
from torch._inductor.cuda_properties import get_device_capability

@@ -49,12 +50,12 @@
)


def check_dtype(f_name, t, dtype):
def check_dtype(f_name, t, dtype, *additional_dtypes):
check(
t.dtype == dtype
and t.dtype in (torch.half, torch.bfloat16, torch.float),
and t.dtype in ((torch.half, torch.bfloat16, torch.float) + tuple(additional_dtypes)),
f"{f_name}(): all inputs are expected to be of the same dtype "
"and one of (half, bfloat16, float32), "
f"and one of (half, bfloat16, float32) or {additional_dtypes}, "
f"but got dtype == {t.dtype}.",
)

@@ -563,18 +564,26 @@
):
f_name = "sampled_addmm"

check_bsr_layout(f_name, input)
input_broadcasted = broadcast_batch_dims_bsr(f_name, input, mat1, mat2)

if not skip_checks:
check_bsr_layout(f_name, input)
check_device(f_name, mat1, input.device)
check_device(f_name, mat2, input.device)
check_dtype(f_name, mat1, input.dtype)
check_dtype(f_name, mat2, input.dtype)
if beta != 0.0 and input.dtype is torch.bool:
check(
False,
f"{f_name}(): having beta == {beta} not equal to 0.0 with boolean mask is not allowed."
)
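# A boolean BSR input only supplies the sampling pattern; the computation dtype comes
# from mat1/mat2, which is also why `out` below is materialized with mat1.dtype.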
if input.dtype is not torch.bool:
check_dtype(f_name, mat1, input.dtype)
check_dtype(f_name, mat2, input.dtype)
else:
check_dtype(f_name, mat1, mat2.dtype)
check_mm_compatible_shapes(f_name, mat1, mat2)
if out is not None:
check_bsr_layout(f_name, out)
check_device(f_name, out, input.device)
check_device(f_name, out, mat1.device)
check_dtype(f_name, out, input.dtype)
check(
out.shape == input_broadcasted.shape
@@ -585,7 +594,7 @@
)

if out is None:
out = input_broadcasted.clone()
out = input_broadcasted.to(mat1.dtype, copy=True)
else:
out.copy_(input_broadcasted)

@@ -835,7 +844,43 @@
size=input.shape,
layout=input.layout
)
def _scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float = 0.0,
is_causal: bool = False
):
f_name = "_scaled_dot_product_attention"
check(
not is_causal,
f"{f_name}(): is_causal == True is not supported."
)
check(
attn_mask is not None and attn_mask.layout == torch.sparse_bsr,
f"{f_name}(): attn_mask == None is not supported and "

Check failure on line 862 in torch/sparse/_triton_ops.py (GitHub Actions / lintrunner / linux-job): MYPY [union-attr]: Item "None" of "Optional[Tensor]" has no attribute "layout"
f"attn_mask.layout must be {torch.sparse_bsr}, but got "
f"attn_mask.layout == {attn_mask.layout}"
)

check_device(f_name, key, query.device)
check_device(f_name, value, query.device)
check_device(f_name, attn_mask, query.device)

check_dtype(f_name, key, query.dtype)
check_dtype(f_name, value, query.dtype)
if attn_mask.dtype is not torch.bool:

Check failure on line 873 in torch/sparse/_triton_ops.py (GitHub Actions / lintrunner / linux-job): MYPY [union-attr]: Item "None" of "Optional[Tensor]" has no attribute "dtype"
check_dtype(f_name, attn_mask, query.dtype)

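# The BSR pipeline: sampled_addmm evaluates Q @ K^T only at the blocks present in the
# mask's sparsity pattern (beta=0.0 discards the mask values themselves), the block
# values are then scaled by 1/sqrt(E), passed through bsr_softmax and (optionally)
# dropout, and finally multiplied against the dense V via bsr_dense_mm.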
sdpa = sampled_addmm(attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False)

Check failure on line 876 in torch/sparse/_triton_ops.py (GitHub Actions / lintrunner / linux-job): MYPY [arg-type]: Argument 1 to "sampled_addmm" has incompatible type "Optional[Tensor]"; expected "Tensor"
sdpa.values().div_(math.sqrt(query.size(-1)))
sdpa = bsr_softmax(sdpa)
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
sdpa = bsr_dense_mm(sdpa, value)
return sdpa
else:
bsr_softmax = None # type: ignore[assignment]
bsr_dense_mm = None # type: ignore[assignment]
sampled_addmm = None # type: ignore[assignment]
_scaled_dot_product_attention = None # type: ignore[assignment]
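For context, a hedged end-to-end sketch of how the new frontend is meant to be called, mirroring the test added in this PR. It assumes a CUDA device with Triton available; _scaled_dot_product_attention is a private helper, so the import path and signature may change.

import torch
from torch.sparse._triton_ops import _scaled_dot_product_attention

device, dtype, block_size = "cuda", torch.float16, 16
m, n, k = 128, 128, 64

query = torch.rand(m, k, device=device, dtype=dtype)
key = torch.rand(n, k, device=device, dtype=dtype)
value = torch.rand(n, k, device=device, dtype=dtype)

# Boolean block lower-triangular mask, as in the test above.
blocks = torch.tril(torch.ones(m // block_size, n // block_size, dtype=torch.bool, device=device))
attn_mask = blocks.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)

# Dense (strided) reference vs. the BSR-masked Triton path; tolerances are loose for fp16.
expected = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
res = _scaled_dot_product_attention(query, key, value, attn_mask.to_sparse_bsr(block_size))
torch.testing.assert_close(res, expected, rtol=1e-3, atol=1e-3)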