Added linalg.pinv #48399
Changes from 11 commits
@@ -95,22 +95,70 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
   return std::make_tuple(det_sign, abslogdet_val);
 }

-Tensor pinverse(const Tensor& self, double rcond) {
-  TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() >= 2,
-              "pinverse(", self.scalar_type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions "
+Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
+  TORCH_CHECK((at::isFloatingType(input.scalar_type()) || at::isComplexType(input.scalar_type())) && input.dim() >= 2,
+              "linalg_pinv(", input.scalar_type(), "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
               "of floating types");
-  if (self.numel() == 0) {
+  if (input.numel() == 0) {
     // Match NumPy
-    auto self_sizes = self.sizes().vec();
-    std::swap(self_sizes[self.dim() - 1], self_sizes[self.dim() - 2]);
-    return at::empty(self_sizes, self.options());
+    auto input_sizes = input.sizes().vec();
+    std::swap(input_sizes[input.dim() - 1], input_sizes[input.dim() - 2]);
+    return at::empty(input_sizes, input.options());
   }
+  Tensor rcond_ = rcond;
Review comment: Can you elaborate on what's happening here, and add a comment elaborating?
Reply: It was not needed here and I removed it. b7a3eb3
+  if (rcond.dim() > 0) {
+    rcond_ = rcond.unsqueeze(-1);
+  }
-  Tensor U, S, V;
-  std::tie(U, S, V) = self.svd();
-  Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);
-  Tensor S_pseudoinv = at::where(S > rcond * max_val, S.reciprocal(), at::zeros({}, S.options())).to(self.dtype());
-  // computes V.conj() @ diag(S_pseudoinv) @ U.T.conj()
-  return at::matmul(V.conj() * S_pseudoinv.unsqueeze(-2), U.transpose(-2, -1).conj());
+  // If not Hermitian use singular value decomposition, else use eigenvalue decomposition
+  if (!hermitian) {
+    // until https://github.com/pytorch/pytorch/issues/45821 is resolved
+    // svd() returns conjugated V for complex-valued input
+    Tensor U, S, V_conj;
+    // TODO: replace input.svd with linalg_svd
+    std::tie(U, S, V_conj) = input.svd();
+    Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);  // singular values are sorted in descending order
+    Tensor S_pseudoinv = at::where(S > rcond_ * max_val, S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
Review comment: Parens around …
Review comment: Is the …
Reply: Yes, …
+    // computes V @ diag(S_pseudoinv) @ U.T.conj()
+    // TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved
+    return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
+  } else {
+    Tensor S, U;
+    std::tie(S, U) = at::linalg_eigh(input);
+    // For Hermitian matrices, singular values equal to abs(eigenvalues)
+    Tensor S_abs = S.abs();
+    // eigenvalues are sorted in ascending order starting with negative values, we need a maximum value of abs(eigenvalues)
+    Tensor max_val = S_abs.amax(/*dim=*/-1, /*keepdim=*/true);
+    Tensor S_pseudoinv = at::where(S_abs > rcond_ * max_val, S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
+    // computes U @ diag(S_pseudoinv) @ U.conj().T
+    return at::matmul(U * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
+  }
+}
+
+Tensor linalg_pinv(const Tensor& input, double rcond, bool hermitian) {
+  Tensor rcond_tensor = at::full({}, rcond, input.options().dtype(ScalarType::Double));
+  return at::linalg_pinv(input, rcond_tensor, hermitian);
+}
+
+// TODO: implement _out variant avoiding copy and using already allocated storage directly
+Tensor& linalg_pinv_out(Tensor& result, const Tensor& input, const Tensor& rcond, bool hermitian) {
+  TORCH_CHECK(result.scalar_type() == input.scalar_type(),
+              "result dtype ", result.scalar_type(), " does not match the expected dtype ", input.scalar_type());
+
+  Tensor result_tmp = at::linalg_pinv(input, rcond, hermitian);
+  at::native::resize_output(result, result_tmp.sizes());
+  result.copy_(result_tmp);
+  return result;
+}
|
||
Tensor& linalg_pinv_out(Tensor& result, const Tensor& input, double rcond, bool hermitian) { | ||
Tensor rcond_tensor = at::full({}, rcond, input.options().dtype(ScalarType::Double)); | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return at::linalg_pinv_out(result, input, rcond_tensor, hermitian); | ||
} | ||
|
||
Tensor pinverse(const Tensor& self, double rcond) { | ||
return at::linalg_pinv(self, rcond, /*hermitian=*/false); | ||
} | ||
|
||
Tensor& linalg_matrix_rank_out(Tensor& result, const Tensor& self, optional<double> tol, bool hermitian) { | ||
|
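The two branches above can be easier to follow in Python. The sketch below mirrors the same thresholded-SVD / eigendecomposition idea using public PyTorch calls; it is an illustration only, not the PR code: the helper name pinv_sketch is made up here, and it assumes real-valued input to sidestep the conjugated-V issue mentioned in the comments.

import torch

def pinv_sketch(A, rcond=1e-15, hermitian=False):
    # Illustrative helper (not the PR code): pseudo-inverse via thresholded SVD / eigh,
    # assuming a real-valued matrix (or batch of matrices) A.
    if not hermitian:
        U, S, V = torch.svd(A)                       # singular values sorted in descending order
        max_val = S[..., :1]                         # largest singular value of each matrix
        S_pinv = torch.where(S > rcond * max_val, S.reciprocal(), torch.zeros_like(S))
        # V @ diag(S_pinv) @ U^T
        return (V * S_pinv.unsqueeze(-2)) @ U.transpose(-2, -1)
    else:
        S, U = torch.linalg.eigh(A)                  # eigenvalues sorted in ascending order
        S_abs = S.abs()                              # |eigenvalues| play the role of singular values
        max_val = S_abs.amax(dim=-1, keepdim=True)
        S_pinv = torch.where(S_abs > rcond * max_val, S.reciprocal(), torch.zeros_like(S))
        # U @ diag(S_pinv) @ U^T
        return (U * S_pinv.unsqueeze(-2)) @ U.transpose(-2, -1)

A = torch.randn(3, 5, dtype=torch.float64)
print(torch.allclose(A @ pinv_sketch(A) @ A, A))     # expect True for a well-conditioned A

The batched case works the same way because the elementwise multiply against S_pinv.unsqueeze(-2) broadcasts over any leading batch dimensions, just like the at::matmul calls above.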
@@ -2064,6 +2064,105 @@ def run_test_singular_input(batch_dim, n):
        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
            run_test_singular_input(*params)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_pinv(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test_main(A, hermitian):
            # Testing against definition for pseudo-inverses
            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
            if A.numel() > 0:
                self.assertEqual(A, A @ A_pinv @ A, atol=self.precision, rtol=self.precision)
                self.assertEqual(A_pinv, A_pinv @ A @ A_pinv, atol=self.precision, rtol=self.precision)
Review comment: These lines probably want rtol = 0.
Reply: Mathematically yes, but it doesn't work with rtol=0. It fails for fp32 and complex64 and a larger rtol is needed for this test to pass. We do two matrix-matrix multiplications (…)
Review comment: OK, no worries.
Review comment: Note to self: we should support setting rtol and atol with the PrecisionOverride decorator.
                self.assertEqual(A @ A_pinv, (A @ A_pinv).conj().transpose(-2, -1))
                self.assertEqual(A_pinv @ A, (A_pinv @ A).conj().transpose(-2, -1))
            else:
                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

            # Check out= variant
Review comment: Too bad the OpInfo's won't test out= because we don't have multi-tensor out= testing (yet).
            out = torch.empty_like(A_pinv)
            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, A_pinv)

        def run_test_numpy(A, hermitian):
            # Check against NumPy output
            # Test float rcond, and specific value for each matrix
            rconds = [float(torch.rand(1)), torch.rand(A.shape[:-2], dtype=torch.double, device=device)]
            # Test broadcasting of rcond
            if A.ndim > 2:
                rconds.append(torch.rand(A.shape[-3], device=device))
            for rcond in rconds:
                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
                self.assertEqual(actual, expected)

Review comment: Awesome enumeration over sizes.
        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
            A = torch.randn(*sizes, dtype=dtype, device=device)
            hermitian = False
            run_test_main(A, hermitian)
            run_test_numpy(A, hermitian)
        # Check hermitian = True
        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
            hermitian = True
            run_test_main(A, hermitian)
            run_test_numpy(A, hermitian)
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float64)
    def test_pinv_autograd(self, device, dtype):
Review comment: If this adds an OpInfo then I don't think a custom autograd test is needed?
Reply: Yes, it's not needed. I'll remove it.
        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

        n = 5
        for batches in ([], [2], [2, 3]):
            # using .to(device) instead of device=device because @xwang233 claims it's faster
            a = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype).to(device)
            a.requires_grad_()

            def func(a, hermitian):
                if hermitian:
                    a = a + a.conj().transpose(-2, -1)
                return torch.linalg.pinv(a, hermitian=hermitian)

            for hermitian in [False, True]:
                gradcheck(func, [a, hermitian])
                gradgradcheck(func, [a, hermitian])
    # TODO: RuntimeError: svd does not support automatic differentiation for outputs with complex dtype.
    # See https://github.com/pytorch/pytorch/pull/47761
    @unittest.expectedFailure
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.complex128)
    def test_pinv_autograd_complex_xfailed(self, device, dtype):
Review comment: This failure test will be needed. There are other failure cases that need to be tested for, too. Like rcond having the wrong dtype or the sizes of the inputs being incorrect.
Reply: I added …
        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

        n = 5
        batches = (2, 3)
        # using .to(device) instead of device=device because @xwang233 claims it's faster
        a = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype).to(device)
        a.requires_grad_()

        def func(a, hermitian):
            if hermitian:
                a = a + a.conj().transpose(-2, -1)
            return torch.linalg.pinv(a, hermitian=hermitian)

        for hermitian in [False, True]:
            gradcheck(func, [a, hermitian])
            gradgradcheck(func, [a, hermitian])

    def solve_test_helper(self, A_dims, b_dims, device, dtype):
        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@@ -483,6 +483,79 @@
>>> tensor([ 11.6734+0.j, 105.1037+0.j, 10.1978+0.j])
""")

pinv = _add_docstr(_linalg.linalg_pinv, r"""
linalg.pinv(input, rcond=1e-15, hermitian=False) -> Tensor
Review comment: Note for @heitorschueroff, we should be sure to docs deprecate the old pinverse before 1.8.
Review comment: The NumPy version of this operator can throw a runtime error. The other linear algebra functions that we've recently added also explain when they may throw runtime errors. Is there something we can document here, too?
Reply: I added a note that if svd or eigh do not converge then the runtime error will be thrown. It's difficult to test these situations though.
Review comment: What makes that hard to test?
Reply: I can't find how to generate matrices that would make svd or eigh fail. They're quite robust.
Computes the pseudo-inverse (also known as the Moore-Penrose inverse) of a matrix :attr:`input`,
or of each matrix in a batched :attr:`input`.
The pseudo-inverse is computed using singular value decomposition (see :func:`torch.linalg.svd`) by default.
Review comment: Unfortunately torch.linalg.svd doesn't exist yet. This can reference torch.svd for now.
Reply: Sure I can change that.
Review comment: It is nice to dream... but we should probably keep all the references in the docs working ;)
If :attr:`hermitian` is ``True``, then :attr:`input` is assumed to be Hermitian (symmetric if real-valued),
and the computation of the pseudo-inverse is done by obtaining the eigenvalues and eigenvectors
Review comment: Overall this paragraph is very good. In the future we may want to elaborate on the precise computation using the outputs of the singular value decomposition or the eigenvalues and eigenvectors.
(see :func:`torch.linalg.eigh`).
The singular values (or the absolute eigenvalues when :attr:`hermitian` is ``True``) that are below
the specified :attr:`rcond` threshold are treated to be zero and discarded in the computation.

Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` datatypes.
Review comment: We should verify that float16 and bfloat16 are not supported and throw the proper error message when given. See comments above.
.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU.

Args:
    input (Tensor): the input matrix of size :math:`(m, n)` or the batch of matrices of size :math:`(*, m, n)`
                    where `*` is one or more batch dimensions.
    rcond (float, Tensor, optional): the tolerance value to determine the cutoff for small singular values. Default: 1e-15
Review comment: One challenge with clearly documenting rcond is explaining what different tensor shapes do. Do you think we can elaborate on this? In particular, the shape of rcond must be broadcastable to the singular value tensor returned from torch.svd, right? Is it worth including an rcond example where rcond isn't a scalar? Possibly by inspecting the singular value decomposition and then determining a cutoff?
Reply: Yes, rcond must be broadcastable to the singular values. I think it's a good enough condition to mention and doesn't need further explanation?
    hermitian(bool, optional): indicates whether :attr:`input` is Hermitian. Default: ``False``

Examples::

    >>> input = torch.randn(3, 5)
    >>> input
    tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132],
            [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294],
            [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
    >>> torch.linalg.pinv(input)
    tensor([[ 0.0600, -0.1933, -0.2090],
            [-0.0903, -0.0817, -0.4752],
            [-0.7124, -0.1631, -0.2272],
            [ 0.1356,  0.3933, -0.5023],
            [-0.0308, -0.1725, -0.5216]])

    Batched linalg.pinv example
    >>> a = torch.randn(2, 6, 3)
    >>> b = torch.linalg.pinv(a)
    >>> torch.matmul(b, a)
    tensor([[[ 1.0000e+00,  1.6391e-07, -1.1548e-07],
             [ 8.3121e-08,  1.0000e+00, -2.7567e-07],
             [ 3.5390e-08,  1.4901e-08,  1.0000e+00]],

            [[ 1.0000e+00, -8.9407e-08,  2.9802e-08],
             [-2.2352e-07,  1.0000e+00,  1.1921e-07],
             [ 0.0000e+00,  8.9407e-08,  1.0000e+00]]])

    Hermitian input example
    >>> a = torch.randn(3, 3, dtype=torch.complex64)
    >>> a = a + a.t().conj()  # creates a Hermitian matrix
    >>> b = torch.linalg.pinv(a, hermitian=True)
    >>> torch.matmul(b, a)
    tensor([[ 1.0000e+00+0.0000e+00j, -1.1921e-07-2.3842e-07j,
              5.9605e-08-2.3842e-07j],
            [ 5.9605e-08+2.3842e-07j,  1.0000e+00+2.3842e-07j,
             -4.7684e-07+1.1921e-07j],
            [-1.1921e-07+0.0000e+00j, -2.3842e-07-2.9802e-07j,
              1.0000e+00-1.7897e-07j]])

    Non-default rcond example
Review comment: This example is tricky. Because we don't see the intermediate steps it's hard to understand the effect that rcond is having. The next example has the same issue. What are your thoughts? I think we can leave out these examples for now since this PR seems to be almost ready to merge otherwise, and maybe return to add them later?
Reply: The documentation says that the pseudo-inverse is calculated using SVD and it says that rcond determines which singular values should be set to zero. The first example demonstrates that using the default and some other value for rcond gives different results, as expected. Just using the facts from the docs I think it's understandable that different rconds in general give different results. I think we should keep this example because examples should contain code snippets demonstrating each variable. Studying math to really understand the effect of rcond (the topic of Tikhonov regularization) is definitely not needed here. We're on the level of "something in the input changes, something in the output should change as well". The purpose of the second rcond example (…) How about keeping these examples as is (that's working and valid code) and returning to them later to improve them?
Reply: @mruberry, so should we remove those rcond examples now and think about it later or keep them?
Review comment: No it's fine, if you like them let's keep them.
    >>> rcond = 0.5
    >>> a = torch.randn(3, 3)
    >>> torch.linalg.pinv(a)
    tensor([[ 0.2971, -0.4280, -2.0111],
            [-0.0090,  0.6426, -0.1116],
            [-0.7832, -0.2465,  1.0994]])
    >>> torch.linalg.pinv(a, rcond)
    tensor([[-0.2672, -0.2351, -0.0539],
            [-0.0211,  0.6467, -0.0698],
            [-0.4400, -0.3638, -0.0910]])
""")
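Following up on the review discussion about showing intermediate steps: one way a future doc revision could make the rcond cutoff visible is to print the singular values first. This is only a sketch of that idea, not part of the PR (the printed values depend on the random input, so none are shown here):

import torch

a = torch.randn(3, 3, dtype=torch.float64)
S = torch.svd(a).S                       # singular values, sorted in descending order
rcond = 0.5
print(S)                                 # inspect the spectrum
print(S > rcond * S[..., :1])            # which singular values survive the cutoff

# Components below rcond * S.max() are dropped, so a coarse rcond usually
# gives a different pseudo-inverse than the default rcond=1e-15.
print(torch.allclose(torch.linalg.pinv(a), torch.linalg.pinv(a, rcond=rcond)))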
tensorinv = _add_docstr(_linalg.linalg_tensorinv, r"""
linalg.tensorinv(input, ind=2, *, out=None) -> Tensor
Review comment: "of floating types" is misleading. The requirement is that the tensors have a floating point dtype or a complex dtype. Also, is this the right place to validate that input and rcond have the same dtype? Also also, what happens when the dtype is bfloat16 or half?
Review comment: Oh wait, I'm mistaken. It should not be that the dtype of input and rcond are the same. Shouldn't it be that rcond must always be the "value type" of input? That is, if input is double rcond should be double, but if input is complex double then rcond should be double, right?
Reply: I changed the torch check to accept only float, double, cfloat or cdouble types. As for rcond, I don't actually know how we should restrict its types. When only a float is given from the Python interface it always gets converted to double in C++ and then we always create a scalar tensor of type double. If a tensor is passed directly, then all types should be valid that allow multiplication with the max_val tensor (which can only be of type float or double) and that allow comparison with a tensor of type float or double. So I guess any floating and integer types (everything that is not complex) should be valid here for any allowed input?
Reply: I added a restriction for the rcond type not to accept complex types: 4de9786. And added tests to show that it works for all other types: 4ac1eaf.
Review comment: Your analysis makes sense. Restricting it to not be complex sounds good.
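To illustrate the conclusion of this thread, here is a small sketch of rcond passed as a real-valued tensor that broadcasts over the batch; the rejection of a complex rcond is assumed from the restriction described above, not verified against a specific release.

import torch

A = torch.randn(3, 5, 5, dtype=torch.complex128)           # batch of 3 matrices
# One cutoff per matrix; rcond stays real-valued even for complex input.
rcond = torch.tensor([1e-15, 1e-3, 0.1], dtype=torch.float64)
P = torch.linalg.pinv(A, rcond=rcond)
print(P.shape)                                              # torch.Size([3, 5, 5])

# A complex rcond should be rejected by the check referenced above.
try:
    torch.linalg.pinv(A, rcond=torch.tensor(1e-3 + 0j))
except RuntimeError as e:
    print("rejected:", e)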