diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 16f706ca0ed5..270ffeeee5c5 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -535,11 +536,12 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector& infos auto self_matrix_stride = matrixStride(self); auto batch_size = batchCount(self); auto n = self.size(-2); + auto lda = std::max(1, n); int info; for (int64_t i = 0; i < batch_size; i++) { scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - lapackCholesky(uplo, n, self_working_ptr, n, &info); + lapackCholesky(uplo, n, self_working_ptr, lda, &info); infos[i] = info; if (info != 0) { return; @@ -584,6 +586,20 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) { return result; } +Tensor linalg_cholesky(const Tensor &self) { + squareCheckInputs(self); + return at::_cholesky_helper(self, /*upper=*/false).tril_(); +} + +Tensor& linalg_cholesky_out(Tensor &result, const Tensor &self) { + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + Tensor result_tmp = at::linalg_cholesky(self); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 4f9ff63d0ece..5f52a4fa2a51 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1333,10 +1333,11 @@ AT_ERROR("cholesky: MAGMA library not found in " auto self_data = self.data_ptr(); magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)"); + auto lda = std::max(1, n); if (self.dim() == 2) { magma_int_t info = 0; - magmaCholesky(uplo, n, self_data, n, &info); + magmaCholesky(uplo, n, self_data, lda, &info); infos[0] = info; } else { auto self_mat_stride = matrixStride(self); @@ -1367,14 +1368,14 @@ AT_ERROR("cholesky: MAGMA library not found in " magma_int_t* info_array_cur = &info_array[mini_idx]; magmaCholeskyBatched( - uplo, n, self_array_cur, n, info_array_cur, batch_limit, magma_queue); + uplo, n, self_array_cur, lda, info_array_cur, batch_limit, magma_queue); } // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaCholeskyBatched( - uplo, n, &self_array[mini_idx], n, &info_array[mini_idx], batch_size % batch_limit, magma_queue); + uplo, n, &self_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue); } for (int64_t i = 0; i < batch_size; i++) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 08cae18d4ae0..c847aad12639 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8909,6 +8909,19 @@ # # See linalg_det as an example. +- func: linalg_cholesky(Tensor self) -> Tensor + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + DefaultBackend: linalg_cholesky + +- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + dispatch: + DefaultBackend: linalg_cholesky_out + # torch.linalg.det, alias for torch.det - func: linalg_det(Tensor self) -> Tensor python_module: linalg diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 14d3ca1767e9..e1b87a1b2f56 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -12,6 +12,7 @@ Common linear algebra operations. Functions --------- +.. autofunction:: cholesky .. autofunction:: det .. autofunction:: norm .. autofunction:: tensorsolve diff --git a/test/test_autograd.py b/test/test_autograd.py index d7948b56f27d..1f5f9a92584e 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2595,30 +2595,6 @@ def test_var_mean_differentiable(self): torch.autograd.backward(r2, grad) self.assertTrue(torch.allclose(input1.grad, input2.grad, rtol=0.01, atol=0.0)) - @skipIfNoLapack - def test_cholesky(self): - def func(root, upper): - x = 0.5 * (root + root.transpose(-1, -2).conj()) - return torch.cholesky(x, upper) - - def run_test(upper, dims, dtype): - root = torch.rand(*dims, dtype=dtype, requires_grad=True) - root = root + torch.eye(dims[-1]) - - gradcheck(func, [root, upper]) - gradgradcheck(func, [root, upper]) - - root = torch.rand(*dims, dtype=dtype) - root = torch.matmul(root, root.transpose(-1, -2).conj()) - root.requires_grad_() - chol = root.cholesky().sum().backward() - self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian - - for upper, dims, dtype in product([True, False], - [(3, 3), (4, 3, 2, 2)], - [torch.double, torch.cdouble]): - run_test(upper, dims, dtype) - @skipIfNoLapack def test_cholesky_solve(self): def _test_with_size(A_dims, B_dims, upper): diff --git a/test/test_linalg.py b/test/test_linalg.py index 16cfea8f536f..b86c71f0ed37 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -9,8 +9,9 @@ (TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, - onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA) +from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck, gradgradcheck @@ -58,6 +59,230 @@ def run_test_case(a, b): run_test_case(zero_strided, b) run_test_case(a, zero_strided) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_cholesky(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + def run_test(shape, batch, contiguous): + A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device) + if A.numel() > 0 and not contiguous: + A = A.transpose(-2, -1) + self.assertFalse(A.is_contiguous()) + expected_L = np.linalg.cholesky(A.cpu().numpy()) + actual_L = torch.linalg.cholesky(A) + + # For fp32 individual entries in matrices can differ between PyTorch and NumPy + # Let's compare the norms of matrices instead + if A.numel() > 0 and dtype in [torch.float32, torch.complex64]: + # axis is specified to calculate matrix norm for batched input + expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1)) + actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1)) + # Compare the norms with standard tolerances + self.assertEqual(actual_norm, expected_norm) + # and individual values with a higher tolerance + self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5) + else: + self.assertEqual(actual_L, expected_L) + + shapes = (0, 3, 5) + batches = ((), (3, ), (2, 2)) + larger_input_case = [(100, (5, ), True)] + for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case: + run_test(shape, batch, contiguous) + + # check the out= variant + A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device) + out = torch.empty_like(A) + ans = torch.linalg.cholesky(A, out=out) + self.assertEqual(ans, out) + expected = torch.linalg.cholesky(A) + self.assertEqual(expected, out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_cholesky_errors_and_warnings(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + # cholesky requires the input to be a square matrix or batch of square matrices + A = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.cholesky(A) + A = torch.randn(2, 2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'): + np.linalg.cholesky(A.cpu().numpy()) + + # cholesky requires the input to be at least 2 dimensional tensor + A = torch.randn(2, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, + r'1-dimensional array given\. Array must be at least two-dimensional'): + np.linalg.cholesky(A.cpu().numpy()) + + # if the input matrix is singular, an error should be raised + A = torch.eye(3, 3, dtype=dtype, device=device) + A[-1, -1] = 0 # Now A is singular + with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'): + torch.linalg.cholesky(A) + with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'): + np.linalg.cholesky(A.cpu().numpy()) + + # if at least one matrix in the batch is singular, an error should be raised + A = torch.eye(3, 3, dtype=dtype, device=device) + A = A.reshape((1, 3, 3)) + A = A.repeat(5, 1, 1) + A[4, -1, -1] = 0 # Now A[4] is singular + with self.assertRaisesRegex(RuntimeError, r'For batch 4: U\(3,3\) is zero, singular U\.'): + torch.linalg.cholesky(A) + + # if out tensor with wrong shape is passed a warning is given + A = random_hermitian_pd_matrix(3, dtype=dtype, device=device) + out = torch.empty(2, 3, dtype=dtype, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.cholesky(A, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(A).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.cholesky(A, out=out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_cholesky_autograd(self, device, dtype): + def func(root): + x = 0.5 * (root + root.transpose(-1, -2).conj()) + return torch.linalg.cholesky(x) + + def run_test(shape): + root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True) + root = root + torch.eye(shape[-1], dtype=dtype, device=device) + + gradcheck(func, root) + gradgradcheck(func, root) + + root = torch.rand(*shape, dtype=dtype, device=device) + root = torch.matmul(root, root.transpose(-1, -2).conj()) + root.requires_grad_() + chol = torch.linalg.cholesky(root).sum().backward() + self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian + + shapes = ((3, 3), (4, 3, 2, 2)) + for shape in shapes: + run_test(shape) + + # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py + @slowTest + @skipCUDAIf(True, "See issue #26789.") + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_old_cholesky_batched_many_batches(self, device, dtype): + from torch.testing._internal.common_utils import random_symmetric_pd_matrix + + def cholesky_test_helper(n, batchsize, device, upper): + A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) + chol_fact = torch.cholesky(A, upper=upper) + if upper: + # Correctness check + self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) + # Upper triangular check + self.assertEqual(chol_fact, chol_fact.triu()) + else: + # Correctness check + self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) + # Lower triangular check + self.assertEqual(chol_fact, chol_fact.tril()) + + for upper, batchsize in itertools.product([True, False], [262144, 524288]): + cholesky_test_helper(2, batchsize, device, upper) + + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_cholesky_batched(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + def cholesky_test_helper(n, batch_dims, upper): + A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device) + cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) + cholesky_exp = cholesky_exp.reshape_as(A) + self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) + + for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]): + cholesky_test_helper(3, batchsize, upper) + + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @tf32_on_and_off(0.01) + def test_old_cholesky(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_pd_matrix + + A = random_hermitian_pd_matrix(10, dtype=dtype, device=device) + + # default Case + C = torch.cholesky(A) + B = torch.mm(C, C.t().conj()) + self.assertEqual(A, B, atol=1e-14, rtol=0) + + # test Upper Triangular + U = torch.cholesky(A, True) + B = torch.mm(U.t().conj(), U) + self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix') + + # test Lower Triangular + L = torch.cholesky(A, False) + B = torch.mm(L, L.t().conj()) + self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix') + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_cholesky_empty(self, device, dtype): + def run_test(upper): + A = torch.empty(0, 0, dtype=dtype, device=device) + chol = torch.cholesky(A, upper) + chol_A = torch.matmul(chol, chol.t().conj()) + self.assertEqual(A, chol_A) + for upper in [True, False]: + run_test(upper) + + @onlyCPU + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_old_cholesky_autograd(self, device, dtype): + def func(root, upper): + x = 0.5 * (root + root.transpose(-1, -2).conj()) + return torch.cholesky(x, upper) + + def run_test(upper, dims): + root = torch.rand(*dims, dtype=dtype, device=device, requires_grad=True) + root = root + torch.eye(dims[-1]) + + gradcheck(func, [root, upper]) + gradgradcheck(func, [root, upper]) + + root = torch.rand(*dims, dtype=dtype, device=device) + root = torch.matmul(root, root.transpose(-1, -2).conj()) + root.requires_grad_() + chol = root.cholesky().sum().backward() + self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian + + for upper, dims in itertools.product([True, False], [(3, 3), (4, 3, 2, 2)]): + run_test(upper, dims) + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @precisionOverride({torch.bfloat16: 1e-1}) @dtypes(*(torch.testing.get_all_dtypes())) diff --git a/test/test_torch.py b/test/test_torch.py index b8061fc19353..56315cd22784 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7729,98 +7729,6 @@ def test_cholesky_inverse(self, device, dtype): inv1 = torch.cholesky_inverse(chol, False) self.assertLessEqual(inv0.dist(inv1), 1e-12) - @slowTest - @skipCUDAIf(True, "See issue #26789.") - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double) - def test_cholesky_batched_many_batches(self, device, dtype): - from torch.testing._internal.common_utils import random_symmetric_pd_matrix - - def cholesky_test_helper(n, batchsize, device, upper): - A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) - chol_fact = torch.cholesky(A, upper=upper) - if upper: - # Correctness check - self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) - # Upper triangular check - self.assertEqual(chol_fact, chol_fact.triu()) - else: - # Correctness check - self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) - # Lower triangular check - self.assertEqual(chol_fact, chol_fact.tril()) - - for upper, batchsize in product([True, False], [262144, 524288]): - cholesky_test_helper(2, batchsize, device, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_cholesky_batched(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) - - def cholesky_test_helper(n, batch_dims, upper): - # This is a workaround while there is no support for complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(n, *batch_dims, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - # There is no support for complex batched matmul yet - matmul_list = [] - for mat in A.contiguous().view(-1, n, n): - matmul_list.append(mat @ mat.t().conj()) - A = torch.stack(matmul_list).view(*batch_dims, n, n) - else: - A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device) - cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) - cholesky_exp = cholesky_exp.reshape_as(A) - self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) - - for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]): - cholesky_test_helper(3, batchsize, upper) - - @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_cholesky(self, device, dtype): - from torch.testing._internal.common_utils import \ - (random_symmetric_pd_matrix, - random_fullrank_matrix_distinct_singular_value) - - # This is a workaround while there is no support for complex random_symmetric_pd_matrix - if dtype.is_complex: - real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 - A_real = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device) - A_imag = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device) - A = A_real + 1j * A_imag - A = A @ A.t().conj() - else: - A = random_symmetric_pd_matrix(10, dtype=dtype, device=device) - - # default Case - C = torch.cholesky(A) - C_ = C.cpu().numpy() - B = np.matmul(C_, C_.T.conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0) - - # test Upper Triangular - U = torch.cholesky(A, True) - U_ = U.cpu().numpy() - B = np.matmul(U_.T.conj(), U_) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix') - - # test Lower Triangular - L = torch.cholesky(A, False) - L_ = L.cpu().numpy() - B = np.matmul(L_, L_.T.conj()) - self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix') - def test_view(self, device): tensor = torch.rand(15, device=device) template = torch.rand(3, 5, device=device) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 2ffd28400481..37d46781907f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -294,6 +294,9 @@ - name: cholesky(Tensor self, bool upper=False) -> Tensor self: cholesky_backward(grad, upper, result) +- name: linalg_cholesky(Tensor self) -> Tensor + self: cholesky_backward(grad, false, result) + - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor self, input2: cholesky_solve_backward(grad, self, input2, result, upper) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index ccce10e820ce..ce03dfaaee7b 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -94,7 +94,7 @@ 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', - 'exp', 'nonzero', 'mean', 'inverse', 'solve' + 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky' } # Some operators invalidate the grad_accumulator. Let's reset it. diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index c0bae62510e6..426c5fd6077d 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -8,6 +8,14 @@ namespace linalg { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { +inline Tensor cholesky(const Tensor& self) { + return torch::linalg_cholesky(self); +} + +inline Tensor cholesky_out(Tensor& result, const Tensor& self) { + return torch::linalg_cholesky_out(result, self); +} + inline Tensor det(const Tensor& self) { return torch::linalg_det(self); } @@ -39,6 +47,24 @@ inline Tensor& tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ +/// Cholesky decomposition +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.cholesky +/// +/// Example: +/// ``` +/// auto A = torch::randn({4, 4}); +/// auto A = torch::matmul(A, A.t()); +/// auto L = torch::linalg::cholesky(A); +/// assert(torch::allclose(torch::matmul(L, L.t()), A)); +/// ``` +inline Tensor cholesky(const Tensor& self) { + return detail::cholesky(self); +} + +inline Tensor cholesky_out(Tensor& result, const Tensor& self) { + return detail::cholesky_out(result, self); +} /// See the documentation of torch.linalg.det inline Tensor linalg_det(const Tensor& self) { diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index bf2947da81c8..247519d30a1b 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -8,6 +8,80 @@ # Note: This not only adds doc strings for functions in the linalg namespace, but # also connects the torch.linalg Python namespace to the torch._C._linalg builtins. +cholesky = _add_docstr(_linalg.linalg_cholesky, r""" +linalg.cholesky(input, *, out=None) -> Tensor + +Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices) +positive-definite matrix or the Cholesky decompositions for a batch of such matrices. +Each decomposition has the form: + +.. math:: + + \text{input} = LL^H + +where :math:`L` is a lower-triangular matrix and :math:`L^H` is the conjugate transpose of :math:`L`, +which is just a transpose for the case of real-valued input matrices. +In code it translates to ``input = L @ L.t()` if :attr:`input` is real-valued and +``input = L @ L.conj().t()`` if :attr:`input` is complex-valued. +The batch of :math:`L` matrices is returned. + +Supports real-valued and complex-valued inputs. + +.. note:: If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices + and one or more of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown. + If :attr:`input` is a batch of matrices, then the error message will include the batch index + of the first matrix that is not Hermitian positive-definite. + +.. warning:: This function always checks whether :attr:`input` is a Hermitian positive-definite matrix + using `info` argument to LAPACK/MAGMA call. For CUDA this causes cross-device memory synchronization. + +Args: + input (Tensor): the input tensor of size :math:`(*, n, n)` consisting of Hermitian positive-definite + :math:`n \times n` matrices, where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> a = torch.randn(2, 2, dtype=torch.complex128) + >>> a = torch.mm(a, a.t().conj()) # creates a Hermitian positive-definite matrix + >>> l = torch.linalg.cholesky(a) + >>> a + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + >>> l + tensor([[1.5895+0.0000j, 0.0000+0.0000j], + [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) + >>> torch.mm(l, l.t().conj()) + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + + >>> a = torch.randn(3, 2, 2, dtype=torch.float64) + >>> a = torch.matmul(a, a.transpose(-2, -1)) # creates a symmetric positive-definite matrix + >>> l = torch.linalg.cholesky(a) + >>> a + tensor([[[ 1.1629, 2.0237], + [ 2.0237, 6.6593]], + + [[ 0.4187, 0.1830], + [ 0.1830, 0.1018]], + + [[ 1.9348, -2.5744], + [-2.5744, 4.6386]]], dtype=torch.float64) + >>> l + tensor([[[ 1.0784, 0.0000], + [ 1.8766, 1.7713]], + + [[ 0.6471, 0.0000], + [ 0.2829, 0.1477]], + + [[ 1.3910, 0.0000], + [-1.8509, 1.1014]]], dtype=torch.float64) + >>> torch.allclose(torch.matmul(l, l.transpose(-2, -1)), a) + True +""") + det = _add_docstr(_linalg.linalg_det, r""" linalg.det(input) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index d539e7ef120c..875be8caee11 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -272,6 +272,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.chain_matmul: lambda *matrices: -1, torch.channel_shuffle: lambda input, groups : -1, torch.cholesky: lambda input, upper=False, out=None: -1, + torch.linalg.cholesky: lambda input, out=None: -1, torch.cholesky_inverse: lambda input, upper=False, out=None: -1, torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1, torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1, diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index df606c1c4bd1..f158f4d5d0d4 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1503,6 +1503,19 @@ def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs): + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5 +def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device): + """ + Returns a batch of random Hermitian positive-definite matrices. + The shape of the result is batch_dims + (matrix_size, matrix_size) + + The following example creates a tensor of size 2 x 4 x 3 x 3 + >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device) + """ + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), + dtype=dtype, device=device) + return torch.matmul(A, A.transpose(-2, -1).conj()) + + def make_nonzero_det(A, sign=None, min_singular_value=0.1): u, s, v = A.svd() s.clamp_(min=min_singular_value)