nikitaved
Collaborator

@nikitaved nikitaved commented May 11, 2023

This PR implements a sampled_addmm kernel that works with a BSR mask.

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented May 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101163

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8a25f86:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: sparse release notes category label May 11, 2023
@nikitaved nikitaved marked this pull request as draft May 11, 2023 07:42
nikitaved added a commit that referenced this pull request May 11, 2023
ghstack-source-id: ffa39d1
Pull Request resolved: #101163
nikitaved added a commit that referenced this pull request May 12, 2023
ghstack-source-id: 3cdcc6f
Pull Request resolved: #101163
nikitaved added a commit that referenced this pull request May 15, 2023
ghstack-source-id: de96103
Pull Request resolved: #101163
batches = [(), (2,), (2, 2)]
size = [128, 256, 0]

def sampled_addmm_ref(input, mat1, mat2, alpha, beta):
Contributor

Could another way to implement this be `torch.addmm(input.to_dense(), mat1, mat2, alpha, beta).sparse_mask(input.to_sparse()).to_sparse_bsr(input.values().shape[-2:])`?

Sure it's less efficient, but a simpler reference implementation could give us more confidence?
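
For illustration, the dense reference idea suggested above might look roughly like the sketch below. It is not the code from this PR: the helper name `sampled_addmm_dense_ref` is hypothetical, it handles only non-batched inputs, it returns a dense tensor for comparison, and it masks by multiplying with a 0/1 pattern matrix instead of calling `sparse_mask`. It assumes a PyTorch build where BSR tensors support `to_dense()`.

```python
import torch


def sampled_addmm_dense_ref(input_bsr, mat1, mat2, *, alpha=1.0, beta=1.0):
    """Hypothetical dense reference: beta * input + alpha * (mat1 @ mat2),
    kept only at the positions covered by input's stored blocks."""
    # Rebuild the sparsity pattern as a dense 0/1 matrix by replacing every
    # stored value with one, so explicitly stored zeros remain in the pattern.
    pattern = torch.sparse_bsr_tensor(
        input_bsr.crow_indices(),
        input_bsr.col_indices(),
        torch.ones_like(input_bsr.values()),
        size=input_bsr.shape,
    ).to_dense()
    dense = beta * input_bsr.to_dense() + alpha * (mat1 @ mat2)
    # Zero out everything outside the stored blocks.
    return dense * pattern


torch.manual_seed(0)
# 4x4 input with two non-zero 2x2 blocks on the block diagonal.
dense_input = torch.zeros(4, 4)
dense_input[:2, :2] = torch.randn(2, 2)
dense_input[2:, 2:] = torch.randn(2, 2)
input_bsr = dense_input.to_sparse_bsr((2, 2))

mat1 = torch.randn(4, 3)
mat2 = torch.randn(3, 4)
result = sampled_addmm_dense_ref(input_bsr, mat1, mat2, alpha=0.5, beta=2.0)
```

A kernel under test could then be compared with something like `assertEqual(res_tri.to_dense(), result)`; batched inputs would still need a loop over the batch dimensions.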

Collaborator Author

@nikitaved nikitaved May 23, 2023

We test against CSR from `sparse.sampled_addmm` for more confidence. The proposed solution also only works for non-batched inputs, so, for simplicity, we would still need to loop over batches... Alternatively, we can remove the reference and just test against CSR with half promoted to float. Would that be preferred?

Contributor

Yes, testing against CSR with float would work too. That's even simpler.

Contributor

@cpuhrsch cpuhrsch left a comment

I'm a bit worried about the complexity of the reference addmm implementation just so we can have a simple comparison point.

res_csr = torch.sparse.sampled_addmm(csr, mat1csr, mat2csr, alpha=alpha, beta=beta)
self.assertEqual(res_tri.to_dense(), res_csr.to_dense())

# Check grid consistency
Contributor

What does this mean?

Collaborator Author

I will update the comment to clarify that.

nikitaved added 2 commits May 24, 2023 08:29
@cpuhrsch
Contributor

Looks like there's one lint error left; otherwise this is good to go. Thanks for writing this!

nikitaved added 2 commits May 25, 2023 08:31
@nikitaved
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: sparse release notes category

5 participants