Fix how torch.linalg.cond handles complex to real downgrade
kurtamohler committed Dec 8, 2020
1 parent 0af066f commit df1244f
Showing 5 changed files with 20 additions and 22 deletions.
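In short: `torch.linalg.cond` previously downgraded complex inputs to a real result dtype only for `p = ±2`; after this commit the result is real-valued for every supported norm type, matching how norms of complex tensors behave. A minimal sketch of the intended user-visible behavior (assuming a build that includes this commit):

```python
import torch

a = torch.randn(3, 3, dtype=torch.cfloat)
for p in ['fro', 2]:
    c = torch.linalg.cond(a, p)
    # The condition number of a complex matrix is real-valued,
    # in the value type of the input (cfloat -> float).
    assert c.dtype == torch.float32
```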
19 changes: 10 additions & 9 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1779,7 +1779,8 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) {
       "linalg_cond does not support yet this case.");
   }
   auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
-  Tensor result = at::full(result_shape, INFINITY, self.options());
+  TensorOptions options = self.options().dtype(toValueType(self.scalar_type()));
+  Tensor result = at::full(result_shape, INFINITY, options);
   return result;
 }

@@ -1814,7 +1815,8 @@ Tensor _linalg_cond_helper(const Tensor& self, c10::variant<Scalar, std::string>
 // Return zero for each matrix in the batch
 Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) {
   auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
-  return at::zeros(result_shape, self.options().dtype(dtype));
+  TensorOptions options = self.options().dtype(toValueType(self.scalar_type()));
+  return at::zeros(result_shape, options);
 }

 void _linalg_cond_check_ord(c10::variant<Scalar, std::string> ord_variant) {
@@ -1847,8 +1849,7 @@ Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
   // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
   if (self.numel() == 0) {
     auto real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
-    auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type();
-    return _linalg_cond_empty_matrix(self, expected_dtype);
+    return _linalg_cond_empty_matrix(self, real_dtype);
   }

   // If ord == None or ord == ±2
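As the comment above notes, NumPy leaves the condition number of a 0x0 matrix undefined, and this implementation returns 0.0 instead. With this change that zero is allocated in the real value type even for complex inputs. A sketch:

```python
import torch

a = torch.empty(0, 0, dtype=torch.cdouble)
c = torch.linalg.cond(a)
# 0-dim result holding 0.0, in the real value type (cdouble -> double).
assert c.item() == 0.0 and c.dtype == torch.float64
```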
@@ -1881,10 +1882,9 @@ Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional<Scalar> opt
   // the result is always real-valued, for other cases it is complex-valued for the complex-valued input.
   ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
   Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
-  auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type();

-  TORCH_CHECK(result.scalar_type() == expected_dtype,
-              "result dtype ", result.scalar_type(), " does not match the expected dtype ", expected_dtype);
+  TORCH_CHECK(result.scalar_type() == real_dtype,
+              "result dtype ", result.scalar_type(), " does not match the expected dtype ", real_dtype);

   Tensor result_tmp = at::linalg_cond(self, opt_ord);
   at::native::resize_output(result, result_tmp.sizes());
@@ -1914,8 +1914,9 @@ Tensor linalg_cond(const Tensor& self, std::string ord) {

 // TODO: implement _out variant avoiding copy and using already allocated storage directly
 Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) {
-  TORCH_CHECK(result.scalar_type() == self.scalar_type(),
-              "result dtype ", result.scalar_type(), " does not match the expected dtype ", self.scalar_type());
+  ScalarType real_type = toValueType(self.scalar_type());
+  TORCH_CHECK(result.scalar_type() == real_type,
+              "result dtype ", result.scalar_type(), " does not match the expected dtype ", real_type);

   Tensor result_tmp = at::linalg_cond(self, ord);
   at::native::resize_output(result, result_tmp.sizes());
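Both `linalg_cond_out` overloads now validate the `out` tensor against the real value type rather than against the input dtype (or an ord-dependent dtype). A sketch of the behavior this enforces, with the error text taken from the `TORCH_CHECK` above:

```python
import torch

a = torch.ones(2, 2, dtype=torch.cfloat)

out = torch.empty(0, dtype=torch.float32)  # real value type of cfloat
torch.linalg.cond(a, 'fro', out=out)       # accepted, resized to a 0-dim result

bad = torch.empty(0, dtype=torch.cfloat)
try:
    torch.linalg.cond(a, 'fro', out=bad)
except RuntimeError as e:
    print(e)  # result dtype ComplexFloat does not match the expected dtype Float
```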
2 changes: 1 addition & 1 deletion aten/src/ATen/native/ReduceOps.cpp
@@ -669,7 +669,7 @@ static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt
   ScalarType in_dtype = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
   TORCH_CHECK(
       at::isFloatingType(in_dtype) || at::isComplexType(in_dtype),
-      "Can only calculate the mean of floating types. Got ",
+      "Can only calculate the norm of floating point and complex dtypes. Got ",
       toString(in_dtype),
       " instead.");

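The old message misnamed the operation as a mean; the check itself rejects any input (or explicit `dtype`) that is neither floating point nor complex. A sketch of when it fires:

```python
import torch

try:
    torch.norm(torch.tensor([1, 2, 3]))  # integer (Long) input
except RuntimeError as e:
    # Can only calculate the norm of floating point and complex dtypes. Got Long instead.
    print(e)
```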
9 changes: 2 additions & 7 deletions test/test_linalg.py
@@ -1047,11 +1047,6 @@ def run_test_case(input, p):
         for input_size in input_sizes:
             input = torch.randn(*input_size, dtype=dtype, device=device)
             for p in norm_types:
-                # frobenius norm not supported for complex tensors
-                if dtype.is_complex and p == 'fro':
-                    with self.assertRaisesRegex(RuntimeError, "frobenius norm not supported for complex tensors"):
-                        torch.linalg.cond(input, p)
-                    continue
                 run_test_case(input, p)

         # test empty batch sizes
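The deleted branch asserted that the Frobenius norm raised for complex tensors; with the complex-to-real handling fixed, `'fro'` now runs through `run_test_case` like the other norm types. A sketch of what the updated test exercises, comparing against the definition of the condition number (assuming the device supports complex matrix inversion):

```python
import torch

a = torch.randn(4, 4, dtype=torch.cfloat)
c = torch.linalg.cond(a, 'fro')
# cond(a, 'fro') = ||a||_fro * ||a^-1||_fro, returned as a real tensor.
expected = torch.norm(a, 'fro') * torch.norm(torch.inverse(a), 'fro')
assert c.dtype == torch.float32 and torch.allclose(c, expected)
```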
@@ -1079,7 +1074,7 @@ def run_test_case(input, p):
         for input_size in input_sizes:
             input = torch.randn(*input_size, dtype=dtype, device=device)
             for p in ['fro', 2]:
-                expected_dtype = a.real.dtype if dtype.is_complex and p == 2 else dtype
+                expected_dtype = a.real.dtype if dtype.is_complex else dtype
                 expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device)
                 actual = torch.linalg.cond(input, p)
                 self.assertEqual(actual, expected)
@@ -1107,7 +1102,7 @@ def test_cond_errors_and_warnings(self, device, dtype):
         # if non-empty out tensor with wrong shape is passed a warning is given
         a = torch.ones((2, 2), dtype=dtype, device=device)
         for p in ['fro', 2]:
-            real_dtype = a.real.dtype if dtype.is_complex and p == 2 else dtype
+            real_dtype = a.real.dtype if dtype.is_complex else dtype
             out = torch.empty(a.shape, dtype=real_dtype, device=device)
             with warnings.catch_warnings(record=True) as w:
                 # Trigger warning
4 changes: 2 additions & 2 deletions torch/functional.py
@@ -1263,8 +1263,8 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa

     Args:
         input (Tensor): The input tensor. Its data type must be either a floating
-            point or complex type. For complex inputs, the norm is calculated on the
-            absolute values of each element. If the input is complex and neither
+            point or complex type. For complex inputs, the norm is calculated using the
+            absolute value of each element. If the input is complex and neither
             :attr:`dtype` nor :attr:`out` is specified, the result's data type will
             be the corresponding floating point type (e.g. float if :attr:`input` is
             complexfloat).
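The corrected wording reflects what actually happens: the norm of a complex tensor is computed from elementwise absolute values, so the result is real. For instance:

```python
import torch

v = torch.tensor([3 + 4j])  # default complex dtype, cfloat
n = torch.norm(v)           # 2-norm over |3 + 4j| = 5.0
assert n.item() == 5.0 and n.dtype == torch.float32
```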
8 changes: 5 additions & 3 deletions torch/linalg/__init__.py
@@ -288,8 +288,8 @@
     is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D
     will be returned. Its data type must be either a floating point or complex type. For complex
     inputs, the norm is calculated on the absolute values of each element. If the input is
-    complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will be
-    the corresponding downgraded real number type.
+    complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will
+    be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat).

     ord (int, float, inf, -inf, 'fro', 'nuc', optional): The order of norm.
         inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object.
@@ -412,7 +412,9 @@
 times the matrix norm of the inverse of :attr:`input`. And for norms ``p = {None, 2, -2}`` this is defined as
 the ratio between the largest and smallest singular values.

-This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
+This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
+If the input is complex and neither :attr:`dtype` nor :attr:`out` is specified, the result's data type will
+be the corresponding floating point type (e.g. float if :attr:`input` is complexfloat).

 .. note:: For ``p = {None, 2, -2}`` the condition number is computed as the ratio between the largest
     and smallest singular values computed using :func:`torch.linalg.svd`.
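A short illustration of the documented dtype rule (a sketch, with values chosen so the answer is exact):

```python
import torch

a = torch.tensor([[1., 0.], [0., 2.]], dtype=torch.cdouble)
c = torch.linalg.cond(a)  # default p: ratio of largest to smallest singular value
# Singular values are 2 and 1, so cond(a) == 2.0, as a double (real type of cdouble).
assert c.dtype == torch.float64 and torch.isclose(c, torch.tensor(2.0, dtype=torch.float64))
```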
