Added linalg.cholesky #46083

Closed
wants to merge 62 commits into from
Changes from 14 commits
Commits
62 commits
6b7f14a
wip linalg.cholesky
IvanYashchuk Oct 7, 2020
f7a08f4
Added xfailed test case
IvanYashchuk Oct 8, 2020
32e10f8
Added cholesky to csrc/api/include/torch/linalg.h
IvanYashchuk Oct 8, 2020
3c4d5a4
Updated example in docs
IvanYashchuk Oct 9, 2020
307020e
Added random_hermitian_pd_matrix for the test
IvanYashchuk Oct 9, 2020
9e9e0c0
Use random_hermitian_pd_matrix in the test_torch/cholesky
IvanYashchuk Oct 9, 2020
ae4f3ee
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 11, 2020
1798832
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 13, 2020
e74de2c
No need for skip if numpy not found anymore
IvanYashchuk Oct 13, 2020
9343678
Added larger input case
IvanYashchuk Oct 13, 2020
604f0a8
Added assertRaises tests for cholesky
IvanYashchuk Oct 13, 2020
5f80042
Added a note to the docs that the error is given if the input is not
IvanYashchuk Oct 13, 2020
da4e88b
Enabled autograd for linalg_cholesky
IvanYashchuk Oct 13, 2020
efb725c
Added a note to the documentation about complex support
IvanYashchuk Oct 13, 2020
6297b6b
Added the out= variant
IvanYashchuk Oct 14, 2020
709273b
Moved error checks to a separate test
IvanYashchuk Oct 14, 2020
21cfca0
Added xfailed test for cholesky cuda autograd
IvanYashchuk Oct 14, 2020
31cbe75
Only complex128 is needed for autograd test
IvanYashchuk Oct 14, 2020
df0172e
Added a docstring for random_hermitian_pd_matrix
IvanYashchuk Oct 14, 2020
b54c7f8
Updated linalg.cholesky docs
IvanYashchuk Oct 14, 2020
7878183
Added a note on error message for batch of singular matrices
IvanYashchuk Oct 15, 2020
0d4a8c7
In tests compare norms of the resulting matrices
IvanYashchuk Oct 15, 2020
e4832d3
Added entry in overrides.py
IvanYashchuk Oct 15, 2020
e800b97
Added test case for batch singular input
IvanYashchuk Oct 15, 2020
8b58586
Moved tests for torch.cholesky to test_linalg.py
IvanYashchuk Oct 15, 2020
7aabce3
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 15, 2020
12d11ce
Added a dispatch section with DefaultBackend in native_functions.yaml
IvanYashchuk Oct 15, 2020
97501c6
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 15, 2020
29f94c9
gradgradcheck for cholesky now works
IvanYashchuk Oct 15, 2020
00c41ed
Updated documentation for linalg cholesky
IvanYashchuk Oct 19, 2020
34cef7d
Add one more assert for out= test
IvanYashchuk Oct 19, 2020
5d2230d
Updated test_cholesky_errors
IvanYashchuk Oct 19, 2020
63e922d
Added non contiguous test
IvanYashchuk Oct 19, 2020
6d64067
Make test_cholesky_autograd run on gpu for fp64
IvanYashchuk Oct 19, 2020
40935f1
Changed np.allclose -> torch.allclose
IvanYashchuk Oct 19, 2020
aaad340
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 19, 2020
1d67cd0
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 27, 2020
22be1a5
Fix long lines
IvanYashchuk Oct 28, 2020
d547b63
Remove unused import
IvanYashchuk Oct 28, 2020
fe89410
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 29, 2020
9e2cfb4
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 2, 2020
30239f2
Merge branch 'master' into linalg-cholesky
IvanYashchuk Nov 3, 2020
f723421
Use at::native::resize_output
IvanYashchuk Nov 3, 2020
1b46c19
Merge branch 'master' into linalg-cholesky
IvanYashchuk Nov 4, 2020
10d276a
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 6, 2020
e224c97
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 10, 2020
048316d
Added a warning about data movement for cuda inputs
IvanYashchuk Nov 10, 2020
0a13e1d
Use typed std::max
IvanYashchuk Nov 10, 2020
173fea1
Replaced torch.allclose with self.assertEqual
IvanYashchuk Nov 10, 2020
b26f52c
Removed unused import
IvanYashchuk Nov 10, 2020
ba3708e
Fix imports
IvanYashchuk Nov 10, 2020
5f0abff
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 11, 2020
b670f26
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 12, 2020
f2ee1f6
Finish merge
IvanYashchuk Nov 12, 2020
344053f
Merge branch 'master' into linalg-cholesky
Nov 12, 2020
869b40a
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 12, 2020
d52b83b
Differentiation of complex cholesky on cuda now works for single input;
IvanYashchuk Nov 12, 2020
c98076b
Merge branch 'master' into linalg-cholesky
IvanYashchuk Nov 12, 2020
287c878
Batched matmul for complex on CUDA is implemented now; fix tests
IvanYashchuk Nov 13, 2020
03b29a3
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 13, 2020
88e23c3
Remove redundant code from test_old_cholesky_batched
IvanYashchuk Nov 13, 2020
eb507e5
flake8 fix
IvanYashchuk Nov 13, 2020
8 changes: 7 additions & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -534,11 +534,12 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector<int64_t>& infos
auto self_matrix_stride = matrixStride(self);
auto batch_size = batchCount(self);
auto n = self.size(-2);
auto lda = std::max(int64_t{1}, n);

int info;
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
lapackCholesky<scalar_t>(uplo, n, self_working_ptr, n, &info);
lapackCholesky<scalar_t>(uplo, n, self_working_ptr, lda, &info);
infos[i] = info;
if (info != 0) {
return;
@@ -583,6 +584,11 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
return result;
}

Tensor linalg_cholesky(const Tensor &self) {
squareCheckInputs(self);
return at::_cholesky_helper(self, /*upper=*/false).tril_();
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
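For reference, a minimal Python sketch of the behavior this change exposes (assuming a build where torch.linalg.cholesky is available; _cholesky_helper is an internal helper, not public API):

import torch

A = torch.randn(3, 3, dtype=torch.float64)
A = A @ A.transpose(-2, -1) + torch.eye(3, dtype=torch.float64)  # make A symmetric positive-definite

L = torch.linalg.cholesky(A)     # always the lower-triangular factor (upper=false in the helper call)
assert torch.equal(L, L.tril())  # the upper triangle is zeroed out by tril_()
assert torch.allclose(L @ L.transpose(-2, -1), A)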
7 changes: 4 additions & 3 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1068,10 +1068,11 @@ AT_ERROR("cholesky: MAGMA library not found in "

auto self_data = self.data_ptr<scalar_t>();
magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)");
auto lda = std::max(magma_int_t{1}, n);

if (self.dim() == 2) {
magma_int_t info = 0;
magmaCholesky<scalar_t>(uplo, n, self_data, n, &info);
magmaCholesky<scalar_t>(uplo, n, self_data, lda, &info);
infos[0] = info;
} else {
auto self_mat_stride = matrixStride(self);
@@ -1102,14 +1103,14 @@ AT_ERROR("cholesky: MAGMA library not found in "
magma_int_t* info_array_cur = &info_array[mini_idx];

magmaCholeskyBatched<scalar_t>(
uplo, n, self_array_cur, n, info_array_cur, batch_limit, magma_queue);
uplo, n, self_array_cur, lda, info_array_cur, batch_limit, magma_queue);
}

// Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit
// which concisely is equal to batch_size % batch_limit
if (batch_size % batch_limit != 0) {
magmaCholeskyBatched<scalar_t>(
uplo, n, &self_array[mini_idx], n, &info_array[mini_idx], batch_size % batch_limit, magma_queue);
uplo, n, &self_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue);
}

for (int64_t i = 0; i < batch_size; i++) {
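The lda = std::max(1, n) change matters for empty matrices: LAPACK and MAGMA require lda >= max(1, n), so n == 0 must still pass lda = 1. A minimal sketch of the degenerate shapes the new tests exercise (assuming empty inputs are supported, as the test_cholesky shapes below indicate):

import torch

for shape in [(0, 0), (3, 0, 0)]:   # a 0x0 matrix and a batch of 0x0 matrices
    A = torch.empty(shape, dtype=torch.float64)
    L = torch.linalg.cholesky(A)
    assert L.shape == A.shape        # an empty factor with the same shape is returned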
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -8284,6 +8284,11 @@
#
# See linalg_det as an example.

- func: linalg_cholesky(Tensor self) -> Tensor
python_module: linalg
use_c10_dispatcher: full
variants: function

# torch.linalg.det, alias for torch.det
- func: linalg_det(Tensor self) -> Tensor
python_module: linalg
1 change: 1 addition & 0 deletions docs/source/linalg.rst
@@ -12,5 +12,6 @@ Common linear algebra operations.
Functions
---------

.. autofunction:: cholesky
.. autofunction:: det
.. autofunction:: norm
89 changes: 87 additions & 2 deletions test/test_linalg.py
@@ -7,9 +7,10 @@
from torch.testing._internal.common_utils import \
(TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
(instantiate_device_type_tests, dtypes, dtypesIfCPU, dtypesIfCUDA,
onlyCUDA, onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
from torch.autograd import gradcheck
from torch.autograd import gradcheck, gradgradcheck

if TEST_NUMPY:
import numpy as np
@@ -922,6 +923,90 @@ def test_nuclear_norm_exceptions_old(self, device):
self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))


@precisionOverride({torch.float: 1e-2, torch.cfloat: 1e-4})
Collaborator:
The absolute tolerance for float32 is surprisingly large. What's going on there?

Collaborator Author:
I assume that NumPy could silently convert the input to fp64 and then back to fp32.
The CPU tests pass with a 1e-3 tolerance; the GPU only passes with 1e-2. That doesn't mean the decomposition is incorrect, it's still valid, but the individual entries can differ a bit across libraries/environments.
A similar thing was observed for inverse (#45034), where the results of cuSOLVER and MAGMA differ a lot for fp32.

Collaborator:
I think your analysis is probably correct. Would you verify what dtype NumPy computes in and add a comment here explaining the absolute tolerance value?

How significant a loss of precision is 1e-2 here? In terms of relative tolerance, I mean?
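For a rough sense of scale, here is a sketch (not part of this PR) that measures the relative error of a float32 factorization against a float64 reference; the exact figure depends on the random draw, the backend, and the conditioning of the input:

import torch

A64 = torch.randn(100, 100, dtype=torch.float64)
A64 = A64 @ A64.transpose(-2, -1) + 100 * torch.eye(100, dtype=torch.float64)  # well-conditioned SPD matrix

L64 = torch.linalg.cholesky(A64)                    # float64 reference
L32 = torch.linalg.cholesky(A64.to(torch.float32))  # float32 factorization

rel_err = (L32.to(torch.float64) - L64).norm() / L64.norm()
print(rel_err.item())  # roughly float32 machine epsilon scaled by the conditioning of A64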

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypesIfCPU(torch.float32, torch.float64, torch.complex64, torch.complex128)
@dtypesIfCUDA(torch.float32, torch.float64)
def test_cholesky(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

def run_test(shape, batch):
A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
expected_L = np.linalg.cholesky(A.cpu().numpy())
actual_L = torch.linalg.cholesky(A)
self.assertEqual(actual_L, expected_L)

shapes = (0, 3, 5)
batches = ((), (3, ), (2, 2))
larger_input_case = [(100, (5, ))]
for shape, batch in list(itertools.product(shapes, batches)) + larger_input_case:
run_test(shape, batch)

# cholesky requires the input to be a square matrix
A = torch.randn(2, 3, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
torch.linalg.cholesky(A)
with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
np.linalg.cholesky(A.cpu().numpy())

# cholesky requires the input to be a matrix
A = torch.randn(2, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
torch.linalg.cholesky(A)
with self.assertRaisesRegex(np.linalg.LinAlgError, r'1-dimensional array given\. Array must be at least two-dimensional'):
np.linalg.cholesky(A.cpu().numpy())

# if the input matrix is singular, an error should be raised
A = torch.eye(3, 3, dtype=dtype, device=device)
A[-1, -1] = 0 # Now A is singular
with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'):
torch.linalg.cholesky(A)
with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
np.linalg.cholesky(A.cpu().numpy())

# TODO: once there is more support for complex dtypes on GPU, they should be added to the above test,
# particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat is fixed
@unittest.expectedFailure
@onlyCUDA
@skipCUDAIfNoMagma
@dtypes(torch.complex64, torch.complex128)
def test_cholesky_xfailed(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
expected_L = np.linalg.cholesky(A.cpu().numpy())
actual_L = torch.linalg.cholesky(A)
Collaborator:
Does this produce the aforementioned RuntimeError? If so, then I think it's probably better to use a with self.assertRaisesRegex(...):, rather than expectedFailure.

Collaborator Author:
The idea is to use an xfail test (unittest.expectedFailure) because this is a bug that we can't fix right now.
I would expect to use self.assertRaisesRegex(...) (and consequently mark the test as passing) for intended errors that are not bugs.

self.assertEqual(actual_L, expected_L)

# TODO: enable CUDA tests once
Collaborator:
Maybe it would also be good to mention that once JIT and CUDA complex support are added, this test should be moved to test_autograd_and_jit. Or, if the tests in test_autograd_and_jit have been moved into common_methods_invocations.py by then, it should be moved there.

JIT and CUDA complex support can probably be done in future PRs.

Collaborator Author:
The test for torch.cholesky from test/test_autograd.py is not part of common_methods_invocations.py:method_tests. Maybe the reason is that it requires making the input Hermitian for gradcheck to work correctly:

def func(x):
    x = 0.5 * (x + x.transpose(-1, -2).conj()) # Make `x` Hermitian
    return torch.linalg.cholesky(x)

func(x) is being tested, not torch.linalg.cholesky(x) directly, so I don't know whether it can later be moved to test_autograd_and_jit/common_methods_invocations.py.

Collaborator:
As far as I understand, I don't think that prevents us from putting this in common_methods_invocations.py:method_tests. The second element of an entry in method_tests can be a constructing function, for instance: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L521

So I think it would be possible to make a constructing function for cholesky too.
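For illustration, such a constructing function might look like the sketch below (hypothetical helper name, not part of this PR; the exact shape of a method_tests entry is not reproduced here):

import torch

def _make_cholesky_input(matrix_size, *batch_dims, dtype=torch.float64, device='cpu'):
    # Hypothetical constructor: returns a Hermitian positive-definite tensor so that
    # gradcheck can be run on torch.linalg.cholesky directly.
    a = torch.randn(*batch_dims, matrix_size, matrix_size, dtype=dtype, device=device)
    return a @ a.transpose(-2, -1).conj() + torch.eye(matrix_size, dtype=dtype, device=device)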

# RuntimeError: "triangular_solve_cuda" not implemented for 'ComplexDouble' is fixed
@onlyCPU
@skipCPUIfNoLapack
@dtypes(torch.float64, torch.complex128)
def test_cholesky_autograd(self, device, dtype):
def func(root):
x = 0.5 * (root + root.transpose(-1, -2).conj())
return torch.linalg.cholesky(x)

def run_test(shape):
root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True)
root = root + torch.eye(shape[-1], dtype=dtype, device=device)

gradcheck(func, root)
# TODO: gradgradcheck does not work correctly yet for complex
Collaborator:
@anjali411 has been working on better gradcheck support for complex; I'm not sure about the gradgradcheck status. However, I think @peterbell10 found a nice workaround for this issue. @peterbell10, can you comment?

Collaborator Author:
gradgradcheck works with #45737. Once it's merged, I'll update the code here.

Collaborator:
Sounds good.

if not dtype.is_complex:
gradgradcheck(func, root)

root = torch.rand(*shape, dtype=dtype, device=device)
root = torch.matmul(root, root.transpose(-1, -2).conj())
root.requires_grad_()
chol = torch.linalg.cholesky(root).sum().backward()
self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian

shapes = ((3, 3), (4, 3, 2, 2))
for shape in shapes:
run_test(shape)

instantiate_device_type_tests(TestLinalg, globals())

if __name__ == '__main__':
14 changes: 2 additions & 12 deletions test/test_torch.py
@@ -7820,19 +7820,9 @@ def cholesky_test_helper(n, batch_dims, upper):
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@tf32_on_and_off(0.01)
def test_cholesky(self, device, dtype):
Collaborator:
Would you move this test to test_linalg.py? It will need a new name, like "test_torch_cholesky" or "test_old_cholesky."

Collaborator:
Also, is there a change to this test that would demonstrate your lda fix?

Also, if you spot other Cholesky tests, let's move them too.

from torch.testing._internal.common_utils import \
(random_symmetric_pd_matrix,
random_fullrank_matrix_distinct_singular_value)
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

# This is a workaround while there is no support for complex random_symmetric_pd_matrix
if dtype.is_complex:
real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64
A_real = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device)
A_imag = random_fullrank_matrix_distinct_singular_value(10, dtype=real_dtype, device=device)
A = A_real + 1j * A_imag
A = A @ A.t().conj()
else:
A = random_symmetric_pd_matrix(10, dtype=dtype, device=device)
A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)

# default Case
C = torch.cholesky(A)
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -294,6 +294,9 @@
- name: cholesky(Tensor self, bool upper=False) -> Tensor
self: cholesky_backward(grad, upper, result)

- name: linalg_cholesky(Tensor self) -> Tensor
self: cholesky_backward(grad, false, result)

- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
self, input2: cholesky_solve_backward(grad, self, input2, result, upper)

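The new derivatives.yaml entry reuses cholesky_backward with upper=false. For reference, a Python sketch of the usual lower-triangular Cholesky backward formula (a sketch of the standard derivation for a single matrix, not a line-for-line copy of the C++ implementation):

import torch

def cholesky_backward_sketch(grad_L, L):
    # Given A = L @ L^H and the gradient grad_L w.r.t. L, return the gradient w.r.t. A.
    phi = L.transpose(-2, -1).conj() @ grad_L
    phi = phi.tril()
    phi.diagonal(dim1=-2, dim2=-1).mul_(0.5)            # halve the diagonal
    eye = torch.eye(L.shape[-1], dtype=L.dtype, device=L.device)
    L_inv = torch.triangular_solve(eye, L, upper=False).solution
    grad_A = L_inv.transpose(-2, -1).conj() @ phi @ L_inv
    return 0.5 * (grad_A + grad_A.transpose(-2, -1).conj())  # keep the result Hermitian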
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -162,7 +162,7 @@
'cosh', '__rmul__', 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex',
'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky'
'dot', 'vdot', 'cholesky', 'linalg_cholesky'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
18 changes: 18 additions & 0 deletions torch/csrc/api/include/torch/linalg.h
@@ -8,6 +8,10 @@ namespace linalg {
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {

inline Tensor cholesky(const Tensor& self) {
return torch::linalg_cholesky(self);
}

inline Tensor det(const Tensor& self) {
return torch::linalg_det(self);
}
@@ -31,6 +35,20 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// Cholesky decomposition
///
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.cholesky
///
/// Example:
/// ```
/// auto A = torch::randn({4, 4});
/// A = torch::matmul(A, A.t());
/// auto L = torch::linalg::cholesky(A);
/// assert(torch::allclose(torch::matmul(L, L.t()), A));
/// ```
inline Tensor cholesky(const Tensor& self) {
return detail::cholesky(self);
}

/// See the documentation of torch.linalg.det
inline Tensor linalg_det(const Tensor& self) {
45 changes: 45 additions & 0 deletions torch/linalg/__init__.py
@@ -8,6 +8,51 @@
# Note: This not only adds doc strings for functions in the linalg namespace, but
# also connects the torch.linalg Python namespace to the torch._C._linalg builtins.

cholesky = _add_docstr(_linalg.linalg_cholesky, r"""
linalg.cholesky(input) -> Tensor

Returns the Cholesky decomposition.
Collaborator:
I think combining the first few paragraphs would be helpful. What about something like:

"Computes the Cholesky decomposition of a Hermitian positive-definite matrix or the Cholesky decompositions of a batch of such matrices. If the matrices are real-valued then each of their Cholesky decompositions can be written as A = LL^t, for a lower triangular matrix L, and this function returns the matrix L for each input matrix. If the matrices are complex-valued then their Cholesky decompositions are A = L @ L.H for a lower triangular matrix L, where L.H is the conjugate transpose of L."


Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
positive-definite matrix :math:`A`, or of batches of Hermitian positive-definite matrices.
The returned matrix ``L`` is lower-triangular, and
the decomposition has the form:

.. math::

A = LL^H

If :attr:`input` is a batch of Hermitian positive-definite
matrices, then the returned tensor will be composed of lower-triangular Cholesky factors
of each of the individual matrices.

.. note:: If :attr:`input` is not a Hermitian positive-definite matrix, a RuntimeError is raised
Collaborator:
This note also needs to account for the batching behavior. Maybe something like:

"If :attr:`input` is not a Hermitian positive-definite matrix, or if it's a batch of matrices and one of them is not a Hermitian positive-definite matrix, then a RuntimeError will be thrown."

As for the error behavior: does it specify which matrix in the batch was discovered to not be a Hermitian positive-definite matrix, too? Another option here would be to not elaborate about the content of the error message.

Collaborator Author:

Collaborator:
Great; let's be clear about that (and maybe even add a test for it erroring on the correct batch element).

saying that the input is singular and mentioning which minor of the input matrix is not positive-definite.

.. note::
Supports real and complex inputs.
Backpropagation for complex inputs is only supported on the CPU.

Collaborator (@mruberry, Nov 10, 2020):
Add a warning that this causes cross-device synchronization when called on CUDA inputs

cc @heitorschueroff who's looking into cross-device data movement in operations that use MAGMA

Args:
input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more
Collaborator:
Calling the tensor "A" here isn't helpful and possibly misleading if the input is a batch of matrices. Maybe you can say "a Hermitian positive-definite matrix or batch of such matrices"?

If multiple matrices are passed as input, does each of them have to have the same size? That requirement (or lack thereof) should be mentioned somewhere in the documentation.

Collaborator Author:
I'll remove the "A".
The input tensor must have a shape of the form (*, n, n), where the last two dimensions are equal, i.e. a batch of square matrices. There is a check for that, and tests. I could add that the matrices should be square, but that's already implied by calling a matrix positive-definite.

Collaborator Author:
And it's not technically possible to pass multiple matrices of different sizes as the input, since we don't support lists of tensors and we treat the last two dimensions of a tensor as the matrix we process.

Collaborator:
Right. Good point. I've been in too many discussions about nested tensors, which can be ragged, recently. My mistake.

batch dimensions consisting of Hermitian positive-definite matrices.

Example::

>>> a = torch.randn(2, 2, dtype=torch.complex128)
>>> a = torch.mm(a, a.t().conj()) # To make a Hermitian
IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved
>>> l = torch.linalg.cholesky(a)
IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved
>>> a
tensor([[2.5266+0.0000j, 1.9586-2.0626j],
[1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128)
>>> l
tensor([[1.5895+0.0000j, 0.0000+0.0000j],
[1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128)
>>> torch.mm(l, l.t().conj())
tensor([[2.5266+0.0000j, 1.9586-2.0626j],
[1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128)
""")

det = _add_docstr(_linalg.linalg_det, r"""
linalg.det(input) -> Tensor

8 changes: 8 additions & 0 deletions torch/testing/_internal/common_utils.py
@@ -1528,6 +1528,14 @@ def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs):
+ torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5


def random_hermitian_pd_matrix(matrix_size, *batch_dims, **kwargs):
Collaborator:
Add a comment describing what this function does and how to call it.

Collaborator (@mruberry, Oct 14, 2020):
I think the signature of this function should be: (matrix_size, *, batch_dims, dtype, device).

Since it's an internal testing function it probably shouldn't have default values. Each caller should be responsible for correctly setting dtype and device.

Collaborator Author:
Okay, I can change that. I was just following the other existing functions.

def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs):
    dtype = kwargs.get('dtype', torch.double)
    device = kwargs.get('device', 'cpu')
    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
                    dtype=dtype, device=device)
    return torch.matmul(A, A.transpose(-2, -1))

In fact, all calls to random_symmetric_pd_matrix can be safely replaced by random_hermitian_pd_matrix, probably in a follow-up PR.

dtype = kwargs.get('dtype', torch.double)
device = kwargs.get('device', 'cpu')
A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
dtype=dtype, device=device)
return torch.matmul(A, A.transpose(-2, -1).conj())
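A short usage sketch for the new helper (assuming the signature shown above; the final check via torch.cholesky succeeds because A @ A^H is almost surely positive-definite for a random full-rank draw):

import torch
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

A = random_hermitian_pd_matrix(4, 2, dtype=torch.complex128, device='cpu')  # a batch of two 4x4 matrices
assert A.shape == (2, 4, 4)
assert torch.allclose(A, A.transpose(-2, -1).conj())  # Hermitian by construction
torch.cholesky(A)  # raises only if a draw happens to be numerically singular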


def make_nonzero_det(A, sign=None, min_singular_value=0.1):
u, s, v = A.svd()
s.clamp_(min=min_singular_value)