Triton kernel for bsr @ dense #94823

Closed
wants to merge 11 commits
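A minimal usage sketch of the kernel this PR adds, assuming the API as exercised by the test added below (`bsr_dense_mm(bsr, dense)` on CUDA with half/bfloat16 inputs); the shapes, dtype, and block size here are illustrative, not requirements:

```python
import torch
from torch._inductor.utils import has_triton
from torch.sparse._triton_ops import bsr_dense_mm

if has_triton():
    m, k, n, block_size = 128, 256, 128, 32
    a = torch.randn(m, k, device="cuda", dtype=torch.half)
    b = torch.randn(k, n, device="cuda", dtype=torch.half)
    a_bsr = a.to_sparse_bsr(block_size)   # lhs in BSR format with 32x32 blocks
    out = bsr_dense_mm(a_bsr, b)          # Triton-backed bsr @ dense
    torch.testing.assert_close(out, a_bsr.to_dense() @ b)
```

The test below additionally passes the optional `max_grid` and `is_sparse_rowspace_mode` arguments to exercise different launch-grid configurations.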
3 changes: 3 additions & 0 deletions mypy.ini
@@ -185,6 +185,9 @@ ignore_errors = True
# Third party dependencies that don't have types.
#

[mypy-triton.*]
ignore_missing_imports = True

[mypy-tensorflow.*]
ignore_missing_imports = True

56 changes: 55 additions & 1 deletion test/test_sparse_csr.py
@@ -9,7 +9,7 @@
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
(TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, parametrize,
subtest, skipIfTorchDynamo)
subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU)
from torch.testing._internal.common_device_type import \
(ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan)
@@ -1462,6 +1462,60 @@ def run_test_block_addmm_addmv(self,
self.assertEqual(actual, out)
self.assertEqual(actual, expected)

@parametrize("block_size", [16, 32, 64])
@parametrize("index_dtype", [torch.int32, torch.int64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [])
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial

from torch._inductor.utils import has_triton
from torch.sparse._triton_ops import bsr_dense_mm

if not has_triton():
self.skipTest("Triton is not available.")

# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
batches = [(), (2,)]
size = [128, 256, 0]

# Whether to make inputs orthogonal so that the product is zero
make_orthogonal = [True, False]

for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
bsr = tensor(bs + (m, k))
        # NOTE: dense is created as (n, k); it is transposed to (k, n) before the matmul below.
dense = tensor(bd + (n, k))

if is_ortho:
bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)

bsr = bsr.to_sparse_bsr(block_size)

res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1))
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
self.assertEqual(res_tri, res_dense)

        # Check whether bsr_dense_mm handles different grid sizes.
        # None means the maximum possible grid size, which is CUDA-dependent.
grid_size = (None, 2, 4)
grid_gen = itertools.product(grid_size, repeat=3)
for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
res_tri = bsr_dense_mm(
bsr,
dense.transpose(-2, -1),
max_grid=grid,
is_sparse_rowspace_mode=is_sparse_rowspace
)
self.assertEqual(res_tri, res_dense)

# TODO: block_size 1 is broken
@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])