From 6b7f14ad980cfd0638e50567631760682e100ef3 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 7 Oct 2020 11:40:20 -0500 Subject: [PATCH 01/44] wip linalg.cholesky --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 8 +++- .../ATen/native/cuda/BatchLinearAlgebra.cu | 7 ++-- aten/src/ATen/native/native_functions.yaml | 5 +++ docs/source/linalg.rst | 1 + test/test_linalg.py | 34 +++++++++++++++ torch/linalg/__init__.py | 41 +++++++++++++++++++ 6 files changed, 92 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index e7e5659babbb..ae9a74a04a50 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -534,11 +534,12 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector& infos auto self_matrix_stride = matrixStride(self); auto batch_size = batchCount(self); auto n = self.size(-2); + auto lda = std::max(int64_t{1}, n); int info; for (int64_t i = 0; i < batch_size; i++) { scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - lapackCholesky(uplo, n, self_working_ptr, n, &info); + lapackCholesky(uplo, n, self_working_ptr, lda, &info); infos[i] = info; if (info != 0) { return; @@ -583,6 +584,11 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) { return result; } +Tensor linalg_cholesky(const Tensor &self) { + squareCheckInputs(self); + return at::_cholesky_helper(self, /*upper=*/false).tril_(); +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index e9dfe2d9285d..e33a9019fcd6 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -940,10 +940,11 @@ AT_ERROR("cholesky: MAGMA library not found in " auto self_data = self.data_ptr(); magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)"); + auto lda = std::max(magma_int_t{1}, n); if (self.dim() == 2) { magma_int_t info = 0; - magmaCholesky(uplo, n, self_data, n, &info); + magmaCholesky(uplo, n, self_data, lda, &info); infos[0] = info; } else { auto self_mat_stride = matrixStride(self); @@ -974,14 +975,14 @@ AT_ERROR("cholesky: MAGMA library not found in " magma_int_t* info_array_cur = &info_array[mini_idx]; magmaCholeskyBatched( - uplo, n, self_array_cur, n, info_array_cur, batch_limit, magma_queue); + uplo, n, self_array_cur, lda, info_array_cur, batch_limit, magma_queue); } // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaCholeskyBatched( - uplo, n, &self_array[mini_idx], n, &info_array[mini_idx], batch_size % batch_limit, magma_queue); + uplo, n, &self_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue); } for (int64_t i = 0; i < batch_size; i++) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e64a66a07417..8c28e2f035fd 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8244,6 +8244,11 @@ # # See linalg_det as an example. 
+- func: linalg_cholesky(Tensor self) -> Tensor + python_module: linalg + use_c10_dispatcher: full + variants: function + # torch.linalg.det, alias for torch.det - func: linalg_det(Tensor self) -> Tensor python_module: linalg diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 834b6a60ac93..1eea045641d4 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -12,5 +12,6 @@ Common linear algebra operations. Functions --------- +.. autofunction:: cholesky .. autofunction:: det .. autofunction:: norm diff --git a/test/test_linalg.py b/test/test_linalg.py index 97c7b926faf4..e7c0fbcc3a31 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -678,6 +678,40 @@ def test_norm_fastpaths(self, device): expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0) self.assertEqual(result, expected) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_cholesky(self, device, dtype): + from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value, random_symmetric_pd_matrix + + def run_test(shape, batch): + # matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + # matrix = matrix + 1e-5*torch.eye(shape, device=device) + # matrix = (shape, *batch, dtype=dtype, device=device) + # matrix = matrix @ matrix.transpose(-2, -1).conj() + if dtype.is_complex: + real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 + A_real = random_fullrank_matrix_distinct_singular_value(shape, *batch, dtype=real_dtype, device=device) + A_imag = random_fullrank_matrix_distinct_singular_value(shape, *batch, dtype=real_dtype, device=device) + A = A_real + 1j * A_imag + A = A @ A.transpose(-2, -1).conj() + else: + A = random_symmetric_pd_matrix(shape, *batch, dtype=dtype, device=device) + expected_L = np.linalg.cholesky(A.cpu().numpy()) + actual_L = torch.linalg.cholesky(A) + self.assertEqual(actual_L, expected_L) + + shapes = (0, 3, 5) + batches = ((), (3, ), (2, 2)) + for shape, batch in itertools.product(shapes, batches): + run_test(shape, batch) + + # cholesky requires a square matrix + t = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaises(RuntimeError): + torch.linalg.cholesky(t) + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 5e2b59c45c80..91f42232bf2b 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -8,6 +8,47 @@ # Note: This not only adds doc strings for functions in the linalg namespace, but # also connects the torch.linalg Python namespace to the torch._C._linalg builtins. +cholesky = _add_docstr(_linalg.linalg_cholesky, r""" +linalg.cholesky(input) -> Tensor + +Returns the Cholesky decomposition. + +Computes the Cholesky decomposition of a Hermitian positive-definite +matrix :math:`A` or for batches of Hermitian positive-definite matrices. +The returned matrix ``L`` is lower-triangular, and +the decomposition has the form: + +.. math:: + + A = LL^H + +If :attr:`input` is a batch of Hermitian positive-definite +matrices, then the returned tensor will be composed of lower-triangular Cholesky factors +of each of the individual matrices. + +Args: + input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of symmetric positive-definite matrices. 
+ +Example:: + + >>> a = torch.randn(2, 2, dtype=torch.complex128) + >>> a = a + a.t().conj() # To make a Hermitian + >>> l = torch.cholesky(a) + >>> a + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> l + tensor([[ 1.5528, 0.0000, 0.0000], + [-0.4821, 1.0592, 0.0000], + [ 0.9371, 0.5487, 0.7023]]) + >>> torch.mm(l, l.t().conj()) + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) +""") + det = _add_docstr(_linalg.linalg_det, r""" linalg.det(input) -> Tensor From f7a08f4fc5e960581ae01426c7794c0b0086a3c5 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 8 Oct 2020 07:21:34 -0500 Subject: [PATCH 02/44] Added xfailed test case --- test/test_linalg.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index e7c0fbcc3a31..c7d524ab6643 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -6,7 +6,8 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) + (instantiate_device_type_tests, dtypes, dtypesIfCPU, dtypesIfCUDA, + onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck @@ -681,19 +682,16 @@ def test_norm_fastpaths(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128) + @dtypesIfCUDA(torch.float32, torch.float64) def test_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value, random_symmetric_pd_matrix + from torch.testing._internal.common_utils import random_matrix, random_symmetric_pd_matrix def run_test(shape, batch): - # matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) - # matrix = matrix + 1e-5*torch.eye(shape, device=device) - # matrix = (shape, *batch, dtype=dtype, device=device) - # matrix = matrix @ matrix.transpose(-2, -1).conj() if dtype.is_complex: real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(shape, *batch, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(shape, *batch, dtype=real_dtype, device=device) + A_real = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) + A_imag = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) A = A_real + 1j * A_imag A = A @ A.transpose(-2, -1).conj() else: @@ -712,6 +710,24 @@ def run_test(shape, batch): with self.assertRaises(RuntimeError): torch.linalg.cholesky(t) + # TODO: once there is more support for complex dtypes on GPU, they shall be added to above test + # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed + @unittest.expectedFailure + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @onlyCUDA + @skipCUDAIfNoMagma + @dtypes(torch.complex64, torch.complex128) + def test_cholesky_xfailed(self, device, dtype): + from torch.testing._internal.common_utils import 
random_matrix + real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 + A_real = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) + A_imag = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) + A = A_real + 1j * A_imag + A = A @ A.transpose(-2, -1).conj() + expected_L = np.linalg.cholesky(A.cpu().numpy()) + actual_L = torch.linalg.cholesky(A) + self.assertEqual(actual_L, expected_L) + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': From 32e10f89caddee2788d133ebf9da1b06e16e7ec0 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 8 Oct 2020 07:35:58 -0500 Subject: [PATCH 03/44] Added cholesky to csrc/api/include/torch/linalg.h --- torch/csrc/api/include/torch/linalg.h | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 5ce90dcc972e..7bdfa76d2ce3 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -8,6 +8,10 @@ namespace linalg { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { +inline Tensor cholesky(const Tensor& self) { + return torch::linalg_cholesky(self); +} + inline Tensor det(const Tensor& self) { return torch::linalg_det(self); } @@ -31,6 +35,20 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ +/// Cholesky decomposition +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.cholesky +/// +/// Example: +/// ``` +/// auto A = torch::randn({4, 4}); +/// auto A = torch::matmul(A, A.t()); +/// auto L = torch::linalg::cholesky(A); +/// assert(torch::allclose(torch::matmul(L, L.t()), A)); +/// ``` +inline Tensor cholesky(const Tensor& self) { + return detail::cholesky(self); +} /// See the documentation of torch.linalg.det inline Tensor linalg_det(const Tensor& self) { From 3c4d5a4ff2f354bce45c69fd9bed94e69607ce77 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 9 Oct 2020 01:41:51 -0500 Subject: [PATCH 04/44] Updated example in docs --- torch/linalg/__init__.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 91f42232bf2b..94837816d99c 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -13,8 +13,8 @@ Returns the Cholesky decomposition. -Computes the Cholesky decomposition of a Hermitian positive-definite -matrix :math:`A` or for batches of Hermitian positive-definite matrices. +Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices) +positive-definite matrix :math:`A` or for batches of Hermitian positive-definite matrices. 
The returned matrix ``L`` is lower-triangular, and the decomposition has the form: @@ -33,20 +33,17 @@ Example:: >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a = a + a.t().conj() # To make a Hermitian - >>> l = torch.cholesky(a) + >>> a = torch.mm(a, a.t().conj()) # To make a Hermitian + >>> l = torch.linalg.cholesky(a) >>> a - tensor([[ 2.4112, -0.7486, 1.4551], - [-0.7486, 1.3544, 0.1294], - [ 1.4551, 0.1294, 1.6724]]) + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) >>> l - tensor([[ 1.5528, 0.0000, 0.0000], - [-0.4821, 1.0592, 0.0000], - [ 0.9371, 0.5487, 0.7023]]) + tensor([[1.5895+0.0000j, 0.0000+0.0000j], + [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) >>> torch.mm(l, l.t().conj()) - tensor([[ 2.4112, -0.7486, 1.4551], - [-0.7486, 1.3544, 0.1294], - [ 1.4551, 0.1294, 1.6724]]) + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) """) det = _add_docstr(_linalg.linalg_det, r""" From 307020ec7accf9a1954fe12fc241b733880ac620 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 9 Oct 2020 01:48:42 -0500 Subject: [PATCH 05/44] Added random_hermitian_pd_matrix for the test --- test/test_linalg.py | 19 ++++--------------- torch/testing/_internal/common_utils.py | 8 ++++++++ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index c7d524ab6643..c16b56cabd87 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -685,17 +685,10 @@ def test_norm_fastpaths(self, device): @dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128) @dtypesIfCUDA(torch.float32, torch.float64) def test_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import random_matrix, random_symmetric_pd_matrix + from torch.testing._internal.common_utils import random_hermitian_pd_matrix def run_test(shape, batch): - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) - A_imag = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - A = A @ A.transpose(-2, -1).conj() - else: - A = random_symmetric_pd_matrix(shape, *batch, dtype=dtype, device=device) + A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) expected_L = np.linalg.cholesky(A.cpu().numpy()) actual_L = torch.linalg.cholesky(A) self.assertEqual(actual_L, expected_L) @@ -718,12 +711,8 @@ def run_test(shape, batch): @skipCUDAIfNoMagma @dtypes(torch.complex64, torch.complex128) def test_cholesky_xfailed(self, device, dtype): - from torch.testing._internal.common_utils import random_matrix - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) - A_imag = random_matrix(shape, shape, *batch, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - A = A @ A.transpose(-2, -1).conj() + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) expected_L = np.linalg.cholesky(A.cpu().numpy()) actual_L = torch.linalg.cholesky(A) self.assertEqual(actual_L, expected_L) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 9c9d27bf195b..186e2ab5a837 100644 --- a/torch/testing/_internal/common_utils.py +++ 
b/torch/testing/_internal/common_utils.py @@ -1528,6 +1528,14 @@ def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs): + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5 +def random_hermitian_pd_matrix(matrix_size, *batch_dims, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), + dtype=dtype, device=device) + return torch.matmul(A, A.transpose(-2, -1).conj()) + + def make_nonzero_det(A, sign=None, min_singular_value=0.1): u, s, v = A.svd() s.clamp_(min=min_singular_value) From 9e9e0c088365841caf4f0dfa68e2d60b0fb7873b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 9 Oct 2020 01:51:23 -0500 Subject: [PATCH 06/44] Use random_hermitian_pd_matrix in the test_torch/cholesky --- test/test_torch.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 3ff5a1d73822..fb2b2c8bb638 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7795,19 +7795,9 @@ def cholesky_test_helper(n, batch_dims, upper): @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) + from torch.testing._internal.common_utils import random_hermitian_pd_matrix - # This is a workaround while there is no support for complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - A = A @ A.t().conj() - else: - A = random_symmetric_pd_matrix(10, dtype=dtype, device=device) + A = random_hermitian_pd_matrix(10, dtype=dtype, device=device) # default Case C = torch.cholesky(A) From e74de2caf63547f88f38040e78af49ba6ada5c84 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 13 Oct 2020 07:47:44 -0500 Subject: [PATCH 07/44] No need for skip if numpy not found anymore --- test/test_linalg.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index c7be24600320..c921971ca9f2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -924,7 +924,6 @@ def test_nuclear_norm_exceptions_old(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128) @dtypesIfCUDA(torch.float32, torch.float64) def test_cholesky(self, device, dtype): @@ -949,7 +948,6 @@ def run_test(shape, batch): # TODO: once there is more support for complex dtypes on GPU, they shall be added to above test # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed @unittest.expectedFailure - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @onlyCUDA @skipCUDAIfNoMagma @dtypes(torch.complex64, torch.complex128) From 934367840e957541a3e52f3bad0d51cae2ef51af Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 13 Oct 2020 08:00:17 -0500 Subject: [PATCH 08/44] Added larger input case --- test/test_linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index c921971ca9f2..734c80f46cb5 100644 --- 
a/test/test_linalg.py +++ b/test/test_linalg.py @@ -937,7 +937,8 @@ def run_test(shape, batch): shapes = (0, 3, 5) batches = ((), (3, ), (2, 2)) - for shape, batch in itertools.product(shapes, batches): + larger_input_case = [(100, (5, ))] + for shape, batch in list(itertools.product(shapes, batches)) + larger_input_case: run_test(shape, batch) # cholesky requires a square matrix From 604f0a87f2ef32856daa30a002c9a7cd42ab44ce Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 13 Oct 2020 08:49:11 -0500 Subject: [PATCH 09/44] Added assertRaises tests for cholesky Added precisionOverride. MAGMA cholesky seem to disagree a bit with NumPy on large input hence 1e-2 atol --- test/test_linalg.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 734c80f46cb5..13f91853d13e 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -922,6 +922,8 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + + @precisionOverride({torch.float: 1e-2, torch.cfloat: 1e-4}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128) @@ -941,10 +943,27 @@ def run_test(shape, batch): for shape, batch in list(itertools.product(shapes, batches)) + larger_input_case: run_test(shape, batch) - # cholesky requires a square matrix - t = torch.randn(2, 3, device=device, dtype=dtype) - with self.assertRaises(RuntimeError): - torch.linalg.cholesky(t) + # cholesky requires the input to be a square matrix + A = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'): + np.linalg.cholesky(A.cpu().numpy()) + + # cholesky requires the input to be a matrix + A = torch.randn(2, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, r'1-dimensional array given\. 
Array must be at least two-dimensional'): + np.linalg.cholesky(A.cpu().numpy()) + + # if the input matrix is singular, an error should be raised + A = torch.eye(3, 3, dtype=dtype, device=device) + A[-1, -1] = 0 # Now A is singular + with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'): + np.linalg.cholesky(A.cpu().numpy()) # TODO: once there is more support for complex dtypes on GPU, they shall be added to above test # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed From 5f8004277685e7606bc801c36391000f9c5c7fa5 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 13 Oct 2020 09:03:32 -0500 Subject: [PATCH 10/44] Added a note to the docs that the error is given if the input is not positive definite --- torch/linalg/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 94837816d99c..4e2f95bde58e 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -26,6 +26,9 @@ matrices, then the returned tensor will be composed of lower-triangular Cholesky factors of each of the individual matrices. +.. note:: If the :attr:`input` is not Hermitian positive-definite matrix a RuntimeError is raised + saying that the input is singular and mentioning which minor of the input matrix is not positive-definite. + Args: input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more batch dimensions consisting of symmetric positive-definite matrices. From da4e88b264a822d254f6a4b1323d551b3306090b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 13 Oct 2020 09:40:11 -0500 Subject: [PATCH 11/44] Enabled autograd for linalg_cholesky --- test/test_linalg.py | 33 +++++++++++++++++++++++++++-- tools/autograd/derivatives.yaml | 3 +++ tools/autograd/gen_variable_type.py | 2 +- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 13f91853d13e..a9556b70e150 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8,9 +8,9 @@ (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, dtypesIfCPU, dtypesIfCUDA, - onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) + onlyCUDA, onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args -from torch.autograd import gradcheck +from torch.autograd import gradcheck, gradgradcheck if TEST_NUMPY: import numpy as np @@ -978,6 +978,35 @@ def test_cholesky_xfailed(self, device, dtype): actual_L = torch.linalg.cholesky(A) self.assertEqual(actual_L, expected_L) + # TODO: enable CUDA tests once + # RuntimeError: "triangular_solve_cuda" not implemented for 'ComplexDouble' is fixed + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_cholesky_autograd(self, device, dtype): + def func(root): + x = 0.5 * (root + root.transpose(-1, -2).conj()) + return torch.linalg.cholesky(x) + + def run_test(shape): + root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) + root = root + torch.eye(shape[-1], dtype=dtype, device=device) + + gradcheck(func, root) + # TODO: gradgradcheck does not work correctly yet for complex + if not dtype.is_complex: + 
gradgradcheck(func, root) + + root = torch.rand(*shape, dtype=dtype, device=device) + root = torch.matmul(root, root.transpose(-1, -2).conj()) + root.requires_grad_() + chol = torch.linalg.cholesky(root).sum().backward() + self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian + + shapes = ((3, 3), (4, 3, 2, 2)) + for shape in shapes: + run_test(shape) + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9c6cd4c578de..5b2435b20440 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -294,6 +294,9 @@ - name: cholesky(Tensor self, bool upper=False) -> Tensor self: cholesky_backward(grad, upper, result) +- name: linalg_cholesky(Tensor self) -> Tensor + self: cholesky_backward(grad, false, result) + - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor self, input2: cholesky_solve_backward(grad, self, input2, result, upper) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index a1f162f91471..d05cd43de107 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -162,7 +162,7 @@ 'cosh', '__rmul__', 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex', 'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd', 'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward', - 'dot', 'vdot', 'cholesky' + 'dot', 'vdot', 'cholesky', 'linalg_cholesky' } # Some operators invalidate the grad_accumulator. Let's reset it. From efb725ca14d3ef524d4f3ca5d1614c27851614f9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 13 Oct 2020 09:40:42 -0500 Subject: [PATCH 12/44] Added a note to the documentation about complex support --- torch/linalg/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 4e2f95bde58e..f1b94fe53254 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -29,6 +29,10 @@ .. note:: If the :attr:`input` is not Hermitian positive-definite matrix a RuntimeError is raised saying that the input is singular and mentioning which minor of the input matrix is not positive-definite. +.. note:: + Supports real and complex inputs. + Backpropagation for complex inputs is only supported on the CPU. + Args: input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more batch dimensions consisting of symmetric positive-definite matrices. 
From 6297b6b795a6a48b069cde784782c5c4ecefd8f9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 14 Oct 2020 05:17:30 -0500 Subject: [PATCH 13/44] Added the out= variant --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 14 ++++++++++++++ aten/src/ATen/native/native_functions.yaml | 4 ++++ test/test_linalg.py | 11 +++++++++++ torch/csrc/api/include/torch/linalg.h | 8 ++++++++ torch/linalg/__init__.py | 5 ++++- 5 files changed, 41 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index ae9a74a04a50..0974b146d6b7 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -589,6 +589,20 @@ Tensor linalg_cholesky(const Tensor &self) { return at::_cholesky_helper(self, /*upper=*/false).tril_(); } +Tensor& linalg_cholesky_out(Tensor &result, const Tensor &self) { + squareCheckInputs(self); + CheckedFrom c = "linalg_cholesky_out"; + TensorArg result_arg(result, "result", 0); + TensorArg self_arg(self, "self", 1); + checkSameSize(c, result_arg, self_arg); + checkSameType(c, result_arg, self_arg); + TORCH_CHECK(self.device() == result.device(), + "Expected input and out tensors to be on the same device, but found input on ", + self.device(), " and out on ", result.device(), " instead."); + result.copy_(native::linalg_cholesky(self)); + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index c16e28797d81..f5e9db1133d9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8289,6 +8289,10 @@ use_c10_dispatcher: full variants: function +- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ python_module: linalg + variants: function + # torch.linalg.det, alias for torch.det - func: linalg_det(Tensor self) -> Tensor python_module: linalg diff --git a/test/test_linalg.py b/test/test_linalg.py index a9556b70e150..da512790d0ed 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -943,6 +943,17 @@ def run_test(shape, batch): for shape, batch in list(itertools.product(shapes, batches)) + larger_input_case: run_test(shape, batch) + # check the out= variant + A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device) + out = torch.empty_like(A) + ans = torch.linalg.cholesky(A, out=out) + self.assertEqual(ans, out) + + # cholesky requires out to have same shape as input + out = torch.empty(2, 3, dtype=dtype, device=device) + with self.assertRaisesRegex(RuntimeError, r'to have same size as tensor for argument'): + torch.linalg.cholesky(A, out=out) + # cholesky requires the input to be a square matrix A = torch.randn(2, 3, device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 7bdfa76d2ce3..07082cc03640 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -12,6 +12,10 @@ inline Tensor cholesky(const Tensor& self) { return torch::linalg_cholesky(self); } +inline Tensor cholesky_out(Tensor& result, const Tensor& self) { + return torch::linalg_cholesky_out(result, self); +} + inline Tensor det(const Tensor& self) { return torch::linalg_det(self); } @@ -50,6 +54,10 @@ inline Tensor cholesky(const Tensor& self) { return detail::cholesky(self); } +inline Tensor cholesky_out(Tensor& result, const Tensor& self) { + return detail::cholesky_out(result, self); +} + /// See the documentation of torch.linalg.det inline Tensor linalg_det(const Tensor& self) { return detail::det(self); diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index f1b94fe53254..e9d2a7e69baf 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -9,7 +9,7 @@ # also connects the torch.linalg Python namespace to the torch._C._linalg builtins. cholesky = _add_docstr(_linalg.linalg_cholesky, r""" -linalg.cholesky(input) -> Tensor +linalg.cholesky(input, *, out=None) -> Tensor Returns the Cholesky decomposition. @@ -37,6 +37,9 @@ input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more batch dimensions consisting of symmetric positive-definite matrices. +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. 
Default: ``None`` + Example:: >>> a = torch.randn(2, 2, dtype=torch.complex128) From 709273b06c64dcc8991a3282eb9e3c2e8bb1085a Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 14 Oct 2020 05:35:37 -0500 Subject: [PATCH 14/44] Moved error checks to a separate test --- test/test_linalg.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index da512790d0ed..075c62895341 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -949,7 +949,14 @@ def run_test(shape, batch): ans = torch.linalg.cholesky(A, out=out) self.assertEqual(ans, out) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_cholesky_errors(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + # cholesky requires out to have same shape as input + A = random_hermitian_pd_matrix(3, dtype=dtype, device=device) out = torch.empty(2, 3, dtype=dtype, device=device) with self.assertRaisesRegex(RuntimeError, r'to have same size as tensor for argument'): torch.linalg.cholesky(A, out=out) From 21cfca07aaf9324d8aaf25f71b1c670072734cac Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 14 Oct 2020 05:38:53 -0500 Subject: [PATCH 15/44] Added xfailed test for cholesky cuda autograd --- test/test_linalg.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 075c62895341..6ccdead6c2b2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -996,8 +996,6 @@ def test_cholesky_xfailed(self, device, dtype): actual_L = torch.linalg.cholesky(A) self.assertEqual(actual_L, expected_L) - # TODO: enable CUDA tests once - # RuntimeError: "triangular_solve_cuda" not implemented for 'ComplexDouble' is fixed @onlyCPU @skipCPUIfNoLapack @dtypes(torch.float64, torch.complex128) @@ -1025,6 +1023,22 @@ def run_test(shape): for shape in shapes: run_test(shape) + # TODO: enable CUDA tests once (merge with above test) + # RuntimeError: "triangular_solve_cuda" not implemented for 'ComplexDouble' is fixed + @unittest.expectedFailure + @onlyCUDA + @skipCUDAIfNoMagma + @dtypes(torch.complex64, torch.complex128) + def test_cholesky_autograd_xfailed(self, device, dtype): + def func(root): + x = 0.5 * (root + root.transpose(-1, -2).conj()) + return torch.linalg.cholesky(x) + + shape = (3, 3) + root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) + root = root + torch.eye(shape[-1], dtype=dtype, device=device) + gradcheck(func, root) + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': From 31cbe7531e93775af24596e534c529daf9eef472 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 14 Oct 2020 05:57:38 -0500 Subject: [PATCH 16/44] Only complex128 is needed for autograd test --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 6ccdead6c2b2..f4bae853f0e0 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1028,7 +1028,7 @@ def run_test(shape): @unittest.expectedFailure @onlyCUDA @skipCUDAIfNoMagma - @dtypes(torch.complex64, torch.complex128) + @dtypes(torch.complex128) def test_cholesky_autograd_xfailed(self, device, dtype): def func(root): x = 0.5 * (root + root.transpose(-1, -2).conj()) From df0172e78593c86e35d9790da21d4ead3d1033ec Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 14 Oct 2020 05:58:05 -0500 Subject: [PATCH 17/44] Added a docstring 
for random_hermitian_pd_matrix --- torch/testing/_internal/common_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 186e2ab5a837..8ae90056c002 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1528,9 +1528,14 @@ def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs): + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5 -def random_hermitian_pd_matrix(matrix_size, *batch_dims, **kwargs): - dtype = kwargs.get('dtype', torch.double) - device = kwargs.get('device', 'cpu') +def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device): + """ + Returns a batch of random Hermitian positive-definite matrices. + The shape of the result is batch_dims + (matrix_size, matrix_size) + + The following example creates a tensor of size 2 x 4 x 3 x 3 + >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device) + """ A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device) return torch.matmul(A, A.transpose(-2, -1).conj()) From b54c7f8a46bb760839e05b35c4038d8620dc1594 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 14 Oct 2020 06:36:07 -0500 Subject: [PATCH 18/44] Updated linalg.cholesky docs --- torch/linalg/__init__.py | 53 ++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index e9d2a7e69baf..70bd57371a75 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -11,39 +11,36 @@ cholesky = _add_docstr(_linalg.linalg_cholesky, r""" linalg.cholesky(input, *, out=None) -> Tensor -Returns the Cholesky decomposition. - Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices) -positive-definite matrix :math:`A` or for batches of Hermitian positive-definite matrices. +positive-definite matrix or the Cholesky decompositions for a batch of such matrices. The returned matrix ``L`` is lower-triangular, and the decomposition has the form: .. math:: - A = LL^H + \text{input} = LL^H -If :attr:`input` is a batch of Hermitian positive-definite -matrices, then the returned tensor will be composed of lower-triangular Cholesky factors -of each of the individual matrices. +where :math:`L^H` is the conjugate transpose of :math:`L`, which is just a transpose for the case +of real-valued input matrices. +In code it translates to ``input = L @ L.t()` if :attr:`input` is real-valued and +``input = L @ L.conj().t()`` if :attr:`input` is complex-valued. -.. note:: If the :attr:`input` is not Hermitian positive-definite matrix a RuntimeError is raised - saying that the input is singular and mentioning which minor of the input matrix is not positive-definite. +Supports real and complex inputs. Backpropagation for complex inputs is only supported on the CPU. -.. note:: - Supports real and complex inputs. - Backpropagation for complex inputs is only supported on the CPU. +.. note:: If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices + and one or more of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown. Args: - input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more - batch dimensions consisting of symmetric positive-definite matrices. 
+ input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of Hermitian positive-definite matrices. Keyword args: out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` -Example:: +Examples:: >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a = torch.mm(a, a.t().conj()) # To make a Hermitian + >>> a = torch.mm(a, a.t().conj()) # makes a Hermitian positive-definite matrix >>> l = torch.linalg.cholesky(a) >>> a tensor([[2.5266+0.0000j, 1.9586-2.0626j], @@ -54,6 +51,30 @@ >>> torch.mm(l, l.t().conj()) tensor([[2.5266+0.0000j, 1.9586-2.0626j], [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + + >>> a = torch.randn(3, 2, 2, dtype=torch.float64) + >>> a = torch.matmul(a, a.transpose(-2, -1)) # makes a symmetric positive-definite matrix + >>> l = torch.linalg.cholesky(a) + >>> a + tensor([[[ 1.1629, 2.0237], + [ 2.0237, 6.6593]], + + [[ 0.4187, 0.1830], + [ 0.1830, 0.1018]], + + [[ 1.9348, -2.5744], + [-2.5744, 4.6386]]], dtype=torch.float64) + >>> l + tensor([[[ 1.0784, 0.0000], + [ 1.8766, 1.7713]], + + [[ 0.6471, 0.0000], + [ 0.2829, 0.1477]], + + [[ 1.3910, 0.0000], + [-1.8509, 1.1014]]], dtype=torch.float64) + >>> torch.allclose(torch.matmul(l, l.transpose(-2, -1)), a) + True """) det = _add_docstr(_linalg.linalg_det, r""" From 78781830050d3a55458127eb7a5ae7f3b71598c1 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 15 Oct 2020 15:54:04 +0000 Subject: [PATCH 19/44] Added a note on error message for batch of singular matrices --- torch/linalg/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 70bd57371a75..dc95bd06ce69 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -29,6 +29,7 @@ .. note:: If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices and one or more of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown. + The error message tells the index of the first problematic matrix in the batch. 
Args: input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more From 0d4a8c73bf03a1e6a40a2f663e320430ca92a63e Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 15 Oct 2020 16:22:00 +0000 Subject: [PATCH 20/44] In tests compare norms of the resulting matrices --- test/test_linalg.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index f4bae853f0e0..f8dab9043c2e 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -923,7 +923,6 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) - @precisionOverride({torch.float: 1e-2, torch.cfloat: 1e-4}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128) @@ -935,7 +934,19 @@ def run_test(shape, batch): A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) expected_L = np.linalg.cholesky(A.cpu().numpy()) actual_L = torch.linalg.cholesky(A) - self.assertEqual(actual_L, expected_L) + matrices_are_equal = np.allclose(actual_L.cpu().numpy(), expected_L) + + # For fp32 individual entries in matrices can differ between PyTorch and NumPy + # Let's compare the norms of matrices instead + if A.numel() > 0 and not matrices_are_equal: + # axis is specified to calculate matrix norm for batched input + expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1)) + actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1)) + norms_are_equal = np.allclose(actual_norm.cpu().numpy(), expected_norm) + matrices_are_equal = np.allclose(actual_L.cpu().numpy(), expected_L, atol=1e-2, rtol=1e-5) + self.assertTrue(matrices_are_equal and norms_are_equal) + else: + self.assertTrue(matrices_are_equal) shapes = (0, 3, 5) batches = ((), (3, ), (2, 2)) From e4832d35daa4582aad3a5e8d5d74148079b1ba9a Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 15 Oct 2020 16:26:54 +0000 Subject: [PATCH 21/44] Added entry in overrides.py --- torch/overrides.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/overrides.py b/torch/overrides.py index d64d7a4f37a4..c36290091c87 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -277,6 +277,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.chain_matmul: lambda *matrices: -1, torch.channel_shuffle: lambda input, groups : -1, torch.cholesky: lambda input, upper=False, out=None: -1, + torch.linalg.cholesky: lambda input, out=None: -1, torch.cholesky_inverse: lambda input, upper=False, out=None: -1, torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1, torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1, From e800b979c75fb7e027c8be6e8c426f6897165fc9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 15 Oct 2020 16:39:30 +0000 Subject: [PATCH 22/44] Added test case for batch singular input --- test/test_linalg.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index f8dab9043c2e..fe7d189f9210 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -994,6 +994,14 @@ def test_cholesky_errors(self, device, dtype): with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'): np.linalg.cholesky(A.cpu().numpy()) + # if at least one matrix in the batch is singular, an error should be raised + A = torch.eye(3, 3, dtype=dtype, device=device) + A = A.reshape((1, 3, 3)) + A = A.repeat(5, 1, 1) + 
A[4, -1, -1] = 0 # Now A[4] is singular + with self.assertRaisesRegex(RuntimeError, r'For batch 4: U\(3,3\) is zero, singular U\.'): + torch.linalg.cholesky(A) + # TODO: once there is more support for complex dtypes on GPU, they shall be added to above test # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed @unittest.expectedFailure From 8b58586063d937c5f05ebcd0313c747540779e7f Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 15 Oct 2020 17:11:24 +0000 Subject: [PATCH 23/44] Moved tests for torch.cholesky to test_linalg.py --- test/test_autograd.py | 26 --------- test/test_linalg.py | 124 +++++++++++++++++++++++++++++++++++++++++- test/test_torch.py | 80 --------------------------- 3 files changed, 122 insertions(+), 108 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index 91da9e79b885..3625ab22a7d7 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2498,32 +2498,6 @@ def test_var_mean_differentiable(self): torch.autograd.backward(r2, grad) self.assertTrue(torch.allclose(input1.grad, input2.grad, rtol=0.01, atol=0.0)) - @skipIfNoLapack - def test_cholesky(self): - def func(root, upper): - x = 0.5 * (root + root.transpose(-1, -2).conj()) - return torch.cholesky(x, upper) - - def run_test(upper, dims, dtype): - root = torch.rand(*dims, dtype=dtype, requires_grad=True) - root = root + torch.eye(dims[-1]) - - gradcheck(func, [root, upper]) - # TODO: gradgradcheck does not work correctly yet for complex - if not dtype.is_complex: - gradgradcheck(func, [root, upper]) - - root = torch.rand(*dims, dtype=dtype) - root = torch.matmul(root, root.transpose(-1, -2).conj()) - root.requires_grad_() - chol = root.cholesky().sum().backward() - self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian - - for upper, dims, dtype in product([True, False], - [(3, 3), (4, 3, 2, 2)], - [torch.double, torch.cdouble]): - run_test(upper, dims, dtype) - @skipIfNoLapack def test_cholesky_solve(self): def _test_with_size(A_dims, B_dims, upper): diff --git a/test/test_linalg.py b/test/test_linalg.py index fe7d189f9210..1641851d9e0d 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -5,10 +5,11 @@ from random import randrange from torch.testing._internal.common_utils import \ - (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor) + (TestCase, run_tests, slowTest, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, dtypesIfCPU, dtypesIfCUDA, - onlyCUDA, onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) + onlyCUDA, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) +from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck, gradgradcheck @@ -1058,6 +1059,125 @@ def func(root): root = root + torch.eye(shape[-1], dtype=dtype, device=device) gradcheck(func, root) + # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py + @slowTest + @skipCUDAIf(True, "See issue #26789.") + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_old_cholesky_batched_many_batches(self, device, dtype): + from torch.testing._internal.common_utils import random_symmetric_pd_matrix + + def cholesky_test_helper(n, batchsize, device, upper): + A = 
random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) + chol_fact = torch.cholesky(A, upper=upper) + if upper: + # Correctness check + self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) + # Upper triangular check + self.assertEqual(chol_fact, chol_fact.triu()) + else: + # Correctness check + self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) + # Lower triangular check + self.assertEqual(chol_fact, chol_fact.tril()) + + for upper, batchsize in itertools.product([True, False], [262144, 524288]): + cholesky_test_helper(2, batchsize, device, upper) + + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_cholesky_batched(self, device, dtype): + from torch.testing._internal.common_utils import \ + (random_symmetric_pd_matrix, + random_fullrank_matrix_distinct_singular_value) + + def cholesky_test_helper(n, batch_dims, upper): + # This is a workaround while there is no support for batched complex random_symmetric_pd_matrix + if dtype.is_complex: + real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 + A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) + A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) + A = A_real + 1j * A_imag + # There is no support for complex batched matmul yet + matmul_list = [] + for mat in A.contiguous().view(-1, n, n): + matmul_list.append(mat @ mat.t().conj()) + A = torch.stack(matmul_list).view(*batch_dims, n, n) + else: + A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device) + cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) + cholesky_exp = cholesky_exp.reshape_as(A) + self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) + + for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]): + cholesky_test_helper(3, batchsize, upper) + + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @tf32_on_and_off(0.01) + def test_old_cholesky(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + A = random_hermitian_pd_matrix(10, dtype=dtype, device=device) + + # default Case + C = torch.cholesky(A) + B = torch.mm(C, C.t().conj()) + self.assertEqual(A, B, atol=1e-14, rtol=0) + + # test Upper Triangular + U = torch.cholesky(A, True) + B = torch.mm(U.t().conj(), U) + self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix') + + # test Lower Triangular + L = torch.cholesky(A, False) + B = torch.mm(L, L.t().conj()) + self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix') + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_cholesky_empty(self, device, dtype): + def run_test(upper): + A = torch.empty(0, 0, dtype=dtype, device=device) + chol = torch.cholesky(A, upper) + chol_A = torch.matmul(chol, chol.t().conj()) + self.assertEqual(A, chol_A) + for upper in [True, False]: + run_test(upper) + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def 
test_old_cholesky_autograd(self, device, dtype): + def func(root, upper): + x = 0.5 * (root + root.transpose(-1, -2).conj()) + return torch.cholesky(x, upper) + + def run_test(upper, dims): + root = torch.rand(*dims, dtype=dtype, device=device, requires_grad=True) + root = root + torch.eye(dims[-1]) + + gradcheck(func, [root, upper]) + # TODO: gradgradcheck does not work correctly yet for complex + if not dtype.is_complex: + gradgradcheck(func, [root, upper]) + + root = torch.rand(*dims, dtype=dtype, device=device) + root = torch.matmul(root, root.transpose(-1, -2).conj()) + root.requires_grad_() + chol = root.cholesky().sum().backward() + self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian + + for upper, dims in itertools.product([True, False], [(3, 3), (4, 3, 2, 2)]): + run_test(upper, dims) + instantiate_device_type_tests(TestLinalg, globals()) if __name__ == '__main__': diff --git a/test/test_torch.py b/test/test_torch.py index 4eea78a343ad..ac9bfe0522c4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7759,86 +7759,6 @@ def test_cholesky_inverse(self, device, dtype): inv1 = torch.cholesky_inverse(chol, False) self.assertLessEqual(inv0.dist(inv1), 1e-12) - @slowTest - @skipCUDAIf(True, "See issue #26789.") - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_cholesky_batched_many_batches(self, device, dtype): - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - - def cholesky_test_helper(n, batchsize, device, upper): - A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) - chol_fact = torch.cholesky(A, upper=upper) - if upper: - # Correctness check - self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) - # Upper triangular check - self.assertEqual(chol_fact, chol_fact.triu()) - else: - # Correctness check - self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) - # Lower triangular check - self.assertEqual(chol_fact, chol_fact.tril()) - - for upper, batchsize in product([True, False], [262144, 524288]): - cholesky_test_helper(2, batchsize, device, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_cholesky_batched(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) - - def cholesky_test_helper(n, batch_dims, upper): - # This is a workaround while there is no support for complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - # There is no support for complex batched matmul yet - matmul_list = [] - for mat in A.contiguous().view(-1, n, n): - matmul_list.append(mat @ mat.t().conj()) - A = torch.stack(matmul_list).view(*batch_dims, n, n) - else: - A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device) - cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) - cholesky_exp = cholesky_exp.reshape_as(A) - self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) - - for upper, batchsize in product([True, False], [(3,), 
(3, 4), (2, 3, 4)]): - cholesky_test_helper(3, batchsize, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - @tf32_on_and_off(0.01) - def test_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix - - A = random_hermitian_pd_matrix(10, dtype=dtype, device=device) - - # default Case - C = torch.cholesky(A) - B = torch.mm(C, C.t().conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0) - - # test Upper Triangular - U = torch.cholesky(A, True) - B = torch.mm(U.t().conj(), U) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix') - - # test Lower Triangular - L = torch.cholesky(A, False) - B = torch.mm(L, L.t().conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix') - def test_view(self, device): tensor = torch.rand(15, device=device) template = torch.rand(3, 5, device=device) From 12d11cecb836ca85717a59601bffdabc00c286d5 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 15 Oct 2020 18:24:49 +0000 Subject: [PATCH 24/44] Added a dispatch section with DefaultBackend in native_functions.yaml --- aten/src/ATen/native/native_functions.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0c12737d184d..3a74651dca3f 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8818,10 +8818,14 @@ python_module: linalg use_c10_dispatcher: full variants: function + dispatch: + DefaultBackend: linalg_cholesky - func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
python_module: linalg variants: function + dispatch: + DefaultBackend: linalg_cholesky_out # torch.linalg.det, alias for torch.det - func: linalg_det(Tensor self) -> Tensor From 29f94c9e48d38d4a4f066b6e7271ebef2398d48d Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 15 Oct 2020 18:46:19 +0000 Subject: [PATCH 25/44] gradgradcheck for cholesky now works --- test/test_linalg.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 1641851d9e0d..08c6ea90f7d7 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1029,9 +1029,7 @@ def run_test(shape): root = root + torch.eye(shape[-1], dtype=dtype, device=device) gradcheck(func, root) - # TODO: gradgradcheck does not work correctly yet for complex - if not dtype.is_complex: - gradgradcheck(func, root) + gradgradcheck(func, root) root = torch.rand(*shape, dtype=dtype, device=device) root = torch.matmul(root, root.transpose(-1, -2).conj()) @@ -1165,9 +1163,7 @@ def run_test(upper, dims): root = root + torch.eye(dims[-1]) gradcheck(func, [root, upper]) - # TODO: gradgradcheck does not work correctly yet for complex - if not dtype.is_complex: - gradgradcheck(func, [root, upper]) + gradgradcheck(func, [root, upper]) root = torch.rand(*dims, dtype=dtype, device=device) root = torch.matmul(root, root.transpose(-1, -2).conj()) From 00c41ed6a124ac64de2bd8d28eef27f8835c2fc8 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 19 Oct 2020 08:44:25 +0000 Subject: [PATCH 26/44] Updated documentation for linalg cholesky --- torch/linalg/__init__.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index dc95bd06ce69..a0647b46a3ef 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -13,27 +13,28 @@ Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices) positive-definite matrix or the Cholesky decompositions for a batch of such matrices. -The returned matrix ``L`` is lower-triangular, and -the decomposition has the form: +Each decomposition has the form: .. math:: \text{input} = LL^H -where :math:`L^H` is the conjugate transpose of :math:`L`, which is just a transpose for the case -of real-valued input matrices. +where :math:`L` is a lower-triangular matrix and :math:`L^H` is the conjugate transpose of :math:`L`, +which is just a transpose for the case of real-valued input matrices. In code it translates to ``input = L @ L.t()` if :attr:`input` is real-valued and ``input = L @ L.conj().t()`` if :attr:`input` is complex-valued. +The batch of :math:`L` matrices is returned. Supports real and complex inputs. Backpropagation for complex inputs is only supported on the CPU. .. note:: If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices and one or more of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown. - The error message tells the index of the first problematic matrix in the batch. + If :attr:`input` is a batch of matrices, then the error message will include the batch index + of the first matrix that is not Hermitian positive-definite. Args: - input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more - batch dimensions consisting of Hermitian positive-definite matrices. 
+ input (Tensor): the input tensor of size :math:`(*, n, n)` consisting of Hermitian positive-definite + :math:`n \times n` matrices, where `*` is zero or more batch dimensions. Keyword args: out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` @@ -41,7 +42,7 @@ Examples:: >>> a = torch.randn(2, 2, dtype=torch.complex128) - >>> a = torch.mm(a, a.t().conj()) # makes a Hermitian positive-definite matrix + >>> a = torch.mm(a, a.t().conj()) # creates a Hermitian positive-definite matrix >>> l = torch.linalg.cholesky(a) >>> a tensor([[2.5266+0.0000j, 1.9586-2.0626j], @@ -54,7 +55,7 @@ [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) >>> a = torch.randn(3, 2, 2, dtype=torch.float64) - >>> a = torch.matmul(a, a.transpose(-2, -1)) # makes a symmetric positive-definite matrix + >>> a = torch.matmul(a, a.transpose(-2, -1)) # creates a symmetric positive-definite matrix >>> l = torch.linalg.cholesky(a) >>> a tensor([[[ 1.1629, 2.0237], From 34cef7d5bc28c218070ede6047c0e99fe51d6ab8 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 19 Oct 2020 08:47:43 +0000 Subject: [PATCH 27/44] Add one more assert for out= test --- test/test_linalg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index 08c6ea90f7d7..0102ff19e536 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -960,6 +960,8 @@ def run_test(shape, batch): out = torch.empty_like(A) ans = torch.linalg.cholesky(A, out=out) self.assertEqual(ans, out) + expected = torch.linalg.cholesky(A) + self.assertEqual(expected, out) @skipCUDAIfNoMagma @skipCPUIfNoLapack From 5d2230d84e02298852377432726085676c66b4dc Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 19 Oct 2020 08:51:01 +0000 Subject: [PATCH 28/44] Updated test_cholesky_errors --- test/test_linalg.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 0102ff19e536..f90c1ec65822 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -975,14 +975,17 @@ def test_cholesky_errors(self, device, dtype): with self.assertRaisesRegex(RuntimeError, r'to have same size as tensor for argument'): torch.linalg.cholesky(A, out=out) - # cholesky requires the input to be a square matrix + # cholesky requires the input to be a square matrix or batch of square matrices A = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.cholesky(A) + A = torch.randn(2, 2, 3, device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): torch.linalg.cholesky(A) with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'): np.linalg.cholesky(A.cpu().numpy()) - # cholesky requires the input to be a matrix + # cholesky requires the input to be at least 2 dimensional tensor A = torch.randn(2, device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): torch.linalg.cholesky(A) From 63e922daf6941bbee8fca910857dc354ab6ef786 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 19 Oct 2020 09:04:15 +0000 Subject: [PATCH 29/44] Added non contiguous test --- test/test_linalg.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index f90c1ec65822..a94ce929c66b 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -931,8 +931,11 @@ def 
test_nuclear_norm_exceptions_old(self, device): def test_cholesky(self, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix - def run_test(shape, batch): + def run_test(shape, batch, contiguous): A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) + if A.numel() > 0 and not contiguous: + A = A.transpose(-2, -1) + self.assertFalse(A.is_contiguous()) expected_L = np.linalg.cholesky(A.cpu().numpy()) actual_L = torch.linalg.cholesky(A) matrices_are_equal = np.allclose(actual_L.cpu().numpy(), expected_L) @@ -951,9 +954,9 @@ def run_test(shape, batch): shapes = (0, 3, 5) batches = ((), (3, ), (2, 2)) - larger_input_case = [(100, (5, ))] - for shape, batch in list(itertools.product(shapes, batches)) + larger_input_case: - run_test(shape, batch) + larger_input_case = [(100, (5, ), True)] + for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case: + run_test(shape, batch, contiguous) # check the out= variant A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device) From 6d64067904706d4a60830dc247d70b1730d6c6bc Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 19 Oct 2020 09:10:29 +0000 Subject: [PATCH 30/44] Make test_cholesky_autograd run on gpu for fp64 --- test/test_linalg.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index a94ce929c66b..347bd6a14a69 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1024,9 +1024,10 @@ def test_cholesky_xfailed(self, device, dtype): actual_L = torch.linalg.cholesky(A) self.assertEqual(actual_L, expected_L) - @onlyCPU + @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypes(torch.float64, torch.complex128) + @dtypesIfCPU(torch.float64, torch.complex128) + @dtypes(torch.float64) def test_cholesky_autograd(self, device, dtype): def func(root): x = 0.5 * (root + root.transpose(-1, -2).conj()) From 40935f1f073ac91ae61f123d1b0b674e655a29ee Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 19 Oct 2020 12:14:16 +0000 Subject: [PATCH 31/44] Changed np.allclose -> torch.allclose --- test/test_linalg.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 347bd6a14a69..36ef7a9ec55e 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -938,7 +938,7 @@ def run_test(shape, batch, contiguous): self.assertFalse(A.is_contiguous()) expected_L = np.linalg.cholesky(A.cpu().numpy()) actual_L = torch.linalg.cholesky(A) - matrices_are_equal = np.allclose(actual_L.cpu().numpy(), expected_L) + matrices_are_equal = torch.allclose(actual_L, torch.from_numpy(expected_L).to(device)) # For fp32 individual entries in matrices can differ between PyTorch and NumPy # Let's compare the norms of matrices instead @@ -946,9 +946,10 @@ def run_test(shape, batch, contiguous): # axis is specified to calculate matrix norm for batched input expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1)) actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1)) - norms_are_equal = np.allclose(actual_norm.cpu().numpy(), expected_norm) - matrices_are_equal = np.allclose(actual_L.cpu().numpy(), expected_L, atol=1e-2, rtol=1e-5) - self.assertTrue(matrices_are_equal and norms_are_equal) + # Compare the norms with standard tolerances + self.assertEqual(actual_norm, expected_norm) + # and individual values with a higher tolerance + self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5) else: 
self.assertTrue(matrices_are_equal) From 22be1a52a906a0ac2f133aebc8f1a8cda7b85254 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 28 Oct 2020 12:10:47 -0500 Subject: [PATCH 32/44] Fix long lines --- test/test_linalg.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 2ff7ed281cd0..d52673b3e282 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -986,7 +986,8 @@ def test_cholesky_errors(self, device, dtype): A = torch.randn(2, device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): torch.linalg.cholesky(A) - with self.assertRaisesRegex(np.linalg.LinAlgError, r'1-dimensional array given\. Array must be at least two-dimensional'): + with self.assertRaisesRegex(np.linalg.LinAlgError, + r'1-dimensional array given\. Array must be at least two-dimensional'): np.linalg.cholesky(A.cpu().numpy()) # if the input matrix is singular, an error should be raised @@ -1099,8 +1100,10 @@ def cholesky_test_helper(n, batch_dims, upper): # This is a workaround while there is no support for batched complex random_symmetric_pd_matrix if dtype.is_complex: real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) + A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, + dtype=real_dtype, device=device) + A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, + dtype=real_dtype, device=device) A = A_real + 1j * A_imag # There is no support for complex batched matmul yet matmul_list = [] From d547b6391f9892a9b122b47cdb7a69998723402b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 28 Oct 2020 12:23:10 -0500 Subject: [PATCH 33/44] Remove unused import --- test/test_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 1f4bb8338bdc..9bba63f559f9 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -36,7 +36,7 @@ from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCUDAIfNotRocm, \ onlyCUDA, onlyCPU, \ - dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \ + dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, precisionOverride, \ PYTORCH_CUDA_MEMCHECK, largeCUDATensorTest, largeTensorTest, onlyOnCPUAndCUDA, expectedAlertNondeterministic from typing import Dict, List, Tuple, Union import torch.backends.quantized From f723421e2aa0684d5c8c8f560adaa1ae59656a8d Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 3 Nov 2020 03:54:26 -0600 Subject: [PATCH 34/44] Use at::native::resize_output --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 16 ++++++-------- test/test_linalg.py | 23 ++++++++++++++------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 54c0f8a93c80..c24e6027624b 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -591,16 +592,11 @@ Tensor linalg_cholesky(const Tensor &self) { } Tensor& linalg_cholesky_out(Tensor &result, const Tensor 
&self) { - squareCheckInputs(self); - CheckedFrom c = "linalg_cholesky_out"; - TensorArg result_arg(result, "result", 0); - TensorArg self_arg(self, "self", 1); - checkSameSize(c, result_arg, self_arg); - checkSameType(c, result_arg, self_arg); - TORCH_CHECK(self.device() == result.device(), - "Expected input and out tensors to be on the same device, but found input on ", - self.device(), " and out on ", result.device(), " instead."); - result.copy_(native::linalg_cholesky(self)); + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + Tensor result_tmp = at::linalg_cholesky(self); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); return result; } diff --git a/test/test_linalg.py b/test/test_linalg.py index 8d0686d2e02a..439876275544 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1084,15 +1084,9 @@ def run_test(shape, batch, contiguous): @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_cholesky_errors(self, device, dtype): + def test_cholesky_errors_and_warnings(self, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix - # cholesky requires out to have same shape as input - A = random_hermitian_pd_matrix(3, dtype=dtype, device=device) - out = torch.empty(2, 3, dtype=dtype, device=device) - with self.assertRaisesRegex(RuntimeError, r'to have same size as tensor for argument'): - torch.linalg.cholesky(A, out=out) - # cholesky requires the input to be a square matrix or batch of square matrices A = torch.randn(2, 3, device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): @@ -1127,6 +1121,21 @@ def test_cholesky_errors(self, device, dtype): with self.assertRaisesRegex(RuntimeError, r'For batch 4: U\(3,3\) is zero, singular U\.'): torch.linalg.cholesky(A) + # if out tensor with wrong shape is passed a warning is given + A = random_hermitian_pd_matrix(3, dtype=dtype, device=device) + out = torch.empty(2, 3, dtype=dtype, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.cholesky(A, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(A).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.cholesky(A, out=out) + # TODO: once there is more support for complex dtypes on GPU, they shall be added to above test # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed @unittest.expectedFailure From 048316d0d17e82d992a9a5994675f0b67c53b464 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 10:35:03 -0600 Subject: [PATCH 35/44] Added a warning about data movement for cuda inputs --- torch/linalg/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 83776057a6ad..476d80894553 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -32,6 +32,9 @@ If :attr:`input` is a batch of matrices, then the error message will include the batch index of the first matrix that is not Hermitian positive-definite. +.. 
warning:: This function always checks whether :attr:`input` is a Hermitian positive-definite matrix + using `info` argument to LAPACK/MAGMA call. For CUDA this causes cross-device memory synchronization. + Args: input (Tensor): the input tensor of size :math:`(*, n, n)` consisting of Hermitian positive-definite :math:`n \times n` matrices, where `*` is zero or more batch dimensions. From 0a13e1da1a76e541f8493dc2d9576f5273852367 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 10:54:55 -0600 Subject: [PATCH 36/44] Use typed std::max --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 2 +- aten/src/ATen/native/cuda/BatchLinearAlgebra.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index c24e6027624b..270ffeeee5c5 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -536,7 +536,7 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector& infos auto self_matrix_stride = matrixStride(self); auto batch_size = batchCount(self); auto n = self.size(-2); - auto lda = std::max(int64_t{1}, n); + auto lda = std::max(1, n); int info; for (int64_t i = 0; i < batch_size; i++) { diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index dcc069814e13..4d2259cfecca 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1139,7 +1139,7 @@ AT_ERROR("cholesky: MAGMA library not found in " auto self_data = self.data_ptr(); magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)"); - auto lda = std::max(magma_int_t{1}, n); + auto lda = std::max(1, n); if (self.dim() == 2) { magma_int_t info = 0; From 173fea1fbe37d39a36f64765dc0e0542c3cd1fe4 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 10:58:16 -0600 Subject: [PATCH 37/44] Replaced torch.allclose with self.assertEqual --- test/test_linalg.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 45a4c18e97d4..bf0d8a0b0fd9 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -6,7 +6,7 @@ from random import randrange from torch.testing._internal.common_utils import \ - (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor) + (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, slowTest, make_tensor) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyCUDA, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, @@ -1224,11 +1224,10 @@ def run_test(shape, batch, contiguous): self.assertFalse(A.is_contiguous()) expected_L = np.linalg.cholesky(A.cpu().numpy()) actual_L = torch.linalg.cholesky(A) - matrices_are_equal = torch.allclose(actual_L, torch.from_numpy(expected_L).to(device)) # For fp32 individual entries in matrices can differ between PyTorch and NumPy # Let's compare the norms of matrices instead - if A.numel() > 0 and not matrices_are_equal: + if A.numel() > 0 and dtype in [torch.float32, torch.complex64]: # axis is specified to calculate matrix norm for batched input expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1)) actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1)) @@ -1237,7 +1236,7 @@ def run_test(shape, batch, contiguous): # and individual values with a higher 
tolerance self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5) else: - self.assertTrue(matrices_are_equal) + self.assertEqual(actual_L, expected_L) shapes = (0, 3, 5) batches = ((), (3, ), (2, 2)) From b26f52cc4b504272991e466b19980ba78858223a Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 10:59:35 -0600 Subject: [PATCH 38/44] Removed unused import --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index bf0d8a0b0fd9..88cb38c369cc 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -10,7 +10,7 @@ from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyCUDA, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, - skipCUDAIfNoMagmaAndNoCusolver, onlyOnCPUAndCUDA) + onlyOnCPUAndCUDA) from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck, gradgradcheck From ba3708e2c18e51370c734c4dfb7e2c50ee1cd612 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 10 Nov 2020 13:17:20 -0600 Subject: [PATCH 39/44] Fix imports --- test/test_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 980297de47a3..ab2fe444cc64 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -38,7 +38,7 @@ from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCUDAIfNotRocm, \ onlyCUDA, onlyCPU, \ - dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, precisionOverride, \ + dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \ PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA, expectedAlertNondeterministic from typing import Dict, List, Tuple, Union import torch.backends.quantized From f2ee1f636aa3a0528ab8604d7957c55ff6b1c28a Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 12 Nov 2020 04:05:20 -0600 Subject: [PATCH 40/44] Finish merge --- test/test_linalg.py | 271 -------------------------------------------- 1 file changed, 271 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 7fb3e77b3abf..2eebeedc1bc9 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1734,49 +1734,6 @@ def check(equation, operands, regex, exception=RuntimeError): check('a, ba', [x, y], r'operands do not broadcast with remapped shapes \[original->remapped\]: ' r'\[2\]->\[1, 2\] \[2, 3\]->\[2, 3\]') -<<<<<<< HEAD - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128) - @dtypesIfCUDA(torch.float32, torch.float64) - def test_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix - - def run_test(shape, batch, contiguous): - A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) - if A.numel() > 0 and not contiguous: - A = A.transpose(-2, -1) - self.assertFalse(A.is_contiguous()) - expected_L = np.linalg.cholesky(A.cpu().numpy()) - actual_L = torch.linalg.cholesky(A) - - # For fp32 individual entries in matrices can differ between PyTorch and NumPy - # Let's compare the norms of matrices instead - if A.numel() > 0 and dtype in [torch.float32, torch.complex64]: - # axis is specified to calculate matrix norm 
for batched input - expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1)) - actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1)) - # Compare the norms with standard tolerances - self.assertEqual(actual_norm, expected_norm) - # and individual values with a higher tolerance - self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5) - else: - self.assertEqual(actual_L, expected_L) - - shapes = (0, 3, 5) - batches = ((), (3, ), (2, 2)) - larger_input_case = [(100, (5, ), True)] - for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case: - run_test(shape, batch, contiguous) - - # check the out= variant - A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device) - out = torch.empty_like(A) - ans = torch.linalg.cholesky(A, out=out) - self.assertEqual(ans, out) - expected = torch.linalg.cholesky(A) - self.assertEqual(expected, out) -======= def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, device, dtype): triangle_function = torch.triu if upper else torch.tril @@ -1795,214 +1752,10 @@ def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, if unitriangular: A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) return b, A_triangular ->>>>>>> upstream/master - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) -<<<<<<< HEAD - def test_cholesky_errors_and_warnings(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix - - # cholesky requires the input to be a square matrix or batch of square matrices - A = torch.randn(2, 3, device=device, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): - torch.linalg.cholesky(A) - A = torch.randn(2, 2, 3, device=device, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): - torch.linalg.cholesky(A) - with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'): - np.linalg.cholesky(A.cpu().numpy()) - - # cholesky requires the input to be at least 2 dimensional tensor - A = torch.randn(2, device=device, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): - torch.linalg.cholesky(A) - with self.assertRaisesRegex(np.linalg.LinAlgError, - r'1-dimensional array given\. 
Array must be at least two-dimensional'): - np.linalg.cholesky(A.cpu().numpy()) - - # if the input matrix is singular, an error should be raised - A = torch.eye(3, 3, dtype=dtype, device=device) - A[-1, -1] = 0 # Now A is singular - with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'): - torch.linalg.cholesky(A) - with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'): - np.linalg.cholesky(A.cpu().numpy()) - - # if at least one matrix in the batch is singular, an error should be raised - A = torch.eye(3, 3, dtype=dtype, device=device) - A = A.reshape((1, 3, 3)) - A = A.repeat(5, 1, 1) - A[4, -1, -1] = 0 # Now A[4] is singular - with self.assertRaisesRegex(RuntimeError, r'For batch 4: U\(3,3\) is zero, singular U\.'): - torch.linalg.cholesky(A) - - # if out tensor with wrong shape is passed a warning is given - A = random_hermitian_pd_matrix(3, dtype=dtype, device=device) - out = torch.empty(2, 3, dtype=dtype, device=device) - with warnings.catch_warnings(record=True) as w: - # Trigger warning - torch.linalg.cholesky(A, out=out) - # Check warning occurs - self.assertEqual(len(w), 1) - self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) - - # dtypes should match - out = torch.empty_like(A).to(torch.int) - with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): - torch.linalg.cholesky(A, out=out) - - # TODO: once there is more support for complex dtypes on GPU, they shall be added to above test - # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed - @unittest.expectedFailure - @onlyCUDA - @skipCUDAIfNoMagma - @dtypes(torch.complex64, torch.complex128) - def test_cholesky_xfailed(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix - A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) - expected_L = np.linalg.cholesky(A.cpu().numpy()) - actual_L = torch.linalg.cholesky(A) - self.assertEqual(actual_L, expected_L) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypesIfCPU(torch.float64, torch.complex128) - @dtypes(torch.float64) - def test_cholesky_autograd(self, device, dtype): - def func(root): - x = 0.5 * (root + root.transpose(-1, -2).conj()) - return torch.linalg.cholesky(x) - - def run_test(shape): - root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) - root = root + torch.eye(shape[-1], dtype=dtype, device=device) - - gradcheck(func, root) - gradgradcheck(func, root) - - root = torch.rand(*shape, dtype=dtype, device=device) - root = torch.matmul(root, root.transpose(-1, -2).conj()) - root.requires_grad_() - chol = torch.linalg.cholesky(root).sum().backward() - self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian - shapes = ((3, 3), (4, 3, 2, 2)) - for shape in shapes: - run_test(shape) - - # TODO: enable CUDA tests once (merge with above test) - # RuntimeError: "triangular_solve_cuda" not implemented for 'ComplexDouble' is fixed - @unittest.expectedFailure - @onlyCUDA - @skipCUDAIfNoMagma - @dtypes(torch.complex128) - def test_cholesky_autograd_xfailed(self, device, dtype): - def func(root): - x = 0.5 * (root + root.transpose(-1, -2).conj()) - return torch.linalg.cholesky(x) - - shape = (3, 3) - root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) - root = root + torch.eye(shape[-1], dtype=dtype, device=device) - gradcheck(func, root) - - # NOTE: old_cholesky* tests 
were moved here from test_torch.py and test_autograd.py - @slowTest - @skipCUDAIf(True, "See issue #26789.") - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_old_cholesky_batched_many_batches(self, device, dtype): - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - - def cholesky_test_helper(n, batchsize, device, upper): - A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) - chol_fact = torch.cholesky(A, upper=upper) - if upper: - # Correctness check - self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) - # Upper triangular check - self.assertEqual(chol_fact, chol_fact.triu()) - else: - # Correctness check - self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) - # Lower triangular check - self.assertEqual(chol_fact, chol_fact.tril()) - - for upper, batchsize in itertools.product([True, False], [262144, 524288]): - cholesky_test_helper(2, batchsize, device, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_old_cholesky_batched(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) - - def cholesky_test_helper(n, batch_dims, upper): - # This is a workaround while there is no support for batched complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, - dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, - dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - # There is no support for complex batched matmul yet - matmul_list = [] - for mat in A.contiguous().view(-1, n, n): - matmul_list.append(mat @ mat.t().conj()) - A = torch.stack(matmul_list).view(*batch_dims, n, n) - else: - A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device) - cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) - cholesky_exp = cholesky_exp.reshape_as(A) - self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) - - for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]): - cholesky_test_helper(3, batchsize, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - @tf32_on_and_off(0.01) - def test_old_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix - - A = random_hermitian_pd_matrix(10, dtype=dtype, device=device) - - # default Case - C = torch.cholesky(A) - B = torch.mm(C, C.t().conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0) - - # test Upper Triangular - U = torch.cholesky(A, True) - B = torch.mm(U.t().conj(), U) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix') - - # test Lower Triangular - L = torch.cholesky(A, False) - B = torch.mm(L, L.t().conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix') - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def 
test_old_cholesky_empty(self, device, dtype): - def run_test(upper): - A = torch.empty(0, 0, dtype=dtype, device=device) - chol = torch.cholesky(A, upper) - chol_A = torch.matmul(chol, chol.t().conj()) - self.assertEqual(A, chol_A) - for upper in [True, False]: - run_test(upper) -======= @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-8, torch.complex128: 1e-8}) def test_triangular_solve(self, device, dtype): @@ -2128,33 +1881,10 @@ def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular) # broadcasting b run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b ->>>>>>> upstream/master @onlyCPU @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) -<<<<<<< HEAD - def test_old_cholesky_autograd(self, device, dtype): - def func(root, upper): - x = 0.5 * (root + root.transpose(-1, -2).conj()) - return torch.cholesky(x, upper) - - def run_test(upper, dims): - root = torch.rand(*dims, dtype=dtype, device=device, requires_grad=True) - root = root + torch.eye(dims[-1]) - - gradcheck(func, [root, upper]) - gradgradcheck(func, [root, upper]) - - root = torch.rand(*dims, dtype=dtype, device=device) - root = torch.matmul(root, root.transpose(-1, -2).conj()) - root.requires_grad_() - chol = root.cholesky().sum().backward() - self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian - - for upper, dims in itertools.product([True, False], [(3, 3), (4, 3, 2, 2)]): - run_test(upper, dims) -======= def test_triangular_solve_singular(self, device, dtype): b = torch.rand(3, 1, dtype=dtype, device=device) A = torch.eye(3, 3, dtype=dtype, device=device) @@ -2182,7 +1912,6 @@ def func(A, b): run_test((3, 3), (3, 2)) run_test((2, 3, 3), (2, 3, 4)) run_test((2, 3, 3), (2, 3, 2)) ->>>>>>> upstream/master instantiate_device_type_tests(TestLinalg, globals()) From d52b83bb5470bbbe4e8a41dc090e2cf5f9923be1 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 12 Nov 2020 13:58:17 -0600 Subject: [PATCH 41/44] Differentiation of complex cholesky on cuda now works for single input; batched complex input does not work until batched complex matmul is implemented --- test/test_linalg.py | 4 ++-- torch/linalg/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 81ef607cff11..5f2ab2f5d654 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -196,7 +196,7 @@ def run_test(shape): run_test(shape) # TODO: enable CUDA tests once (merge with above test) - # RuntimeError: "triangular_solve_cuda" not implemented for 'ComplexDouble' is fixed + # batched matmul for complex dtypes on CUDA is implemented @unittest.expectedFailure @onlyCUDA @skipCUDAIfNoMagma @@ -206,7 +206,7 @@ def func(root): x = 0.5 * (root + root.transpose(-1, -2).conj()) return torch.linalg.cholesky(x) - shape = (3, 3) + shape = (3, 2, 2) root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) root = root + torch.eye(shape[-1], dtype=dtype, device=device) gradcheck(func, root) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 476d80894553..9adc05b4d319 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -25,7 +25,7 @@ ``input = L @ L.conj().t()`` if :attr:`input` is complex-valued. 
The batch of :math:`L` matrices is returned. -Supports real and complex inputs. Backpropagation for complex inputs is only supported on the CPU. +Supports real and complex inputs. Backpropagation for batched complex inputs is only supported on the CPU. .. note:: If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices and one or more of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown. From 287c878dbb67c32f184c6212830f43bbcdf61599 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 13 Nov 2020 02:21:10 -0600 Subject: [PATCH 42/44] Batched matmul for complex on CUDA is implemented now; fix tests --- test/test_linalg.py | 39 ++++----------------------------------- torch/linalg/__init__.py | 2 +- 2 files changed, 5 insertions(+), 36 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 78bc2fd0367e..7d06488d140b 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8,8 +8,8 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, dtypesIfCPU, dtypesIfCUDA, - onlyCUDA, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + (instantiate_device_type_tests, dtypes, + onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA) from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args @@ -61,8 +61,7 @@ def run_test_case(a, b): @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128) - @dtypesIfCUDA(torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_cholesky(self, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -156,23 +155,9 @@ def test_cholesky_errors_and_warnings(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): torch.linalg.cholesky(A, out=out) - # TODO: once there is more support for complex dtypes on GPU, they shall be added to above test - # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed - @unittest.expectedFailure - @onlyCUDA - @skipCUDAIfNoMagma - @dtypes(torch.complex64, torch.complex128) - def test_cholesky_xfailed(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix - A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) - expected_L = np.linalg.cholesky(A.cpu().numpy()) - actual_L = torch.linalg.cholesky(A) - self.assertEqual(actual_L, expected_L) - @skipCUDAIfNoMagma @skipCPUIfNoLapack - @dtypesIfCPU(torch.float64, torch.complex128) - @dtypes(torch.float64) + @dtypes(torch.float64, torch.complex128) def test_cholesky_autograd(self, device, dtype): def func(root): x = 0.5 * (root + root.transpose(-1, -2).conj()) @@ -195,22 +180,6 @@ def run_test(shape): for shape in shapes: run_test(shape) - # TODO: enable CUDA tests once (merge with above test) - # batched matmul for complex dtypes on CUDA is implemented - @unittest.expectedFailure - @onlyCUDA - @skipCUDAIfNoMagma - @dtypes(torch.complex128) - def test_cholesky_autograd_xfailed(self, device, 
dtype): - def func(root): - x = 0.5 * (root + root.transpose(-1, -2).conj()) - return torch.linalg.cholesky(x) - - shape = (3, 2, 2) - root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) - root = root + torch.eye(shape[-1], dtype=dtype, device=device) - gradcheck(func, root) - # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py @slowTest @skipCUDAIf(True, "See issue #26789.") diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 9fef97f0ad13..247519d30a1b 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -25,7 +25,7 @@ ``input = L @ L.conj().t()`` if :attr:`input` is complex-valued. The batch of :math:`L` matrices is returned. -Supports real and complex inputs. Backpropagation for batched complex inputs is only supported on the CPU. +Supports real-valued and complex-valued inputs. .. note:: If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices and one or more of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown. From 88e23c3a6091505a9148ff80083f414cfa2e1379 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 13 Nov 2020 02:33:01 -0600 Subject: [PATCH 43/44] Remove redundant code from test_old_cholesky_batched --- test/test_linalg.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 7d06488d140b..ed5334392739 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -211,26 +211,10 @@ def cholesky_test_helper(n, batchsize, device, upper): @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_old_cholesky_batched(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) + from torch.testing._internal.common_utils import random_hermitian_pd_matrix def cholesky_test_helper(n, batch_dims, upper): - # This is a workaround while there is no support for batched complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, - dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, - dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - # There is no support for complex batched matmul yet - matmul_list = [] - for mat in A.contiguous().view(-1, n, n): - matmul_list.append(mat @ mat.t().conj()) - A = torch.stack(matmul_list).view(*batch_dims, n, n) - else: - A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device) + A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device) cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) cholesky_exp = cholesky_exp.reshape_as(A) self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) From eb507e5ec7e5288b4711e0a8cb3f3457b58f6ca9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 13 Nov 2020 03:00:12 -0600 Subject: [PATCH 44/44] flake8 fix --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index ed5334392739..b86c71f0ed37 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -211,7 +211,7 @@ def cholesky_test_helper(n, batchsize, device, upper): @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, 
torch.complex64, torch.complex128) def test_old_cholesky_batched(self, device, dtype): - from torch.testing._internal.common_utils import random_hermitian_pd_matrix + from torch.testing._internal.common_utils import random_hermitian_pd_matrix def cholesky_test_helper(n, batch_dims, upper): A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
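
For quick reference, below is a minimal usage sketch of the torch.linalg.cholesky API that this series introduces, mirroring the checks performed in the tests above (comparison against np.linalg.cholesky, reconstruction of the input, and the out= variant). It is a sketch under assumptions: it presumes a PyTorch build that already contains these patches, and the helper make_hermitian_pd is illustrative only, not part of the PyTorch API.

    import torch
    import numpy as np

    def make_hermitian_pd(n, *batch, dtype=torch.complex128, device='cpu'):
        # Hermitian positive-definite input built as root @ root^H + I,
        # the same construction used by the autograd tests above.
        # (make_hermitian_pd is a hypothetical helper for this sketch.)
        root = torch.rand(*batch, n, n, dtype=dtype, device=device)
        a = torch.matmul(root, root.transpose(-2, -1).conj())
        return a + torch.eye(n, dtype=dtype, device=device)

    a = make_hermitian_pd(3, 2)       # batch of two 3x3 matrices
    l = torch.linalg.cholesky(a)      # lower-triangular factors

    # Cross-check against NumPy and verify a = l @ l^H, as test_cholesky does.
    assert np.allclose(l.cpu().numpy(), np.linalg.cholesky(a.cpu().numpy()))
    reconstructed = torch.matmul(l, l.transpose(-2, -1).conj())
    assert np.allclose(reconstructed.cpu().numpy(), a.cpu().numpy())

    # out= variant: the factor is written into `out`. Per the error/warning
    # tests above, a wrong output shape only triggers a resize warning, while
    # a mismatched dtype, a non-square input, or a non-positive-definite input
    # raises a RuntimeError.
    out = torch.empty_like(a)
    torch.linalg.cholesky(a, out=out)
    assert torch.equal(out, l)

Note that, unlike the older torch.cholesky exercised by the test_old_cholesky_* tests, the new function takes no upper flag and always returns the lower-triangular factor.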