Skip to content

Commit

Permalink
Added linalg.cholesky (#46083)
Browse files Browse the repository at this point in the history
Summary:
This PR adds the `torch.linalg.cholesky` function, which matches the behavior of `numpy.linalg.cholesky`.

Fixed the `lda` (leading dimension) argument passed to the `lapackCholesky` and MAGMA Cholesky calls: it is now `max(1, n)` instead of `n`.
Added `random_hermitian_pd_matrix` helper function for tests.

Ref #42666.

Pull Request resolved: #46083

Reviewed By: ailzhang

Differential Revision: D24861752

Pulled By: mruberry

fbshipit-source-id: 214dbceb4e8a2c589df209493efd843962d25593
  • Loading branch information
IvanYashchuk authored and facebook-github-bot committed Nov 14, 2020
1 parent e8fecd5 commit 260daf0
Show file tree
Hide file tree
Showing 13 changed files with 379 additions and 122 deletions.
18 changes: 17 additions & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
Expand Up @@ -5,6 +5,7 @@
#include <ATen/ExpandUtils.h>

#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/Parallel.h>

Expand Down Expand Up @@ -535,11 +536,12 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector<int64_t>& infos
auto self_matrix_stride = matrixStride(self);
auto batch_size = batchCount(self);
auto n = self.size(-2);
auto lda = std::max<int64_t>(1, n);

int info;
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
lapackCholesky<scalar_t>(uplo, n, self_working_ptr, n, &info);
lapackCholesky<scalar_t>(uplo, n, self_working_ptr, lda, &info);
infos[i] = info;
if (info != 0) {
return;
Expand Down Expand Up @@ -584,6 +586,20 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
return result;
}

// Implementation behind the `linalg_cholesky` schema entry.
//
// The input is first validated with squareCheckInputs. The factorization
// helper is then called with upper=false, and tril_() is applied in place to
// its output so that only the lower triangle is kept.
Tensor linalg_cholesky(const Tensor &self) {
  squareCheckInputs(self);
  Tensor factor = at::_cholesky_helper(self, /*upper=*/false);
  factor.tril_();
  return factor;
}

// out= variant of linalg_cholesky.
//
// `result` must already have the same dtype as `self`; otherwise TORCH_CHECK
// reports the mismatch. The factor is computed into a temporary, `result` is
// resized to the temporary's shape via resize_output, and the values are then
// copied over. Returns `result`.
Tensor& linalg_cholesky_out(Tensor &result, const Tensor &self) {
  const auto result_dtype = result.scalar_type();
  const auto self_dtype = self.scalar_type();
  TORCH_CHECK(result_dtype == self_dtype,
              "result dtype ", result_dtype, " does not match self dtype ", self_dtype);

  const Tensor factor = at::linalg_cholesky(self);
  at::native::resize_output(result, factor.sizes());
  result.copy_(factor);
  return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
Expand Down
7 changes: 4 additions & 3 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
Expand Up @@ -1333,10 +1333,11 @@ AT_ERROR("cholesky: MAGMA library not found in "

auto self_data = self.data_ptr<scalar_t>();
magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)");
auto lda = std::max<magma_int_t>(1, n);

if (self.dim() == 2) {
magma_int_t info = 0;
magmaCholesky<scalar_t>(uplo, n, self_data, n, &info);
magmaCholesky<scalar_t>(uplo, n, self_data, lda, &info);
infos[0] = info;
} else {
auto self_mat_stride = matrixStride(self);
Expand Down Expand Up @@ -1367,14 +1368,14 @@ AT_ERROR("cholesky: MAGMA library not found in "
magma_int_t* info_array_cur = &info_array[mini_idx];

magmaCholeskyBatched<scalar_t>(
uplo, n, self_array_cur, n, info_array_cur, batch_limit, magma_queue);
uplo, n, self_array_cur, lda, info_array_cur, batch_limit, magma_queue);
}

// Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit
// which concisely is equal to batch_size % batch_limit
if (batch_size % batch_limit != 0) {
magmaCholeskyBatched<scalar_t>(
uplo, n, &self_array[mini_idx], n, &info_array[mini_idx], batch_size % batch_limit, magma_queue);
uplo, n, &self_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue);
}

for (int64_t i = 0; i < batch_size; i++) {
Expand Down
13 changes: 13 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -8909,6 +8909,19 @@
#
# See linalg_det as an example.

# torch.linalg.cholesky: function-only variant exposed in the `linalg` Python module
- func: linalg_cholesky(Tensor self) -> Tensor
python_module: linalg
use_c10_dispatcher: full
variants: function
dispatch:
DefaultBackend: linalg_cholesky

# out= overload of torch.linalg.cholesky; takes the mutable `out` tensor and returns it
- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
variants: function
dispatch:
DefaultBackend: linalg_cholesky_out

# torch.linalg.det, alias for torch.det
- func: linalg_det(Tensor self) -> Tensor
python_module: linalg
Expand Down
1 change: 1 addition & 0 deletions docs/source/linalg.rst
Expand Up @@ -12,6 +12,7 @@ Common linear algebra operations.
Functions
---------

.. autofunction:: cholesky
.. autofunction:: det
.. autofunction:: norm
.. autofunction:: tensorsolve
24 changes: 0 additions & 24 deletions test/test_autograd.py
Expand Up @@ -2595,30 +2595,6 @@ def test_var_mean_differentiable(self):
torch.autograd.backward(r2, grad)
self.assertTrue(torch.allclose(input1.grad, input2.grad, rtol=0.01, atol=0.0))

@skipIfNoLapack
def test_cholesky(self):
    """Gradient checks for ``torch.cholesky`` on double and complex-double inputs.

    For every combination of ``upper``, input shape and dtype this runs
    ``gradcheck``/``gradgradcheck`` and then verifies that the gradient of
    ``cholesky(...).sum()`` with respect to the input is Hermitian.
    """
    def func(root, upper):
        # Average the input with its conjugate transpose so that cholesky
        # always receives a Hermitian matrix.
        x = 0.5 * (root + root.transpose(-1, -2).conj())
        return torch.cholesky(x, upper)

    def run_test(upper, dims, dtype):
        root = torch.rand(*dims, dtype=dtype, requires_grad=True)
        root = root + torch.eye(dims[-1])

        gradcheck(func, [root, upper])
        gradgradcheck(func, [root, upper])

        root = torch.rand(*dims, dtype=dtype)
        root = torch.matmul(root, root.transpose(-1, -2).conj())
        root.requires_grad_()
        # backward() returns None, so its result is deliberately not bound.
        root.cholesky().sum().backward()
        self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())  # Check the gradient is hermitian

    for upper, dims, dtype in product([True, False],
                                      [(3, 3), (4, 3, 2, 2)],
                                      [torch.double, torch.cdouble]):
        run_test(upper, dims, dtype)

@skipIfNoLapack
def test_cholesky_solve(self):
def _test_with_size(A_dims, B_dims, upper):
Expand Down
227 changes: 226 additions & 1 deletion test/test_linalg.py
Expand Up @@ -9,8 +9,9 @@
(TestCase, run_tests, TEST_NUMPY, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, TEST_WITH_ASAN, make_tensor)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes,
onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA)
from torch.testing._internal.common_cuda import tf32_on_and_off
from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
from torch.autograd import gradcheck, gradgradcheck

