Port cholesky_inverse to ATen #50269
Conversation
Use Tensor of ints for 'infos' instead of std::vector
💊 CI failures summary (Dr. CI, as of commit f117e84): 💚 Looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
Benchmarks for the change, before and after (result tables omitted).
So this PR pretty significantly regresses performance on both CUDA and CPU? That's worrying. What do you think is happening and can we address this? cc @ngimel
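The benchmark tables themselves did not survive in this extract. As context, a minimal sketch of how such a before/after comparison can be taken with `timeit`; the workload below is a placeholder, not the PR's actual benchmark:

```python
import timeit

def bench(fn, number=100, repeat=3):
    # Best-of-repeat average per call: taking the min over repeats is the
    # usual way to reduce timer and scheduler noise.
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number

# Placeholder workload standing in for e.g. torch.cholesky_inverse on some size.
workload = lambda: sum(x * x for x in range(1000))
per_call_seconds = bench(workload)
```

One would run the same `bench` call on the old and new builds and compare the per-call times.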
```cpp
TORCH_CHECK(result.device() == input.device(),
            "result device ", result.device(), " does not match input device ", input.device());

// Single matrix MAGMA routine requires 'infos' to reside in CPU memory,
```
Wait -- doesn't this PR implement batched CUDA, though?
Oh, this comment is a bit outdated; I need to change it. At the time I thought the batched MAGMA routine would require `infos` to live on the GPU.
Now the CUDA path is implemented in terms of MAGMA's `apply_cholesky_solve`. It differs from the other batched functions in that it doesn't take an array-of-ints `infos` argument and doesn't raise any errors related to the algorithm; it returns only a single integer indicating whether the passed arguments were valid or not.
If cuSOLVER were used instead in the future, then `infos` would need to be created on the GPU.
I will change it to:
- if the input is on CPU, we need `infos` of size `batchsize(input)` to fill with the error codes from LAPACK for each matrix in the batched tensor;
- if the input is on GPU, we need only a single integer living on CPU to store the error code from MAGMA.

It also means that for some inputs the CPU path could raise an error while the GPU path would output something, because the underlying operations are different.
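The CPU-side bookkeeping described above, one LAPACK error code per matrix in the batch, might be sketched like this. `check_infos` and its message format are illustrative, not the actual ATen helper:

```python
def check_infos(infos):
    """Map per-matrix LAPACK `info` codes to an error, batched-CPU style.

    Convention (illustrative): infos[i] == 0 means success for batch element i,
    infos[i] == k > 0 means the k-th diagonal entry was zero / not positive.
    """
    for i, info in enumerate(infos):
        if info > 0:
            raise RuntimeError(
                f"cholesky_inverse: U({info},{info}) is zero for batch element {i}")

check_infos([0, 0, 0])  # all matrices in the batch succeeded: no error raised
```

On the GPU path, by contrast, there would be just one integer to check rather than a vector of them.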
> It also means that for some inputs CPU could raise an error while GPU would output something because the operations used are different.
Are there some specific cases you're thinking of?
I checked the docs. It will happen for inputs with a zero on the diagonal: the CPU version raises an error with the index of the zero diagonal element, while CUDA gives `inf`. I need to add this to the tests.
```python
In [1]: import torch

In [2]: a = torch.randn(2, 2)

In [3]: a
Out[3]:
tensor([[-0.6556, -0.4479],
        [-0.9347,  0.8169]])

In [4]: a[1, 1] = 0

In [5]: a
Out[5]:
tensor([[-0.6556, -0.4479],
        [-0.9347,  0.0000]])

In [6]: torch.cholesky_inverse(a)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-680561e9e67c> in <module>
----> 1 torch.cholesky_inverse(a)

RuntimeError: Lapack Error potri : A(2,2) is 0, A cannot be factorized at ../aten/src/TH/generic/THTensorLapack.cpp:245

In [7]: torch.cholesky_solve(torch.eye(2), a)
Out[7]:
tensor([[ inf, -inf],
        [-inf,  inf]])

In [8]: torch.cholesky_solve(torch.eye(2, device='cuda'), a.cuda())
Out[8]:
tensor([[ inf, -inf],
        [-inf,  inf]], device='cuda:0')
```
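The divergence above comes from the algorithms rather than the devices: LAPACK's `potri` explicitly checks the diagonal and reports the zero pivot through its error code, while a plain IEEE-754 triangular solve just divides by the zero pivot and propagates infinities. A pure-Python sketch of the latter (an illustration, not the MAGMA kernel):

```python
def forward_substitution(L, b):
    # Solve L x = b for lower-triangular L with no pivot checks, so a zero
    # diagonal entry turns into +/-inf (or nan) instead of an error.
    n = len(b)
    x = [0.0] * n
    for i in range(n):
        s = b[i] - sum(L[i][j] * x[j] for j in range(i))
        if L[i][i] != 0.0:
            x[i] = s / L[i][i]
        else:
            # Mimic IEEE-754 division by zero, which Python's `/` would raise on.
            x[i] = float("inf") if s > 0 else float("-inf") if s < 0 else float("nan")
    return x

# Same shape of input as in the session above: a zero on the diagonal.
x = forward_substitution([[-0.6556, 0.0], [-0.9347, 0.0]], [1.0, 0.0])
```

With the zero pivot, the second component of `x` becomes infinite, matching the `inf` entries `cholesky_solve` produced above.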
Testing (and, in the near future, documenting) this behavior sounds great.
Great work, @IvanYashchuk. There are a couple of follow-ups (adjustments based on future OpInfo fixes, docs), as @anjali411 points out, but I think they're separable. I've added this to the list of operators to review in our 1.8 scrub.
When you're happy with this PR ping me and let's merge it.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@IvanYashchuk please rebase the PR and let me know once it's ready for merge.
@anjali411 I resolved the conflict.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 merged this pull request in 6e4746c.
After this PR was merged, the builds on ppc64le started failing.
@dncliss thanks for reporting the issue, and yeah, the fix should be similar. @IvanYashchuk could you create a fix for this?
@anjali411 Thanks for confirming. I didn't open a new issue (I figure maybe @IvanYashchuk can just push the one-line fix), but if you want it reported that way I could do so. For reference, in the regular ppc64le builds seen here: https://powerci.osuosl.org/job/pytorch-master-nightly-py3-linux-ppc64le-gpu/, builds older than 1047 had the error fixed by PR 51217, 1047 itself was a successful build, and 1048 onward introduced the failure after the merge of this cholesky_inverse feature PR. The error, as before, is an "undefined reference" during linking.
@anjali411, @dncliss I submitted the fix. Sorry for the trouble! |
Summary: It was overlooked that vsx dispatch is also needed for cholesky_inverse cpu dispatch. See #50269 (comment) Pull Request resolved: #51562 Reviewed By: H-Huang Differential Revision: D26199581 Pulled By: anjali411 fbshipit-source-id: 5d02c6da52ce1d2e9e26001f5d4648a71dd0e829
…puts. (#69069) Summary: While implementing #68720, we found out empirically that `torch.cholesky_inverse` supports batched inputs, but this is not explained in the docs: [link](#68720 (review)). `torch.cholesky_inverse` was implemented in #50269 and the doc was updated in #31275 but not merged. neerajprad Pull Request resolved: #69069 Reviewed By: mrshenli Differential Revision: D32979362 Pulled By: neerajprad fbshipit-source-id: 0967c969434ce6e0ab15889c240149c23c0bce44
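The batched semantics discussed in these commit messages, one Cholesky-based inverse per matrix in the batch, can be sketched with NumPy. This illustrates the math `cholesky_inverse` implements rather than the ATen implementation: for each SPD matrix `A = L L^T`, the result is `inv(A) = inv(L).T @ inv(L)`.

```python
import numpy as np

rng = np.random.default_rng(0)
B = rng.standard_normal((3, 4, 4))
A = B @ B.transpose(0, 2, 1) + 4 * np.eye(4)   # a batch of 3 SPD 4x4 matrices
L = np.linalg.cholesky(A)                      # batched Cholesky factors

# What cholesky_inverse computes, applied matrix-by-matrix over the batch:
# recover inv(A) from the factor L as inv(L).T @ inv(L).
inv_from_factor = np.stack([np.linalg.inv(Li).T @ np.linalg.inv(Li) for Li in L])

assert np.allclose(inv_from_factor, np.linalg.inv(A))
```

In torch the equivalent batched call would take the stacked factors directly, which is the behavior the docs update makes explicit.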
Now we can remove `_th_potri`! Compared to the original TH-based `cholesky_inverse`, complex (#33152) and batched inputs (#7500) are now supported on both CPU and CUDA. Closes #24685.
Closes #24543.
Ref. #49421, #42666