Allowing batching for det/logdet/slogdet operations #22909
Conversation
cc: @ssnl
Really looking forward to this, I need it as soon as possible.
} else if ((det == 0).all().item<uint8_t>()) { // all matrices are singular
  return singular_case_backward(grad, self, det);
} else { // some matrices are invertible, some matrices are singular
  return at::where(det != 0,
I'm still not a fan of this. It means that just running det in a loop and cat-ing the results afterwards can often be much faster. You are already syncing by accessing det.all() anyway, so why not just do a det.nonzero(), select and compute both parts, and then put them together?
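For illustration only, a Python-level sketch of that pattern (the real change lives in the C++ autograd code; the helper names and the single leading batch dimension are assumptions here):

```python
import torch

def det_backward_sketch(grad, self, det,
                        nonsingular_case_backward, singular_case_backward):
    # Assumed helpers: each computes the gradient for its sub-batch only.
    # Select the invertible and singular matrices once, compute each piece
    # on its own sub-batch, then scatter the results back into one tensor.
    grad_self = torch.zeros_like(self)

    nonzero_idx = (det != 0).nonzero(as_tuple=True)[0]  # invertible matrices
    zero_idx = (det == 0).nonzero(as_tuple=True)[0]     # singular matrices

    if nonzero_idx.numel() > 0:
        grad_self[nonzero_idx] = nonsingular_case_backward(
            grad[nonzero_idx], self[nonzero_idx], det[nonzero_idx])
    if zero_idx.numel() > 0:
        grad_self[zero_idx] = singular_case_backward(
            grad[zero_idx], self[zero_idx], det[zero_idx])
    return grad_self
```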
I actually didn’t think of nonzero, I’ll try it out.
Yeah, thanks! In-place index_put of the gradients into the correct locations may break double backward, so you might also want to check whether that still works.
Inplace index_put has a gradient defined, so I don't think it'll break double backward. I've made changes as per your suggestions.
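If it helps, a minimal standalone check along these lines (not the PR's test code) that double backward survives an in-place scatter:

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

def scatter_fn(x):
    # zeros + index assignment (index_put_ under the hood), mimicking how
    # per-group gradients get written back into a single output tensor
    out = torch.zeros(4, dtype=x.dtype)
    out[torch.tensor([0, 2])] = x * x
    return out

x = torch.randn(2, dtype=torch.double, requires_grad=True)
assert gradcheck(scatter_fn, (x,))       # first-order gradients
assert gradgradcheck(scatter_fn, (x,))   # double backward
```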
if ((det != 0).all().item<uint8_t>()) { // all matrices are invertible
  return unsqueeze_multiple(grad * det, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1);
} else if ((det == 0).all().item<uint8_t>()) { // all matrices are singular
I'm pretty sure these two if conditions can be merged to avoid some duplicated computation.
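A sketch of the merged branch structure, written in Python for brevity (the helper names are placeholders; the real code is C++):

```python
def det_backward_dispatch(grad, self, det, nonsingular_case_backward,
                          singular_case_backward, mixed_case_backward):
    # Evaluate the invertibility mask once and sync with the host once,
    # instead of checking (det != 0).all() and (det == 0).all() separately.
    nonzero_mask = det != 0
    n_nonzero = int(nonzero_mask.sum())          # single device-to-host sync
    if n_nonzero == det.numel():                 # all matrices invertible
        return nonsingular_case_backward(grad, self, det)
    if n_nonzero == 0:                           # all matrices singular
        return singular_case_backward(grad, self, det)
    return mixed_case_backward(grad, self, det, nonzero_mask)  # mixed batch
```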
This is great, but in the end this is a wrapper (since there are no actual batched kernels), so users don't need to write a for loop themselves. That also means we shouldn't make certain cases slower than a for loop.
Emmm, you mean we cannot actually get acceleration from parallel computing...?
@Gsunshine The forward is batched with MAGMA's batched LU call, but the backward is just a loop, unless @vishwakftw is launching svd on different streams, which I don't think is happening with the current code.
@Gsunshine yes, @ssnl is right about the forward call. For the backward, the invertible case uses a batched MAGMA getri call, but in the case of singular matrices the backward computation is looped.
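For context, roughly what the batched forward amounts to, written here with the public torch.linalg.lu_factor API rather than the internal MAGMA helper the PR uses: the determinant is the product of U's diagonal times the sign of the pivot permutation.

```python
import torch

def batched_det_sketch(A):
    # A: (..., n, n) batch of square matrices; lu_factor dispatches to a
    # batched LU factorization (MAGMA on CUDA, LAPACK on CPU).
    LU, pivots = torch.linalg.lu_factor(A)
    diag_U = LU.diagonal(dim1=-2, dim2=-1)
    # pivots are 1-indexed; each pivot that differs from its row index is one
    # row swap, and det(P) = (-1) ** (number of swaps)
    n = A.shape[-1]
    swaps = (pivots != torch.arange(1, n + 1, device=A.device)).sum(dim=-1)
    sign = 1.0 - 2.0 * (swaps % 2).to(A.dtype)
    return sign * diag_U.prod(dim=-1)

A = torch.randn(2, 3, 4, 4)
assert torch.allclose(batched_det_sketch(A), torch.det(A), rtol=1e-4, atol=1e-5)
```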
@ssnl @vishwakftw I got it, thanks! If you need any help with the enhancement, please feel free to ping @Gsunshine. Hope to offer my help!
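To make the loop-versus-batched question concrete, a rough timing sketch (sizes are arbitrary; actual numbers depend on backend, batch size, and matrix size):

```python
import time
import torch

A = torch.randn(512, 16, 16)

t0 = time.perf_counter()
looped = torch.stack([torch.det(a) for a in A])   # one det call per matrix
t_loop = time.perf_counter() - t0

t0 = time.perf_counter()
batched = torch.det(A)                            # single batched call
t_batch = time.perf_counter() - t0

print(f"loop: {t_loop * 1e3:.2f} ms   batched: {t_batch * 1e3:.2f} ms")
assert torch.allclose(looped, batched, rtol=1e-4, atol=1e-4)
```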
@pytorchbot rebase this please
_lu_det_P_diag_U_info is renamed to _lu_det_P_diag_U
@pytorchbot rebase this please
@ssnl I optimized the forward implementation now (turns out we actually don't need the infos tensor from the helper function to be returned).
I have a bunch of perf comments, but they are all minor. Feel free to just merge this after you fix them!
Also, this looks great! Thank you so much for working on this! :)
These things are small, so I pushed to your branch. Hopefully we can get a green CI.
Force-pushed from 26d03b5 to 3ad956c
Force-pushed from 3ad956c to 874c259
@ssnl thank you very much for the detailed review. Sorry I couldn't get to them before you did; it was already late over here.
@pytorchbot merge this please
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Changelog:
- Add batching for det / logdet / slogdet operations
- Update derivative computation to support batched inputs (and consequently batched outputs)
- Update docs

Pull Request resolved: #22909

Test Plan:
- Add a `test_det_logdet_slogdet_batched` method in `test_torch.py` to test `torch.det`, `torch.logdet` and `torch.slogdet` on batched inputs. This relies on the correctness of `torch.det` on single matrices (tested by `test_det_logdet_slogdet`). A port of this test is added to `test_cuda.py`
- Add autograd tests for batched inputs

Differential Revision: D16580988
Pulled By: ezyang
fbshipit-source-id: b76c87212fbe621f42a847e3b809b5e60cfcdb7a
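A quick usage example of the batched behaviour the changelog describes (shapes chosen arbitrarily):

```python
import torch

A = torch.randn(2, 3, 4, 4)            # a (2, 3) batch of 4x4 matrices

det = torch.det(A)                      # shape (2, 3)
logdet = torch.logdet(A)                # shape (2, 3); nan where det is negative
sign, logabsdet = torch.slogdet(A)      # both shape (2, 3)

assert det.shape == logdet.shape == sign.shape == (2, 3)
assert torch.allclose(sign * logabsdet.exp(), det, rtol=1e-4, atol=1e-4)
```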