Triton kernel for bsr @ dense
ghstack-source-id: 6c26e3f6fc357d1cc659d3f023c6bc121c824111
Pull Request resolved: #94823
nikitaved committed Feb 14, 2023
1 parent 94f0808 commit a909e9a
Showing 3 changed files with 665 additions and 0 deletions.
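
For context, here is a minimal usage sketch of the kernel added in this commit, modeled on the call sites in the test below. The import paths (torch.sparse._triton_ops.bsr_dense_mm, torch._inductor.utils.has_triton) come from the diff; the shapes, block size, and tolerances are illustrative assumptions.

    import torch
    from torch._inductor.utils import has_triton
    from torch.sparse._triton_ops import bsr_dense_mm

    if has_triton():
        # Blocked-sparse (BSR) lhs times a dense rhs on CUDA; half requires SM53+.
        # Shapes and block size are illustrative (dims must divide by the block size).
        lhs = torch.rand(128, 256, dtype=torch.half, device="cuda").to_sparse_bsr(32)
        rhs = torch.rand(256, 64, dtype=torch.half, device="cuda")
        out = bsr_dense_mm(lhs, rhs)
        # Should match the dense reference up to fp16 accumulation error.
        assert torch.allclose(out, lhs.to_dense() @ rhs, atol=1e-2, rtol=1e-2)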
mypy.ini: 3 additions & 0 deletions
@@ -185,6 +185,9 @@ ignore_errors = True
# Third party dependencies that don't have types.
#

[mypy-triton.*]
ignore_missing_imports = True

[mypy-tensorflow.*]
ignore_missing_imports = True

test/test_sparse_csr.py: 54 additions & 0 deletions
@@ -1462,6 +1462,60 @@ def run_test_block_addmm_addmv(self,
        self.assertEqual(actual, out)
        self.assertEqual(actual, expected)

    @parametrize("block_size", [16, 32, 64])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @skipCUDAIfRocm
    @onlyCUDA
    @dtypes(torch.half, torch.bfloat16)
    @dtypesIfCUDA(*[torch.half] if SM53OrLater else [],
                  *[torch.bfloat16] if SM80OrLater else [])
    def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
        from functools import partial

        from torch._inductor.utils import has_triton
        from torch.sparse._triton_ops import bsr_dense_mm

        if not has_triton():
            self.skipTest("Triton is not available.")

        # Each entry of a block-level product accumulates block_size products of
        # values drawn from [low, high), so it lies in block_size * [low^2, high^2).
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,)]
        size = [128, 256, 0]

        # Whether to make the inputs orthogonal so that their product is zero.
        make_orthogonal = [True, False]

        for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
            bsr = tensor(bs + (m, k))
            # NOTE: dense is created as (n, k) and transposed to (k, n) below.
            dense = tensor(bd + (n, k))

            if is_ortho:
                bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
                dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)

            bsr = bsr.to_sparse_bsr(block_size)

            res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1))
            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
            self.assertEqual(res_tri, res_dense)

            # Check that bsr_dense_mm handles different grid sizes;
            # None means the maximum possible grid size, which is CUDA-dependent.
            grid_size = (None, 2, 4)
            grid_gen = itertools.product(grid_size, repeat=3)
            for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
                res_tri = bsr_dense_mm(
                    bsr,
                    dense.transpose(-2, -1),
                    max_grid=grid,
                    is_sparse_rowspace_mode=is_sparse_rowspace,
                )
                self.assertEqual(res_tri, res_dense)

    # TODO: block_size 1 is broken
    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
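An aside on the orthogonality trick used in the test above: padding with zero blocks along the contraction dimension, bsr' = [bsr | 0] and dense' = [0 | dense], guarantees bsr' @ dense'^T = bsr @ 0^T + 0 @ dense^T = 0, which exercises the kernel on outputs that are exactly zero. A standalone sketch with plain dense tensors (shapes are arbitrary):

    import torch

    a = torch.rand(4, 8)   # stands in for bsr, shape (m, k)
    b = torch.rand(6, 8)   # stands in for dense, created as (n, k)
    a_pad = torch.cat((a, torch.zeros_like(a)), dim=-1)  # [a | 0], shape (4, 16)
    b_pad = torch.cat((torch.zeros_like(b), b), dim=-1)  # [0 | b], shape (6, 16)
    # The nonzero halves never overlap along the contraction dim, so every
    # inner product is exactly zero.
    assert torch.equal(a_pad @ b_pad.transpose(-2, -1), torch.zeros(4, 6))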
