bsr_dense_mm(): better test coverage #100543

Closed · wants to merge 19 commits
49 changes: 47 additions & 2 deletions test/test_sparse_csr.py
@@ -3365,8 +3365,9 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
 
         bsr = bsr.to_sparse_bsr(block_size)
 
-        if bsr.dim() == 2:
-            # Test against linear to check dispatch.
+        if bsr.dim() == 2 and dtype != torch.float:
+            # Test against linear to check dispatch
+            # which takes place for torch.half and torch.bfloat16.
             res_tri = torch.nn.functional.linear(dense, bsr)
             res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
 
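The dtype guard added above matters because, per the updated comment, F.linear with a BSR weight dispatches to the Triton kernel only for torch.half and torch.bfloat16, so comparing against linear in float32 would not exercise that path. A minimal sketch of the call shape being tested (assuming a CUDA device with Triton available; sizes chosen for illustration):

    import torch

    x = torch.rand(32, 32, dtype=torch.half, device="cuda")
    w = torch.rand(32, 32, dtype=torch.half, device="cuda").to_sparse_bsr(16)
    # With a half-precision BSR weight, this is the linear call the test
    # compares against the dense reference.
    y = torch.nn.functional.linear(x, w)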
@@ -3397,6 +3398,50 @@ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
             )
             self.assertEqual(res_tri, res_dense)
 
+    @onlyCUDA
+    @skipIfRocm
+    @dtypes(torch.half)
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
+    def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
+        from torch.sparse._triton_ops import bsr_dense_mm
+
+        rhs = torch.rand(32, 32, dtype=dtype, device=device)
+        lhs = rhs.to_sparse_bsr(16)
+        with self.assertRaisesRegex(ValueError, "only BSR sparse format is supported"):
+            bsr_dense_mm(lhs.to_sparse_bsc(16), rhs)
+        with self.assertRaisesRegex(ValueError, "on the same GPU device"):
+            bsr_dense_mm(lhs, rhs.cpu())
+        if torch.cuda.device_count() > 1:
+            with self.assertRaisesRegex(ValueError, "on the same GPU device"):
+                bsr_dense_mm(lhs.to("cuda:0"), rhs.to("cuda:1"))
Comment on lines +3414 to +3416

Contributor: We should keep all multi-GPU tests separate (and decorated with @requiresMultiGPU or something like that).

+        with self.assertRaisesRegex(ValueError, "all inputs are expected to be of the same dtype"):
+            bsr_dense_mm(lhs, rhs.to(torch.float))
+        with self.assertRaisesRegex(ValueError, r"and one of \(half, bfloat16, float32\)"):
+            bsr_dense_mm(lhs.to(torch.double), rhs.to(torch.double))
+        with self.assertRaisesRegex(ValueError, "all inputs are expected to be at least 2D"):
+            bsr_dense_mm(lhs, torch.rand(1, dtype=dtype, device=device))
+        with self.assertRaisesRegex(ValueError, "sizes are not compatible for matrix multiplication"):
+            bsr_dense_mm(lhs, torch.rand(1, 1, dtype=dtype, device=device))
+        with self.assertRaisesRegex(ValueError,
+                                    r"dense.size\(-1\) == 15 should be divisible by blocksize\[0\] == 16"):
+            bsr_dense_mm(lhs, torch.rand(32, 15, dtype=dtype, device=device))
+        # Blocksizes check
+        for blocksize in (15, 30):
+            n = blocksize * 2
+            rhs = torch.rand(n, n, dtype=dtype, device=device)
+            lhs = rhs.to_sparse_bsr(blocksize)
+            with self.assertRaisesRegex(ValueError, "should be at least 16 and a power of 2"):
+                bsr_dense_mm(lhs, rhs)
+        # out check
+        rhs = torch.rand(2, 32, 32, dtype=dtype, device=device)
+        lhs = rhs.to_sparse_bsr(16)
+        with self.assertRaisesRegex(ValueError, r"`out` argument has wrong shape"):
+            out = torch.rand(2, 30, 30, dtype=dtype, device=device)
+            bsr_dense_mm(lhs, rhs, out=out)
+        with self.assertRaisesRegex(ValueError, r"only row-major/col-major `out`"):
+            out = torch.rand(32, 32, 2, dtype=dtype, device=device).transpose(0, -1)
+            bsr_dense_mm(lhs, rhs, out=out)
 
 
 # e.g., TestSparseCSRCPU and TestSparseCSRCUDA
 instantiate_device_type_tests(TestSparseCSR, globals())
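For contrast with the error cases above, a call that passes all of the new checks might look like this (a sketch, assuming a CUDA device with Triton available; the shapes and the blocksize of 16 mirror the test):

    import torch
    from torch.sparse._triton_ops import bsr_dense_mm

    dense = torch.rand(32, 32, dtype=torch.half, device="cuda")
    bsr = dense.to_sparse_bsr(16)    # blocksize 16: at least 16 and a power of 2
    res = bsr_dense_mm(bsr, dense)   # same dtype, same device, compatible sizes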
28 changes: 24 additions & 4 deletions torch/sparse/_triton_ops.py
@@ -431,9 +431,6 @@ def bsr_dense_mm(
     max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
     out: Optional[torch.Tensor] = None,
 ):
-    m, kl = bsr.shape[-2:]
-    kr, n = dense.shape[-2:]
-
     def check(cond, msg):
         if not cond:
             raise ValueError(msg)
@@ -463,19 +460,42 @@ def check(cond, msg):
f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
)

m, kl = bsr.shape[-2:]
kr, n = dense.shape[-2:]

check(
kl == kr,
"bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
)

row_block = bsr.values().shape[-2]
row_block, col_block = bsr.values().shape[-2:]
check(
not n % row_block,
nikitaved marked this conversation as resolved.
Show resolved Hide resolved
f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
f"blocksize[0] == {row_block}.",
)

def is_power_of_two(v):
return not (v & (v - 1))
malfet (Contributor), May 8, 2023:

Nit (I'm not sure how `not` is defined for integers in Python, perhaps you can send a link).

Suggested change:

-            return not (v & (v - 1))
+            return v & (v - 1) == 0

nikitaved (Collaborator, Author), May 8, 2023:

From https://docs.python.org/3/reference/expressions.html#not:

"In the context of Boolean operations, and also when expressions are used by control flow statements, the following values are interpreted as false: False, None, numeric zero of all types, and empty strings and containers (including strings, tuples, lists, dictionaries, sets and frozensets). All other values are interpreted as true. User-defined objects can customize their truth value by providing a __bool__() method."

malfet (Contributor):

My point here is that, in my mind, mixing the bitwise `&` with the boolean `not` hurts readability; on the other hand, it's up to the author/maintainer of the codebase. I.e., as long as @cpuhrsch is fine reading that, I'm OK as well.

+        def is_compatible_blocksize(b):
+            assert len(b) == 2
+            res = True
+            for blocksize in b:
+                # Triton loads only blocks which are at least 16 and powers of 2.
+                res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
+            return res
+
+        check(
+            is_compatible_blocksize((row_block, col_block)),
+            f"bsr_dense_mm(): sparse inputs' blocksize ({row_block}, {col_block}) "
+            "should be at least 16 and a power of 2 in each dimension.",
+        )
+    else:
+        m, kl = bsr.shape[-2:]
+        kr, n = dense.shape[-2:]
 
     original_batch_dims_broadcasted = broadcast_batch_dims(bsr, dense)
 
     if out is not None and not skip_checks:
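To see the new blocksize check from the caller's side, a sketch (assuming CUDA and Triton; it mirrors the (15, 30) cases in the new test):

    import torch
    from torch.sparse._triton_ops import bsr_dense_mm

    dense = torch.rand(60, 60, dtype=torch.half, device="cuda")
    bsr = dense.to_sparse_bsr(30)  # 30 is >= 16 but not a power of 2
    try:
        bsr_dense_mm(bsr, dense)
    except ValueError as e:
        # "... blocksize (30, 30) should be at least 16 and a power of 2 ..."
        print(e)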
Expand Down