
Make btriunpack work for high dimensional batches and faster than before #15286

Closed

Conversation

vishwakftw (Contributor):

Changelog:

  • Optimize btriunpack by using torch.where instead of indexing, in-place operations instead of out-of-place operations, and avoiding costly permutations by composing the final permutation over a Python list (a sketch of both ideas follows this description).

Test plan:

  • Added tests for btriunpack in test_torch.py (and a port to test_cuda.py)

This should help unblock testing in #14964. I created a separate PR so that reviewing can be done efficiently.
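For illustration, here is a minimal sketch of the two ideas under assumed names and shapes; it is not the PR's actual code, and the helper names are made up:

```python
import torch

def unpack_pivots_sketch(pivots, sz):
    # Compose LAPACK's 1-based row swaps on a cheap Python list first,
    # then materialize the permutation matrix with a single indexing op,
    # instead of permuting a tensor once per swap.
    perm = list(range(sz))
    for i, p in enumerate(pivots.tolist()):
        perm[i], perm[p - 1] = perm[p - 1], perm[i]
    return torch.eye(sz)[perm]

def unpack_lu_sketch(LU_data):
    # Split the packed LU factor into unit-lower L and upper U using
    # torch.where masks, which broadcast over any number of batch dims.
    sz = LU_data.size(-1)
    eye = torch.eye(sz, dtype=LU_data.dtype, device=LU_data.device)
    strictly_lower = torch.tril(
        torch.ones(sz, sz, dtype=torch.bool, device=LU_data.device), -1)
    zero = torch.zeros_like(LU_data)
    L = torch.where(strictly_lower, LU_data, zero) + eye  # unit diagonal
    U = torch.where(strictly_lower, zero, LU_data)
    return L, U
```

Because the masks broadcast, the same code path covers a single matrix and an arbitrarily deep batch, which is the "high dimensional batches" part of the title.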

vishwakftw (Contributor, Author):

The CI failures are unrelated to this change.

vishwakftw (Contributor, Author):

@zou3519 is it possible to get someone to review this? It is not high-priority, but some feedback would be helpful. Thanks.


```python
@skipIfNoLapack
def test_btriunpack(self):
    self._test_btriunpack(self, lambda t: t)
```
Reviewer (Contributor):

Ah, the pleasure of making your code general, but then not making use of the generality :>

vishwakftw (Contributor, Author):

I’m sorry, did I do something wrong here?

Reviewer (Contributor):

Nope, just observing a quirk of the existing tests.
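For context, the lambda t: t argument is a device-cast hook; the CUDA port mentioned in the test plan presumably reuses the same helper with a CUDA cast. A self-contained sketch of the pattern (names assumed, not the actual test-suite code):

```python
import torch

def _test_roundtrip(cast):
    # One test helper, parameterized over a device cast.
    A = cast(torch.randn(3, 3))
    assert torch.allclose(A, A.clone())

_test_roundtrip(lambda t: t)              # CPU path, as in test_torch.py
if torch.cuda.is_available():
    _test_roundtrip(lambda t: t.cuda())   # CUDA port, as in test_cuda.py
```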

```python
# Swap columns j and k of the i-th permutation matrix in the batch:
t = P[i, :, j].clone()
P[i, :, j] = P[i, :, k]
P[i, :, k] = t
```

```python
# Build a batch of identity matrices matching LU_data's shape:
P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
```
Reviewer (Contributor):

I guess we should probably add repeat_as at some point ;)
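To illustrate why the diff clones after expand_as (and why a repeat_as would be handy), a small sketch with made-up shapes:

```python
import torch

LU_data = torch.randn(4, 3, 3)       # hypothetical batch of 3x3 factors
P = torch.eye(3).expand_as(LU_data)  # broadcast view: no memory copied yet
assert P.stride(0) == 0              # the batch dim is stride-0
P = P.clone()                        # give each matrix its own storage
t = P[0, :, 0].clone()               # now the in-place column swap is safe
P[0, :, 0] = P[0, :, 2]
P[0, :, 2] = t
```

PyTorch rejects in-place writes into the expanded view, because every batch entry aliases the same storage; clone() materializes independent copies, which is what makes the swaps above legal.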

(A review thread on torch/functional.py was marked outdated and resolved.)
ezyang (Contributor) commented on Dec 20, 2018:

If you want to be super cool, copy paste the old implementation into the test suite and do some "reference implementation versus new optimized implementation" tests.
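Such a test might look roughly like this; reference_impl stands for a frozen copy of the old implementation kept inside the test suite, and all names here are hypothetical:

```python
import torch

def compare_impls(LU_data, LU_pivots, new_impl, reference_impl):
    # Run the frozen old implementation next to the optimized one and
    # check that every unpacked factor (P, L, U) matches elementwise.
    for new, ref in zip(new_impl(LU_data, LU_pivots),
                        reference_impl(LU_data, LU_pivots)):
        assert torch.allclose(new, ref)
```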

vishwakftw (Contributor, Author):

@ezyang I don't think that is necessary, because the reconstruction takes care of the correctness of the implementation.

However, if you insist, I don't mind adding the old-impl vs. new-impl tests as well.
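The reconstruction argument can be made concrete with a small check. This sketch uses the API names of the era (torch.btrifact and torch.btriunpack, since renamed to torch.lu and torch.lu_unpack) and assumes the factors recombine as P @ L @ U:

```python
import torch

A = torch.randn(2, 5, 5)                 # hypothetical batched input
LU_data, LU_pivots = torch.btrifact(A)   # batched LU factorization
P, L, U = torch.btriunpack(LU_data, LU_pivots)
# If unpacking is correct, the factors recombine to the original matrices.
assert torch.allclose(P @ L @ U, A, atol=1e-5)
```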

ezyang (Contributor) commented on Dec 21, 2018:

> I don't think that is necessary, because the reconstruction takes care of the correctness of the implementation.
>
> However, if you insist, I don't mind adding the old-impl vs. new-impl tests as well.

Nope, sounds good to me.

vishwakftw (Contributor, Author):

@ezyang @zou3519 is this good to go?

facebook-github-bot (Contributor) left a comment:

@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

vishwakftw deleted the btriunpack-fast-many-dims branch on December 31, 2018 at 02:46.
vishwakftw added a commit to vishwakftw/pytorch that referenced this pull request on Jan 18, 2019:

Make btriunpack work for high dimensional batches and faster than before (pytorch#15286)

Summary:
Changelog:
- Optimize btriunpack by using `torch.where` instead of indexing, in-place operations instead of out-of-place operations, and avoiding costly permutations by computing the final permutation over a list.

Pull Request resolved: pytorch#15286

Differential Revision: D13562038

Pulled By: soumith

fbshipit-source-id: e2c94cfab5322bf1d24bf56d7b056619f553acc6
soumith pushed commits that referenced this pull request on Jan 18 and Jan 29, 2019, with the same commit summary (#15286).