Added linalg.cholesky #46083

Closed
wants to merge 62 commits
Changes from 50 commits

Commits
62 commits
6b7f14a
wip linalg.cholesky
IvanYashchuk Oct 7, 2020
f7a08f4
Added xfailed test case
IvanYashchuk Oct 8, 2020
32e10f8
Added cholesky to csrc/api/include/torch/linalg.h
IvanYashchuk Oct 8, 2020
3c4d5a4
Updated example in docs
IvanYashchuk Oct 9, 2020
307020e
Added random_hermitian_pd_matrix for the test
IvanYashchuk Oct 9, 2020
9e9e0c0
Use random_hermitian_pd_matrix in the test_torch/cholesky
IvanYashchuk Oct 9, 2020
ae4f3ee
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 11, 2020
1798832
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 13, 2020
e74de2c
No need for skip if numpy not found anymore
IvanYashchuk Oct 13, 2020
9343678
Added larger input case
IvanYashchuk Oct 13, 2020
604f0a8
Added assertRaises tests for cholesky
IvanYashchuk Oct 13, 2020
5f80042
Added a note to the docs that the error is given if the input is not
IvanYashchuk Oct 13, 2020
da4e88b
Enabled autograd for linalg_cholesky
IvanYashchuk Oct 13, 2020
efb725c
Added a note to the documentation about complex support
IvanYashchuk Oct 13, 2020
6297b6b
Added the out= variant
IvanYashchuk Oct 14, 2020
709273b
Moved error checks to a separate test
IvanYashchuk Oct 14, 2020
21cfca0
Added xfailed test for cholesky cuda autograd
IvanYashchuk Oct 14, 2020
31cbe75
Only complex128 is needed for autograd test
IvanYashchuk Oct 14, 2020
df0172e
Added a docstring for random_hermitian_pd_matrix
IvanYashchuk Oct 14, 2020
b54c7f8
Updated linalg.cholesky docs
IvanYashchuk Oct 14, 2020
7878183
Added a note on error message for batch of singular matrices
IvanYashchuk Oct 15, 2020
0d4a8c7
In tests compare norms of the resulting matrices
IvanYashchuk Oct 15, 2020
e4832d3
Added entry in overrides.py
IvanYashchuk Oct 15, 2020
e800b97
Added test case for batch singular input
IvanYashchuk Oct 15, 2020
8b58586
Moved tests for torch.cholesky to test_linalg.py
IvanYashchuk Oct 15, 2020
7aabce3
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 15, 2020
12d11ce
Added a dispatch section with DefaultBackend in native_functions.yaml
IvanYashchuk Oct 15, 2020
97501c6
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 15, 2020
29f94c9
gradgradcheck for cholesky now works
IvanYashchuk Oct 15, 2020
00c41ed
Updated documentation for linalg cholesky
IvanYashchuk Oct 19, 2020
34cef7d
Add one more assert for out= test
IvanYashchuk Oct 19, 2020
5d2230d
Updated test_cholesky_errors
IvanYashchuk Oct 19, 2020
63e922d
Added non contiguous test
IvanYashchuk Oct 19, 2020
6d64067
Make test_cholesky_autograd run on gpu for fp64
IvanYashchuk Oct 19, 2020
40935f1
Changed np.allclose -> torch.allclose
IvanYashchuk Oct 19, 2020
aaad340
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 19, 2020
1d67cd0
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 27, 2020
22be1a5
Fix long lines
IvanYashchuk Oct 28, 2020
d547b63
Remove unused import
IvanYashchuk Oct 28, 2020
fe89410
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Oct 29, 2020
9e2cfb4
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 2, 2020
30239f2
Merge branch 'master' into linalg-cholesky
IvanYashchuk Nov 3, 2020
f723421
Use at::native::resize_output
IvanYashchuk Nov 3, 2020
1b46c19
Merge branch 'master' into linalg-cholesky
IvanYashchuk Nov 4, 2020
10d276a
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 6, 2020
e224c97
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 10, 2020
048316d
Added a warning about data movement for cuda inputs
IvanYashchuk Nov 10, 2020
0a13e1d
Use typed std::max
IvanYashchuk Nov 10, 2020
173fea1
Replaced torch.allclose with self.assertEqual
IvanYashchuk Nov 10, 2020
b26f52c
Removed unused import
IvanYashchuk Nov 10, 2020
ba3708e
Fix imports
IvanYashchuk Nov 10, 2020
5f0abff
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 11, 2020
b670f26
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 12, 2020
f2ee1f6
Finish merge
IvanYashchuk Nov 12, 2020
344053f
Merge branch 'master' into linalg-cholesky
Nov 12, 2020
869b40a
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 12, 2020
d52b83b
Differentiation of complex cholesky on cuda now works for single input;
IvanYashchuk Nov 12, 2020
c98076b
Merge branch 'master' into linalg-cholesky
IvanYashchuk Nov 12, 2020
287c878
Batched matmul for complex on CUDA is implemented now; fix tests
IvanYashchuk Nov 13, 2020
03b29a3
Merge remote-tracking branch 'upstream/master' into linalg-cholesky
IvanYashchuk Nov 13, 2020
88e23c3
Remove redundant code from test_old_cholesky_batched
IvanYashchuk Nov 13, 2020
eb507e5
flake8 fix
IvanYashchuk Nov 13, 2020
18 changes: 17 additions & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -5,6 +5,7 @@
#include <ATen/ExpandUtils.h>

#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/Parallel.h>

@@ -535,11 +536,12 @@ static void apply_cholesky(Tensor& self, bool upper, std::vector<int64_t>& infos
auto self_matrix_stride = matrixStride(self);
auto batch_size = batchCount(self);
auto n = self.size(-2);
auto lda = std::max<int64_t>(1, n);

int info;
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
lapackCholesky<scalar_t>(uplo, n, self_working_ptr, n, &info);
lapackCholesky<scalar_t>(uplo, n, self_working_ptr, lda, &info);
infos[i] = info;
if (info != 0) {
return;
@@ -584,6 +586,20 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
return result;
}

Tensor linalg_cholesky(const Tensor &self) {
squareCheckInputs(self);
return at::_cholesky_helper(self, /*upper=*/false).tril_();
}

Tensor& linalg_cholesky_out(Tensor &result, const Tensor &self) {
TORCH_CHECK(result.scalar_type() == self.scalar_type(),
"result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type());
Tensor result_tmp = at::linalg_cholesky(self);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
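
For context, a minimal usage sketch of the Python API this diff wires up (assuming a build of this branch; the input construction is only illustrative, similar in spirit to the random_hermitian_pd_matrix test helper added in this PR):

import torch

# Build a Hermitian positive-definite matrix: A @ A^H plus a small diagonal shift.
a = torch.randn(3, 3, dtype=torch.complex128)
a = a @ a.conj().transpose(-2, -1) + 1e-3 * torch.eye(3, dtype=torch.complex128)

l = torch.linalg.cholesky(a)
assert torch.equal(l, l.tril())  # the tril_() above guarantees a lower-triangular factor
assert torch.allclose(l @ l.conj().transpose(-2, -1), a)  # L @ L^H reconstructs the input
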
7 changes: 4 additions & 3 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1139,10 +1139,11 @@ AT_ERROR("cholesky: MAGMA library not found in "

auto self_data = self.data_ptr<scalar_t>();
magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)");
auto lda = std::max<magma_int_t>(1, n);

if (self.dim() == 2) {
magma_int_t info = 0;
magmaCholesky<scalar_t>(uplo, n, self_data, n, &info);
magmaCholesky<scalar_t>(uplo, n, self_data, lda, &info);
infos[0] = info;
} else {
auto self_mat_stride = matrixStride(self);
@@ -1173,14 +1174,14 @@ AT_ERROR("cholesky: MAGMA library not found in "
magma_int_t* info_array_cur = &info_array[mini_idx];

magmaCholeskyBatched<scalar_t>(
uplo, n, self_array_cur, n, info_array_cur, batch_limit, magma_queue);
uplo, n, self_array_cur, lda, info_array_cur, batch_limit, magma_queue);
}

// Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit
// which concisely is equal to batch_size % batch_limit
if (batch_size % batch_limit != 0) {
magmaCholeskyBatched<scalar_t>(
uplo, n, &self_array[mini_idx], n, &info_array[mini_idx], batch_size % batch_limit, magma_queue);
uplo, n, &self_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue);
}

for (int64_t i = 0; i < batch_size; i++) {
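
As a hedged sketch of what the MAGMA path above serves (assuming a CUDA build of this branch; the batch size here is arbitrary and does not specifically exercise the batch_limit chunking), batched positive-definite inputs on the GPU go through magmaCholeskyBatched with lda = max(1, n):

import torch

if torch.cuda.is_available():
    # Batch of symmetric positive-definite matrices on the GPU.
    a = torch.randn(8, 4, 4, dtype=torch.float64, device='cuda')
    a = a @ a.transpose(-2, -1) + 1e-3 * torch.eye(4, dtype=torch.float64, device='cuda')
    l = torch.linalg.cholesky(a)
    assert torch.allclose(l @ l.transpose(-2, -1), a)
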
13 changes: 13 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -9261,6 +9261,19 @@
#
# See linalg_det as an example.

- func: linalg_cholesky(Tensor self) -> Tensor
python_module: linalg
use_c10_dispatcher: full
variants: function
dispatch:
DefaultBackend: linalg_cholesky

- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
variants: function
dispatch:
DefaultBackend: linalg_cholesky_out

# torch.linalg.det, alias for torch.det
- func: linalg_det(Tensor self) -> Tensor
python_module: linalg
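
The out= overload registered here is backed by linalg_cholesky_out above, which resizes the out tensor via resize_output and rejects dtype mismatches. A minimal sketch of the resulting Python behaviour (the error text in the comment is rendered from the TORCH_CHECK message, not quoted from a run of this branch):

import torch

a = torch.randn(3, 3, dtype=torch.float64)
a = a @ a.transpose(-2, -1) + 1e-3 * torch.eye(3, dtype=torch.float64)

out = torch.empty(0, dtype=torch.float64)
torch.linalg.cholesky(a, out=out)  # out is resized to (3, 3) and filled in place

bad = torch.empty(3, 3, dtype=torch.float32)
try:
    torch.linalg.cholesky(a, out=bad)  # mismatched result dtype trips the TORCH_CHECK
except RuntimeError as err:
    print(err)  # e.g. "result dtype Float does not match self dtype Double"
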
1 change: 1 addition & 0 deletions docs/source/linalg.rst
@@ -12,6 +12,7 @@ Common linear algebra operations.
Functions
---------

.. autofunction:: cholesky
.. autofunction:: det
.. autofunction:: norm
.. autofunction:: tensorsolve
24 changes: 0 additions & 24 deletions test/test_autograd.py
@@ -2595,30 +2595,6 @@ def test_var_mean_differentiable(self):
torch.autograd.backward(r2, grad)
self.assertTrue(torch.allclose(input1.grad, input2.grad, rtol=0.01, atol=0.0))

@skipIfNoLapack
def test_cholesky(self):
def func(root, upper):
x = 0.5 * (root + root.transpose(-1, -2).conj())
return torch.cholesky(x, upper)

def run_test(upper, dims, dtype):
root = torch.rand(*dims, dtype=dtype, requires_grad=True)
root = root + torch.eye(dims[-1])

gradcheck(func, [root, upper])
gradgradcheck(func, [root, upper])

root = torch.rand(*dims, dtype=dtype)
root = torch.matmul(root, root.transpose(-1, -2).conj())
root.requires_grad_()
chol = root.cholesky().sum().backward()
self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian

for upper, dims, dtype in product([True, False],
[(3, 3), (4, 3, 2, 2)],
[torch.double, torch.cdouble]):
run_test(upper, dims, dtype)

@skipIfNoLapack
def test_cholesky_solve(self):
def _test_with_size(A_dims, B_dims, upper):
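
Per the commit history, the test removed here was migrated to test_linalg.py and autograd support was enabled for linalg_cholesky. A rough sketch of the equivalent check against the new API (the function name and input construction below mirror the removed test but are illustrative, not the exact code now in test_linalg.py):

import torch
from torch.autograd import gradcheck, gradgradcheck

def cholesky_of_hermitian_part(root):
    # Hermitize the input so the matrix handed to cholesky stays positive definite.
    x = 0.5 * (root + root.transpose(-1, -2).conj())
    return torch.linalg.cholesky(x)

root = torch.rand(3, 3, dtype=torch.cdouble, requires_grad=True)
root = root + torch.eye(3)

gradcheck(cholesky_of_hermitian_part, [root])
gradgradcheck(cholesky_of_hermitian_part, [root])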