Expand Down Expand Up @@ -58,6 +59,230 @@ def run_test_case(a, b):
run_test_case(zero_strided, b)
run_test_case(a, zero_strided)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_cholesky(self, device, dtype):
    """Compare ``torch.linalg.cholesky`` with ``np.linalg.cholesky``.

    Covers empty, small and one larger input, batched and unbatched,
    contiguous and transposed (non-contiguous) inputs, plus the ``out=``
    variant. Skipped when NumPy is unavailable because NumPy provides the
    reference result.
    """
    from torch.testing._internal.common_utils import random_hermitian_pd_matrix

    def run_test(shape, batch, contiguous):
        A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
        if A.numel() > 0 and not contiguous:
            A = A.transpose(-2, -1)
            self.assertFalse(A.is_contiguous())
        expected_L = np.linalg.cholesky(A.cpu().numpy())
        actual_L = torch.linalg.cholesky(A)

        # For fp32 individual entries in matrices can differ between PyTorch and NumPy
        # Let's compare the norms of matrices instead
        if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
            # axis is specified to calculate matrix norm for batched input
            expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
            actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
            # Compare the norms with standard tolerances
            self.assertEqual(actual_norm, expected_norm)
            # and individual values with a higher tolerance
            self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
        else:
            self.assertEqual(actual_L, expected_L)

    shapes = (0, 3, 5)
    batches = ((), (3, ), (2, 2))
    larger_input_case = [(100, (5, ), True)]
    for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case:
        run_test(shape, batch, contiguous)

    # check the out= variant
    A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device)
    out = torch.empty_like(A)
    ans = torch.linalg.cholesky(A, out=out)
    self.assertEqual(ans, out)
    expected = torch.linalg.cholesky(A)
    self.assertEqual(expected, out)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_cholesky_errors_and_warnings(self, device, dtype):
    """Check the errors and warnings raised by ``torch.linalg.cholesky``.

    Several cases also assert the corresponding ``np.linalg.cholesky``
    error, so the test is skipped when NumPy is unavailable.
    """
    from torch.testing._internal.common_utils import random_hermitian_pd_matrix

    # cholesky requires the input to be a square matrix or batch of square matrices
    A = torch.randn(2, 3, device=device, dtype=dtype)
    with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
        torch.linalg.cholesky(A)
    A = torch.randn(2, 2, 3, device=device, dtype=dtype)
    with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
        torch.linalg.cholesky(A)
    with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
        np.linalg.cholesky(A.cpu().numpy())

    # cholesky requires the input to be at least 2 dimensional tensor
    A = torch.randn(2, device=device, dtype=dtype)
    with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
        torch.linalg.cholesky(A)
    with self.assertRaisesRegex(np.linalg.LinAlgError,
                                r'1-dimensional array given\. Array must be at least two-dimensional'):
        np.linalg.cholesky(A.cpu().numpy())

    # if the input matrix is singular, an error should be raised
    A = torch.eye(3, 3, dtype=dtype, device=device)
    A[-1, -1] = 0  # Now A is singular
    with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'):
        torch.linalg.cholesky(A)
    with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
        np.linalg.cholesky(A.cpu().numpy())

    # if at least one matrix in the batch is singular, an error should be raised
    A = torch.eye(3, 3, dtype=dtype, device=device)
    A = A.reshape((1, 3, 3))
    A = A.repeat(5, 1, 1)
    A[4, -1, -1] = 0  # Now A[4] is singular
    with self.assertRaisesRegex(RuntimeError, r'For batch 4: U\(3,3\) is zero, singular U\.'):
        torch.linalg.cholesky(A)

    # if out tensor with wrong shape is passed a warning is given
    A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
    out = torch.empty(2, 3, dtype=dtype, device=device)
    with warnings.catch_warnings(record=True) as w:
        # Trigger warning
        torch.linalg.cholesky(A, out=out)
        # Check warning occurs
        self.assertEqual(len(w), 1)
        self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

    # dtypes should match
    out = torch.empty_like(A).to(torch.int)
    with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
        torch.linalg.cholesky(A, out=out)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float64, torch.complex128)
