Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. #88078
Dr. CI: ✅ No failures as of commit 5835817.
batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
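To illustrate what the line under review does: it finds the (batch, row) coordinates of rows that actually contain nonzero blocks, so empty rows can be skipped. A minimal self-contained sketch, assuming `nnz_per_row` is derived from BSR `crow_indices` by differencing (the actual kernel setup in the PR may differ):

```python
import torch

# Hypothetical example: compressed row pointers of a batched BSR matrix,
# shape (batch, nrows + 1). The per-row block counts are their differences.
crow_indices = torch.tensor([[0, 2, 2, 5],
                             [0, 0, 1, 3]])
nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]

# nonzero(as_tuple=True) yields the coordinates of rows with work to do,
# so fully empty rows never get scheduled.
batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
```

Here `batch_idx` is `[0, 0, 1, 1]` and `row_idx` is `[0, 2, 1, 2]`: rows 0 and 2 of batch 0 and rows 1 and 2 of batch 1 are the only non-empty ones.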
nonzero could be removed when the number of nonzero rows is high; skipping empty rows could then be delegated to the kernel.
nonzero is also one of those ops that will cause a sync if I'm not mistaken.
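The sync happens because the output shape of `nonzero` is data-dependent: the host must wait for the device to learn how many indices to allocate. A hedged sketch of the trade-off (the mask-based alternative here is illustrative, not the PR's actual fix):

```python
import torch

nnz_per_row = torch.tensor([2, 0, 3, 0])

# Data-dependent output shape: on a CUDA tensor this forces an implicit
# device-to-host synchronization before the result's size is known.
idx = nnz_per_row.nonzero(as_tuple=True)[0]

# Fixed-shape alternative: a boolean mask carries the same information,
# keeps launches asynchronous, and lets the kernel skip empty rows itself.
mask = nnz_per_row != 0
```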
Exactly, so it makes sense to avoid using it in some circumstances, which also addresses your comments below.
Force-pushed from a7c503b to 8a0f7d0.
import triton.language as tl

def compressed_indices_to_plain_indices(cidx, pidx):
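For context, a helper with this name would expand CSR/BSR compressed pointers into one plain (COO-style) index per stored element. A minimal sketch under that assumption; the actual signature and body in the PR may differ, and `pidx` is not needed for this simplified single-matrix version:

```python
import torch

def compressed_indices_to_plain_indices(cidx):
    # cidx: (nrows + 1,) compressed row pointers; the per-row element
    # counts are their consecutive differences.
    counts = cidx[1:] - cidx[:-1]
    # Repeat each row index by its count to get one row id per element,
    # i.e. the COO row indices corresponding to the CSR layout.
    return torch.repeat_interleave(
        torch.arange(counts.numel(), device=cidx.device), counts
    )

crow = torch.tensor([0, 2, 2, 5])
plain = compressed_indices_to_plain_indices(crow)
```

For `crow = [0, 2, 2, 5]` this yields `[0, 0, 2, 2, 2]`: two elements in row 0, none in row 1, three in row 2.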
nit: is this unused?
Not here anymore, right, but it is a very nice util for other kernels we might want to have for sparse in Triton. We can remove it for now...
Force-pushed from 8a0f7d0 to 5dee33f.
Changed the title from "bsr @ strided performance in addmm for bfloat16/half with Triton kernels." to "bsr @ strided performance in baddmm for bfloat16/half with Triton kernels."
@nikitaved it would be good to update the PR description to explain what was fixed since the last revert.

@malfet, nothing was really fixed; we just disabled the tests for CUDA 11.6, where the kernel was hanging. But new tests showed up that scan native functions, including dummy entries to be overwritten with Triton implementations, and this is where things started breaking...
From the offline discussion with @cpuhrsch: we decided to remove all the C++ hooks for now.
All right, CUDA 11.6 is deprecated. Let's spin this one up again.
Closing in favor of #94823 for more granular issue control.
As per title.

Additionally we also introduce support for:
- dot
(… limitation).

cc @ngimel @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire @VitalyFedyunin