Port CPU torch.orgqr to ATen #50502
Conversation
native_functions.yml
auto infos = at::empty({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt).device(kCPU));

// if result is not empty and not in batched column major format we have to allocate a temporary tensor
if (result.numel() != 0 && !result.transpose(-2, -1).is_contiguous()) {
This is really nice.
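As background for the contiguity check in the snippet above, here is a hedged NumPy sketch of the same idea for the 2-D case (as_column_major is a hypothetical helper, not part of ATen): LAPACK expects column-major (Fortran-order) storage, and a matrix is column-major exactly when its transpose is row-major, which is what result.transpose(-2, -1).is_contiguous() tests.

```python
import numpy as np

# Hypothetical helper mirroring the ATen check above (2-D sketch).
# LAPACK wants column-major data; a matrix is column-major exactly
# when its transpose is C-contiguous.
def as_column_major(result):
    swapped = np.swapaxes(result, -2, -1)
    if result.size != 0 and not swapped.flags['C_CONTIGUOUS']:
        # not column-major: allocate a Fortran-ordered temporary copy
        return np.asfortranarray(result)
    return result  # empty or already column-major: use as-is
```

The empty-size short circuit matters: for an empty result no temporary is needed at all, which foreshadows the empty-matrix discussion later in this thread.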
Tensor orgqr(const Tensor& input, const Tensor& tau) {
  Tensor result = at::empty({0}, input.options());
  result = at::orgqr_outf(input, tau, result);
orgqr_outf must be a typo here, right?
No, that's the actual function now. at::orgqr_outf(input, tau, result) is equivalent to at::orgqr_out(result, input, tau).
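To illustrate the convention being discussed, here is a hedged pure-Python sketch (not ATen code; add is a stand-in operation): the _outf variant takes the same arguments as the functional form with the output appended last, while the _out variant takes the output first. Both write into the caller-provided buffer.

```python
import numpy as np

# Sketch of the ATen out-variant naming convention (assumption: `add`
# stands in for any op). `add_out` takes the output buffer first,
# `add_outf` takes it last; both fill the caller-provided buffer.
def add_out(result, a, b):
    np.add(a, b, out=result)
    return result

def add_outf(a, b, result):
    # same computation, out argument moved to the end
    return add_out(result, a, b)
```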
@@ -4049,21 +4049,98 @@ def test_renorm_ps(self, device):

@onlyCPU
@skipCPUIfNoLapack
def test_orgqr_errors(self, device):
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_orgqr(self, device, dtype):
Pretty cool that this was never tested previously.
actual = torch.orgqr(reflectors, tau)
self.assertEqual(expected, actual)

out = torch.empty_like(A)
Let's add an OpInfo for this function either in the CUDA port PR or after it. Then we won't need this out test, I think.
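As background for what such tests exercise: orgqr forms the matrix Q from the Householder reflectors and scalar factors tau that a QR factorization (geqrf) produces. A hedged NumPy sketch of that computation follows — note the reflectors and tau below are constructed directly with tau_i = 2/(v_i·v_i) for illustration, rather than taken from a real LAPACK factorization.

```python
import numpy as np

def householder_product(reflectors, tau):
    # Form Q = H_0 H_1 ... H_{k-1} with H_i = I - tau[i] * v_i v_i^T,
    # where v_i is column i of `reflectors` with v_i[:i] zeroed and
    # v_i[i] set to 1 -- the storage convention LAPACK's ?orgqr uses.
    m = reflectors.shape[0]
    Q = np.eye(m)
    for i in range(len(tau)):
        v = reflectors[:, i].copy()
        v[:i] = 0.0
        v[i] = 1.0
        Q = Q @ (np.eye(m) - tau[i] * np.outer(v, v))
    return Q
```

With tau_i = 2/(v_i·v_i), every factor is an exact reflection, so the product is orthogonal — a property a test can assert without needing a reference QR decomposition.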
Hey @IvanYashchuk!
This port looks pretty good. @ngimel and I made a few comments.
However, we were looking at this function and wondering if it should just be removed. We could find no uses of it within Facebook. Do you think it's interesting vs. torch.linalg.qr? If we want to support this functionality, maybe we'd offer it as a mode on torch.linalg.qr instead of this cryptically named function.
Looking forward to hearing your thoughts.
EDIT: We preemptively edited the tracking issue #49421 to reflect deprecating and removing, not porting, functions like orgqr. If we decide we don't want to keep it then we can change the issue back.
Some more context on deprecating this function: it is used in one function that appears to be copied around some repos on GitHub.
There is a recent request for this function to support CUDA and differentiation (#50104). I definitely agree that we shouldn't use LAPACK's name. I was thinking to introduce a more descriptive name together with the PR for the backward rule, something like
Aha! Thank you for reminding me. We can work on a name improvement like you've suggested in a future PR (using an alias).
Force-pushed from 511f4dd to 20b0654.
CI failures are not related to this PR.
Thanks @IvanYashchuk!
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This is hitting a series of errors internally. I suggest we rebase and resubmit it using ci-all, especially since this base looks like it has several failing builds. The failures are the CPU variants of test_orgqr: caffe2/test:linalg - test_orgqr_cpu_float32 (test_linalg.TestLinalgCPU). Example error output:
int lwork = -1;
scalar_t wkopt;
lapackOrgqr<scalar_t>(m, n_columns, k, self_data, lda, tau_data, &wkopt, lwork, &infos_data[0]);
lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
I don't like this error: Argument 8 has illegal value. The 8th argument is lwork; it is an integer for which we should obtain the value from LAPACK (the size of the work array, lwork ≥ n).
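The two-phase protocol under discussion can be sketched in hedged pure Python (fake_orgqr is a stand-in, not a real LAPACK binding): calling with lwork = -1 is a workspace query that only reports the optimal size in work[0]; passing a too-small lwork on the real call yields info = -8, i.e. "argument 8 (lwork) has an illegal value".

```python
# Stand-in for a LAPACK ?orgqr call (assumption: real LAPACK applies
# the actual Householder transformations; only the lwork protocol is
# modeled here).
def fake_orgqr(m, n, work, lwork):
    optimal = max(1, n)       # reference LAPACK requires lwork >= max(1, n)
    if lwork == -1:           # workspace query: report size, compute nothing
        work[0] = float(optimal)
        return 0              # info == 0: success
    if lwork < optimal:
        return -8             # info == -8: lwork (argument 8 of ?orgqr) illegal
    return 0                  # workspace large enough: would compute here

def call_with_query(m, n):
    work = [0.0]
    info = fake_orgqr(m, n, work, -1)   # phase 1: ask for the optimal lwork
    assert info == 0
    lwork = int(work[0])
    work = [0.0] * lwork                # phase 2: allocate and call for real
    return fake_orgqr(m, n, work, lwork)
```

Note that max(1, n) makes the query return 1 even for n == 0, which is the empty-matrix behavior of reference LAPACK discussed further down the thread.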
The version used internally seems to be much older than the version in CI.
Could it be the change in how the function is called, using m instead of lda? I'll try to debug internally, too, to get a better sense for what's going on. @ngimel points out that torch.linalg.qr must be relying on this same function, so it's surprising we haven't seen this issue previously.
at::native::resize_as_(result, input.transpose(-2, -1), MemoryFormat::Contiguous);
result.transpose_(-2, -1);
}
Adding an early return here

// early return for empty matrices
if (result.numel() == 0) {
  infos.fill_(0);
  return result;
}

fixes the internal error. I now wonder how the OSS tests pass, because for an empty matrix lwork is returned as 0, and that's an illegal value (it should be at least 1).
Reference LAPACK returns lwork as 1 for empty matrices:
https://github.com/Reference-LAPACK/lapack/blob/master/SRC/dorgqr.f#L193-L196
Maybe the older version of MKL didn't do that, while the newer one matches the reference implementation; that would explain why OSS tests pass.
We haven't seen issues with torch.linalg.qr because the early return is used there.
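The fix described above can be sketched as a hedged Python analogue (orgqr_guarded and lapack_call are hypothetical names, not ATen APIs): bail out before the workspace query whenever the result is empty, so a LAPACK/MKL build that reports lwork = 0 for empty inputs is never consulted.

```python
import numpy as np

def orgqr_guarded(result, infos, lapack_call):
    # Early return for empty matrices: mark every batch entry as
    # successful (info == 0) and skip the LAPACK call entirely,
    # including its lwork workspace query.
    if result.size == 0:
        infos.fill(0)
        return result
    return lapack_call(result, infos)
```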
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Codecov Report
@@ Coverage Diff @@
## master #50502 +/- ##
==========================================
+ Coverage 80.99% 81.01% +0.01%
==========================================
Files 1916 1917 +1
Lines 209552 209556 +4
==========================================
+ Hits 169736 169762 +26
+ Misses 39816 39794 -22
Now we can remove _th_orgqr! Compared to the original TH-based orgqr, complex (#33152) and batched inputs are now supported. CUDA support will be added in a follow-up PR.
Closes #24747
Ref. #49421, #42666