
Allowing batching for det/logdet/slogdet operations #22909

Closed · wants to merge 24 commits

Conversation

@vishwakftw (Contributor) commented Jul 16, 2019

Changelog:

  • Add batching for det / logdet / slogdet operations (see the usage sketch after this list)
  • Update the derivative computation to support batched inputs (and consequently batched outputs)
  • Update docs

Test Plan:

  • Add a `test_det_logdet_slogdet_batched` method in `test_torch.py` to test `torch.det`, `torch.logdet` and `torch.slogdet` on batched inputs. This relies on the correctness of `torch.det` on single matrices (tested by `test_det_logdet_slogdet`). A port of this test is added to `test_cuda.py`.
  • Add autograd tests for batched inputs.
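A minimal usage sketch of the batched API this PR adds (shapes and values are illustrative):

```python
import torch

# After this PR, det / logdet / slogdet accept inputs of shape (*, n, n)
# and return one result per matrix in the batch (logdet behaves the same,
# returning nan for matrices with a negative determinant).
A = torch.randn(4, 3, 3)

dets = torch.det(A)                   # shape (4,)
signs, logabsdets = torch.slogdet(A)  # each of shape (4,)

# The batched result should agree with a plain Python loop:
looped = torch.stack([torch.det(A[i]) for i in range(A.shape[0])])
print(torch.allclose(dets, looped))
```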

@pytorchbot added the module: internals (Related to internal abstractions in c10 and ATen) and module: operators labels Jul 16, 2019
@vishwakftw (Contributor Author):

cc: @ssnl

@calincru:

Really looking forward to this, I need it as soon as possible.

@pytorchbot added the module: cuda (Related to torch.cuda, and CUDA support in general), module: docs (Related to our documentation, both in docs/ and docblocks) and module: tests (Issues related to tests, not the torch.testing module) labels Jul 19, 2019
@vishwakftw vishwakftw requested a review from ssnl July 19, 2019 15:19
@vishwakftw vishwakftw changed the title [WIP] Allowing batching for det/logdet/slogdet operations Allowing batching for det/logdet/slogdet operations Jul 19, 2019
Excerpt under review:

```cpp
} else if ((det == 0).all().item<uint8_t>()) { // all matrices are singular
  return singular_case_backward(grad, self, det);
} else { // some matrices are invertible, some matrices are singular
  return at::where(det != 0,
```
Collaborator:

I'm still not a fan of this. It means that just running det in a loop and cat-ing the results afterwards can often be much faster.

You are already synchronizing by calling det.all() anyway. So why not do a det.nonzero(), select and compute both parts, and then put them together?
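A hypothetical Python mirror of that suggestion (the real change lives in C++/ATen; `singular_case_backward` and `nonsingular_case_backward` are the helper names from the surrounding diff, passed in here as callables):

```python
import torch

def det_backward_mixed(grad, self_, det,
                       nonsingular_case_backward, singular_case_backward):
    # Split the batch by singularity once, run each branch only on its
    # sub-batch, and scatter the results back, instead of computing both
    # branches over the full batch and selecting with at::where().
    nonsingular = (det != 0).nonzero().squeeze(-1)
    singular = (det == 0).nonzero().squeeze(-1)

    grad_self = torch.empty_like(self_)
    grad_self[nonsingular] = nonsingular_case_backward(
        grad[nonsingular], self_[nonsingular], det[nonsingular])
    grad_self[singular] = singular_case_backward(
        grad[singular], self_[singular], det[singular])
    return grad_self
```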

Contributor Author:

I actually didn’t think of nonzero, I’ll try it out.

Collaborator:

Yeah, thanks! In-place index_put-ing gradients into the correct locations may break double backward, so you might also want to check that it still works.

Contributor Author:

In-place index_put has a gradient defined, so I don't think it'll break double backward. I've made the changes per your suggestions.
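One hypothetical way to verify the double-backward claim, using autograd's built-in checkers (double precision is needed for the finite-difference comparison):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Random batches are almost surely nonsingular, so this exercises the
# invertible code path; gradgradcheck differentiates the backward itself.
A = torch.randn(3, 4, 4, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.det, (A,))
assert gradgradcheck(torch.det, (A,))
```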


Excerpt under review:

```cpp
if ((det != 0).all().item<uint8_t>()) { // all matrices are invertible
  return unsqueeze_multiple(grad * det, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
} else if ((det == 0).all().item<uint8_t>()) { // all matrices are singular
```
Collaborator:

I'm pretty sure these two if conditions can be merged to avoid some duplicated computation.
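For reference, the identity the invertible branch in the excerpt builds on, d det(A)/dA = det(A) · A^{-T}, can be spot-checked numerically (a standalone sketch, not code from this PR):

```python
import torch

A = torch.randn(2, 3, 3, dtype=torch.double, requires_grad=True)
torch.det(A).sum().backward()

# Jacobi's formula for invertible A: d det(A) / dA = det(A) * A^{-T},
# broadcast over the batch dimension here.
expected = (torch.det(A.detach())[..., None, None]
            * A.detach().inverse().transpose(-2, -1))
print(torch.allclose(A.grad, expected))
```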

@ssnl (Collaborator) commented Jul 19, 2019

This is great, but in the end it is a wrapper (since there are no actual batched kernels) so that users don't need to write a for loop themselves. So we shouldn't make certain cases slower than a for loop.

@cpuhrsch added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Jul 19, 2019
@Gsunshine:

> This is great, but in the end it is a wrapper (since there are no actual batched kernels) so that users don't need to write a for loop themselves. So we shouldn't make certain cases slower than a for loop.

Emmm, you mean we cannot actually get acceleration from parallel computing...?

@ssnl (Collaborator) commented Jul 20, 2019

@Gsunshine The forward is batched with MAGMA's batched LU call, but the backward is just a loop, unless @vishwakftw is launching svd on different streams, which I don't think is happening with the current code.
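For context, a rough sketch of how a determinant falls out of a batched LU factorization (written against today's torch.linalg API, not the ATen internals this PR touches):

```python
import torch

A = torch.randn(5, 3, 3)
LU, pivots = torch.linalg.lu_factor(A)  # batched LU with partial pivoting

# det(A) = sign(P) * prod(diag(U)); every actual row swap flips the sign.
diag_U = torch.diagonal(LU, dim1=-2, dim2=-1)
n = torch.arange(1, A.shape[-1] + 1)
num_swaps = (pivots != n).sum(-1)           # LAPACK pivots are 1-indexed
sign = (1 - 2 * (num_swaps % 2)).to(A.dtype)
det = sign * diag_U.prod(-1)
print(torch.allclose(det, torch.linalg.det(A)))
```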

@vishwakftw (Contributor Author):

@Gsunshine yes, @ssnl is right about the forward call. For backward, in the case of invertible matrices it uses a batched MAGMA getri call, but in the case of singular matrices the backward computation is looped.

@Gsunshine commented Jul 20, 2019

@ssnl @vishwakftw I got it, thanks!

I'd prefer a logdet for invertible matrices with both a batched forward and a batched backward call, while keeping slogdet as a general function with an SVD backward. Maybe a sym_pos_def argument for SPD matrices could be added to det & logdet to enable an even faster batched Cholesky implementation (sketched below).

If you need any help with the enhancement, please feel free to ping @Gsunshine. Happy to help!
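A minimal sketch of that Cholesky shortcut, assuming the input really is symmetric positive definite (spd_logdet and the sym_pos_def idea are hypothetical, not part of this PR):

```python
import torch

def spd_logdet(A):
    # For SPD A = L L^T, logdet(A) = 2 * sum(log(diag(L))),
    # so one batched Cholesky replaces LU plus sign handling.
    L = torch.cholesky(A)
    return 2.0 * torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1)

# Example: a batch of random SPD matrices.
B = torch.randn(4, 3, 3, dtype=torch.double)
A = B @ B.transpose(-2, -1) + 1e-3 * torch.eye(3, dtype=torch.double)
print(torch.allclose(spd_logdet(A), torch.logdet(A)))
```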

@vishwakftw (Contributor Author):

@pytorchbot rebase this please

@vishwakftw (Contributor Author):

@pytorchbot rebase this please

@vishwakftw (Contributor Author):

@ssnl I've optimized the forward implementation now (it turns out we don't actually need the helper function to return the infos tensor).

@ssnl (Collaborator) left a review:

I have a bunch of perf comments, but they are all minor. Feel free to merge this after you fix them!

Also, this looks great! Thank you so much for working on this! :)

Review comments were left on aten/src/ATen/native/LinearAlgebra.cpp and tools/autograd/templates/Functions.cpp; all have been resolved.
@ssnl (Collaborator) commented Jul 31, 2019

These things are small so I pushed to your branch. Hopefully we can get a green CI.

@ssnl force-pushed the batch-det-logdet-slogdet branch 3 times, most recently from 26d03b5 to 3ad956c, on July 31, 2019 02:10
@ssnl force-pushed the batch-det-logdet-slogdet branch from 3ad956c to 874c259 on July 31, 2019 02:29
@vishwakftw (Contributor Author):

@ssnl thank you very much for the detailed review. Sorry I couldn't get to the comments before you did; it was already late here.

@vishwakftw (Contributor Author):

@pytorchbot merge this please

@pytorchbot added the merge-this-please (Was marked for merge with @pytorchbot merge this please) label Jul 31, 2019
@facebook-github-bot (Contributor) left a comment:

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

vishwakftw added a commit to vishwakftw/pytorch that referenced this pull request Jul 31, 2019
Summary:
Changelog:
- Add batching for det / logdet / slogdet operations
- Update derivative computation to support batched inputs (and consequently batched outputs)
- Update docs
Pull Request resolved: pytorch#22909

Test Plan:
- Add a `test_det_logdet_slogdet_batched` method in `test_torch.py` to test `torch.det`, `torch.logdet` and `torch.slogdet` on batched inputs. This relies on the correctness of `torch.det` on single matrices (tested by `test_det_logdet_slogdet`). A port of this test is added to `test_cuda.py`
- Add autograd tests for batched inputs

Differential Revision: D16580988

Pulled By: ezyang

fbshipit-source-id: b76c87212fbe621f42a847e3b809b5e60cfcdb7a
@vishwakftw vishwakftw deleted the batch-det-logdet-slogdet branch July 31, 2019 17:27
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 31, 2019, with the same summary, test plan, and metadata as the commit above (Pull Request resolved: pytorch/pytorch#22909).
@facebook-github-bot (Contributor):

@ezyang merged this pull request in 5d130e4.

soumith pushed a commit that referenced this pull request Aug 2, 2019, again with the same commit message (Pull Request resolved: #22909).
Labels

  • merge-this-please (Was marked for merge with @pytorchbot merge this please)
  • Merged
  • module: cuda (Related to torch.cuda, and CUDA support in general)
  • module: docs (Related to our documentation, both in docs/ and docblocks)
  • module: internals (Related to internal abstractions in c10 and ATen)
  • module: tests (Issues related to tests, not the torch.testing module)
  • oncall: jit (Add this issue/PR to JIT oncall triage queue)
  • open source
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
9 participants