bsr_dense_mm(): better test coverage #100543
```diff
@@ -431,9 +431,6 @@ def bsr_dense_mm(
     max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
     out: Optional[torch.Tensor] = None,
 ):
-    m, kl = bsr.shape[-2:]
-    kr, n = dense.shape[-2:]
-
     def check(cond, msg):
         if not cond:
             raise ValueError(msg)
@@ -463,19 +460,42 @@ def check(cond, msg):
             f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
         )

+        m, kl = bsr.shape[-2:]
+        kr, n = dense.shape[-2:]
+
         check(
             kl == kr,
             "bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
             f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
         )

-        row_block = bsr.values().shape[-2]
+        row_block, col_block = bsr.values().shape[-2:]
         check(
             not n % row_block,
             f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
             f"blocksize[0] == {row_block}.",
         )
+
+        def is_power_of_two(v):
+            return not (v & (v - 1))
```
Review discussion on `is_power_of_two`:

Comment: Nit (I'm not sure how …)
Suggested change: …

Reply: From https://docs.python.org/3/reference/expressions.html#not: "In the context of Boolean operations, and also when expressions are used by control flow statements, the following values are interpreted as false: False, None, numeric zero of all types, and empty strings and containers (including strings, tuples, lists, dictionaries, sets and frozensets). All other values are interpreted as true. User-defined objects can customize their truth value by providing a __bool__() method."

Comment: My point here is that, in my mind, mixing logical …
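The two spellings being debated agree for every positive integer: `v & (v - 1)` clears the lowest set bit, so it is zero exactly when `v` has a single bit set. A quick self-contained check (the `_truthy`/`_explicit` names are illustrative, not from the diff):

```python
def is_power_of_two_truthy(v):
    # Spelling from the diff: relies on Python truthiness (0 is falsy,
    # any nonzero int is truthy), so `not` converts the result to a bool.
    return not (v & (v - 1))

def is_power_of_two_explicit(v):
    # Reviewer's suggested spelling: an explicit integer comparison.
    return (v & (v - 1)) == 0

# Equivalent over positive integers; note both return True for v == 0,
# which the caller avoids since blocksizes are checked to be >= 16.
for v in range(1, 2049):
    assert is_power_of_two_truthy(v) == is_power_of_two_explicit(v)

print(is_power_of_two_truthy(16), is_power_of_two_truthy(24))  # True False
```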
```diff
+
+        def is_compatible_blocksize(b):
+            assert len(b) == 2
+            res = True
+            for blocksize in b:
+                # Triton loads only blocks which are at least 16 and powers of 2.
+                res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
+            return res
+
+        check(
+            is_compatible_blocksize((row_block, col_block)),
+            f"bsr_dense_mm(): sparse inputs' blocksize ({row_block}, {col_block}) "
+            "should be at least 16 and a power of 2 in each dimension.",
+        )
+    else:
+        m, kl = bsr.shape[-2:]
+        kr, n = dense.shape[-2:]

     original_batch_dims_broadcasted = broadcast_batch_dims(bsr, dense)

     if out is not None and not skip_checks:
```
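The blocksize constraint in the diff can be exercised in isolation. A minimal sketch, using the helper names from the diff but collapsing the reduction loop into an idiomatic `all()` (behaviorally equivalent):

```python
def is_power_of_two(v):
    # A power of two has exactly one bit set; v & (v - 1) clears it.
    return not (v & (v - 1))

def is_compatible_blocksize(b):
    # Triton loads only blocks which are at least 16 and powers of 2.
    assert len(b) == 2
    return all(blocksize >= 16 and is_power_of_two(blocksize) for blocksize in b)

print(is_compatible_blocksize((16, 64)))  # True
print(is_compatible_blocksize((16, 24)))  # False: 24 is not a power of 2
print(is_compatible_blocksize((8, 32)))   # False: 8 < 16
```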
Comment: We should keep all multi-GPU tests separate (and decorated with `@requiresMultiGPU` or something like that).
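A decorator like the one suggested above can be built on `unittest.skipIf`. This is a generic sketch, not PyTorch's actual test-suite machinery: the `requiresMultiGPU` name comes from the comment, and `device_count` here is a stand-in for `torch.cuda.device_count()` so the example runs without CUDA:

```python
import unittest

def device_count():
    # Stand-in for torch.cuda.device_count(); hard-coded so the sketch runs anywhere.
    return 1

def requiresMultiGPU(fn):
    # Skip the test unless at least two GPUs are visible.
    return unittest.skipIf(device_count() < 2, "requires at least 2 GPUs")(fn)

class TestBsrDenseMM(unittest.TestCase):
    @requiresMultiGPU
    def test_multi_gpu_path(self):
        # Would exercise the multi-GPU code path; skipped on single-GPU machines.
        self.fail("should never run with device_count() == 1")

result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(TestBsrDenseMM)
)
print(result.testsRun, len(result.skipped))  # 1 1
```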