Add forward AD for linalg.det and simplify its backward #79487
Conversation
This PR is in preparation for implementing `logdet` and `slogdet` as structured kernels + implementing them with more efficient derivatives.

We implement forward AD for det. We also simplify the implementation of the backward, and leave a note on how to implement it properly for singular matrices. We leave that for future work.

Note (by looking at the OpInfo) that the current implementation passes the same tests as the one before. We skip the forward-over-backward in the singular case, as that one was not working in the gradgrad case either.
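To make the forward-AD rule concrete, here is a minimal sketch of the formula it is based on (Jacobi's formula, d det(A)[E] = det(A) * tr(A^{-1} E)), checked against `torch.autograd.forward_ad`. This is an illustration, not the PR's code; `A` and `E` are arbitrary example inputs, and the check assumes a build that already contains this PR's forward AD for `det`.

```python
import torch
import torch.autograd.forward_ad as fwAD

# Jacobi's formula: the directional derivative of det at a nonsingular A
# in direction E is det(A) * tr(A^{-1} E).
A = torch.randn(4, 4, dtype=torch.double)
E = torch.randn(4, 4, dtype=torch.double)   # tangent direction

manual_jvp = torch.linalg.det(A) * torch.linalg.solve(A, E).trace()

# Forward-mode AD computes the same JVP once linalg.det has a forward rule.
with fwAD.dual_level():
    dual_A = fwAD.make_dual(A, E)
    out = torch.linalg.det(dual_A)
    ad_jvp = fwAD.unpack_dual(out).tangent

print(torch.allclose(manual_jvp, ad_jvp))
```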
❌ 2 new failures as of commit 833bd72 (more details on the Dr. CI page).
🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
// The proper way of doing this is doing `auto mask = det == 0.;` and then
// if any determinant is zero, use an SVD decomposition to compute the
// derivative in those inputs (not all inputs). The derivative may be then
But this is what is done in the current code. Why is it being removed now?
Not quite. The point here is that:
- If there is just one singular value that is zero, then the derivative is not zero
- If there is more than one singular value that is zero, the derivative is zero
- The SVD returns the singular values in decreasing order, so if the det is zero, we can just drop the smallest singular value and compute the derivative as (det(U) * det(V).conj() * s_1 * ... * s_{n-1}) u v^H, where u and v are the singular vectors associated with the smallest singular value
The last point is the relevant one for performing this optimisation (see the sketch below).
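For concreteness, here is a minimal real-valued sketch of that SVD route. It is not the PR's code: `A` and `grad_out` are illustrative names, the input is a hand-picked rank-deficient matrix, and complex inputs would additionally need the conjugation mentioned above.

```python
import torch

# A square matrix whose determinant is zero because exactly one singular
# value vanishes (rank n-1).
A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.double)
grad_out = torch.tensor(1., dtype=torch.double)

U, S, Vh = torch.linalg.svd(A)
# Singular values are returned in descending order, so the zero one is S[-1].
prod_rest = S[:-1].prod()      # s_1 * ... * s_{n-1}
u = U[:, -1]                   # left singular vector of the smallest singular value
v = Vh[-1]                     # right singular vector of the smallest singular value

scale = torch.linalg.det(U) * torch.linalg.det(Vh) * prod_rest
grad_A = grad_out * scale * torch.outer(u, v)   # rank-one gradient of det at this A
```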
OK, makes sense. Why not implement it in this PR, though?
I realised it's not that easy. It's easy to compute the derivative, but the second derivative will not be correct.
I am not implementing it in this PR as I need to move on to do other things. I've left a detailed comment on what the issue is.
Fwiw, afaik the second derivative of torch.prod is not correct for a vector with more than one zero, so it's fine if it doesn't work for torch.det either.
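A hedged way to probe that claim (not part of this PR; the input values are arbitrary, and whether the check passes may also depend on the tolerances used):

```python
import torch

# gradgradcheck compares the analytic second derivative of prod against a
# numerical one; the discussion above suggests it can be off when the input
# contains more than one exact zero.
x = torch.tensor([1.5, 0.0, 0.0, -2.0], dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradgradcheck(torch.prod, (x,), raise_exception=False)
print("gradgrad of prod with two zeros passes:", ok)
```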
Sorry, forgot to push the changes. Now they're in.
// compute higher derivatives of `x.prod()` and `x` has more than one zero.
return at::linalg_solve(A.mH(), d);
}
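For reference, a minimal PyTorch-level sketch of the nonsingular-case backward that this `linalg_solve` call implements: grad_A = grad_out * det(A) * A^{-H}, computed with a solve instead of an explicit inverse. This is an assumed equivalent formulation, not the PR's actual code; `A`, `grad_out` and `rhs` are illustrative names, and the example is kept real-valued (the complex case needs the appropriate conjugations).

```python
import torch

A = torch.randn(3, 3, dtype=torch.double)
grad_out = torch.tensor(0.7, dtype=torch.double)

det = torch.linalg.det(A)
# Solve A^H X = (grad_out * conj(det(A))) * I,  i.e.  X = grad_out * conj(det(A)) * A^{-H}.
rhs = (grad_out * det.conj()) * torch.eye(3, dtype=A.dtype)
grad_A = torch.linalg.solve(A.mH, rhs)

# Cross-check against autograd's backward of det on the same input.
A_req = A.clone().requires_grad_(True)
torch.linalg.det(A_req).backward(grad_out)
print(torch.allclose(grad_A, A_req.grad))
```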
Do you know which samples `gradgradcheck` fails on?
I filed a follow-up PR #80217 to unblock this one (as this one is already working and those modifications may introduce some weird errors on some platforms).
// The issue with this code is that the derivative given by autograd of
// prod_safe_zeros_backward is not the second derivative of the product.
// It is not clear to me how to implement the second derivative of the
// product efficiently. Note that this is also currently a problem when we
What is the issue with the `prod`?
It's a non-issue, sorry. This comment is removed in #80217
Thank you, Mario. LGTM! Let's see the output, and that of the next one in the stack.
SGTM
@pytorchbot merge -g
@pytorchbot successfully started a merge job. Check the current status here
Merge failed because 2 additional jobs have failed; the first few of them are: periodic, linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck / test (default, 1, 2, linux.4xlarge.nvidia.gpu)
The error, while in the backward of conv, seems completely unrelated to this PR and looks like a flaky test. cc @janeyx99 @jbschlosser (note that the previous CI run was green and there were no changes in this PR between that run and this one). I'm going YOLO and trying a merge anyway. Wish me the best.
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here
Summary: This PR is in preparation for implementing `logdet` and `slogdet` as structured kernels + implementing them with more efficient derivatives. We implement forward AD for det. We also simplify the implementation of the backward, and leave a note on how to implement it properly for singular matrices. We leave that for future work. Note (by looking at the OpInfo) that the current implementation passes the same tests as the one before. We skip the forward-over-backward in the singular case, as that one was not working in the gradgrad case either.
Pull Request resolved: #79487
Approved by: https://github.com/nikitaved, https://github.com/albanD
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/42a2359612743f791fd2c0684b75f2ff450bc0b3
Reviewed By: seemethere
Differential Revision: D37423853
Pulled By: seemethere
fbshipit-source-id: 627829937609ae977c7c539a5aa6cf009512e162