[Benchmark] gather_gemv kernel and test #635
Conversation
stack-info: PR: #635, branch: Sibylau/stack/3
Force-pushed from f0765bb to d64b898
Thanks @Sibylau! Left some nit comments; you may also need to rebase to fix conflicts with the main branch.
examples/gather_gemv.py (outdated)

    def baseline_gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
        """PyTorch baseline implementation."""
        # A hard-wired fix for tritonbench baseline: w[idx].to(x.dtype) @ x
Maybe this comment can be removed.
examples/gather_gemv.py (outdated)

        for idx_val in idx.tolist():
            outputs.append(w[idx_val].to(x.dtype) @ x)
        return torch.stack(outputs, dim=0)
        # return torch.stack([w[idx[0]].to(x.dtype) @ x, w[idx[1]].to(x.dtype) @ x])
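For context, the gather-then-GEMV pattern this baseline implements can be sketched without PyTorch. The following minimal NumPy version (shapes and names are illustrative, not taken from the PR) gathers weight matrices by index and applies a matrix-vector product to each selected matrix, matching the per-index loop above:

```python
import numpy as np

def gather_gemv_ref(w: np.ndarray, idx: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Reference gather_gemv: for each index i in idx, compute w[i] @ x.

    w   : (E, S, S) stack of square weight matrices
    idx : (B,)      integer indices selecting matrices from w
    x   : (S,)      input vector
    returns (B, S)  one matrix-vector product per selected index
    """
    # Gather the selected matrices, then let matmul batch the mat-vec products.
    return w[idx] @ x

# Tiny fixed example.
w = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3)
idx = np.array([1, 0, 1])
x = np.ones(3, dtype=np.float32)
out = gather_gemv_ref(w, idx, x)
print(out.shape)  # (3, 3)
```

This is equivalent to stacking the per-index products, as the loop in the diff does, but expressed as one batched operation.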
Maybe remove this commented-out line?
test/test_examples.py (outdated)

        args,
        expected(*args),
        fn_name="gather_gemv",
        block_sizes=[64, 64],
Not exactly sure why the AMD job fails, but I suspect changing the block sizes to smaller values might help.

@yf225 The CI test on ROCm fails due to a generated-code mismatch. Do you know why the generated code for AMD is different? Can I put a @skipIfRocm on this kernel test?
test/test_examples.py (outdated)

            )
        )

    @skipIfRocm("failure on rocm")
Instead of using skipIfRocm, which skips the whole test including the output equality check, maybe we can add a skip_rocm: bool arg to def assertExpectedJournal that skips only the journal check when the device is ROCm. To detect ROCm, we could add something like this to _testing.py:

    def is_rocm() -> bool:
        """Return True if running on ROCm (AMD GPU)."""
        return (
            triton.runtime.driver.active.get_current_target().backend == "hip"
            and DEVICE.type == "cuda"
        )

(Please feel free to do this in a follow-up PR. Thanks!)
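The proposed skip_rocm flag only needs to gate the journal (generated-code) comparison, not the numeric equality check. A hypothetical, dependency-free sketch of that decision logic follows; journal_check_required is an invented helper name for illustration, and assertExpectedJournal's real signature in this repo may differ:

```python
def journal_check_required(device_is_rocm: bool, skip_rocm: bool) -> bool:
    """Decide whether the generated-code (journal) comparison should run.

    The output-equality assertion always runs regardless; only the journal
    check is conditionally skipped, so ROCm still validates correctness.
    """
    return not (skip_rocm and device_is_rocm)

# On ROCm with skip_rocm=True, the journal check is skipped...
print(journal_check_required(True, True))    # False
# ...but everywhere else it still runs.
print(journal_check_required(False, True))   # True
print(journal_check_required(True, False))   # True
```

Inside assertExpectedJournal, the equivalent guard would be an early return before the journal comparison when the check is not required.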
stack-info: PR: #635, branch: Sibylau/stack/3
Force-pushed from 78a4c22 to c8421c3
Stacked PRs:
[Benchmark] gather_gemv kernel and test