Added computing matrix condition numbers (linalg.cond) #45832
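For reference, the condition numbers implemented in this diff follow the NumPy linalg.cond conventions; the summary below is inferred from the code in this PR (general norm types go through an explicit inverse, while ord = ±2 goes through the singular values):

\kappa_p(A) = \lVert A \rVert_p \, \lVert A^{-1} \rVert_p, \quad p \in \{\pm 1, \pm\infty, \mathrm{fro}, \mathrm{nuc}\}
\kappa_{2}(A) = \frac{\sigma_{\max}(A)}{\sigma_{\min}(A)}, \qquad \kappa_{-2}(A) = \frac{\sigma_{\min}(A)}{\sigma_{\max}(A)}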
@@ -18,6 +18,8 @@
#include <limits>
#include <ATen/NamedTensorUtils.h>

#include <c10/util/variant.h>

namespace at {
namespace native {

@@ -1704,6 +1706,160 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt
  return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype);
}

Tensor _linalg_cond_exception_helper(const Tensor& self) {
  // For batched input if at least one matrix in the batch is not invertible,
  // we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop.
  // This should change when at::inverse works with silent errors
  if (self.dim() > 2) {
    TORCH_CHECK(false,
        "One or more matrices in the batch was not invertible! "
        "linalg_cond does not support yet this case.");
  }
  auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
  Tensor result = at::full(result_shape, INFINITY, self.options());
  return result;
}

// This function helps to dispatch norm computations depending on 'ord' of variant type
Tensor _linalg_cond_helper(const Tensor& self, c10::variant<Scalar, std::string> ord_variant) {
  // Ignore errors if not invertible, result is INFINITY in this case
  // Currently checking for error in at::inverse causes cross-device data movement
  // For batched input if at least one matrix in the batch is not invertible,
  // then the result for all other (possibly) invertible matrices will be infinity as well
  // since there is currently no way to use at::inverse with silent errors
  Tensor self_inverse;
  try {
    self_inverse = at::inverse(self);
  } catch (const std::exception& e) {
    if (strstr(e.what(), "singular")) {
      return _linalg_cond_exception_helper(self);
    } else {
      TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what());
    }
  }
  std::array<int64_t, 2> dim_arr = {-2, -1};
  optional<IntArrayRef> dim = IntArrayRef(dim_arr);

  return c10::visit([&](auto&& ord) {
    Tensor norm_self = at::linalg_norm(self, ord, dim);
    Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim);
    Tensor result = norm_self * norm_inverse;
    return result;
  }, ord_variant);
}
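The ord_variant/c10::visit dispatch above exists because at::linalg_norm has separate overloads for numeric (Scalar) and string norm types. As a minimal standalone sketch of the same pattern (illustrative only, not part of this diff; the describe function and its double/std::string alternatives are made up for the example):

#include <c10/util/variant.h>  // c10::variant / c10::visit mirror std::variant / std::visit
#include <iostream>
#include <string>

// The generic lambda is instantiated once per alternative type,
// and c10::visit calls it with whichever alternative the variant currently holds.
void describe(c10::variant<double, std::string> ord) {
  c10::visit([](auto&& o) { std::cout << "ord = " << o << '\n'; }, ord);
}

int main() {
  describe(2.0);                 // numeric ord, like the Scalar alternative above
  describe(std::string("fro"));  // string ord, like "fro"/"nuc" above
}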
// Return zero for each matrix in the batch
Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) {
  auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
  return at::zeros(result_shape, self.options().dtype(dtype));
}

void _linalg_cond_check_ord(c10::variant<Scalar, std::string> ord_variant) {
  if (ord_variant.index() == 0) {
    Scalar* ord = c10::get_if<Scalar>(&ord_variant);
    double abs_ord = std::abs(ord->toDouble());
    TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY,
        "linalg_cond got an invalid norm type: ", ord->toDouble());
  } else if (ord_variant.index() == 1) {
    std::string* ord = c10::get_if<std::string>(&ord_variant);
    TORCH_CHECK(*ord == "fro" || *ord == "nuc",
        "linalg_cond got an invalid norm type: ", *ord);
  } else {
    TORCH_CHECK(false,
        "linalg_cond: something went wrong while checking the norm type");
  }
}

// Numerical or None norms
Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
  TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ",
      self.dim(), " dimensions.");

  // The default case is using 2-norm
  Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;

  c10::variant<Scalar, std::string> ord_variant = ord;
  _linalg_cond_check_ord(ord_variant);

  // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
  if (self.numel() == 0) {
    auto real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
    auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type();
    return _linalg_cond_empty_matrix(self, expected_dtype);
  }

  // If ord == None or ord == ±2
  if (std::abs(ord.toDouble()) == 2.0) {
    auto singular_values = std::get<1>(at::svd(self));
    // singular values are sorted in descending order
    auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1);
    auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1);
    Tensor result;
    if (ord.toDouble() == -2.0) {
      result = s_min / s_max;
    } else {
      result = s_max / s_min;
    }
    return result;
  }
  // ord == ±1 ord == ±inf
  // since at::inverse is used in the implementation, self has to be a tensor consisting of square matrices
  // the same check as squareCheckInputs(self) but with a slightly more informative error message
  TORCH_CHECK(self.size(-1) == self.size(-2),
"linalg_cond with ±1 or ±inf norm types only supports square matrices or batches of square matrices " | ||
"but got ", self.size(-1), " by ", self.size(-2), " matrices"); | ||
  return _linalg_cond_helper(self, ord_variant);
}

Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional<Scalar> opt_ord) {
  // If ord == None or ord == ±2 then SVD is used to compute the condition number
  // the result is always real-valued, for other cases it is complex-valued for the complex-valued input.
  ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
  Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
  auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type();
Review comment: Technically instead of casting ord to double unconditionally we should first check that it's not complex and not an integer that will overflow, but that seems onerous. Not for this PR, of course, but I wonder if the Scalar type could/does do that for us automatically? It might catch some errors with the handling of complex scalars.
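A rough sketch of the kind of ord validation suggested here, using the existing c10::Scalar predicates; this only illustrates the reviewer's point, is not part of this diff, and leaves out the integer-overflow concern:

#include <c10/core/Scalar.h>
#include <c10/util/Exception.h>
#include <cmath>

// Hypothetical helper: reject complex ord values before the toDouble() cast,
// so the cast only ever sees real (floating-point or integral) scalars.
void check_ord_is_real(const c10::Scalar& ord) {
  TORCH_CHECK(!ord.isComplex(), "linalg_cond: ord must be a real number");
  double abs_ord = std::abs(ord.toDouble());
  TORCH_CHECK(abs_ord == 1.0 || abs_ord == 2.0 || abs_ord == INFINITY,
      "linalg_cond got an invalid norm type: ", ord.toDouble());
}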
  TORCH_CHECK(result.scalar_type() == expected_dtype,
      "result dtype ", result.scalar_type(), " does not match the expected dtype ", expected_dtype);

  Tensor result_tmp = at::linalg_cond(self, opt_ord);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

// Frobenius or nuclear norms
Tensor linalg_cond(const Tensor& self, std::string ord) {
  // the same checks as squareCheckInputs(self) but with a slightly more informative error message
  TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ",
      self.dim(), " dimensions.");
  TORCH_CHECK(self.size(-1) == self.size(-2),
      "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices "
      "but got ", self.size(-1), " by ", self.size(-2), " matrices");

  c10::variant<Scalar, std::string> ord_variant = ord;
  _linalg_cond_check_ord(ord_variant);

  // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
  if (self.numel() == 0) {
    return _linalg_cond_empty_matrix(self, self.scalar_type());
  }

  return _linalg_cond_helper(self, ord_variant);
}
// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) {
  TORCH_CHECK(result.scalar_type() == self.scalar_type(),
      "result dtype ", result.scalar_type(), " does not match the expected dtype ", self.scalar_type());

  Tensor result_tmp = at::linalg_cond(self, ord);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
  /*
  The idea is to reduce the problem to 2D square matrix inversion.
Review comment: Let's remove this second line.

Review comment: Follow-up question, if the intent is to return inf for every value, should this be a warning and not a check? Maybe something like... ?

Review comment: We've decided here #45832 (comment) that since returning inf for every matrix is not correct, we should raise an error. The batched case will be properly implemented in the future once it's possible to call at::inverse with silent errors (this should be possible after #48261 is ready, similarly to how it's now possible to call at::solve_out_info without error checks here).

Review comment: OK; I'm convinced.
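To make the behavior settled on in this thread concrete, here is a hypothetical check of the two paths (a sketch only, assuming a build that includes this PR; the tensor values are made up for illustration):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A single singular matrix: at::inverse throws, the "singular" message is caught,
  // and _linalg_cond_exception_helper returns inf for the 2D case.
  auto singular = at::zeros({2, 2});
  std::cout << at::linalg_cond(singular, at::Scalar(1)) << '\n';  // inf

  // A batch containing a singular matrix currently raises instead of returning inf,
  // as decided in the discussion above.
  auto batch = at::stack({at::eye(2), at::zeros({2, 2})});
  try {
    at::linalg_cond(batch, at::Scalar(1));
  } catch (const std::exception& e) {
    std::cout << "error: " << e.what() << '\n';
  }
}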