From 7ce8a20c7ebf661f28c6b8c7317a36d0c14b00ae Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Mon, 5 Oct 2020 05:17:43 -0500
Subject: [PATCH 01/44] Added linalg_cond

---
 aten/src/ATen/native/LinearAlgebra.cpp     | 21 +++++++++++++
 aten/src/ATen/native/native_functions.yaml |  8 +++++
 test/test_linalg.py                        | 30 +++++++++++++++++++
 torch/linalg/__init__.py                   | 35 ++++++++++++++++++++
 4 files changed, 94 insertions(+)

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 9c3742c129de..1d0a9a02acbe 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1599,6 +1599,27 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt
   return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype);
 }
 
+// Numerical or None norms
+Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
+  optional<Scalar> ord = opt_ord.has_value() ? opt_ord : 2;
+  IntArrayRef dim{-2, -1};
+  Tensor self_inverse = at::inverse(self);
+  Tensor norm_self = at::linalg_norm(self, ord, dim);
+  Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim);
+  Tensor result = norm_self * norm_inverse;
+  return result;
+}
+
+// Frobenius norm
+Tensor linalg_cond(const Tensor& self, std::string ord) {
+  IntArrayRef dim{-2, -1};
+  Tensor self_inverse = at::inverse(self);
+  Tensor norm_self = at::linalg_norm(self, ord, dim);
+  Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim);
+  Tensor result = norm_self * norm_inverse;
+  return result;
+}
+
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
   if (i == j)
     return matrices[i];
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index de5e98037277..f1325347e87b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8263,6 +8263,14 @@
   python_module: linalg
   variants: function
 
+- func: linalg_cond(Tensor self, Scalar? ord=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_cond.ord_str(Tensor self, str ord) -> Tensor
+  python_module: linalg
+  variants: function
+
 ## Functions that are only for testing
 # It is undocumented and should not be used outside of tests.
 - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 97c7b926faf4..2d4524d70d10 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -313,6 +313,36 @@ def run_test_case(input, p, dim, keepdim):
         for ord in ord_settings:
             run_test_case(input, ord, dim, keepdim)
 
+    @precisionOverride({torch.float32: 1e-4})
+    @skipCPUIfNoLapack
+    @skipCUDAIfNoMagma
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    @dtypes(torch.float32, torch.float64)
+    def test_cond_matrix(self, device, dtype):
+        def run_test_case(input, ord):
+            result = torch.linalg.cond(input, ord)
+            input_numpy = input.cpu().numpy()
+            result_numpy = np.linalg.cond(input_numpy, ord)
+
+            self.assertEqual(result, result_numpy, rtol=1e-3, atol=self.precision)
+
+        ord_matrix = [1, -1, 2, -2, inf, -inf, 'fro', None]
+        S = 10
+        test_cases = [
+            # input size, p settings, dim
+            ((S, S), ord_matrix),
+            ((S, S), ord_matrix),
+            ((S, S), ord_matrix),
+            ((S, S, S, S), ord_matrix),
+            ((S, S, S, S), ord_matrix),
+            ((S, S, S, S), ord_matrix),
+            ((S, S, S, S), ord_matrix),
+        ]
+        for input_size, ord_settings in test_cases:
+            input = torch.randn(*input_size, dtype=dtype, device=device)
+            for ord in ord_settings:
+                run_test_case(input, ord)
+
     # Test autograd and jit functionality for linalg functions.
     # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py,
     # the `test_cases` entries below should be moved there. These entries are in a similar format,
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 5e2b59c45c80..ac666c97e199 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -140,3 +140,38 @@
     >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
     (tensor(3.7417), tensor(11.2250))
 """)
+
+cond = _add_docstr(_linalg.linalg_cond, r"""
+linalg.norm(input, p=None) -> Tensor
+
+Returns the condition number of a matrix.
+
+.. note:: The condition number of :attr:`input` is defined as the norm of
+    :attr:`input` times the norm of the inverse of :attr:`input`.
+
+Args:
+    input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
+                    batch dimensions.
+
+    p (int, float, inf, -inf, 'fro', optional): The order of norm.
+        inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object.
+        The following norms can be used:
+
+        =====  ============================
+        ord    norm for matrices
+        =====  ============================
+        None   2-norm, computed directly using the SVD
+        'fro'  Frobenius norm
+        inf    max(sum(abs(x), dim=1))
+        -inf   min(sum(abs(x), dim=1))
+        1      max(sum(abs(x), dim=0))
+        -1     min(sum(abs(x), dim=0))
+        2      2-norm (largest sing. value)
+        -2     smallest singular value
+        =====  ============================
+
+        Default: ``None``
+
+Returns:
+    (Tensor): the condition number of the matrix.
+""")

From dcff9ef4e2a8e0115e22cf10bf457f50d5aec25d Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Mon, 5 Oct 2020 15:22:42 -0500
Subject: [PATCH 02/44] Updated the implementation.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Now for p=None and p=±2 the condition number is computed with SVD.
Added error checks.
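The two linalg_cond overloads added in PATCH 01 implement the textbook definition cond_p(A) = norm(A, p) * norm(inv(A), p), and this commit switches the p=None and p=±2 cases over to singular values, where the 2-norm condition number is the ratio of the largest to the smallest singular value and no inverse is needed. The sketch below restates both routes from the Python side as a sanity check; it assumes a square, invertible, real input and uses today's public PyTorch API, so it is an illustration of the math rather than the patched C++ path itself:

    import numpy as np
    import torch

    a = torch.randn(5, 5, dtype=torch.float64)
    dim = (-2, -1)

    # Route taken by PATCH 01: norm of the matrix times norm of its inverse.
    cond_one = torch.linalg.norm(a, 1, dim=dim) * torch.linalg.norm(torch.inverse(a), 1, dim=dim)

    # Route introduced by PATCH 02: for p=2 the condition number is
    # sigma_max / sigma_min, so no inverse is formed and the formula
    # also makes sense for non-square inputs.
    s = torch.svd(a).S
    cond_two = s.max(dim=-1).values / s.min(dim=-1).values
    cond_minus_two = s.min(dim=-1).values / s.max(dim=-1).values  # p=-2 flips the ratio

    # Cross-check against NumPy, mirroring what test_cond_matrix does.
    assert np.isclose(cond_one.item(), np.linalg.cond(a.numpy(), 1))
    assert np.isclose(cond_two.item(), np.linalg.cond(a.numpy(), 2))

A singular input drives the smallest singular value to zero, so the SVD route overflows toward infinity while the inverse-based route fails outright; the error checks added in this commit, and refined in the commits that follow, exist to reconcile those two behaviors.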
--- aten/src/ATen/native/LinearAlgebra.cpp | 48 ++++++++++++++++++++++---- test/test_linalg.py | 24 ++++++------- torch/linalg/__init__.py | 1 + 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 1d0a9a02acbe..be0e1e9b3241 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1601,22 +1601,56 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt // Numerical or None norms Tensor linalg_cond(const Tensor& self, optional opt_ord) { - optional ord = opt_ord.has_value() ? opt_ord : 2; - IntArrayRef dim{-2, -1}; - Tensor self_inverse = at::inverse(self); - Tensor norm_self = at::linalg_norm(self, ord, dim); - Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - Tensor result = norm_self * norm_inverse; + TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); + TORCH_CHECK(self.dim() >= 2, "Tensor of matrices must have at least 2 dimensions."); + + Tensor self_inverse, result; + + // The default case is using 2-norm + Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; + + // If ord == None or ord == ±2 + if (std::abs(ord.toDouble()) == 2.0) { + auto singular_values = std::get<1>(at::svd(self)); + auto s_max = std::get<0>(singular_values.max(/*dim=*/-1)); + auto s_min = std::get<0>(singular_values.min(/*dim=*/-1)); + if (ord.toDouble() == -2.0) { + result = s_min / s_max; + } + else { + result = s_max / s_min; + } + } + // ord == ±1 ord == ±inf + else { + squareCheckInputs(self); + // Ignore errors if not invertible, self_inverse should contain NaNs in this case + try { + self_inverse = at::inverse(self); + } catch (...) {} + IntArrayRef dim{-2, -1}; + Tensor norm_self = at::linalg_norm(self, ord, dim); + Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); + result = norm_self * norm_inverse; + } + result = at::nan_to_num(result, INFINITY); return result; } // Frobenius norm Tensor linalg_cond(const Tensor& self, std::string ord) { + TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); + squareCheckInputs(self); + // Ignore errors if not invertible, self_inverse should contain NaNs in this case + Tensor self_inverse; + try { + self_inverse = at::inverse(self); + } catch (...) 
{} IntArrayRef dim{-2, -1}; - Tensor self_inverse = at::inverse(self); Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); Tensor result = norm_self * norm_inverse; + result = at::nan_to_num(result, INFINITY); return result; } diff --git a/test/test_linalg.py b/test/test_linalg.py index 2d4524d70d10..88bc9e2120f7 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -313,30 +313,30 @@ def run_test_case(input, p, dim, keepdim): for ord in ord_settings: run_test_case(input, ord, dim, keepdim) - @precisionOverride({torch.float32: 1e-4}) + @precisionOverride({torch.float32: 1e-3}) @skipCPUIfNoLapack @skipCUDAIfNoMagma @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float32, torch.float64) - def test_cond_matrix(self, device, dtype): + def test_cond(self, device, dtype): def run_test_case(input, ord): result = torch.linalg.cond(input, ord) input_numpy = input.cpu().numpy() result_numpy = np.linalg.cond(input_numpy, ord) - self.assertEqual(result, result_numpy, rtol=1e-3, atol=self.precision) + self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) - ord_matrix = [1, -1, 2, -2, inf, -inf, 'fro', None] + norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] S = 10 test_cases = [ - # input size, p settings, dim - ((S, S), ord_matrix), - ((S, S), ord_matrix), - ((S, S), ord_matrix), - ((S, S, S, S), ord_matrix), - ((S, S, S, S), ord_matrix), - ((S, S, S, S), ord_matrix), - ((S, S, S, S), ord_matrix), + # input size, norm types settings + ((S, S), norm_types), + ((S, S), norm_types), + ((S, S), norm_types), + ((S, S, S, S), norm_types), + ((S, S, S, S), norm_types), + ((S, S, S, S), norm_types), + ((S, S, S, S), norm_types), ] for input_size, ord_settings in test_cases: input = torch.randn(*input_size, dtype=dtype, device=device) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index ac666c97e199..22041633dd92 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -162,6 +162,7 @@ ===== ============================ None 2-norm, computed directly using the SVD 'fro' Frobenius norm + 'nuc' nuclear norm inf max(sum(abs(x), dim=1)) -inf min(sum(abs(x), dim=1)) 1 max(sum(abs(x), dim=0)) From c88ccc37e77cf8699930d2168d922131ecddd79d Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 05:41:18 -0600 Subject: [PATCH 03/44] Updated non-invertible case implementation --- aten/src/ATen/native/LinearAlgebra.cpp | 12 ++++-- test/test_linalg.py | 57 ++++++++++++++++++-------- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index fc73dddaa0e5..840b08760852 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1630,7 +1630,10 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { // Ignore errors if not invertible, self_inverse should contain NaNs in this case try { self_inverse = at::inverse(self); - } catch (...) {} + } catch (...) 
{ + self_inverse = at::empty_like(self); + at::fill_(self_inverse, NAN); + } IntArrayRef dim{-2, -1}; Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); @@ -1640,7 +1643,7 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { return result; } -// Frobenius norm +// Frobenius or nuclear norms Tensor linalg_cond(const Tensor& self, std::string ord) { TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); squareCheckInputs(self); @@ -1648,7 +1651,10 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { Tensor self_inverse; try { self_inverse = at::inverse(self); - } catch (...) {} + } catch (...) { + self_inverse = at::empty_like(self); + at::fill_(self_inverse, NAN); + } IntArrayRef dim{-2, -1}; Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); diff --git a/test/test_linalg.py b/test/test_linalg.py index c6b9aff65387..c2cb29a13ba8 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -309,36 +309,59 @@ def run_test_case(input, p, dim, keepdim): for ord in ord_settings: run_test_case(input, ord, dim, keepdim) - @precisionOverride({torch.float32: 1e-3}) @skipCPUIfNoLapack @skipCUDAIfNoMagma - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float32, torch.float64) + @precisionOverride({torch.float32: 1e-3}) def test_cond(self, device, dtype): def run_test_case(input, ord): result = torch.linalg.cond(input, ord) - input_numpy = input.cpu().numpy() - result_numpy = np.linalg.cond(input_numpy, ord) + result_numpy = np.linalg.cond(input.cpu().numpy(), ord) self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] - S = 10 - test_cases = [ - # input size, norm types settings - ((S, S), norm_types), - ((S, S), norm_types), - ((S, S), norm_types), - ((S, S, S, S), norm_types), - ((S, S, S, S), norm_types), - ((S, S, S, S), norm_types), - ((S, S, S, S), norm_types), - ] - for input_size, ord_settings in test_cases: + input_sizes = [(3, 3), (2, 3, 3, 3)] + for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) - for ord in ord_settings: + for ord in norm_types: run_test_case(input, ord) + # test for singular input + a = torch.eye(3, dtype=dtype, device=device) + a[-1, -1] = 0 # make 'a' singular + # NumPy returns inf for ord=±2 + # while current PyTorch implementation returns std::numeric_limits::max() + norm_types = [1, -1, inf, -inf, 'fro', 'nuc'] + for ord in norm_types: + run_test_case(a, ord) + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.float32, torch.float64) + @precisionOverride({torch.float32: 1e-3}) + def test_cond_errors(self, device, dtype): + norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] + + # cond expects the input to be non-empty + a = torch.zeros((0, 0), dtype=dtype, device=device) + for ord in norm_types: + with self.assertRaisesRegex(RuntimeError, r'linalg_cond is not defined for empty tensors'): + torch.linalg.cond(a, ord) + + # cond expects the input to be at least 2-dimensional + a = torch.ones(3, dtype=dtype, device=device) + for ord in norm_types: + with self.assertRaisesRegex(RuntimeError, r'Tensor of matrices must have at least 2 dimensions'): + torch.linalg.cond(a, ord) + + # for some norm types cond expects the input to be square + a = torch.ones(3, 2, dtype=dtype, device=device) + norm_types = [1, -1, inf, -inf, 'fro', 'nuc'] + for ord in norm_types: 
+ with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.cond(a, ord) + # Test autograd and jit functionality for linalg functions. # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, # the `test_cases` entries below should be moved there. These entries are in a similar format, From 1e49003f5ad1a81c23a2b7eaa8dd459fb06bd44f Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 05:42:27 -0600 Subject: [PATCH 04/44] Updated docs --- torch/linalg/__init__.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 2d5ac3136063..e5998b413544 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -144,9 +144,12 @@ linalg.norm(input, p=None) -> Tensor Returns the condition number of a matrix. +The condition number of :attr:`input` is defined as the norm of +:attr:`input` times the norm of the inverse of :attr:`input`. -.. note:: The condition number of :attr:`input` is defined as the norm of - :attr:`input` times the norm of the inverse of :attr:`input`. +This function supports only real-valued input. + +.. note:: For non-invertible :attr:`input` and `ord` equal to ±2 a large number is returned instead of inf. Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more @@ -172,8 +175,28 @@ Default: ``None`` -Returns: - (Tensor): the condition number of the matrix. +Examples:: + + >>> from torch import linalg as LA + >>> a = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) + >>> LA.cond(a) + tensor(1.4142) + >>> LA.cond(a, 'fro') + tensor(3.1623) + >>> LA.cond(a, 'nuc') + tensor(9.2426) + >>> LA.cond(a, np.inf) + tensor(2.) + >>> LA.cond(a, -np.inf) + tensor(1.) + >>> LA.cond(a, 1) + tensor(2.) + >>> LA.cond(a, -1) + tensor(1.) + >>> LA.cond(a, 2) + tensor(1.4142) + >>> LA.cond(a, -2) + tensor(0.7071) """) tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" From db2de7ce5e022c7e1d953a3374d332da63c6c00b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 05:46:00 -0600 Subject: [PATCH 05/44] Added entry to common_methods_invocations.py --- aten/src/ATen/native/native_functions.yaml | 4 ++++ test/test_jit.py | 2 +- .../testing/_internal/common_methods_invocations.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ded95e9bf4d5..af0f31bc5526 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8912,10 +8912,14 @@ - func: linalg_cond(Tensor self, Scalar? ord=None) -> Tensor python_module: linalg variants: function + dispatch: + Math: linalg_cond - func: linalg_cond.ord_str(Tensor self, str ord) -> Tensor python_module: linalg variants: function + dispatch: + Math: linalg_cond - func: linalg_tensorsolve(Tensor self, Tensor other, int[]? 
dims=None) -> Tensor python_module: linalg diff --git a/test/test_jit.py b/test/test_jit.py index 378c88eaa1cf..c6623a9494ce 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15765,7 +15765,7 @@ def fn(*inputs, **kwargs): check_types=check_types) # alias annotation testing - if not is_magic_method and test_name not in EXCLUDE_SCRIPT: + if not is_magic_method and test_name not in EXCLUDE_SCRIPT and not exclude_tensor_method(name, test_name): check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable) check(name) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 24ad7531700f..f72fae6a5e43 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1462,6 +1462,16 @@ def method_tests(): ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])]),), 'adv_index_var'), ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()), + ('linalg.cond', (S, S), (), 'default', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S, S), (), 'default_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), (inf,), 'matrix_inf', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), (2,), 'matrix_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), (1,), 'matrix_1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), (-inf,), 'matrix_neg_inf', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), (-2,), 'matrix_neg_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), (-1,), 'matrix_neg_1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), ('fro',), 'fro', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('linalg.cond', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ] def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None): @@ -1756,4 +1766,6 @@ def exclude_tensor_method(name, test_name): return True if 'fft.' in name: return True + if 'linalg.' 
in name: + return True return False From 843c07fa70544e5261cf2a8623df54e8f435a942 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 05:48:08 -0600 Subject: [PATCH 06/44] Added overrides.py entry --- torch/overrides.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/overrides.py b/torch/overrides.py index 9a91866e3f5b..8be33977a1ff 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -284,6 +284,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.combinations: lambda input, r=2, with_replacement=False: -1, torch.complex: lambda real, imag: -1, torch.polar: lambda abs, ang: -1, + torch.linalg.cond: lambda input, ord=None: -1, torch.conj: lambda input, out=None: -1, torch.constant_pad_nd: lambda input, pad, value=0: -1, torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, From fd82325519c233927cc16bf3fc9939f7cb041ab1 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 05:53:10 -0600 Subject: [PATCH 07/44] flake8 --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index c2cb29a13ba8..f0e9a6bfc2f1 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -329,7 +329,7 @@ def run_test_case(input, ord): # test for singular input a = torch.eye(3, dtype=dtype, device=device) - a[-1, -1] = 0 # make 'a' singular + a[-1, -1] = 0 # make 'a' singular # NumPy returns inf for ord=±2 # while current PyTorch implementation returns std::numeric_limits::max() norm_types = [1, -1, inf, -inf, 'fro', 'nuc'] From ed1b7def88c0c5a066b734d511f2f48023212ec2 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 06:16:06 -0600 Subject: [PATCH 08/44] Implement a trick for replacing FLT_MAX or DBL_MAX to INFINITY --- aten/src/ATen/native/LinearAlgebra.cpp | 9 +++++++++ test/test_linalg.py | 3 --- torch/linalg/__init__.py | 2 -- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 840b08760852..41b4bfde56c8 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1639,6 +1639,10 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); result = norm_self * norm_inverse; } + // a trick to convert FLT_MAX or DBL_MAX to INFINITY + // result might also contains infs already then inf + inf - inf = nan + // this will get replaced by inf + result = result + result - result; result = at::nan_to_num(result, INFINITY); return result; } @@ -1660,6 +1664,11 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); Tensor result = norm_self * norm_inverse; result = at::nan_to_num(result, INFINITY); + // a trick to convert FLT_MAX or DBL_MAX to INFINITY + // result might also contains infs already then inf + inf - inf = nan + // this will get replaced by inf + result = result + result - result; + result = at::nan_to_num(result, INFINITY); return result; } diff --git a/test/test_linalg.py b/test/test_linalg.py index f0e9a6bfc2f1..44bf76f30642 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -330,9 +330,6 @@ def run_test_case(input, ord): # test for singular input a = torch.eye(3, dtype=dtype, device=device) a[-1, -1] = 0 # make 'a' singular - # NumPy returns inf for ord=±2 - # while current PyTorch implementation returns std::numeric_limits::max() - norm_types 
= [1, -1, inf, -inf, 'fro', 'nuc'] for ord in norm_types: run_test_case(a, ord) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index e5998b413544..8bd1c35595ed 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -149,8 +149,6 @@ This function supports only real-valued input. -.. note:: For non-invertible :attr:`input` and `ord` equal to ±2 a large number is returned instead of inf. - Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more batch dimensions. From 7be03c7830d2f32a3f9ba01cc0551fa04b62557f Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 07:23:58 -0600 Subject: [PATCH 09/44] Added out= variant --- aten/src/ATen/native/LinearAlgebra.cpp | 22 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 12 ++++++++++ test/test_linalg.py | 26 ++++++++++++++++++++-- torch/linalg/__init__.py | 5 ++++- 4 files changed, 62 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 41b4bfde56c8..f143bc18c8cd 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1647,6 +1647,17 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { return result; } +Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt_ord) { + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + TORCH_CHECK(result.scalar_type() == real_dtype, + "result dtype ", result.scalar_type(), " does not match self.real dtype ", real_dtype); + + Tensor result_tmp = at::linalg_cond(self, opt_ord); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + // Frobenius or nuclear norms Tensor linalg_cond(const Tensor& self, std::string ord) { TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); @@ -1672,6 +1683,17 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { return result; } +Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) { + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + TORCH_CHECK(result.scalar_type() == real_dtype, + "result dtype ", result.scalar_type(), " does not match self.real dtype ", real_dtype); + + Tensor result_tmp = at::linalg_cond(self, ord); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, optional dims) { /* The idea is to reduce the problem to 2D matrix solve. diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index af0f31bc5526..7dbba915522c 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8915,12 +8915,24 @@ dispatch: Math: linalg_cond +- func: linalg_cond.out(Tensor self, Scalar? ord=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + dispatch: + Math: linalg_cond_out + - func: linalg_cond.ord_str(Tensor self, str ord) -> Tensor python_module: linalg variants: function dispatch: Math: linalg_cond +- func: linalg_cond.ord_str_out(Tensor self, str ord, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + dispatch: + Math: linalg_cond_out + - func: linalg_tensorsolve(Tensor self, Tensor other, int[]? 
dims=None) -> Tensor python_module: linalg variants: function diff --git a/test/test_linalg.py b/test/test_linalg.py index 44bf76f30642..e9745bb11e39 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -317,9 +317,14 @@ def test_cond(self, device, dtype): def run_test_case(input, ord): result = torch.linalg.cond(input, ord) result_numpy = np.linalg.cond(input.cpu().numpy(), ord) - self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) + # test out= variant + out = torch.empty_like(result) + ans = torch.linalg.cond(input, ord, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] input_sizes = [(3, 3), (2, 3, 3, 3)] for input_size in input_sizes: @@ -337,7 +342,7 @@ def run_test_case(input, ord): @skipCUDAIfNoMagma @dtypes(torch.float32, torch.float64) @precisionOverride({torch.float32: 1e-3}) - def test_cond_errors(self, device, dtype): + def test_cond_errors_and_warnings(self, device, dtype): norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] # cond expects the input to be non-empty @@ -359,6 +364,23 @@ def test_cond_errors(self, device, dtype): with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): torch.linalg.cond(a, ord) + # if non-empty out tensor with wrong shape is passed a warning is given + a = torch.ones((2, 2), dtype=dtype, device=device) + for ord in ['fro', 2]: + out = torch.empty_like(a) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.cond(a, ord, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + for ord in ['fro', 2]: + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match"): + torch.linalg.cond(a, ord, out=out) + # Test autograd and jit functionality for linalg functions. # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, # the `test_cases` entries below should be moved there. These entries are in a similar format, diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 8bd1c35595ed..de77aff73431 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -141,7 +141,7 @@ """) cond = _add_docstr(_linalg.linalg_cond, r""" -linalg.norm(input, p=None) -> Tensor +linalg.norm(input, p=None, *, out=None) -> Tensor Returns the condition number of a matrix. The condition number of :attr:`input` is defined as the norm of @@ -173,6 +173,9 @@ Default: ``None`` +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + Examples:: >>> from torch import linalg as LA From f36e7329b390b660ae24ce1ab1b3ed0b86b6a65d Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 13:38:41 -0600 Subject: [PATCH 10/44] Updated implementation and tests. Complex input is now supported. 
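The out= overloads added in PATCH 09 and the complex support arriving in this commit share one caller-visible contract: the result is computed into a temporary, at::native::resize_output resizes the out tensor if needed (warning when a non-empty tensor of the wrong shape is passed), and a mismatched dtype raises instead of casting; for complex inputs the SVD-backed orders (None, ±2) yield a real-valued result while the norm-based orders stay complex. A rough sketch of that contract from the Python side, assuming a build in which these patches have landed:

    import torch

    a = torch.tensor([[1., 0., -1.], [0., 1., 0.], [1., 0., 1.]])

    # Well-formed out: a 0-dim tensor of the expected dtype is filled in place.
    out = torch.empty((), dtype=torch.float32)
    torch.linalg.cond(a, 2, out=out)        # out now holds sigma_max / sigma_min

    # Wrong shape: still succeeds, but resize_output emits a resize warning.
    bad_shape = torch.empty(5, dtype=torch.float32)
    torch.linalg.cond(a, 2, out=bad_shape)

    # Wrong dtype: rejected up front rather than silently cast
    # ("result dtype Int does not match ..." in the tests above).
    bad_dtype = torch.empty((), dtype=torch.int32)
    try:
        torch.linalg.cond(a, 2, out=bad_dtype)
    except RuntimeError as e:
        print(e)

    # Complex input with p=2 goes through the SVD, so the expected out dtype
    # is the real counterpart of the input dtype (float64 for complex128).
    c = torch.randn(3, 3, dtype=torch.complex128)
    out_real = torch.empty((), dtype=torch.float64)
    torch.linalg.cond(c, 2, out=out_real)

The compute-into-temporary, resize, copy_ sequence is exactly what the diff above does in linalg_cond_out, which is why an ill-shaped out tensor warns and proceeds while an ill-typed one fails fast.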
--- aten/src/ATen/native/LinearAlgebra.cpp | 66 +++++++++++++++----------- test/test_linalg.py | 31 ++++++++++-- torch/linalg/__init__.py | 2 +- 3 files changed, 67 insertions(+), 32 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index f143bc18c8cd..e5b86e205cad 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1623,34 +1623,44 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { else { result = s_max / s_min; } + // a trick to convert FLT_MAX or DBL_MAX to INFINITY + // result might also contains infs already then inf + inf - inf = nan + // this will get replaced by inf + result = result + result - result; + result = at::nan_to_num(result, INFINITY); + return result; } // ord == ±1 ord == ±inf else { squareCheckInputs(self); - // Ignore errors if not invertible, self_inverse should contain NaNs in this case + IntArrayRef dim{-2, -1}; + // Ignore errors if not invertible, result is INFINITY in this case try { self_inverse = at::inverse(self); - } catch (...) { - self_inverse = at::empty_like(self); - at::fill_(self_inverse, NAN); + } catch (const std::exception& e) { + if (strstr(e.what(), "singular")) { + result = at::empty_like(at::linalg_norm(self, ord, dim)); + at::fill_(result, INFINITY); + return result; + } + else { + TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); + } } - IntArrayRef dim{-2, -1}; Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); result = norm_self * norm_inverse; } - // a trick to convert FLT_MAX or DBL_MAX to INFINITY - // result might also contains infs already then inf + inf - inf = nan - // this will get replaced by inf - result = result + result - result; - result = at::nan_to_num(result, INFINITY); return result; } Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt_ord) { ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); - TORCH_CHECK(result.scalar_type() == real_dtype, - "result dtype ", result.scalar_type(), " does not match self.real dtype ", real_dtype); + Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; + auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type(); + + TORCH_CHECK(result.scalar_type() == expected_dtype, + "result dtype ", result.scalar_type(), " does not match the expected dtype ", expected_dtype); Tensor result_tmp = at::linalg_cond(self, opt_ord); at::native::resize_output(result, result_tmp.sizes()); @@ -1662,31 +1672,31 @@ Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt Tensor linalg_cond(const Tensor& self, std::string ord) { TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); squareCheckInputs(self); - // Ignore errors if not invertible, self_inverse should contain NaNs in this case - Tensor self_inverse; + + Tensor self_inverse, result; + IntArrayRef dim{-2, -1}; + // Ignore errors if not invertible, result is INFINITY in this case try { self_inverse = at::inverse(self); - } catch (...) 
{ - self_inverse = at::empty_like(self); - at::fill_(self_inverse, NAN); + } catch (const std::exception& e) { + if (strstr(e.what(), "singular")) { + result = at::empty_like(at::linalg_norm(self, ord, dim)); + at::fill_(result, INFINITY); + return result; + } + else { + TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); + } } - IntArrayRef dim{-2, -1}; Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - Tensor result = norm_self * norm_inverse; - result = at::nan_to_num(result, INFINITY); - // a trick to convert FLT_MAX or DBL_MAX to INFINITY - // result might also contains infs already then inf + inf - inf = nan - // this will get replaced by inf - result = result + result - result; - result = at::nan_to_num(result, INFINITY); + result = norm_self * norm_inverse; return result; } Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) { - ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); - TORCH_CHECK(result.scalar_type() == real_dtype, - "result dtype ", result.scalar_type(), " does not match self.real dtype ", real_dtype); + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match the expected dtype ", self.scalar_type()); Tensor result_tmp = at::linalg_cond(self, ord); at::native::resize_output(result, result_tmp.sizes()); diff --git a/test/test_linalg.py b/test/test_linalg.py index e9745bb11e39..1c43de42270d 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -311,7 +311,8 @@ def run_test_case(input, p, dim, keepdim): @skipCPUIfNoLapack @skipCUDAIfNoMagma - @dtypes(torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @dtypesIfCUDA(torch.float32, torch.float64) @precisionOverride({torch.float32: 1e-3}) def test_cond(self, device, dtype): def run_test_case(input, ord): @@ -330,17 +331,37 @@ def run_test_case(input, ord): for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) for ord in norm_types: + # frobenius norm not supported for complex tensors + if dtype.is_complex and ord == 'fro': + continue run_test_case(input, ord) # test for singular input a = torch.eye(3, dtype=dtype, device=device) a[-1, -1] = 0 # make 'a' singular for ord in norm_types: + # frobenius norm not supported for complex tensors + if dtype.is_complex and ord == 'fro': + continue run_test_case(a, ord) + # TODO: once "inverse_cuda" supports complex dtypes, they shall be added to above tests + @unittest.expectedFailure + @onlyCUDA + @skipCUDAIfNoMagma + @dtypes(torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3}) + def test_cond(self, device, dtype): + input_size = (3, 3) + ord = 1 + torch.randn(*input_size, dtype=dtype, device=device) + result = torch.linalg.cond(input, ord) + result_numpy = np.linalg.cond(input.cpu().numpy(), ord) + self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) + @skipCPUIfNoLapack @skipCUDAIfNoMagma - @dtypes(torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) @precisionOverride({torch.float32: 1e-3}) def test_cond_errors_and_warnings(self, device, dtype): norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] @@ -367,7 +388,11 @@ def test_cond_errors_and_warnings(self, device, dtype): # if non-empty out tensor with wrong shape is passed a warning is given a = torch.ones((2, 2), dtype=dtype, device=device) 
        for ord in ['fro', 2]:
-            out = torch.empty_like(a)
+            # frobenius norm not supported for complex tensors
+            if dtype.is_complex and ord == 'fro':
+                continue
+            real_dtype = a.real.dtype if dtype.is_complex else dtype
+            out = torch.empty(a.shape, dtype=real_dtype, device=device)
             with warnings.catch_warnings(record=True) as w:
                 # Trigger warning
                 torch.linalg.cond(a, ord, out=out)
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index de77aff73431..8cc81c6e6701 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -147,7 +147,7 @@
 The condition number of :attr:`input` is defined as the norm of
 :attr:`input` times the norm of the inverse of :attr:`input`.
 
-This function supports only real-valued input.
+This function supports real-valued, and only on CPU, complex-valued input.
 
 Args:
     input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more

From de025dcf1113c4d9012e322d7441a104648b71f8 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Wed, 4 Nov 2020 03:48:12 -0600
Subject: [PATCH 11/44] Fix typo

---
 test/test_linalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_linalg.py b/test/test_linalg.py
index 1c43de42270d..24e25217fffb 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -351,7 +351,7 @@ def run_test_case(input, ord):
     @skipCUDAIfNoMagma
     @dtypes(torch.complex64, torch.complex128)
     @precisionOverride({torch.float32: 1e-3})
-    def test_cond(self, device, dtype):
+    def test_cond_xfailed(self, device, dtype):
         input_size = (3, 3)
         ord = 1
         torch.randn(*input_size, dtype=dtype, device=device)

From 43f52bcdcf84bb6892bac62a023df261d4b83922 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Wed, 4 Nov 2020 04:44:24 -0600
Subject: [PATCH 12/44] Trying to fix initialization of 'dim'

---
 aten/src/ATen/native/LinearAlgebra.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index e5b86e205cad..91ccfafb7ae9 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1633,7 +1633,7 @@ Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
   // ord == ±1 ord == ±inf
   else {
     squareCheckInputs(self);
-    IntArrayRef dim{-2, -1};
+    IntArrayRef dim({-2, -1});
     // Ignore errors if not invertible, result is INFINITY in this case
     try {
       self_inverse = at::inverse(self);
@@ -1674,7 +1674,7 @@ Tensor linalg_cond(const Tensor& self, std::string ord) {
   squareCheckInputs(self);
 
   Tensor self_inverse, result;
-  IntArrayRef dim{-2, -1};
+  IntArrayRef dim({-2, -1});
   // Ignore errors if not invertible, result is INFINITY in this case
   try {
     self_inverse = at::inverse(self);

From 6bded4583a8cf317070da3b47e53d7be0bd0070a Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Wed, 4 Nov 2020 11:06:46 -0600
Subject: [PATCH 13/44] Trying to fix initialization of 'dim'

---
 aten/src/ATen/native/LinearAlgebra.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 58f80f195f99..ea438b330624 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1633,7 +1633,8 @@ Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
   // ord == ±1 ord == ±inf
   else {
     squareCheckInputs(self);
-    IntArrayRef dim({-2, -1});
+    std::array<int64_t, 2> dim_arr = {-2, -1};
+    optional<IntArrayRef> dim = IntArrayRef(dim_arr);
     // Ignore errors if not invertible, result is INFINITY in this case
     try {
self_inverse = at::inverse(self); @@ -1674,7 +1675,8 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { squareCheckInputs(self); Tensor self_inverse, result; - IntArrayRef dim({-2, -1}); + std::array dim_arr = {-2, -1}; + optional dim = IntArrayRef(dim_arr); // Ignore errors if not invertible, result is INFINITY in this case try { self_inverse = at::inverse(self); From 7a4dac6f83ee1f30fb2cacd4fa78f4ce7fa390c4 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 9 Nov 2020 11:46:05 -0600 Subject: [PATCH 14/44] Small changes in the implementation structure --- aten/src/ATen/native/LinearAlgebra.cpp | 67 ++++++++++++++++---------- test/test_linalg.py | 4 +- 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index ea438b330624..27cff26ee411 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1605,9 +1605,8 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt // Numerical or None norms Tensor linalg_cond(const Tensor& self, optional opt_ord) { TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); - TORCH_CHECK(self.dim() >= 2, "Tensor of matrices must have at least 2 dimensions."); - - Tensor self_inverse, result; + TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", + self.dim(), " dimensions."); // The default case is using 2-norm Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; @@ -1617,10 +1616,10 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { auto singular_values = std::get<1>(at::svd(self)); auto s_max = std::get<0>(singular_values.max(/*dim=*/-1)); auto s_min = std::get<0>(singular_values.min(/*dim=*/-1)); + Tensor result; if (ord.toDouble() == -2.0) { result = s_min / s_max; - } - else { + } else { result = s_max / s_min; } // a trick to convert FLT_MAX or DBL_MAX to INFINITY @@ -1630,32 +1629,41 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { result = at::nan_to_num(result, INFINITY); return result; } + // ord == ±1 ord == ±inf - else { - squareCheckInputs(self); - std::array dim_arr = {-2, -1}; - optional dim = IntArrayRef(dim_arr); - // Ignore errors if not invertible, result is INFINITY in this case - try { - self_inverse = at::inverse(self); - } catch (const std::exception& e) { - if (strstr(e.what(), "singular")) { - result = at::empty_like(at::linalg_norm(self, ord, dim)); - at::fill_(result, INFINITY); - return result; - } - else { - TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); - } + // since at::inverse is used in the implementation, self has to be a tensor consisting of square matrices + // the same check as squareCheckInputs(self) but with a slightly more informative error message + TORCH_CHECK(self.size(-1) == self.size(-2), + "linalg_cond only supports square matrices or batches of square matrices " + "but got ", self.size(-1), " by ", self.size(-2), " matrices"); + std::array dim_arr = {-2, -1}; + optional dim = IntArrayRef(dim_arr); + // Ignore errors if not invertible, result is INFINITY in this case + // Currently checking for error in at::inverse causes cross-device data movement + // For batched input if at least one matrix in the batch is not invertible, + // then the result for all other (possibly) invertible matrices will be infinity as well + // since there is currently no way to use at::inverse with silent errors + Tensor 
self_inverse, result; + try { + self_inverse = at::inverse(self); + } catch (const std::exception& e) { + if (strstr(e.what(), "singular")) { + result = at::empty_like(at::linalg_norm(self, ord, dim)); + at::fill_(result, INFINITY); + return result; + } else { + TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); } - Tensor norm_self = at::linalg_norm(self, ord, dim); - Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - result = norm_self * norm_inverse; } + Tensor norm_self = at::linalg_norm(self, ord, dim); + Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); + result = norm_self * norm_inverse; return result; } Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt_ord) { + // If ord == None or ord == ±2 then SVD is used to compute the condition number + // the result is always real-valued, for other cases it is complex-valued for the complex-valued input. ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type(); @@ -1672,12 +1680,21 @@ Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt // Frobenius or nuclear norms Tensor linalg_cond(const Tensor& self, std::string ord) { TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); - squareCheckInputs(self); + // the same checks as squareCheckInputs(self) but with a slightly more informative error message + TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", + self.dim(), " dimensions."); + TORCH_CHECK(self.size(-1) == self.size(-2), + "linalg_cond only supports square matrices or batches of square matrices " + "but got ", self.size(-1), " by ", self.size(-2), " matrices"); Tensor self_inverse, result; std::array dim_arr = {-2, -1}; optional dim = IntArrayRef(dim_arr); // Ignore errors if not invertible, result is INFINITY in this case + // Currently checking for error in at::inverse causes cross-device data movement + // For batched input if at least one matrix in the batch is not invertible, + // then the result for all other (possibly) invertible matrices will be infinity as well + // since there is currently no way to use at::inverse with silent errors try { self_inverse = at::inverse(self); } catch (const std::exception& e) { diff --git a/test/test_linalg.py b/test/test_linalg.py index 8945c5e2c631..8cde6a91843e 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -476,14 +476,14 @@ def test_cond_errors_and_warnings(self, device, dtype): # cond expects the input to be at least 2-dimensional a = torch.ones(3, dtype=dtype, device=device) for ord in norm_types: - with self.assertRaisesRegex(RuntimeError, r'Tensor of matrices must have at least 2 dimensions'): + with self.assertRaisesRegex(RuntimeError, r'supports matrices or batches of matrices'): torch.linalg.cond(a, ord) # for some norm types cond expects the input to be square a = torch.ones(3, 2, dtype=dtype, device=device) norm_types = [1, -1, inf, -inf, 'fro', 'nuc'] for ord in norm_types: - with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + with self.assertRaisesRegex(RuntimeError, r'supports square matrices or batches of square matrices'): torch.linalg.cond(a, ord) # if non-empty out tensor with wrong shape is passed a warning is given From 5705429a09c85f18e306e98473e1740de9cc5b6e Mon Sep 17 00:00:00 2001 From: 
Ivan Yashchuk
Date: Mon, 9 Nov 2020 12:10:14 -0600
Subject: [PATCH 15/44] Use the explicit shape for the result instead of dummy norm computations

---
 aten/src/ATen/native/LinearAlgebra.cpp | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 27cff26ee411..95f54f1ce6d7 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1648,9 +1648,12 @@ Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
     self_inverse = at::inverse(self);
   } catch (const std::exception& e) {
     if (strstr(e.what(), "singular")) {
-      result = at::empty_like(at::linalg_norm(self, ord, dim));
-      at::fill_(result, INFINITY);
-      return result;
+      auto result_shape = self.sizes().vec();
+      result_shape.pop_back();
+      result_shape.pop_back(); // result's shape is equal to self.shape[0:-2]
+      result = at::empty(result_shape, self.options());
+      at::fill_(result, INFINITY);
+      return result;
     } else {
       TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what());
     }
@@ -1699,11 +1702,13 @@ Tensor linalg_cond(const Tensor& self, std::string ord) {
     self_inverse = at::inverse(self);
   } catch (const std::exception& e) {
     if (strstr(e.what(), "singular")) {
-      result = at::empty_like(at::linalg_norm(self, ord, dim));
-      at::fill_(result, INFINITY);
-      return result;
-    }
-    else {
+      auto result_shape = self.sizes().vec();
+      result_shape.pop_back();
+      result_shape.pop_back(); // result's shape is equal to self.shape[0:-2]
+      result = at::empty(result_shape, self.options());
+      at::fill_(result, INFINITY);
+      return result;
+    } else {
       TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what());
     }
   }

From 0d9d49f0ed56bd3f4797dfc41cf825abcc39b7ce Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Tue, 10 Nov 2020 06:09:57 -0600
Subject: [PATCH 16/44] Fixed conversion FLT_MAX / DBL_MAX -> INFINITY

---
 aten/src/ATen/native/LinearAlgebra.cpp | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 95f54f1ce6d7..6f2a9ba249f7 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1622,12 +1622,26 @@ Tensor linalg_cond(const Tensor& self, optional<Scalar> opt_ord) {
     } else {
       result = s_max / s_min;
     }
-    // a trick to convert FLT_MAX or DBL_MAX to INFINITY
-    // result might also contains infs already then inf + inf - inf = nan
-    // this will get replaced by inf
-    result = result + result - result;
-    result = at::nan_to_num(result, INFINITY);
-    return result;
+    // convert FLT_MAX or DBL_MAX to INFINITY for NumPy compatibility
+    switch (result.scalar_type()) {
+      case ScalarType::Double: {
+        Scalar dbl_max = std::numeric_limits<double>::max(); // DBL_MAX
+        Scalar inf = std::numeric_limits<double>::infinity(); // HUGE_VAL
+        return at::where(result == dbl_max, inf, result);
+      }
+      case ScalarType::Float: {
+        Scalar flt_max = std::numeric_limits<float>::max(); // FLT_MAX
+        float inf = std::numeric_limits<float>::infinity(); // HUGE_VALF
+        // Scalar.dtype() is always ScalarType::Double for isFloatingPoint() = true
+        // and at::where doesn't allow arguments with different dtype
+        // so let's use 0-dim tensor filled with inf
+        Tensor inf_tensor = at::empty({}, result.options());
+        at::fill_(inf_tensor, inf);
+        return at::where(result == flt_max, inf_tensor, result);
+      }
+      default:
+        TORCH_CHECK(false, "linalg_cond got an unexpected result type ",
toString(result.scalar_type())); + } } // ord == ±1 ord == ±inf From ae8878f47472d6591b19e2be780bb38abd522491 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 09:15:43 -0600 Subject: [PATCH 17/44] Added a warning for batched singular input; slightly refactored the tests --- aten/src/ATen/native/LinearAlgebra.cpp | 50 +++++++++++++---------- test/test_linalg.py | 55 ++++++++++++++++---------- 2 files changed, 64 insertions(+), 41 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 6f2a9ba249f7..9c4d27878b65 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1602,6 +1602,25 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype); } +Tensor _linalg_cond_exception_helper(const Tensor& self) { + // For batched input if at least one matrix in the batch is not invertible, + // then the result for all other (possibly) invertible matrices will be infinity as well + // since there is currently no way to use at::inverse with silent errors + + auto result_shape = self.sizes().vec(); + result_shape.pop_back(); + result_shape.pop_back(); // result's shape is equal to self.shape[0:-2] + Tensor result = at::empty(result_shape, self.options()); + at::fill_(result, INFINITY); + if (self.dim() > 2) { + // Should this be the not-implemented-error? + TORCH_WARN( + "linalg_cond for the batched input returns infinity for all (possibly invertible) matrices in the batch, " + "if at least one matrix in the batch is not invertible."); + } + return result; +} + // Numerical or None norms Tensor linalg_cond(const Tensor& self, optional opt_ord) { TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); @@ -1650,31 +1669,27 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { TORCH_CHECK(self.size(-1) == self.size(-2), "linalg_cond only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); - std::array dim_arr = {-2, -1}; - optional dim = IntArrayRef(dim_arr); + // Ignore errors if not invertible, result is INFINITY in this case // Currently checking for error in at::inverse causes cross-device data movement // For batched input if at least one matrix in the batch is not invertible, // then the result for all other (possibly) invertible matrices will be infinity as well // since there is currently no way to use at::inverse with silent errors - Tensor self_inverse, result; + Tensor self_inverse; try { self_inverse = at::inverse(self); } catch (const std::exception& e) { if (strstr(e.what(), "singular")) { - auto result_shape = self.sizes().vec(); - result_shape.pop_back(); - result_shape.pop_back(); // result's shape is equal to self.shape[0:-2] - result = at::empty(result_shape, self.options()); - at::fill_(result, INFINITY); - return result; + return _linalg_cond_exception_helper(self); } else { TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); } } + std::array dim_arr = {-2, -1}; + optional dim = IntArrayRef(dim_arr); Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - result = norm_self * norm_inverse; + Tensor result = norm_self * norm_inverse; return result; } @@ -1704,31 +1719,26 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { "linalg_cond only supports square matrices or 
batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); - Tensor self_inverse, result; - std::array dim_arr = {-2, -1}; - optional dim = IntArrayRef(dim_arr); // Ignore errors if not invertible, result is INFINITY in this case // Currently checking for error in at::inverse causes cross-device data movement // For batched input if at least one matrix in the batch is not invertible, // then the result for all other (possibly) invertible matrices will be infinity as well // since there is currently no way to use at::inverse with silent errors + Tensor self_inverse; try { self_inverse = at::inverse(self); } catch (const std::exception& e) { if (strstr(e.what(), "singular")) { - auto result_shape = self.sizes().vec(); - result_shape.pop_back(); - result_shape.pop_back(); // result's shape is equal to self.shape[0:-2] - result = at::empty(result_shape, self.options()); - at::fill_(result, INFINITY); - return result; + return _linalg_cond_exception_helper(self); } else { TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); } } + std::array dim_arr = {-2, -1}; + optional dim = IntArrayRef(dim_arr); Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - result = norm_self * norm_inverse; + Tensor result = norm_self * norm_inverse; return result; } diff --git a/test/test_linalg.py b/test/test_linalg.py index 8cde6a91843e..ef3e884081d4 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -434,6 +434,8 @@ def run_test_case(input, ord): for ord in norm_types: # frobenius norm not supported for complex tensors if dtype.is_complex and ord == 'fro': + with self.assertRaisesRegex(RuntimeError, "frobenius norm not supported for complex tensors"): + torch.linalg.cond(input, ord) continue run_test_case(input, ord) @@ -441,28 +443,12 @@ def run_test_case(input, ord): a = torch.eye(3, dtype=dtype, device=device) a[-1, -1] = 0 # make 'a' singular for ord in norm_types: - # frobenius norm not supported for complex tensors - if dtype.is_complex and ord == 'fro': - continue run_test_case(a, ord) - # TODO: once "inverse_cuda" supports complex dtypes, they shall be added to above tests - @unittest.expectedFailure - @onlyCUDA - @skipCUDAIfNoMagma - @dtypes(torch.complex64, torch.complex128) - @precisionOverride({torch.float32: 1e-3}) - def test_cond_xfailed(self, device, dtype): - input_size = (3, 3) - ord = 1 - torch.randn(*input_size, dtype=dtype, device=device) - result = torch.linalg.cond(input, ord) - result_numpy = np.linalg.cond(input.cpu().numpy(), ord) - self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) - @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @dtypesIfCUDA(torch.float32, torch.float64) @precisionOverride({torch.float32: 1e-3}) def test_cond_errors_and_warnings(self, device, dtype): norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] @@ -489,10 +475,7 @@ def test_cond_errors_and_warnings(self, device, dtype): # if non-empty out tensor with wrong shape is passed a warning is given a = torch.ones((2, 2), dtype=dtype, device=device) for ord in ['fro', 2]: - # frobenius norm not supported for complex tensors - if dtype.is_complex and ord == 'fro': - continue - real_dtype = a.real.dtype if dtype.is_complex else dtype + real_dtype = a.real.dtype if dtype.is_complex and ord == 2 else dtype out = torch.empty(a.shape, dtype=real_dtype, device=device) with warnings.catch_warnings(record=True) 
as w: # Trigger warning @@ -507,6 +490,36 @@ def test_cond_errors_and_warnings(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match"): torch.linalg.cond(a, ord, out=out) + # for batched input if at least one matrix in the batch is not invertible, + # then the result for all other (possibly) invertible matrices will be infinity as well + # since there is currently no way to use torch.inverse with silent errors + batch_dim = 3 + a = torch.eye(3, 3, dtype=dtype, device=device) + a = a.reshape((1, 3, 3)) + a = a.repeat(batch_dim, 1, 1) + a[0, -1, -1] = 0 # now a[0] is singular + for ord in [1, -1, inf, -inf, 'fro', 'nuc']: + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.cond(a, ord) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("for the batched input returns infinity for all" in str(w[-1].message)) + + # TODO: once "inverse_cuda" supports complex dtypes, they shall be added to above tests + @unittest.expectedFailure + @onlyCUDA + @skipCUDAIfNoMagma + @dtypes(torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-3}) + def test_cond_xfailed(self, device, dtype): + input_size = (3, 3) + ord = 1 + torch.randn(*input_size, dtype=dtype, device=device) + result = torch.linalg.cond(input, ord) + result_numpy = np.linalg.cond(input.cpu().numpy(), ord) + self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) + # Test autograd and jit functionality for linalg functions. # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, # the `test_cases` entries below should be moved there. These entries are in a similar format, From 8dc29e0d6d832efa1135e1db76718c762acb1511 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 09:16:08 -0600 Subject: [PATCH 18/44] Updated documentation --- torch/linalg/__init__.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 8cc81c6e6701..59b1e35a9b0d 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -143,22 +143,29 @@ cond = _add_docstr(_linalg.linalg_cond, r""" linalg.norm(input, p=None, *, out=None) -> Tensor -Returns the condition number of a matrix. -The condition number of :attr:`input` is defined as the norm of -:attr:`input` times the norm of the inverse of :attr:`input`. +Computes the condition number of a matrix :attr:`input`, +or of each matrix in a batched :attr:`input`, using the matrix norm defined by :attr:`p`. +The condition number is defined as the matrix norm of +:attr:`input` times the matrix norm of the inverse of :attr:`input`. This function supports real-valued, and only on CPU, complex-valued input. +.. note:: For `p = {None, 2, -2}` :attr:`input` can be non-square. For other norm types :attr:`input` must be + a square matrix or a batch of square matrices. If :attr:`input` does not satisfy the requirement + then a RuntimeError will be thrown. + +.. note:: For the batched input if at least one matrix in the batch is not invertible, + then the result for all other (possibly) invertible matrices will be erroneously infinity as well. + Args: - input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more - batch dimensions. + input (Tensor): the input matrix of size :math:`(m, n)` or the batch of matrices of size :math:`(*, m, n)` + where `*` is one or more batch dimensions. 
- p (int, float, inf, -inf, 'fro', optional): The order of norm. - inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object. - The following norms can be used: + p (int, float, inf, -inf, 'fro', 'nuc', optional): the type of the matrix norm to use in the computations. + The following norms are supported: ===== ============================ - ord norm for matrices + p norm for matrices ===== ============================ None 2-norm, computed directly using the SVD 'fro' Frobenius norm From 02f43eba32aaeb7fdf8dcb682672d6ca17bb0c4f Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 09:22:21 -0600 Subject: [PATCH 19/44] Added non-square input test cases --- test/test_linalg.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index ef3e884081d4..23e0ade12a7f 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -428,7 +428,7 @@ def run_test_case(input, ord): self.assertEqual(ans, result) norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] - input_sizes = [(3, 3), (2, 3, 3, 3)] + input_sizes = [(32, 32), (2, 3, 3, 3)] for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) for ord in norm_types: @@ -439,6 +439,13 @@ def run_test_case(input, ord): continue run_test_case(input, ord) + # test non-square input + input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)] + for input_size in input_sizes: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in [2, -2, None]: + run_test_case(input, ord) + # test for singular input a = torch.eye(3, dtype=dtype, device=device) a[-1, -1] = 0 # make 'a' singular From 36385d962062c3aa9a400d623ae4d1e3f9cd249b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 09:28:43 -0600 Subject: [PATCH 20/44] Updated norms description --- torch/linalg/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 59b1e35a9b0d..6c61fc06c3ac 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -167,15 +167,15 @@ ===== ============================ p norm for matrices ===== ============================ - None 2-norm, computed directly using the SVD + None 2-norm (ratio of the largest singular value to the smallest singular value) 'fro' Frobenius norm 'nuc' nuclear norm inf max(sum(abs(x), dim=1)) -inf min(sum(abs(x), dim=1)) 1 max(sum(abs(x), dim=0)) -1 min(sum(abs(x), dim=0)) - 2 2-norm (largest sing. 
value) - -2 smallest singular value + 2 2-norm (ratio of the largest singular value to the smallest singular value) + -2 ratio of the smallest singular value to the largest singular value ===== ============================ Default: ``None`` From 49a157768cd81edd9786381de31ae10f920b8b74 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 12 Nov 2020 04:18:25 -0600 Subject: [PATCH 21/44] Raise error for batched input with at least one non-invertible matrix --- aten/src/ATen/native/LinearAlgebra.cpp | 16 +++++++--------- test/test_linalg.py | 12 +++++------- torch/linalg/__init__.py | 2 +- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index f36eebe9eb5f..e3c25422a18d 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1603,20 +1603,18 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt Tensor _linalg_cond_exception_helper(const Tensor& self) { // For batched input if at least one matrix in the batch is not invertible, - // then the result for all other (possibly) invertible matrices will be infinity as well - // since there is currently no way to use at::inverse with silent errors - + // we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop. + // This should change when at::inverse works with silent errors + if (self.dim() > 2) { + TORCH_CHECK(false, + "At least one matrix in the batch is not invertible, its condition number is infinity, " + "linalg_cond does not support yet calculating the condition number for all other (possibly invertible) matrices in the batch."); + } auto result_shape = self.sizes().vec(); result_shape.pop_back(); result_shape.pop_back(); // result's shape is equal to self.shape[0:-2] Tensor result = at::empty(result_shape, self.options()); at::fill_(result, INFINITY); - if (self.dim() > 2) { - // Should this be the not-implemented-error? - TORCH_WARN( - "linalg_cond for the batched input returns infinity for all (possibly invertible) matrices in the batch, " - "if at least one matrix in the batch is not invertible."); - } return result; } diff --git a/test/test_linalg.py b/test/test_linalg.py index 0c9440dffdc2..066f52e52bc7 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -500,20 +500,18 @@ def test_cond_errors_and_warnings(self, device, dtype): torch.linalg.cond(a, ord, out=out) # for batched input if at least one matrix in the batch is not invertible, - # then the result for all other (possibly) invertible matrices will be infinity as well - # since there is currently no way to use torch.inverse with silent errors + # we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop. 
+ # this should change when at::inverse works with silent errors + # NumPy works fine in this case because it's possible to silence the error and get the inverse matrix results + # possibly filled with NaNs batch_dim = 3 a = torch.eye(3, 3, dtype=dtype, device=device) a = a.reshape((1, 3, 3)) a = a.repeat(batch_dim, 1, 1) a[0, -1, -1] = 0 # now a[0] is singular for ord in [1, -1, inf, -inf, 'fro', 'nuc']: - with warnings.catch_warnings(record=True) as w: - # Trigger warning + with self.assertRaisesRegex(RuntimeError, "linalg_cond does not support yet"): torch.linalg.cond(a, ord) - # Check warning occurs - self.assertEqual(len(w), 1) - self.assertTrue("for the batched input returns infinity for all" in str(w[-1].message)) # TODO: once "inverse_cuda" supports complex dtypes, they shall be added to above tests @unittest.expectedFailure @onlyCUDA @skipCUDAIfNoMagma @dtypes(torch.complex64, torch.complex128) @precisionOverride({torch.float32: 1e-3}) def test_cond_xfailed(self, device, dtype): input_size = (3, 3) ord = 1 torch.randn(*input_size, dtype=dtype, device=device) result = torch.linalg.cond(input, ord) result_numpy = np.linalg.cond(input.cpu().numpy(), ord) self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 6c61fc06c3ac..57699d31105e 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -155,7 +155,7 @@ then a RuntimeError will be thrown. .. note:: For the batched input if at least one matrix in the batch is not invertible, - then the result for all other (possibly) invertible matrices will be erroneously infinity as well. + currently getting the result for all other (possibly) invertible matrices in the batch is not implemented. Args: input (Tensor): the input matrix of size :math:`(m, n)` or the batch of matrices of size :math:`(*, m, n)` From 655a03d8e65a126cee658649815ae192b96e9641 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 17 Nov 2020 05:58:00 -0600 Subject: [PATCH 22/44] Updated documentation --- torch/linalg/__init__.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index ad0a3e7e18e6..5f6b945b6a65 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -222,14 +222,17 @@ The condition number is defined as the matrix norm of :attr:`input` times the matrix norm of the inverse of :attr:`input`. -This function supports real-valued, and only on CPU, complex-valued input. +This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtype for :attr:`input`. .. note:: For `p = {None, 2, -2}` :attr:`input` can be non-square. For other norm types :attr:`input` must be - a square matrix or a batch of square matrices. If :attr:`input` does not satisfy the requirement + a square matrix or a batch of square matrices. If :attr:`input` does not satisfy this requirement then a RuntimeError will be thrown. -.. note:: For the batched input if at least one matrix in the batch is not invertible, - currently getting the result for all other (possibly) invertible matrices in the batch is not implemented. +.. note:: If :attr:`input` is a non-invertible matrix then a tensor containing infinity will be returned. + If :attr:`input` is a batch of matrices and one or more of them is not invertible + then a RuntimeError will be thrown. + +.. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU.
Args: input (Tensor): the input matrix of size :math:`(m, n)` or the batch of matrices of size :math:`(*, m, n)` @@ -279,6 +282,16 @@ tensor(1.4142) >>> LA.cond(a, -2) tensor(0.7071) + + >>> a = torch.randn(3, 4, 4) + >>> LA.cond(a) + tensor([ 4.4739, 76.5234, 10.8409]) + + >>> a = torch.randn(3, 4, 4, dtype=torch.complex64) + >>> LA.cond(a) + tensor([ 5.9175, 48.4590, 5.6443]) + >>> LA.cond(a, 1) + tensor([ 11.6734+0.j, 105.1037+0.j, 10.1978+0.j]) """) tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" From 27142ba801c1570e1c64b12ddd66be88e2dde3b5 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 17 Nov 2020 07:25:49 -0600 Subject: [PATCH 23/44] Use enums to dispatch on norm type --- aten/src/ATen/native/LinearAlgebra.cpp | 175 ++++++++++++++++++++----- 1 file changed, 143 insertions(+), 32 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index d8e24e68bcbb..ce03ab8b62cf 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -16,6 +16,8 @@ #include #include +#include + namespace at { namespace native { @@ -1621,6 +1623,111 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) { return result; } +enum class CondMode { + kNorm2, // 2-norm using the SVD, default + kNormNegative2, + kNorm1, + kNormNegative1, + kNormInfinity, + kNormNegativeInfinity, + kNormFrobenius, + kNormNuclear, + kUndefined +}; + +// this function converts cond_mode enum to ord argument compatible with linalg_norm +c10::variant cond_mode_to_ord(CondMode cond_mode) { + c10::variant ord; + switch (cond_mode) { + case CondMode::kNorm2: + ord = 2; + break; + case CondMode::kNormNegative2: + ord = -2; + break; + case CondMode::kNorm1: + ord = 1; + break; + case CondMode::kNormNegative1: + ord = -1; + break; + case CondMode::kNormInfinity: + ord = 1; + break; + case CondMode::kNormNegativeInfinity: + ord = -1; + break; + case CondMode::kNormFrobenius: + ord = "fro"; + break; + case CondMode::kNormNuclear: + ord = "nuc"; + break; + default: + TORCH_CHECK(false, "cond_mode_to_ord: got an unexpected norm type."); + } + return ord; +} + +// this function converts ord argument compatible with linalg_norm to cond_mode enum compatible with _linalg_cond_helper +CondMode ord_to_cond_mode(c10::variant ord_variant) { + if (ord_variant.index() == 0) { + // ord_variant holds a Scalar type + auto ord = c10::get(ord_variant).toDouble(); + if (ord == 2) { + return CondMode::kNorm2; + } else if (ord == -2) { + return CondMode::kNormNegative2; + } else if (ord == 1) { + return CondMode::kNorm1; + } else if (ord == -1) { + return CondMode::kNormNegative1; + } else if (ord == INFINITY) { + return CondMode::kNormInfinity; + } else if (ord == -INFINITY) { + return CondMode::kNormNegativeInfinity; + } + } else if (ord_variant.index() == 1) { + // ord_variant holds a std::string type + std::string ord = c10::get(ord_variant); + if (ord == "fro") { + return CondMode::kNormFrobenius; + } else if (ord == "nuc") { + return CondMode::kNormNuclear; + } + } + return CondMode::kUndefined; +} + +Tensor _linalg_cond_helper(const Tensor& self, CondMode cond_mode) { + // Ignore errors if not invertible, result is INFINITY in this case + // Currently checking for error in at::inverse causes cross-device data movement + // For batched input if at least one matrix in the batch is not invertible, + // then the result for all other (possibly) invertible matrices will be infinity as well + // since there is currently no way to use at::inverse
with silent errors + Tensor self_inverse; + try { + self_inverse = at::inverse(self); + } catch (const std::exception& e) { + if (strstr(e.what(), "singular")) { + return _linalg_cond_exception_helper(self); + } else { + TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); + } + } + std::array dim_arr = {-2, -1}; + optional dim = IntArrayRef(dim_arr); + + c10::variant ord_variant = cond_mode_to_ord(cond_mode); + + return c10::visit([&](auto&& ord) { + Tensor norm_self = at::linalg_norm(self, ord, dim); + Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); + Tensor result = norm_self * norm_inverse; + return result; + }, ord_variant); +} + // Numerical or None norms Tensor linalg_cond(const Tensor& self, optional opt_ord) { TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); @@ -1675,22 +1782,24 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { // For batched input if at least one matrix in the batch is not invertible, // then the result for all other (possibly) invertible matrices will be infinity as well // since there is currently no way to use at::inverse with silent errors - Tensor self_inverse; - try { - self_inverse = at::inverse(self); - } catch (const std::exception& e) { - if (strstr(e.what(), "singular")) { - return _linalg_cond_exception_helper(self); - } else { - TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); - } - } - std::array dim_arr = {-2, -1}; - optional dim = IntArrayRef(dim_arr); - Tensor norm_self = at::linalg_norm(self, ord, dim); - Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - Tensor result = norm_self * norm_inverse; - return result; + // Tensor self_inverse; + // try { + // self_inverse = at::inverse(self); + // } catch (const std::exception& e) { + // if (strstr(e.what(), "singular")) { + // return _linalg_cond_exception_helper(self); + // } else { + // TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); + // } + // } + // std::array dim_arr = {-2, -1}; + // optional dim = IntArrayRef(dim_arr); + // Tensor norm_self = at::linalg_norm(self, ord, dim); + // Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); + // Tensor result = norm_self * norm_inverse; + // return result; + CondMode cond_mode = ord_to_cond_mode(ord); + return _linalg_cond_helper(self, cond_mode); } Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt_ord) { @@ -1724,22 +1833,24 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { // For batched input if at least one matrix in the batch is not invertible, // then the result for all other (possibly) invertible matrices will be infinity as well // since there is currently no way to use at::inverse with silent errors - Tensor self_inverse; - try { - self_inverse = at::inverse(self); - } catch (const std::exception& e) { - if (strstr(e.what(), "singular")) { - return _linalg_cond_exception_helper(self); - } else { - TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); - } - } - std::array dim_arr = {-2, -1}; - optional dim = IntArrayRef(dim_arr); - Tensor norm_self = at::linalg_norm(self, ord, dim); - Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - Tensor result = norm_self * norm_inverse; - return result; + // Tensor self_inverse; + // try { + // self_inverse = at::inverse(self); + // } catch (const std::exception& e) { + // if (strstr(e.what(), "singular")) { + // return _linalg_cond_exception_helper(self); + // } else { + // TORCH_CHECK(false, 
"linalg_cond got an unexpected error:\n", e.what()); + // } + // } + // std::array dim_arr = {-2, -1}; + // optional dim = IntArrayRef(dim_arr); + // Tensor norm_self = at::linalg_norm(self, ord, dim); + // Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); + // Tensor result = norm_self * norm_inverse; + // return result; + CondMode cond_mode = ord_to_cond_mode(ord); + return _linalg_cond_helper(self, cond_mode); } Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) { From 195463f13ee67e52c8ed4bf89bc515d6d5cea961 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 17 Nov 2020 07:32:31 -0600 Subject: [PATCH 24/44] Add missing imports --- test/test_linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index ebaa4e60b491..db68e910259b 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8,8 +8,8 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, - onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + (instantiate_device_type_tests, dtypes, dtypesIfCUDA, + onlyCUDA, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA) from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args From ea95b087f90cec66617ff940d9a451ad522bc8d0 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 17 Nov 2020 07:33:04 -0600 Subject: [PATCH 25/44] Enums are redundant here, use only c10::variant for norm dispatch --- aten/src/ATen/native/LinearAlgebra.cpp | 131 ++----------------------- 1 file changed, 6 insertions(+), 125 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index ce03ab8b62cf..5cab4eae69bf 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1623,83 +1623,8 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) { return result; } -enum class CondMode { - kNorm2, // 2-norm using the SVD, default - kNormNegative2, - kNorm1, - kNormNegative1, - kNormInfinity, - kNormNegativeInfinity, - kNormFrobenius, - kNormNuclear, - kUndefined -}; - -// this function converts cond_mode enum to ord argument compatible with linalg_norm -c10::variant cond_mode_to_ord(CondMode cond_mode) { - c10::variant ord; - switch (cond_mode) { - case CondMode::kNorm2: - ord = 2; - break; - case CondMode::kNormNegative2: - ord = -2; - break; - case CondMode::kNorm1: - ord = 1; - break; - case CondMode::kNormNegative1: - ord = -1; - break; - case CondMode::kNormInfinity: - ord = 1; - break; - case CondMode::kNormNegativeInfinity: - ord = -1; - break; - case CondMode::kNormFrobenius: - ord = "fro"; - break; - case CondMode::kNormNuclear: - ord = "nuc"; - break; - default: - TORCH_CHECK(false, "cond_mode_to_ord: got an unexpected norm type."); - } - return ord; -} - -// this function converts ord argument compatible with linalg_norm to cond_mode enum compatible with _linalg_cond_helper -CondMode ord_to_cond_mode(c10::variant ord_variant) { - if (ord_variant.index() == 0) { - // ord_variant holds a Scalar type - auto ord = c10::get(ord_variant).toDouble(); - if (ord == 2) { - return CondMode::kNorm2; - } 
else if (ord == -2) { - return CondMode::kNormNegative2; - } else if (ord == 1) { - return CondMode::kNorm1; - } else if (ord == -1) { - return CondMode::kNormNegative1; - } else if (ord == INFINITY) { - return CondMode::kNormInfinity; - } else if (ord == -INFINITY) { - return CondMode::kNormNegativeInfinity; - } - } else if (ord_variant.index() == 1) { - // ord_variant holds a std::string type - std::string ord = c10::get(ord_variant); - if (ord == "fro") { - return CondMode::kNormFrobenius; - } else if (ord == "nuc") { - return CondMode::kNormNuclear; - } - } - return CondMode::kUndefined; -} - -Tensor _linalg_cond_helper(const Tensor& self, CondMode cond_mode) { +// This function helps to dispatch norm computations depending on 'ord' of variant type +Tensor _linalg_cond_helper(const Tensor& self, c10::variant ord_variant) { // Ignore errors if not invertible, result is INFINITY in this case // Currently checking for error in at::inverse causes cross-device data movement // For batched input if at least one matrix in the batch is not invertible, @@ -1718,8 +1643,6 @@ Tensor _linalg_cond_helper(const Tensor& self, CondMode cond_mode) { std::array dim_arr = {-2, -1}; optional dim = IntArrayRef(dim_arr); - c10::variant ord_variant = cond_mode_to_ord(cond_mode); - return c10::visit([&](auto&& ord) { Tensor norm_self = at::linalg_norm(self, ord, dim); Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); @@ -1777,29 +1700,8 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { "linalg_cond only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); - // Ignore errors if not invertible, result is INFINITY in this case - // Currently checking for error in at::inverse causes cross-device data movement - // For batched input if at least one matrix in the batch is not invertible, - // then the result for all other (possibly) invertible matrices will be infinity as well - // since there is currently no way to use at::inverse with silent errors - // Tensor self_inverse; - // try { - // self_inverse = at::inverse(self); - // } catch (const std::exception& e) { - // if (strstr(e.what(), "singular")) { - // return _linalg_cond_exception_helper(self); - // } else { - // TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); - // } - // } - // std::array dim_arr = {-2, -1}; - // optional dim = IntArrayRef(dim_arr); - // Tensor norm_self = at::linalg_norm(self, ord, dim); - // Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - // Tensor result = norm_self * norm_inverse; - // return result; - CondMode cond_mode = ord_to_cond_mode(ord); - return _linalg_cond_helper(self, cond_mode); + c10::variant ord_variant = ord; + return _linalg_cond_helper(self, ord_variant); } Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt_ord) { @@ -1828,29 +1730,8 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { "linalg_cond only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); - // Ignore errors if not invertible, result is INFINITY in this case - // Currently checking for error in at::inverse causes cross-device data movement - // For batched input if at least one matrix in the batch is not invertible, - // then the result for all other (possibly) invertible matrices will be infinity as well - // since there is currently no way to use at::inverse with silent errors - // Tensor self_inverse; - // try { - // 
self_inverse = at::inverse(self); - // } catch (const std::exception& e) { - // if (strstr(e.what(), "singular")) { - // return _linalg_cond_exception_helper(self); - // } else { - // TORCH_CHECK(false, "linalg_cond got an unexpected error:\n", e.what()); - // } - // } - // std::array dim_arr = {-2, -1}; - // optional dim = IntArrayRef(dim_arr); - // Tensor norm_self = at::linalg_norm(self, ord, dim); - // Tensor norm_inverse = at::linalg_norm(self_inverse, ord, dim); - // Tensor result = norm_self * norm_inverse; - // return result; - CondMode cond_mode = ord_to_cond_mode(ord); - return _linalg_cond_helper(self, cond_mode); + c10::variant ord_variant = ord; + return _linalg_cond_helper(self, ord_variant); } Tensor& linalg_cond_out(Tensor& result, const Tensor& self, std::string ord) { From fbe025ab7404f69850e56af21bd7bb21c8513791 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 17 Nov 2020 08:26:27 -0600 Subject: [PATCH 26/44] Simplify error message for batched non-invertible case --- aten/src/ATen/native/LinearAlgebra.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 5cab4eae69bf..b4453de90e72 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1612,8 +1612,8 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) { // This should change when at::inverse works with silent errors if (self.dim() > 2) { TORCH_CHECK(false, - "At least one matrix in the batch is not invertible, its condition number is infinity, " - "linalg_cond does not support yet calculating the condition number for all other (possibly invertible) matrices in the batch."); + "One or more matrices in the batch was not invertible!
" + "linalg_cond does not support yet this case."); } auto result_shape = self.sizes().vec(); result_shape.pop_back(); From 0056ca9e9e1fe7976fd28794e142c63a6e79dd1c Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 17 Nov 2020 08:34:12 -0600 Subject: [PATCH 27/44] Use at::full instead of empty + fill_ --- aten/src/ATen/native/LinearAlgebra.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index b4453de90e72..df3d08fe0ed0 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1618,8 +1618,7 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) { auto result_shape = self.sizes().vec(); result_shape.pop_back(); result_shape.pop_back(); // result's shape is equal to self.shape[0:-2] - Tensor result = at::empty(result_shape, self.options()); - at::fill_(result, INFINITY); + Tensor result = at::full(result_shape, INFINITY, self.options()); return result; } @@ -1684,8 +1683,7 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { // Scalar.dtype() is always ScalarType::Double for isFloatingPoint() = true // and at::where doesn't allow arguments with different dtype // so let's use 0-dim tensor filled with inf - Tensor inf_tensor = at::empty({}, result.options()); - at::fill_(inf_tensor, inf); + Tensor inf_tensor = at::full({}, inf, result.options()); return at::where(result == flt_max, inf_tensor, result); } default: From 47da62507b2c0a678499612565807bd4c1043a3f Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 17 Nov 2020 08:40:52 -0600 Subject: [PATCH 28/44] Make error message more specific about the norm types --- aten/src/ATen/native/LinearAlgebra.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index df3d08fe0ed0..c9055ee611e8 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1695,7 +1695,7 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { // since at::inverse is used in the implementation, self has to be a tensor consisting of square matrices // the same check as squareCheckInputs(self) but with a slightly more informative error message TORCH_CHECK(self.size(-1) == self.size(-2), - "linalg_cond only supports square matrices or batches of square matrices " + "linalg_cond with ±1 or ±inf norm types only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); c10::variant ord_variant = ord; @@ -1725,7 +1725,7 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", self.dim(), " dimensions."); TORCH_CHECK(self.size(-1) == self.size(-2), - "linalg_cond only supports square matrices or batches of square matrices " + "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); c10::variant ord_variant = ord; From 6eb461a3b5f4d146b7098890c832d7b6192dc63c Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 18 Nov 2020 10:42:04 -0600 Subject: [PATCH 29/44] Removed unneeded conversions to infinity --- aten/src/ATen/native/LinearAlgebra.cpp | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp 
b/aten/src/ATen/native/LinearAlgebra.cpp index c9055ee611e8..c6dde3d94a45 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1670,25 +1670,7 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { } else { result = s_max / s_min; } - // convert FLT_MAX or DBL_MAX to INFINITY for NumPy compatibility - switch (result.scalar_type()) { - case ScalarType::Double: { - Scalar dbl_max = std::numeric_limits::max(); // DBL_MAX - Scalar inf = std::numeric_limits::infinity(); // HUGE_VAL - return at::where(result == dbl_max, inf, result); - } - case ScalarType::Float: { - Scalar flt_max = std::numeric_limits::max(); // FLT_MAX - float inf = std::numeric_limits::infinity(); // HUGE_VALF - // Scalar.dtype() is always ScalarType::Double for isFloatingPoint() = true - // and at::where doesn't allow arguments with different dtype - // so let's use 0-dim tensor filled with inf - Tensor inf_tensor = at::full({}, inf, result.options()); - return at::where(result == flt_max, inf_tensor, result); - } - default: - TORCH_CHECK(false, "linalg_cond got an unexpected result type ", toString(result.scalar_type())); - } + return result; } // ord == ±1 ord == ±inf From bacdc6591b70e1ba68991c3c5c2bc17c6682ec8d Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 29 Nov 2020 10:38:42 -0600 Subject: [PATCH 30/44] Fix typo norm -> cond --- torch/linalg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index ebd9fb15e41a..a4eb10d1265b 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -340,7 +340,7 @@ """) cond = _add_docstr(_linalg.linalg_cond, r""" -linalg.norm(input, p=None, *, out=None) -> Tensor +linalg.cond(input, p=None, *, out=None) -> Tensor Computes the condition number of a matrix :attr:`input`, or of each matrix in a batched :attr:`input`, using the matrix norm defined by :attr:`p`. From c7b52de9e97b39f64ba46a080dc3f09754a053a8 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 29 Nov 2020 10:43:42 -0600 Subject: [PATCH 31/44] Rename ord -> p --- aten/src/ATen/native/native_functions.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5e3eebe041ca..1f7879a3915b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9484,25 +9484,25 @@ python_module: linalg variants: function -- func: linalg_cond(Tensor self, Scalar? ord=None) -> Tensor +- func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor python_module: linalg variants: function dispatch: Math: linalg_cond -- func: linalg_cond.out(Tensor self, Scalar? ord=None, *, Tensor(a!) out) -> Tensor(a!) +- func: linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!) python_module: linalg variants: function dispatch: Math: linalg_cond_out -- func: linalg_cond.ord_str(Tensor self, str ord) -> Tensor +- func: linalg_cond.p_str(Tensor self, str p) -> Tensor python_module: linalg variants: function dispatch: Math: linalg_cond -- func: linalg_cond.ord_str_out(Tensor self, str ord, *, Tensor(a!) out) -> Tensor(a!) +- func: linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) 
python_module: linalg variants: function dispatch: From 873f93e589096b1df68b74cb695d33a550ec3e1c Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 29 Nov 2020 10:44:50 -0600 Subject: [PATCH 32/44] Fix typo --- torch/linalg/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index a4eb10d1265b..aef6baf41911 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -347,7 +347,7 @@ The condition number is defined as the matrix norm of :attr:`input` times the matrix norm of the inverse of :attr:`input`. -This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtype for :attr:`input`. +This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. .. note:: For `p = {None, 2, -2}` :attr:`input` can be non-square. For other norm types :attr:`input` must be a square matrix or a batch of square matrices. If :attr:`input` does not satisfy this requirement From 581333382ac806c579767545fb5c7d2ce49d2415 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 29 Nov 2020 10:47:52 -0600 Subject: [PATCH 33/44] Remove 2-norm mentions in the docs, leaving just ratio of singular values --- torch/linalg/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index aef6baf41911..b84225aa11e8 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -369,14 +369,14 @@ ===== ============================ p norm for matrices ===== ============================ - None 2-norm (ratio of the largest singular value to the smallest singular value) + None ratio of the largest singular value to the smallest singular value 'fro' Frobenius norm 'nuc' nuclear norm inf max(sum(abs(x), dim=1)) -inf min(sum(abs(x), dim=1)) 1 max(sum(abs(x), dim=0)) -1 min(sum(abs(x), dim=0)) - 2 2-norm (ratio of the largest singular value to the smallest singular value) + 2 ratio of the largest singular value to the smallest singular value -2 ratio of the smallest singular value to the largest singular value ===== ============================ From c0828810c3016db310a57b8dd6f6f69110690d61 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 29 Nov 2020 11:59:21 -0600 Subject: [PATCH 34/44] Use IntArrayRef for constructing result_shape --- aten/src/ATen/native/LinearAlgebra.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index d0cc301e049f..2b33e0cbfea4 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1715,9 +1715,7 @@ Tensor _linalg_cond_exception_helper(const Tensor& self) { "One or more matrices in the batch was not invertible! 
" "linalg_cond does not support yet this case."); } - auto result_shape = self.sizes().vec(); - result_shape.pop_back(); - result_shape.pop_back(); // result's shape is equal to self.shape[0:-2] + auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2); Tensor result = at::full(result_shape, INFINITY, self.options()); return result; } From c1e4bd92a78dc6e4ac785a57833387ebeb99682b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 29 Nov 2020 12:31:18 -0600 Subject: [PATCH 35/44] Allow batched input with zero dimensions; singular values are sorted so no need to use max and min --- aten/src/ATen/native/LinearAlgebra.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 2b33e0cbfea4..71c41fe4b902 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1750,7 +1750,7 @@ Tensor _linalg_cond_helper(const Tensor& self, c10::variant // Numerical or None norms Tensor linalg_cond(const Tensor& self, optional opt_ord) { - TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); + TORCH_CHECK(!(self.numel() == 0 && self.size(-2)*self.size(-1) == 0), "linalg_cond is not defined for empty tensors."); TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", self.dim(), " dimensions."); @@ -1760,8 +1760,9 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { // If ord == None or ord == ±2 if (std::abs(ord.toDouble()) == 2.0) { auto singular_values = std::get<1>(at::svd(self)); - auto s_max = std::get<0>(singular_values.max(/*dim=*/-1)); - auto s_min = std::get<0>(singular_values.min(/*dim=*/-1)); + // singular values are sorted in descending order + auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1); + auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1); Tensor result; if (ord.toDouble() == -2.0) { result = s_min / s_max; @@ -1800,7 +1801,7 @@ Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt // Frobenius or nuclear norms Tensor linalg_cond(const Tensor& self, std::string ord) { - TORCH_CHECK(self.numel() > 0, "linalg_cond is not defined for empty tensors."); + TORCH_CHECK(!(self.numel() == 0 && self.size(-2)*self.size(-1) == 0), "linalg_cond is not defined for empty tensors."); // the same checks as squareCheckInputs(self) but with a slightly more informative error message TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", self.dim(), " dimensions."); From f5296cb80e8ea859c2ca20cf54b1da6957be3f39 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 29 Nov 2020 12:32:04 -0600 Subject: [PATCH 36/44] Added empty batch size test inputs --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 5da92bce5874..1bb56d45a958 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -963,7 +963,7 @@ def run_test_case(input, ord): self.assertEqual(ans, result) norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] - input_sizes = [(32, 32), (2, 3, 3, 3)] + input_sizes = [(32, 32), (2, 3, 3, 3), (0, 3, 3), (0, 2, 5, 5)] for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) for ord in norm_types: From b45237941ff6947648ba24f41f7a7e238da2b5ad Mon Sep 17 00:00:00 2001 From: Ivan 
Yashchuk Date: Sun, 29 Nov 2020 12:35:10 -0600 Subject: [PATCH 37/44] Revert changes to common_methods_invocations.py --- .../testing/_internal/common_methods_invocations.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 8145928f735d..8850c3f7bf49 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1580,16 +1580,6 @@ def method_tests(): ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])]),), 'adv_index_var'), ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()), - ('linalg.cond', (S, S), (), 'default', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S, S), (), 'default_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), (inf,), 'matrix_inf', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), (2,), 'matrix_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), (1,), 'matrix_1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), (-inf,), 'matrix_neg_inf', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), (-2,), 'matrix_neg_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), (-1,), 'matrix_neg_1', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), ('fro',), 'fro', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), - ('linalg.cond', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('triangular_solve', (S, M), ((S, S), ), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), ('kron', (S, S), ((M, L),)) ] @@ -1886,6 +1876,4 @@ def exclude_tensor_method(name, test_name): return True if 'fft.' in name: return True - if 'linalg.' in name: - return True return False From b15fedda0078a4a544f3f2f44e75a46288dc0eba Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 30 Nov 2020 12:13:02 -0600 Subject: [PATCH 38/44] Return 0 for 0x0 matrices --- aten/src/ATen/native/LinearAlgebra.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 71c41fe4b902..bd5aec275c53 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1748,15 +1748,27 @@ Tensor _linalg_cond_helper(const Tensor& self, c10::variant }, ord_variant); } +// Return zero for each matrix in the batch +Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) { + auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2); + return at::zeros(result_shape, self.options().dtype(dtype)); +} + // Numerical or None norms Tensor linalg_cond(const Tensor& self, optional opt_ord) { - TORCH_CHECK(!(self.numel() == 0 && self.size(-2)*self.size(-1) == 0), "linalg_cond is not defined for empty tensors."); TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", self.dim(), " dimensions."); // The default case is using 2-norm Scalar ord = opt_ord.has_value() ? 
opt_ord.value() : 2; + // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input + if (self.numel() == 0) { + auto real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + auto expected_dtype = std::abs(ord.toDouble()) == 2.0 ? real_dtype : self.scalar_type(); + return _linalg_cond_empty_matrix(self, expected_dtype); + } + // If ord == None or ord == ±2 if (std::abs(ord.toDouble()) == 2.0) { auto singular_values = std::get<1>(at::svd(self)); @@ -1801,7 +1813,6 @@ Tensor& linalg_cond_out(Tensor& result, const Tensor& self, optional opt // Frobenius or nuclear norms Tensor linalg_cond(const Tensor& self, std::string ord) { - TORCH_CHECK(!(self.numel() == 0 && self.size(-2)*self.size(-1) == 0), "linalg_cond is not defined for empty tensors."); // the same checks as squareCheckInputs(self) but with a slightly more informative error message TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", self.dim(), " dimensions."); @@ -1820,6 +1820,11 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); + // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input + if (self.numel() == 0) { + return _linalg_cond_empty_matrix(self, self.scalar_type()); + } + c10::variant ord_variant = ord; return _linalg_cond_helper(self, ord_variant); } From 73046d18bdc723690e3a77c7ce062075a4fa68de Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 30 Nov 2020 12:25:15 -0600 Subject: [PATCH 39/44] Added test cases for 0x0 matrices Added complex dtype test cases; Renamed ord -> p --- test/test_linalg.py | 87 +++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 46 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 1bb56d45a958..fd7682814720 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -948,89 +948,98 @@ def run_test_case(input, p, dim, keepdim): @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) @precisionOverride({torch.float32: 1e-3}) def test_cond(self, device, dtype): - def run_test_case(input, ord): - result = torch.linalg.cond(input, ord) - result_numpy = np.linalg.cond(input.cpu().numpy(), ord) + def run_test_case(input, p): + result = torch.linalg.cond(input, p) + result_numpy = np.linalg.cond(input.cpu().numpy(), p) self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) # test out= variant out = torch.empty_like(result) - ans = torch.linalg.cond(input, ord, out=out) + ans = torch.linalg.cond(input, p, out=out) self.assertEqual(ans, out) self.assertEqual(ans, result) norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] - input_sizes = [(32, 32), (2, 3, 3, 3), (0, 3, 3), (0, 2, 5, 5)] + input_sizes = [(32, 32), (2, 3, 3, 3)] for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) - for ord in norm_types: + for p in norm_types: # frobenius norm not supported for complex tensors - if dtype.is_complex and ord == 'fro': + if dtype.is_complex and p == 'fro': with self.assertRaisesRegex(RuntimeError, "frobenius norm not supported for complex tensors"): - torch.linalg.cond(input, ord) + torch.linalg.cond(input, p) continue - run_test_case(input, ord) + run_test_case(input, p) + +
# test empty batch sizes + input_sizes = [(0, 3, 3), (0, 2, 5, 5)] + for input_size in input_sizes: + input = torch.randn(*input_size, dtype=dtype, device=device) + for p in norm_types: + run_test_case(input, p) # test non-square input input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)] for input_size in input_sizes: input = torch.randn(*input_size, dtype=dtype, device=device) - for ord in [2, -2, None]: - run_test_case(input, ord) + for p in [2, -2, None]: + run_test_case(input, p) # test for singular input a = torch.eye(3, dtype=dtype, device=device) a[-1, -1] = 0 # make 'a' singular - for ord in norm_types: - run_test_case(a, ord) + for p in norm_types: + run_test_case(a, p) + + # test for 0x0 matrices. NumPy doesn't work for such input, we return 0 + input_sizes = [(0, 0), (2, 5, 0, 0)] + for input_size in input_sizes: + input = torch.randn(*input_size, dtype=dtype, device=device) + for p in ['fro', 2]: + expected_dtype = a.real.dtype if dtype.is_complex and p == 2 else dtype + expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device) + actual = torch.linalg.cond(input, p) + self.assertEqual(actual, expected) @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - @dtypesIfCUDA(torch.float32, torch.float64) @precisionOverride({torch.float32: 1e-3}) def test_cond_errors_and_warnings(self, device, dtype): norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None] - # cond expects the input to be non-empty - a = torch.zeros((0, 0), dtype=dtype, device=device) - for ord in norm_types: - with self.assertRaisesRegex(RuntimeError, r'linalg_cond is not defined for empty tensors'): - torch.linalg.cond(a, ord) - # cond expects the input to be at least 2-dimensional a = torch.ones(3, dtype=dtype, device=device) - for ord in norm_types: + for p in norm_types: with self.assertRaisesRegex(RuntimeError, r'supports matrices or batches of matrices'): - torch.linalg.cond(a, ord) + torch.linalg.cond(a, p) # for some norm types cond expects the input to be square a = torch.ones(3, 2, dtype=dtype, device=device) norm_types = [1, -1, inf, -inf, 'fro', 'nuc'] - for ord in norm_types: + for p in norm_types: with self.assertRaisesRegex(RuntimeError, r'supports square matrices or batches of square matrices'): - torch.linalg.cond(a, ord) + torch.linalg.cond(a, p) # if non-empty out tensor with wrong shape is passed a warning is given a = torch.ones((2, 2), dtype=dtype, device=device) - for ord in ['fro', 2]: - real_dtype = a.real.dtype if dtype.is_complex and ord == 2 else dtype + for p in ['fro', 2]: + real_dtype = a.real.dtype if dtype.is_complex and p == 2 else dtype out = torch.empty(a.shape, dtype=real_dtype, device=device) with warnings.catch_warnings(record=True) as w: # Trigger warning - torch.linalg.cond(a, ord, out=out) + torch.linalg.cond(a, p, out=out) # Check warning occurs self.assertEqual(len(w), 1) self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) # dtypes should match out = torch.empty_like(a).to(torch.int) - for ord in ['fro', 2]: + for p in ['fro', 2]: with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match"): - torch.linalg.cond(a, ord, out=out) + torch.linalg.cond(a, p, out=out) # for batched input if at least one matrix in the batch is not invertible, # we can't get the result for all other (possibly) invertible matrices in the batch without an explicit for loop. 
@@ -1042,23 +1051,9 @@ def test_cond_errors_and_warnings(self, device, dtype): a = a.reshape((1, 3, 3)) a = a.repeat(batch_dim, 1, 1) a[0, -1, -1] = 0 # now a[0] is singular - for ord in [1, -1, inf, -inf, 'fro', 'nuc']: + for p in [1, -1, inf, -inf, 'fro', 'nuc']: with self.assertRaisesRegex(RuntimeError, "linalg_cond does not support yet"): - torch.linalg.cond(a, ord) - - # TODO: once "inverse_cuda" supports complex dtypes, they shall be added to above tests - @unittest.expectedFailure - @onlyCUDA - @skipCUDAIfNoMagma - @dtypes(torch.complex64, torch.complex128) - @precisionOverride({torch.float32: 1e-3}) - def test_cond_xfailed(self, device, dtype): - input_size = (3, 3) - ord = 1 - torch.randn(*input_size, dtype=dtype, device=device) - result = torch.linalg.cond(input, ord) - result_numpy = np.linalg.cond(input.cpu().numpy(), ord) - self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision) + torch.linalg.cond(a, p) # Test autograd and jit functionality for linalg functions. # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, From 5ed8a776e901dfa9f75f8e39efcbbb9f6d01da92 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 30 Nov 2020 12:33:09 -0600 Subject: [PATCH 40/44] Updated note for p = 2, -2 to include reference to svd --- torch/linalg/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index b84225aa11e8..4cbd1af366e9 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -349,9 +349,11 @@ This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. -.. note:: For `p = {None, 2, -2}` :attr:`input` can be non-square. For other norm types :attr:`input` must be - a square matrix or a batch of square matrices. If :attr:`input` does not satisfy this requirement - then a RuntimeError will be thrown. +.. note:: For `p = {None, 2, -2}` there is a relation between singular value decomposition and matrix norm. + For this case :func:`torch.linalg.svd` is used for computing the condition number as the ratio of + the largest and smallest singular values. Since :func:`torch.linalg.svd` is used :attr:`input` can be non-square. + For other norm types :attr:`input` must be a square matrix or a batch of square matrices. + If :attr:`input` does not satisfy this requirement then a RuntimeError will be thrown. .. note:: If :attr:`input` is a non-invertible matrix then a tensor containing infinity will be returned. 
If :attr:`input` is a batch of matrices and one or more of them is not invertible From e4070984f7b3bdf54f0a7b53d359cc96433d2990 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 30 Nov 2020 15:53:15 -0600 Subject: [PATCH 41/44] Added checks for norm type --- aten/src/ATen/native/LinearAlgebra.cpp | 24 ++++++++++++++++++++++-- test/test_linalg.py | 6 ++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index bd5aec275c53..483703fac215 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1754,6 +1754,22 @@ Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) { return at::zeros(result_shape, self.options().dtype(dtype)); } +void _linalg_cond_check_ord(c10::variant ord_variant) { + if (ord_variant.index() == 0) { + Scalar* ord = c10::get_if(&ord_variant); + double abs_ord = std::abs(ord->toDouble()); + TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY, + "linalg_cond got an invalid norm type: ", ord->toDouble()); + } else if (ord_variant.index() == 1) { + std::string* ord = c10::get_if(&ord_variant); + TORCH_CHECK(*ord == "fro" || *ord == "nuc", + "linalg_cond got an invalid norm type: ", *ord); + } else { + TORCH_CHECK(false, + "linalg_cond: something went wrong while checking the norm type"); + } +} + // Numerical or None norms Tensor linalg_cond(const Tensor& self, optional opt_ord) { TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ", @@ -1762,6 +1778,9 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { // The default case is using 2-norm Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2; + c10::variant ord_variant = ord; + _linalg_cond_check_ord(ord_variant); + // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input if (self.numel() == 0) { auto real_dtype = toValueType(typeMetaToScalarType(self.dtype())); @@ -1791,7 +1810,6 @@ Tensor linalg_cond(const Tensor& self, optional opt_ord) { "linalg_cond with ±1 or ±inf norm types only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); - c10::variant ord_variant = ord; return _linalg_cond_helper(self, ord_variant); } @@ -1820,12 +1838,14 @@ Tensor linalg_cond(const Tensor& self, std::string ord) { "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices " "but got ", self.size(-1), " by ", self.size(-2), " matrices"); + c10::variant ord_variant = ord; + _linalg_cond_check_ord(ord_variant); + // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input if (self.numel() == 0) { return _linalg_cond_empty_matrix(self, self.scalar_type()); } - c10::variant ord_variant = ord; return _linalg_cond_helper(self, ord_variant); } diff --git a/test/test_linalg.py b/test/test_linalg.py index fd7682814720..94922f6bffd2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1055,6 +1055,12 @@ def test_cond_errors_and_warnings(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "linalg_cond does not support yet"): torch.linalg.cond(a, p) + # check invalid norm type + a = torch.ones(3, 3, dtype=dtype, device=device) + for p in ['wrong_norm', 5]: + with self.assertRaisesRegex(RuntimeError, f"linalg_cond got an invalid norm type: {p}"): + torch.linalg.cond(a, p) + # Test autograd and jit 
functionality for linalg functions. # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, # the `test_cases` entries below should be moved there. These entries are in a similar format, From 9008c10d63e7f5ddd0f06bbd5c7f1548c945d917 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 2 Dec 2020 03:25:13 -0600 Subject: [PATCH 42/44] Updated documentation --- torch/linalg/__init__.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 4cbd1af366e9..2f0f5cd37e94 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -344,16 +344,16 @@ Computes the condition number of a matrix :attr:`input`, or of each matrix in a batched :attr:`input`, using the matrix norm defined by :attr:`p`. -The condition number is defined as the matrix norm of -:attr:`input` times the matrix norm of the inverse of :attr:`input`. +For norms ``p = {'fro', 'nuc', inf, -inf, 1, -1}`` this is defined as the matrix norm of :attr:`input` +times the matrix norm of the inverse of :attr:`input`. And for norms ``p = {None, 2, -2}`` this is defined as +the ratio between the largest and smallest singular values. This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. -.. note:: For `p = {None, 2, -2}` there is a relation between singular value decomposition and matrix norm. - For this case :func:`torch.linalg.svd` is used for computing the condition number as the ratio of - the largest and smallest singular values. Since :func:`torch.linalg.svd` is used :attr:`input` can be non-square. - For other norm types :attr:`input` must be a square matrix or a batch of square matrices. - If :attr:`input` does not satisfy this requirement then a RuntimeError will be thrown. +.. note:: For ``p = {None, 2, -2}`` the condition number is computed as the ratio between the largest and smallest singular values + computed using :func:`torch.linalg.svd`. For these norms :attr:`input` may be a non-square matrix or batch of non-square matrices. + For other norms, however, :attr:`input` must be a square matrix or a batch of square matrices, + and if this requirement is not satisfied a RuntimeError will be thrown. .. note:: If :attr:`input` is a non-invertible matrix then a tensor containing infinity will be returned. If :attr:`input` is a batch of matrices and one or more of them is not invertible From f61533284c4b16925303aca7754d5dc1d67e8f44 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 2 Dec 2020 12:59:56 +0200 Subject: [PATCH 43/44] Changed the note on non-invertible input Specified that this note applies only for p = {'fro', 'nuc', inf, -inf, 1, -1} --- torch/linalg/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 2f0f5cd37e94..fa7ffb7194ab 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -355,9 +355,9 @@ For other norms, however, :attr:`input` must be a square matrix or a batch of square matrices, and if this requirement is not satisfied a RuntimeError will be thrown. -.. note:: If :attr:`input` is a non-invertible matrix then a tensor containing infinity will be returned. - If :attr:`input` is a batch of matrices and one or more of them is not invertible - then a RuntimeError will be thrown. +.. 
note:: For ``p = {'fro', 'nuc', inf, -inf, 1, -1}`` if :attr:`input` is a non-invertible matrix then + a tensor containing infinity will be returned. If :attr:`input` is a batch of matrices and one + or more of them is not invertible then a RuntimeError will be thrown. .. note:: When given inputs on a CUDA device, this function synchronizes that device with the CPU. From 994cca67a44f7679a2ba4d4e40ec7b8f84f37c20 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 3 Dec 2020 04:33:43 -0600 Subject: [PATCH 44/44] flake8 --- torch/linalg/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index fa7ffb7194ab..0f89f7304f6f 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -350,8 +350,9 @@ This function supports ``float``, ``double``, and only on CPU, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. -.. note:: For ``p = {None, 2, -2}`` the condition number is computed as the ratio between the largest and smallest singular values - computed using :func:`torch.linalg.svd`. For these norms :attr:`input` may be a non-square matrix or batch of non-square matrices. +.. note:: For ``p = {None, 2, -2}`` the condition number is computed as the ratio between the largest + and smallest singular values computed using :func:`torch.linalg.svd`. + For these norms :attr:`input` may be a non-square matrix or batch of non-square matrices. For other norms, however, :attr:`input` must be a square matrix or a batch of square matrices, and if this requirement is not satisfied a RuntimeError will be thrown.
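
A quick illustration of the two definitions described in the docstring above. This is a reviewer sketch, not part of the patch series: it assumes a build containing these changes (so that `torch.linalg.cond` is available) and checks, in float64, that `p = None`/`p = -2` reduce to singular-value ratios while `p = 1` reduces to the norm of the input times the norm of its inverse:

    import torch

    torch.manual_seed(0)
    a = torch.randn(4, 4, dtype=torch.float64)

    # p = None (default) and p = +/-2: ratios of singular values.
    # torch.linalg.svd returns singular values sorted in descending order.
    u, s, vh = torch.linalg.svd(a)
    assert torch.allclose(torch.linalg.cond(a), s[0] / s[-1])
    assert torch.allclose(torch.linalg.cond(a, -2), s[-1] / s[0])

    # p = 1 (likewise 'fro', 'nuc', inf, -inf, -1): norm(a) * norm(inv(a)),
    # with the norms taken over the last two dimensions, matching the
    # dim{-2, -1} used inside linalg_cond.
    expected = (torch.linalg.norm(a, 1, dim=(-2, -1))
                * torch.linalg.norm(torch.inverse(a), 1, dim=(-2, -1)))
    assert torch.allclose(torch.linalg.cond(a, 1), expected)

Because the `p = {None, 2, -2}` path only needs singular values, it is also the only path that accepts non-square inputs, as the updated note states.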
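
Similarly, a hedged sketch of the error behaviour added by the checks in PATCH 41/44 and documented in PATCH 43/44 — again illustrative only, with the expected messages taken from the TORCH_CHECK calls and the note above rather than verified independently:

    import torch

    a = torch.ones(3, 3, dtype=torch.float64)  # rank-1, hence not invertible

    # Invalid norm types are rejected up front by _linalg_cond_check_ord.
    for p in ['wrong_norm', 5]:
        try:
            torch.linalg.cond(a, p)
        except RuntimeError as e:
            print(e)  # "linalg_cond got an invalid norm type: ..."

    # For the inverse-based norms, the note above documents a tensor
    # containing inf for a single non-invertible matrix, and a
    # RuntimeError when a batch contains a non-invertible member.
    print(torch.linalg.cond(a, 'fro'))  # documented to be tensor(inf)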