Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Triton kernel for bsr @ dense #94823

Closed
wants to merge 11 commits into from
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ ignore_errors = True
# Third party dependencies that don't have types.
#

[mypy-triton.*]
ignore_missing_imports = True

[mypy-tensorflow.*]
ignore_missing_imports = True

Expand Down
54 changes: 54 additions & 0 deletions test/test_sparse_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,60 @@ def run_test_block_addmm_addmv(self,
self.assertEqual(actual, out)
self.assertEqual(actual, expected)

@parametrize("block_size", [16, 32, 64])
@parametrize("index_dtype", [torch.int32, torch.int64])
@skipCUDAIfRocm
@onlyCUDA
@dtypes(torch.half, torch.bfloat16)
@dtypesIfCUDA(*[torch.half] if SM53OrLater else [],
              *[torch.bfloat16] if SM80OrLater else [])
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
    """Compare the Triton ``bsr_dense_mm`` kernel against a dense reference.

    For a grid of batch shapes, matrix sizes (including zero-sized dims),
    and both orthogonal and non-orthogonal operand pairs, computes
    ``bsr @ dense^T`` with the Triton kernel and checks it against
    ``bsr.to_dense() @ dense^T``.  Also exercises ``bsr_dense_mm`` under
    varying ``max_grid`` limits and both row-space modes.

    Skipped unless Triton is available; restricted to CUDA (non-ROCm) and
    to half/bfloat16 on sufficiently new compute capabilities via the
    decorators above.
    """
    from functools import partial

    from torch._inductor.utils import has_triton
    from torch.sparse._triton_ops import bsr_dense_mm

    if not has_triton():
        self.skipTest("Triton is not available.")

    # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
    tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

    # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
    batches = [(), (2,)]
    size = [128, 256, 0]

    # Whether to make inputs orthogonal so that the product is zero
    make_orthogonal = [True, False]

    for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
        bsr = tensor(bs + (m, k))
        # NOTE: do not get confused, it will be transposed
        dense = tensor(bd + (n, k))

        if is_ortho:
            # Pad bsr with zeros on the right and dense with zeros on the
            # left so every row of bsr is orthogonal to every row of dense
            # and the product is exactly zero.
            bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
            dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)

        bsr = bsr.to_sparse_bsr(block_size)

        res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1))
        res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
        self.assertEqual(res_tri, res_dense)

        # check whether bsr_dense_mm handles different grid sizes
        # None means max possible grid size which is CUDA-dependent.
        # NOTE(review): indentation was lost in extraction — this grid-size
        # sweep is reconstructed as part of the outer loop body; confirm
        # against the upstream file whether it belongs inside or after the
        # loop (either placement reuses the in-scope bsr/dense/res_dense).
        grid_size = (None, 2, 4)
        grid_gen = itertools.product(grid_size, repeat=3)
        for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
            res_tri = bsr_dense_mm(
                bsr,
                dense.transpose(-2, -1),
                max_grid=grid,
                is_sparse_rowspace_mode=is_sparse_rowspace
            )
            self.assertEqual(res_tri, res_dense)

# TODO: block_size 1 is broken
@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])
Expand Down