Make btriunpack work for high dimensional batches and faster than before #15286
Conversation
Changelog:
- Optimize btriunpack by using torch.where instead of indexing, using in-place operations, and avoiding costly permutations.

Test plan:
- Added tests for btriunpack in test_torch.py (and a port to test_cuda.py)
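To illustrate the first changelog item, here is a minimal sketch (not the PR's actual code) of how torch.where can split packed LU data into its triangular factors with a single masked select per factor, instead of the indexed assignments the old implementation used:

```python
import torch

# Batch of packed LU factors (shapes and names are illustrative only).
LU = torch.randn(4, 3, 3)
eye = torch.eye(3).expand_as(LU)
# Strictly-lower-triangle mask, broadcast over the batch dimension.
lower_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool), -1).expand_as(LU)

# Unit-lower factor: strictly lower part of LU, ones on the diagonal.
L = torch.where(lower_mask, LU, eye)
# Upper factor: upper triangle of LU including the diagonal, zeros below.
U = torch.where(lower_mask, torch.zeros_like(LU), LU)
```

Both `torch.where` calls produce their result in one pass over the data, which is where the speedup over advanced indexing comes from.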
Failures are unrelated.

@zou3519 is it possible to get someone to review this? Not high-pri, but some feedback would be helpful. Thanks.
@skipIfNoLapack
def test_btriunpack(self):
    self._test_btriunpack(self, lambda t: t)
Ah, the pleasure of making your code general, but then not making use of the generality :>
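The "generality" being teased here is the cast-function pattern: the test helper takes a conversion function so the same body can run on CPU and CUDA, but test_torch.py only ever passes the identity. A hypothetical pure-Python sketch of the pattern (run_body is a stand-in, not the real helper):

```python
# The same test body runs on any device by threading a conversion
# function through it.
def run_body(cast):
    data = cast([1.0, 2.0, 3.0])  # stand-in for building a test tensor
    return sum(data)

# CPU path: identity cast, exactly what `lambda t: t` does above.
cpu_result = run_body(lambda t: t)
# A test_cuda.py port would instead pass something like `lambda t: t.cuda()`.
```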
I’m sorry, did I do something wrong here?
Nope, just observing a quirk of the existing tests.
t = P[i, :, j].clone()
P[i, :, j] = P[i, :, k]
P[i, :, k] = t
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
I guess we should probably add repeat_as at some point ;)
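For context on the expand_as().clone() idiom (and why a repeat_as might be nicer): expand_as returns a broadcast view whose expanded dimension has stride 0, so every batch "copy" aliases the same storage, and the clone() is what materializes writable memory. A small hedged sketch:

```python
import torch

base = torch.eye(3)
view = base.expand(4, 3, 3)   # no allocation: all 4 slices share storage
P = view.clone()              # materialize; in-place writes are now safe
P[0, 0, 0] = 5.0              # only the first batch element changes
```

A hypothetical repeat_as would fuse the two steps by allocating real copies up front.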
If you want to be super cool, copy paste the old implementation into the test suite and do some "reference implementation versus new optimized implementation" tests.
@ezyang I don't think that is necessary, because the reconstruction takes care of the correctness of the implementation. However, if you insist, I don't mind adding the old-impl vs. new-impl tests as well.
Nope, sounds good to me
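For readers unfamiliar with the testing pattern discussed above, here is a hedged sketch of a reference-vs-optimized equivalence test; both functions are hypothetical stand-ins, not the real btriunpack implementations:

```python
import torch

def reference_impl(x):
    # "old style": build the unit-lower factor with arithmetic on triangles
    return torch.tril(x, -1) + torch.eye(x.size(-1))

def optimized_impl(x):
    # "new style": a single torch.where select against a broadcast identity
    n = x.size(-1)
    mask = torch.tril(torch.ones(n, n, dtype=torch.bool), -1)
    return torch.where(mask.expand_as(x), x, torch.eye(n).expand_as(x))

# The test: both implementations must agree on the same random batch.
x = torch.randn(5, 4, 4)
assert torch.equal(reference_impl(x), optimized_impl(x))
```

The value of such a test is that the slow, obviously-correct version pins down the semantics while the fast version is free to change.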
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Make btriunpack work for high dimensional batches and faster than before (pytorch#15286)

Summary:
Changelog:
- Optimize btriunpack by using `torch.where` instead of indexing, in-place operations instead of out-of-place operations, and avoiding costly permutations by computing the final permutation over a list.

Pull Request resolved: pytorch#15286
Differential Revision: D13562038
Pulled By: soumith
fbshipit-source-id: e2c94cfab5322bf1d24bf56d7b056619f553acc6
This should help unblock testing in #14964 . I created a separate PR so that reviewing can be done efficiently.
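For context on "computing the final permutation over a list": LAPACK's getrf reports, for each row i, the 1-based index of the row it was swapped with, and replaying those swaps on a plain Python list yields the final row permutation in one pass, without any tensor operations. A sketch (the helper name is hypothetical):

```python
def pivots_to_permutation(pivots):
    """Convert 1-based LAPACK pivot indices into a 0-based permutation list.

    Replaying each recorded swap on a plain Python list produces the
    final permutation before any tensor memory is touched.
    """
    perm = list(range(len(pivots)))
    for i, p in enumerate(pivots):
        j = p - 1  # LAPACK pivots are 1-based
        perm[i], perm[j] = perm[j], perm[i]
    return perm

# e.g. pivots (2, 3, 3): swap rows 0<->1, then 1<->2, then 2<->2 (no-op)
```

The resulting list can then be used to build or index the permutation matrix P in a single step.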