sampled_addmm: BSR support #101163
This PR implements a `sampled_addmm` kernel that works with a BSR mask.
test/test_sparse_csr.py
batches = [(), (2,), (2, 2)]
size = [128, 256, 0]

def sampled_addmm_ref(input, mat1, mat2, alpha, beta):
Could another way to implement this be `torch.addmm(input.to_dense(), mat1, mat2, alpha, beta).sparse_mask(input.to_sparse()).to_sparse_bsr(input.values().shape[-2:])`?
Sure, it's less efficient, but a simpler reference implementation could give us more confidence?
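To make the suggestion concrete, here is a minimal sketch of the "compute densely, then keep only the masked blocks" idea. It uses NumPy rather than PyTorch so that it stays independent of which sparse-layout conversions a given PyTorch build supports. The function name `sampled_addmm_dense_ref` and the `block_mask` argument are hypothetical and not part of this PR. It assumes the usual `sampled_addmm` semantics, `beta * input + alpha * (mat1 @ mat2)` evaluated only on the sparsity pattern of `input`.

```python
import numpy as np


def sampled_addmm_dense_ref(inp, block_mask, blocksize, mat1, mat2, alpha=1.0, beta=1.0):
    """Hypothetical dense reference for a block-sparse sampled addmm.

    inp        : (M, N) dense array holding the values of the sparse input
    block_mask : (M // bm, N // bn) boolean array, True where a block is stored
    blocksize  : (bm, bn)
    """
    bm, bn = blocksize
    # Expand the per-block mask into a per-element mask of shape (M, N).
    elem_mask = np.kron(
        block_mask.astype(np.uint8), np.ones((bm, bn), dtype=np.uint8)
    ).astype(bool)
    # Dense addmm, then discard everything outside the stored blocks.
    full = beta * inp + alpha * (mat1 @ mat2)
    return np.where(elem_mask, full, 0.0)


rng = np.random.default_rng(0)
block_mask = np.array([[True, False], [False, True]])  # 2x2 grid of 2x2 blocks
inp = rng.standard_normal((4, 4))
mat1 = rng.standard_normal((4, 3))
mat2 = rng.standard_normal((3, 4))
out = sampled_addmm_dense_ref(inp, block_mask, (2, 2), mat1, mat2, alpha=0.5, beta=2.0)
```

If this were ported to PyTorch, note that `alpha` and `beta` are keyword-only arguments of `torch.addmm`, so they would need to be passed as `alpha=alpha, beta=beta`.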
We test against CSR from `sparse.sampled_addmm` for more confidence. The proposed solution also only works for non-batched inputs, so, for simplicity, we would still need to loop over batches... Alternatively, we can remove the reference and just test against CSR with half promoted to float. Would that be preferred?
Yes, testing against CSR with float would work too. That's even simpler.
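The two points raised in this exchange, looping over batch dimensions and comparing a half-precision result against a float-promoted reference, can be sketched as follows. This is an illustrative NumPy sketch only: `masked_addmm` and `batched_ref` are hypothetical helpers, the tolerances are arbitrary choices for this example, and the PR's actual test compares against `torch.sparse.sampled_addmm` on CSR inputs.

```python
import numpy as np


def masked_addmm(inp, elem_mask, mat1, mat2, alpha, beta):
    # Non-batched case: beta * inp + alpha * (mat1 @ mat2), kept only where the mask is True.
    return np.where(elem_mask, beta * inp + alpha * (mat1 @ mat2), 0)


def batched_ref(inp, elem_mask, mat1, mat2, alpha, beta):
    # Flatten any leading batch dimensions, apply the non-batched
    # reference per matrix, then restore the original batch shape.
    batch_shape = inp.shape[:-2]
    results = [
        masked_addmm(i, elem_mask, a, b, alpha, beta)
        for i, a, b in zip(
            inp.reshape(-1, *inp.shape[-2:]),
            mat1.reshape(-1, *mat1.shape[-2:]),
            mat2.reshape(-1, *mat2.shape[-2:]),
        )
    ]
    return np.stack(results).reshape(*batch_shape, *inp.shape[-2:])


rng = np.random.default_rng(0)
# Per-element mask built from a 2x2 grid of 2x2 blocks.
elem_mask = np.kron(np.eye(2, dtype=np.uint8), np.ones((2, 2), dtype=np.uint8)).astype(bool)

# Batch shape (2, 2), matching one of the `batches` cases in the test.
inp16 = rng.standard_normal((2, 2, 4, 4)).astype(np.float16)
m1 = rng.standard_normal((2, 2, 4, 3)).astype(np.float16)
m2 = rng.standard_normal((2, 2, 3, 4)).astype(np.float16)

# Half-precision computation versus the same computation promoted to float32.
out16 = batched_ref(inp16, elem_mask, m1, m2, 0.5, 2.0)
out32 = batched_ref(
    inp16.astype(np.float32), elem_mask,
    m1.astype(np.float32), m2.astype(np.float32), 0.5, 2.0,
)
# Loose tolerances, since float16 carries roughly three decimal digits.
close = np.allclose(out16, out32, rtol=5e-2, atol=5e-2)
```

Promoting the inputs before computing the reference keeps the comparison point free of half-precision rounding, so any mismatch beyond the tolerance can be attributed to the kernel under test.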
I'm a bit worried about the complexity of the reference addmm implementation just so we can have a simple comparison point.
test/test_sparse_csr.py
res_csr = torch.sparse.sampled_addmm(csr, mat1csr, mat2csr, alpha=alpha, beta=beta)
self.assertEqual(res_tri.to_dense(), res_csr.to_dense())

# Check grid consistency
What does this mean?
I will update the comment to clarify that.
Looks like there's one lint error left; otherwise this is good to go. Thanks for writing this!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).
Pull Request resolved: #94825 Approved by: https://github.com/albanD, https://github.com/cpuhrsch