bsr_dense_mm(): better test coverage
ghstack-source-id: 71b3ae9c26eb9e974453458af2ac0489cda374c6
Pull Request resolved: #100543
nikitaved committed May 4, 2023
1 parent 2ebb48f commit e0118e3
Showing 2 changed files with 64 additions and 7 deletions.
43 changes: 40 additions & 3 deletions test/test_sparse_csr.py
@@ -3331,7 +3331,7 @@ class TestSparseCompressedTritonKernels(TestCase):
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16)
-@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [])
+@dtypesIfCUDA(torch.float, torch.half, *[torch.bfloat16] if SM80OrLater else [])
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial
@@ -3365,8 +3365,9 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):

bsr = bsr.to_sparse_bsr(block_size)

-if bsr.dim() == 2:
-    # Test against linear to check dispatch.
+if bsr.dim() == 2 and dtype != torch.float:
+    # Test against linear to check dispatch
+    # which takes place for torch.half and torch.bfloat16.
res_tri = torch.nn.functional.linear(dense, bsr)
res_dense = torch.nn.functional.linear(dense, bsr.to_dense())

@@ -3397,6 +3398,42 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
)
self.assertEqual(res_tri, res_dense)

+    @onlyCUDA
+    @skipIfRocm
+    @dtypes(torch.half)
+    @dtypesIfCUDA(torch.half)
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
+    def test_triton_bsr_dense_error_messages(self, device, dtype):
+        from torch.sparse._triton_ops import bsr_dense_mm
+
+        rhs = torch.rand(32, 32, dtype=dtype, device=device)
+        lhs = rhs.to_sparse_bsr(16)
+        with self.assertRaisesRegex(ValueError, "only BSR sparse format is supported"):
+            bsr_dense_mm(lhs.to_sparse_bsc(16), rhs)
+        with self.assertRaisesRegex(ValueError, "on the same GPU device"):
+            bsr_dense_mm(lhs, rhs.cpu())
+        if torch.cuda.device_count() > 1:
+            with self.assertRaisesRegex(ValueError, "on the same GPU device"):
+                bsr_dense_mm(lhs.to("cuda:0"), rhs.to("cuda:1"))
+        with self.assertRaisesRegex(ValueError, "all inputs are expected to be of the same dtype"):
+            bsr_dense_mm(lhs, rhs.to(torch.float))
+        with self.assertRaisesRegex(ValueError, "and one of \(half, bfloat16, float32\)"):
+            bsr_dense_mm(lhs.to(torch.double), rhs.to(torch.double))
+        with self.assertRaisesRegex(ValueError, "all inputs are expected to be at least 2D"):
+            bsr_dense_mm(lhs, torch.rand(1, dtype=dtype, device=device))
+        with self.assertRaisesRegex(ValueError, "sizes are not compatible for matrix multiplication"):
+            bsr_dense_mm(lhs, torch.rand(1, 1, dtype=dtype, device=device))
+        with self.assertRaisesRegex(ValueError,
+                                    "dense.size\(-1\) == 15 should be divisible by blocksize\[0\] == 16"):
+            bsr_dense_mm(lhs, torch.rand(32, 15, dtype=dtype, device=device))
+        # Blocksizes check
+        for blocksize in (15, 30):
+            n = blocksize * 2
+            rhs = torch.rand(n, n, dtype=dtype, device=device)
+            lhs = rhs.to_sparse_bsr(blocksize)
+            with self.assertRaisesRegex(ValueError, "should be at least 16 and a power of 2"):
+                bsr_dense_mm(lhs, rhs)
+

# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())
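
For orientation, a minimal sketch of the call pattern the new test exercises (illustrative only, not part of this commit). It assumes a CUDA device with Triton available and uses values chosen to satisfy the checks above: matching dtypes, a BSR left operand, a block size that is a power of 2 and at least 16, and dense.size(-1) divisible by the block size.

# Hypothetical usage sketch (not from the commit); requires CUDA and Triton.
import torch
from torch.sparse._triton_ops import bsr_dense_mm

if torch.cuda.is_available():
    dense = torch.rand(64, 64, dtype=torch.half, device="cuda")
    bsr = dense.to_sparse_bsr(16)   # block size 16: a power of 2 and >= 16
    out = bsr_dense_mm(bsr, dense)  # dense.size(-1) == 64 is divisible by 16
    # bsr_dense_mm(dense.to_sparse_bsr(8), dense) would raise ValueError,
    # since the block size must be at least 16 and a power of 2.
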
28 changes: 24 additions & 4 deletions torch/sparse/_triton_ops.py
@@ -387,9 +387,6 @@ def bsr_dense_mm(
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
out: Optional[torch.Tensor] = None,
):
-m, kl = bsr.shape[-2:]
-kr, n = dense.shape[-2:]
-
def check(cond, msg):
if not cond:
raise ValueError(msg)
@@ -419,19 +416,42 @@ def check(cond, msg):
f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
)

+m, kl = bsr.shape[-2:]
+kr, n = dense.shape[-2:]
+
check(
kl == kr,
"bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
)

-row_block = bsr.values().shape[-2]
+row_block, col_block = bsr.values().shape[-2:]
check(
not n % row_block,
f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
f"blocksize[0] == {row_block}.",
)

+def is_power_of_two(v):
+    return not (v & (v - 1))
+
+def is_compatible_blocksize(b):
+    assert len(b) == 2
+    res = True
+    for blocksize in b:
+        # Triton loads only blocks which are at least 16 and powers of 2.
+        res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
+    return res
+
+check(
+    is_compatible_blocksize((row_block, col_block)),
+    f"bsr_dense_mm(): sparse inputs' blocksize ({row_block}, {col_block}) "
+    "should be at least 16 and a power of 2 in each dimension.",
+)
+else:
+    m, kl = bsr.shape[-2:]
+    kr, n = dense.shape[-2:]
+
# Required to undo the fake batch dimension insertion.
original_batch_dims_broadcasted = torch.broadcast_shapes(
bsr.shape[:-2], dense.shape[:-2]
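
The blocksize validation added above can be read in isolation. The sketch below is a compact restatement of the helpers from the diff (names copied from it; the sample block sizes are assumptions) and shows why the (15, 30) block sizes used in the new test are rejected.

# Standalone illustration of the blocksize check; sample values are assumptions.
def is_power_of_two(v):
    # v & (v - 1) clears the lowest set bit, so the result is 0 exactly for powers of 2.
    return not (v & (v - 1))

def is_compatible_blocksize(b):
    # Triton only loads tiles whose dimensions are powers of 2 and at least 16.
    return all(blocksize >= 16 and is_power_of_two(blocksize) for blocksize in b)

assert is_compatible_blocksize((16, 32))      # accepted
assert not is_compatible_blocksize((15, 30))  # rejected, as in the new test
assert not is_compatible_blocksize((16, 24))  # 24 is not a power of 2
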
