
Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. #88078

Closed · wants to merge 116 commits

Conversation

nikitaved (Collaborator) commented Oct 31, 2022

As per title.

Additionally, we introduce support for:

  • Rectangular block sizes whose dimensions are powers of 2 and at least 16 (a limitation of triton's dot).
  • Batched inputs, with broadcasting over the batch dimensions of either argument.
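
For illustration only (not part of the diff), a rough usage sketch of the targeted operation. It assumes torch.matmul accepts the BSR operand on CUDA and that such a configuration dispatches to the new Triton kernel; the shapes and names are made up.

```python
import torch

# Rectangular, power-of-2 block sizes of at least 16 in each dimension,
# per the constraint above (a limitation of triton.language.dot).
dtype = torch.bfloat16
lhs = torch.randn(256, 256, device="cuda", dtype=dtype).to_sparse_bsr(blocksize=(32, 16))
rhs = torch.randn(256, 128, device="cuda", dtype=dtype)  # strided (dense) operand

out = lhs @ rhs  # bsr @ strided matmul in bfloat16
```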

cc @ngimel @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire @VitalyFedyunin

pytorch-bot bot commented Oct 31, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88078

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5835817:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: sparse label Oct 31, 2022
@nikitaved nikitaved added the module: performance, module: sparse, and module: half labels and removed the release notes: sparse label Oct 31, 2022
@pytorch-bot pytorch-bot bot added the release notes: sparse label Oct 31, 2022
Comment on lines 187 to 222
batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)

nikitaved (Collaborator, Author):

nonzero could be removed when the number of nonzero rows is high; skipping empty rows could then be delegated to the kernel.

Contributor:

nonzero is also one of those ops that will cause a sync if I'm not mistaken.

nikitaved (Collaborator, Author):

Exactly, so it makes sense to avoid using it in some circumstances, which also addresses your comments below.
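
For context, a small illustration of the sync concern raised above (not taken from the PR; the tensor names are made up):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# torch.nonzero has a data-dependent output shape, so the returned index tensors
# cannot be sized without copying a count back to the host: an implicit
# device-to-host synchronization on CUDA.
nnz_per_row = torch.randint(0, 3, (4, 64), device=device)
batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)  # blocks the host when on CUDA

# The alternative discussed above: launch the kernel over all rows and let it
# detect and skip empty rows itself, so the host never waits for the count.
```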

import triton.language as tl


def compressed_indices_to_plain_indices(cidx, pidx):
Contributor:

nit: is this unused?

nikitaved (Collaborator, Author):

It is not used here anymore, right, but it is a very nice utility for other sparse kernels we might want to have in Triton. We can remove it for now...
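
For reference, a hedged sketch of what such a utility could do: expand compressed (CSR/BSR-style) row pointers into one plain row index per stored entry. The name matches the reviewed signature, but the body below is an assumption, not the PR's implementation.

```python
import torch

def compressed_indices_to_plain_indices(cidx: torch.Tensor, pidx: torch.Tensor) -> torch.Tensor:
    # cidx: compressed indices of length n_rows + 1 (e.g. crow_indices).
    # pidx: plain indices of the other dimension (e.g. col_indices), one per entry.
    counts = cidx[1:] - cidx[:-1]                             # entries per compressed row
    rows = torch.arange(cidx.numel() - 1, device=cidx.device)
    plain = rows.repeat_interleave(counts)                    # one row index per stored entry
    assert plain.numel() == pidx.numel()                      # must align with the plain indices
    return plain

# Example: a 3x4 CSR matrix with 4 stored entries.
crow = torch.tensor([0, 1, 1, 4])
col = torch.tensor([2, 0, 1, 3])
print(compressed_indices_to_plain_indices(crow, col))         # tensor([0, 2, 2, 2])
```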

@nikitaved nikitaved force-pushed the nikitaved/triton_bsr_dense_mm branch from 8a0f7d0 to 5dee33f Compare November 4, 2022 14:53
@nikitaved nikitaved changed the title Improve bsr @ strided performance in addmm for bfloat16/half with Triton kernels. Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. Nov 4, 2022
Resolved review threads (outdated): test/test_sparse_csr.py, torch/sparse/triton_ops/triton_bsr_dense_mm.py
@nikitaved nikitaved reopened this Jan 26, 2023
malfet (Contributor) commented Jan 26, 2023

@nikitaved it would be good to update the PR description to explain what was fixed since the last revert.
Also, I wonder why the "revertedx2" label is not there...

@malfet malfet requested a review from cpuhrsch January 26, 2023 17:05
nikitaved (Collaborator, Author) commented Jan 26, 2023

@malfet, nothing was really fixed; we just disabled the tests for CUDA 11.6, where the kernel was hanging. But new tests showed up that scan native functions, including dummy entries meant to be overwritten with Triton implementations, and this is where things started breaking...
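
For context, a hedged sketch of gating a Triton-backed test on the CUDA toolkit version (not the PR's exact change; the test class and name are made up):

```python
import unittest
import torch

_cuda = torch.version.cuda  # e.g. "11.6", or None on CPU-only builds
_is_cuda_11_6 = _cuda is not None and _cuda.startswith("11.6")

@unittest.skipIf(_is_cuda_11_6, "Triton bsr @ dense kernel hangs on CUDA 11.6")
class TestTritonBsrDenseMM(unittest.TestCase):
    def test_matmul(self):
        ...
```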

test/test_sparse_csr.py Outdated Show resolved Hide resolved
nikitaved (Collaborator, Author) commented Jan 31, 2023

From the offline discussion with @cpuhrsch: we decided to remove all the C++ hooks for now. The test_decomp issues are real (they come from introducing a native function), but they only manifest after the tests have been spinning for a very long time. We will try to investigate these issues in a follow-up PR.

nikitaved (Collaborator, Author):
All right, CUDA 11.6 is deprecated. Let's spin this one again.

nikitaved (Collaborator, Author):
Closing in favor of #94823 for more granular issue control.

@nikitaved nikitaved closed this Feb 14, 2023
Labels: ciflow/inductor, ciflow/trunk, Merged, module: dynamo, module: half, module: performance, module: sparse, open source, release notes: sparse, Reverted, triaged