def test_cholesky_autograd(self, device, dtype):
    """Gradient checks for ``torch.linalg.cholesky``.

    Runs ``gradcheck``/``gradgradcheck`` for an unbatched and a batched
    shape, then verifies that the gradient of ``cholesky(...).sum()`` with
    respect to the input is Hermitian.
    """
    def func(root):
        # Average the input with its conjugate transpose so that cholesky
        # always receives a Hermitian matrix.
        x = 0.5 * (root + root.transpose(-1, -2).conj())
        return torch.linalg.cholesky(x)

    def run_test(shape):
        root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True)
        root = root + torch.eye(shape[-1], dtype=dtype, device=device)

        gradcheck(func, root)
        gradgradcheck(func, root)

        root = torch.rand(*shape, dtype=dtype, device=device)
        root = torch.matmul(root, root.transpose(-1, -2).conj())
        root.requires_grad_()
        # backward() returns None, so its result is deliberately not bound.
        torch.linalg.cholesky(root).sum().backward()
        self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())  # Check the gradient is hermitian

    shapes = ((3, 3), (4, 3, 2, 2))
    for shape in shapes:
        run_test(shape)

# NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py
@slowTest
@skipCUDAIf(True, "See issue #26789.")
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.double)
def test_old_cholesky_batched_many_batches(self, device, dtype):
    """Exercise ``torch.cholesky`` on very large batches of 2x2 matrices.

    For each batch the factor must rebuild the input and must have the
    requested triangular structure.
    """
    from torch.testing._internal.common_utils import random_symmetric_pd_matrix

    def cholesky_test_helper(n, batchsize, device, upper):
        A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device)
        factor = torch.cholesky(A, upper=upper)
        factor_t = factor.transpose(-2, -1)

        # Correctness check: U^T U for the upper factor, L L^T for the lower one
        rebuilt = factor_t.matmul(factor) if upper else factor.matmul(factor_t)
        self.assertEqual(A, rebuilt)

        # Triangular structure check
        expected_triangle = factor.triu() if upper else factor.tril()
        self.assertEqual(factor, expected_triangle)

    for upper in (True, False):
        for batchsize in (262144, 524288):
            cholesky_test_helper(2, batchsize, device, upper)

@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_old_cholesky_batched(self, device, dtype):
    """Batched ``torch.cholesky`` must agree with factoring each matrix on its own."""
    from torch.testing._internal.common_utils import random_hermitian_pd_matrix

    def cholesky_test_helper(n, batch_dims, upper):
        A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
        # Reference: flatten the batch dimensions, factor every matrix separately,
        # then restore the original batch shape.
        per_matrix = [m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]
        expected = torch.stack(per_matrix).reshape_as(A)
        self.assertEqual(expected, torch.cholesky(A, upper=upper))

    for upper in (True, False):
        for batch_dims in ((3,), (3, 4), (2, 3, 4)):
            cholesky_test_helper(3, batch_dims, upper)

@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@tf32_on_and_off(0.01)
def test_old_cholesky(self, device, dtype):
    """``torch.cholesky`` factors of a 10x10 Hermitian PD matrix must rebuild it."""
    from torch.testing._internal.common_utils import random_hermitian_pd_matrix

    A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)

    def assert_rebuilds(factor, upper, msg=None):
        # Upper factors rebuild A as U^H U, lower factors as L L^H.
        factor_h = factor.t().conj()
        rebuilt = torch.mm(factor_h, factor) if upper else torch.mm(factor, factor_h)
        self.assertEqual(A, rebuilt, atol=1e-14, rtol=0, msg=msg)

    # default Case
    assert_rebuilds(torch.cholesky(A), upper=False)

    # test Upper Triangular
    assert_rebuilds(torch.cholesky(A, True), upper=True,
                    msg='cholesky (upper) did not allow rebuilding the original matrix')

    # test Lower Triangular
    assert_rebuilds(torch.cholesky(A, False), upper=False,
                    msg='cholesky (lower) did not allow rebuilding the original matrix')

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_old_cholesky_empty(self, device, dtype):
    """``torch.cholesky`` must accept a 0x0 input for both values of ``upper``."""
    for upper in (True, False):
        A = torch.empty(0, 0, dtype=dtype, device=device)
        chol = torch.cholesky(A, upper)
        rebuilt = torch.matmul(chol, chol.t().conj())
        self.assertEqual(A, rebuilt)

@onlyCPU
@skipCPUIfNoLapack
# gradcheck/gradgradcheck compare against finite differences and are only
# reliable in double precision, so single-precision dtypes are not exercised.
@dtypes(torch.float64, torch.complex128)
def test_old_cholesky_autograd(self, device, dtype):
    """Gradient checks for the legacy ``torch.cholesky`` (both ``upper`` modes).

    Runs ``gradcheck``/``gradgradcheck`` for an unbatched and a batched
    shape, then verifies that the gradient of ``cholesky().sum()`` with
    respect to the input is Hermitian.
    """
    def func(root, upper):
        # Average the input with its conjugate transpose so that cholesky
        # always receives a Hermitian matrix.
        x = 0.5 * (root + root.transpose(-1, -2).conj())
        return torch.cholesky(x, upper)

    def run_test(upper, dims):
        root = torch.rand(*dims, dtype=dtype, device=device, requires_grad=True)
        # Build the identity with the same dtype/device as `root`.
        root = root + torch.eye(dims[-1], dtype=dtype, device=device)

        gradcheck(func, [root, upper])
        gradgradcheck(func, [root, upper])

        root = torch.rand(*dims, dtype=dtype, device=device)
        root = torch.matmul(root, root.transpose(-1, -2).conj())
        root.requires_grad_()
        # backward() returns None, so its result is deliberately not bound.
        root.cholesky().sum().backward()
        self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())  # Check the gradient is hermitian

    for upper, dims in itertools.product([True, False], [(3, 3), (4, 3, 2, 2)]):
        run_test(upper, dims)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@precisionOverride({torch.bfloat16: 1e-1})
@dtypes(*(torch.testing.get_all_dtypes()))
Expand Down

0 comments on commit 260daf0

Please sign in to comment.