From 4ed7f36ed181ee784f9904d5eabf073701e1fb78 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 22 Nov 2020 04:55:55 -0800 Subject: [PATCH] Added linalg.eigh, linalg.eigvalsh (#45526) Summary: This PR adds `torch.linalg.eigh`, and `torch.linalg.eigvalsh` for NumPy compatibility. The current `torch.symeig` uses (on CPU) a different LAPACK routine than NumPy (`syev` vs `syevd`). Even though it shouldn't matter in practice, `torch.linalg.eigh` uses `syevd` (as NumPy does). Ref https://github.com/pytorch/pytorch/issues/42666 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45526 Reviewed By: gchanan Differential Revision: D25022659 Pulled By: mruberry fbshipit-source-id: 3676b77a121c4b5abdb712ad06702ac4944e900a --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 180 +++++++++++++ aten/src/ATen/native/LinearAlgebraUtils.h | 12 +- .../ATen/native/cuda/BatchLinearAlgebra.cu | 15 ++ aten/src/ATen/native/native_functions.yaml | 31 +++ docs/source/linalg.rst | 2 + test/test_linalg.py | 253 ++++++++++++++++++ test/test_namedtuple_return_api.py | 19 +- tools/autograd/derivatives.yaml | 6 + tools/autograd/gen_variable_type.py | 3 +- torch/csrc/api/include/torch/linalg.h | 38 +++ torch/csrc/autograd/FunctionsManual.cpp | 25 +- torch/linalg/__init__.py | 125 +++++++++ torch/overrides.py | 2 + 13 files changed, 689 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 270ffeeee5c5..37b7c5bbb223 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -72,6 +72,12 @@ extern "C" void cheev_(char *jobz, char *uplo, int *n, std::complex *a, i extern "C" void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info); extern "C" void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info); +// syevd +extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, double *w, std::complex *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info); +extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, float *w, std::complex *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info); +extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info); +extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info); + // gesdd extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex *a, int *lda, double *s, std::complex *u, int *ldu, std::complex *vt, int *ldvt, std::complex *work, int *lwork, double *rwork, int *iwork, int *info); @@ -122,6 +128,9 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala template void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info); +template +void lapackSyevd(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int lrwork, int *iwork, int liwork, int *info); + template void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); @@ -276,6 +285,26 @@ template<> void lapackSymeig(char jobz, char uplo, int n, float *a, int l ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); } +template<> void lapackSyevd, double>(char jobz, char uplo, int n, c10::complex *a, int lda, double *w, c10::complex *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) { + zheevd_(&jobz, &uplo, &n, reinterpret_cast*>(a), &lda, w, reinterpret_cast*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info); +} + +template<> void lapackSyevd, float>(char jobz, char uplo, int n, c10::complex *a, int lda, float *w, c10::complex *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) { + cheevd_(&jobz, &uplo, &n, reinterpret_cast*>(a), &lda, w, reinterpret_cast*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info); +} + +template<> void lapackSyevd(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) { + (void)rwork; // unused + (void)lrwork; // unused + dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); +} + +template<> void lapackSyevd(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) { + (void)rwork; // unused + (void)lrwork; // unused + ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); +} + template<> void lapackSvd, double>(char jobz, int m, int n, c10::complex *a, int lda, double *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, double *rwork, int *iwork, int *info) { zgesdd_(&jobz, &m, &n, reinterpret_cast*>(a), &lda, s, reinterpret_cast*>(u), &ldu, @@ -879,6 +908,157 @@ std::tuple qr_out(Tensor& Q, Tensor& R, const Tensor& self, boo return std::tuple(Q, R); } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v' +// The computation is done in-place: 'v' stores the input and will be overriden, 'w' should be an allocated empty array +// compute_v controls whether eigenvectors should be computed +// uplo_str controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" +// infos is used to store information for possible checks for error +// This function doesn't do any error checks and it's assumed that every argument is valid +template +static void apply_syevd(Tensor& w, Tensor& v, bool compute_v, const std::string& uplo_str, std::vector& infos) { +#ifndef USE_LAPACK + AT_ERROR("syevd: LAPACK library not found in compilation"); +#else + using value_t = typename c10::scalar_value_type::type; + + auto v_data = v.data_ptr(); + auto w_data = w.data_ptr(); + auto v_matrix_stride = matrixStride(v); + auto w_stride = w.size(-1); + auto batch_size = batchCount(v); + auto n = v.size(-1); + auto lda = std::max(int64_t{1}, n); + + // NumPy allows lowercase input for UPLO argument + // It is assumed that uplo_str is either "U" or "L" + char uplo = std::toupper(uplo_str[0]); + char jobz = compute_v ? 'V' : 'N'; + + // Using 'int' instead of int32_t or int64_t is consistent with the current LAPACK interface + // It really should be changed in the future to something like lapack_int that depends on the specific LAPACK library that is linked + // or switch to supporting only 64-bit indexing by default. + int info; + int lwork = -1; + int lrwork = -1; + int liwork = -1; + scalar_t work_query; + value_t rwork_query; + int iwork_query; + + // Run lapackSyevd once, first to get the optimum work size. + // Since we deal with batches of matrices with the same dimensions, doing this outside + // the main loop saves (batch_size - 1) workspace queries which would provide the same result + // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() + lapackSyevd(jobz, uplo, n, v_data, lda, w_data, &work_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, &info); + + lwork = std::max(1, real_impl(work_query)); + Tensor work = at::empty({lwork}, v.options()); + liwork = std::max(1, iwork_query); + Tensor iwork = at::empty({liwork}, at::kInt); + + Tensor rwork; + value_t* rwork_data = nullptr; + if (isComplexType(at::typeMetaToScalarType(v.dtype()))) { + lrwork = std::max(1, rwork_query); + rwork = at::empty({lrwork}, w.options()); + rwork_data = rwork.data_ptr(); + } + + // Now call lapackSyevd for each matrix in the batched input + for (auto i = decltype(batch_size){0}; i < batch_size; i++) { + scalar_t* v_working_ptr = &v_data[i * v_matrix_stride]; + value_t* w_working_ptr = &w_data[i * w_stride]; + lapackSyevd(jobz, uplo, n, v_working_ptr, lda, w_working_ptr, work.data_ptr(), lwork, rwork_data, lrwork, iwork.data_ptr(), liwork, &info); + infos[i] = info; + // The current behaviour for Linear Algebra functions to raise an error if something goes wrong or input doesn't satisfy some requirement + // therefore return early since further computations will be wasted anyway + if (info != 0) { + return; + } + } +#endif +} + +// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self' +// compute_eigenvectors controls whether eigenvectors should be computed +// uplo controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" +// This function prepares correct input for 'apply_syevd' and checks for possible errors using 'infos' +std::tuple _syevd_helper_cpu(const Tensor& self, bool compute_eigenvectors, std::string uplo) { + std::vector infos(batchCount(self), 0); + + auto self_sizes = self.sizes().vec(); + self_sizes.pop_back(); + ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); + auto eigvals = at::empty(self_sizes, self.options().dtype(dtype)); + + auto eigvecs = cloneBatchedColumnMajor(self); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "syevd_cpu", [&]{ + apply_syevd(eigvals, eigvecs, compute_eigenvectors, uplo, infos); + }); + + if (self.dim() > 2) { + batchCheckErrors(infos, "syevd_cpu"); + } else { + singleCheckErrors(infos[0], "syevd_cpu"); + } + if (compute_eigenvectors) { + return std::tuple(eigvals, eigvecs); + } else { + return std::tuple(eigvals, at::empty({0}, self.options())); + } +} + +std::tuple linalg_eigh(const Tensor& self, std::string uplo) { + squareCheckInputs(self); + checkUplo(uplo); + return at::_syevd_helper(self, /*compute_eigenvectors=*/true, uplo); +} + +// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigh on top of _out +// TODO: implement _out variant avoiding copy and using already allocated storage directly +std::tuple linalg_eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { + TORCH_CHECK(eigvecs.scalar_type() == self.scalar_type(), + "eigvecs dtype ", eigvecs.scalar_type(), " does not match self dtype ", self.scalar_type()); + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + TORCH_CHECK(eigvals.scalar_type() == real_dtype, + "eigvals dtype ", eigvals.scalar_type(), " does not match self dtype ", real_dtype); + + Tensor eigvals_tmp, eigvecs_tmp; + std::tie(eigvals_tmp, eigvecs_tmp) = at::linalg_eigh(self, uplo); + + at::native::resize_output(eigvals, eigvals_tmp.sizes()); + eigvals.copy_(eigvals_tmp); + at::native::resize_output(eigvecs, eigvecs_tmp.sizes()); + eigvecs.copy_(eigvecs_tmp); + + return std::tuple(eigvals, eigvecs); +} + +Tensor linalg_eigvalsh(const Tensor& self, std::string uplo) { + squareCheckInputs(self); + checkUplo(uplo); + Tensor eigvals, eigvecs; + std::tie(eigvals, eigvecs) = at::_syevd_helper(self, /*compute_eigenvectors=*/false, uplo); + return eigvals; +} + +// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigvalsh on top of _out +// TODO: implement _out variant avoiding copy and using already allocated storage directly +Tensor& linalg_eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + TORCH_CHECK(result.scalar_type() == real_dtype, + "result dtype ", result.scalar_type(), " does not match self dtype ", real_dtype); + + Tensor result_tmp = at::linalg_eigvalsh(self, uplo); + + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index ff7060323624..2d2ab8cc4cd1 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace at { namespace native { @@ -97,7 +98,7 @@ static inline void batchCheckErrors(std::vector& infos, const char* nam } else if (info > 0) { if (strstr(name, "svd")) { AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")"); - } else if (strstr(name, "symeig")) { + } else if (strstr(name, "symeig") || strstr(name, "syevd")) { AT_ERROR(name, ": For batch ", i, ": the algorithm failed to converge; ", info, " off-diagonal elements of an intermediate tridiagonal form did not converge to zero."); } else if (!allow_singular) { @@ -333,4 +334,13 @@ static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) { return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn); } +// This function checks whether the uplo argument input is valid +// Allowed strings are "u", "U", "l", "L" +static inline void checkUplo(const std::string& uplo) { + // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char + char uplo_uppercase = static_cast(std::toupper(static_cast(uplo[0]))); + TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'), + "Expected UPLO argument to be 'L' or 'U', but got ", uplo); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 5f52a4fa2a51..e5d0b689eb6e 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1783,6 +1783,21 @@ std::tuple _symeig_helper_cuda(const Tensor& self, bool eigenvec } } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self' +// compute_eigenvectors controls whether eigenvectors should be computed +// uplo controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" +// '_symeig_helper_cuda' prepares correct input for 'apply_symeig' and checks for possible errors using 'infos' +// See also CPU implementation in aten/src/ATen/native/BatchLinearAlgebra.cpp +std::tuple _syevd_helper_cuda(const Tensor& self, bool compute_eigenvectors, std::string uplo_str) { + // NumPy allows lowercase input for UPLO argument + // It is assumed that uplo_str is either "U" or "L" + char uplo = std::toupper(uplo_str[0]); + bool upper = uplo == 'U' ? true : false; + return _symeig_helper_cuda(self, compute_eigenvectors, upper); +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2cda27b7b68d..2d5cc1464946 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9373,6 +9373,37 @@ dispatch: DefaultBackend: det +- func: _syevd_helper(Tensor self, bool compute_eigenvectors, str uplo) -> (Tensor, Tensor) + use_c10_dispatcher: full + variants: function + dispatch: + CPU: _syevd_helper_cpu + CUDA: _syevd_helper_cuda + +- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + DefaultBackend: linalg_eigh + +- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + python_module: linalg + dispatch: + DefaultBackend: linalg_eigh_out + +- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor + python_module: linalg + use_c10_dispatcher: full + variants: function + dispatch: + DefaultBackend: linalg_eigvalsh + +- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + dispatch: + DefaultBackend: linalg_eigvalsh_out + # torch.outer, alias for torch.ger - func: outer(Tensor self, Tensor vec2) -> Tensor use_c10_dispatcher: full diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index b420684c84bb..eb7b3c120c61 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -14,6 +14,8 @@ Functions .. autofunction:: cholesky .. autofunction:: det +.. autofunction:: eigh +.. autofunction:: eigvalsh .. autofunction:: norm .. autofunction:: tensorinv .. autofunction:: tensorsolve diff --git a/test/test_linalg.py b/test/test_linalg.py index e651bc7e96ec..37a5c2d9d2fc 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -449,6 +449,259 @@ def test_det(self, device, dtype): with self.assertRaises(RuntimeError): op(t) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigh(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(shape, batch, uplo): + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo) + actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + # sign of eigenvectors is not unique and therefore absolute values are compared + self.assertEqual(abs(actual_v), abs(expected_v)) + # additionally we can flip the sign and then compare the values + # let's choose the convention that the first element of the eigenvector should be positive, + # otherwise flip the sign of the eigenvector + if matrix.numel() > 0: + sign = np.sign(expected_v[..., 0, :]).reshape(batch + (1, shape)) + expected_v = sign * expected_v + torch_real_slice = actual_v[..., 0, :].real if dtype.is_complex else actual_v[..., 0, :] + sign = torch.sign(torch_real_slice).reshape(batch + (1, shape)) + actual_v = sign * actual_v + self.assertEqual(actual_v, expected_v) + + # check the out= variant + out_w = torch.empty_like(actual_w) + out_v = torch.empty_like(actual_v) + ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v)) + self.assertEqual(ans_w, out_w) + self.assertEqual(ans_v, out_v) + self.assertEqual(ans_w, actual_w) + self.assertEqual(abs(ans_v), abs(actual_v)) + + shapes = (0, 3, 5) + batches = ((), (3, ), (2, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test(shape, batch, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigh_lower_uplo(self, device, dtype): + def run_test(shape, batch, uplo): + # check lower case uplo + # use non-symmetric input to check whether uplo argument is working as intended + matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device) + expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo) + actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + self.assertEqual(abs(actual_v), abs(expected_v)) + + uplos = ["u", "l"] + for uplo in uplos: + run_test(3, (2, 2), uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_eigh_errors_and_warnings(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + # eigh requires a square matrix + t = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.linalg.eigh(t) + + # eigh requires 'uplo' parameter to be 'U' or 'L' + t = torch.randn(3, 3, device=device, dtype=dtype) + for uplo in ["a", "wrong"]: + with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + torch.linalg.eigh(t, UPLO=uplo) + with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + np.linalg.eigh(t.cpu().numpy(), UPLO=uplo) + + # if non-empty out tensor with wrong shape is passed a warning is given + a = random_hermitian_matrix(3, dtype=dtype, device=device) + real_dtype = a.real.dtype if dtype.is_complex else dtype + out_w = torch.empty(7, 7, dtype=real_dtype, device=device) + out_v = torch.empty(7, 7, dtype=dtype, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.eigh(a, out=(out_w, out_v)) + # Check warning occurs + self.assertEqual(len(w), 2) + self.assertTrue("An output with one or more elements was resized" in str(w[-2].message)) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out_w = torch.empty_like(a).to(torch.int) + out_v = torch.empty_like(a) + with self.assertRaisesRegex(RuntimeError, "dtype Int does not match self dtype"): + torch.linalg.eigh(a, out=(out_w, out_v)) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigh_non_contiguous(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(matrix, uplo): + self.assertFalse(matrix.is_contiguous()) + expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo) + actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + # sign of eigenvectors is not unique and therefore absolute values are compared + self.assertEqual(abs(actual_v), abs(expected_v)) + + def run_test_permuted(shape, batch, uplo): + # check for permuted / transposed inputs + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix.transpose(-2, -1) + run_test(matrix, uplo) + + def run_test_skipped_elements(shape, batch, uplo): + # check for inputs with skipped elements + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix[::2] + run_test(matrix, uplo) + + shapes = (3, 5) + batches = ((4, ), (4, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test_permuted(shape, batch, uplo) + run_test_skipped_elements(shape, batch, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float64, torch.complex128) + def test_eigh_autograd(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def func(x, uplo): + x = 0.5 * (x + x.conj().transpose(-2, -1)) + return torch.linalg.eigh(x, UPLO=uplo) + + def func_grad_w(x, uplo): + return func(x, uplo)[0] + + def func_grad_v(x, uplo): + # gauge invariant loss function + return abs(func(x, uplo)[1]) + + def run_test(dims, uplo): + x = torch.randn(*dims, dtype=dtype, device=device, requires_grad=True) + + gradcheck(func_grad_w, [x, uplo]) + gradgradcheck(func_grad_w, [x, uplo]) + + gradcheck(func_grad_v, [x, uplo]) + gradgradcheck(func_grad_v, [x, uplo]) + + x = random_hermitian_matrix(dims[-1], *dims[:-2]).requires_grad_() + w, v = torch.linalg.eigh(x) + (w.sum() + abs(v).sum()).backward() + self.assertEqual(x.grad, x.grad.conj().transpose(-1, -2)) # Check the gradient is Hermitian + + for dims, uplo in itertools.product([(3, 3), (2, 3, 3)], ["L", "U"]): + run_test(dims, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigvalsh(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(shape, batch, uplo): + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo) + actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + + # check the out= variant + out = torch.empty_like(actual_w) + ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, actual_w) + + shapes = (0, 3, 5) + batches = ((), (3, ), (2, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test(shape, batch, uplo) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_eigvalsh_errors_and_warnings(self, device, dtype): + # eigvalsh requires a square matrix + t = torch.randn(2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.linalg.eigvalsh(t) + + # eigvalsh requires 'uplo' parameter to be 'U' or 'L' + t = torch.randn(3, 3, device=device, dtype=dtype) + for uplo in ["a", "wrong"]: + with self.assertRaisesRegex(RuntimeError, "be \'L\' or \'U\'"): + torch.linalg.eigvalsh(t, UPLO=uplo) + with self.assertRaisesRegex(ValueError, "be \'L\' or \'U\'"): + np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo) + + # if non-empty out tensor with wrong shape is passed a warning is given + real_dtype = t.real.dtype if dtype.is_complex else dtype + out = torch.empty_like(t).to(real_dtype) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.eigvalsh(t, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(t).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.eigvalsh(t, out=out) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + def test_eigvalsh_non_contiguous(self, device, dtype): + from torch.testing._internal.common_utils import random_hermitian_matrix + + def run_test(matrix, uplo): + self.assertFalse(matrix.is_contiguous()) + expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo) + actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo) + self.assertEqual(actual_w, expected_w) + + def run_test_permuted(shape, batch, uplo): + # check for permuted / transposed inputs + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix.transpose(-2, -1) + run_test(matrix, uplo) + + def run_test_skipped_elements(shape, batch, uplo): + # check for inputs with skipped elements + matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device) + matrix = matrix[::2] + run_test(matrix, uplo) + + shapes = (3, 5) + batches = ((4, ), (4, 2)) + uplos = ["U", "L"] + for shape, batch, uplo in itertools.product(shapes, batches, uplos): + run_test_permuted(shape, batch, uplo) + run_test_skipped_elements(shape, batch, uplo) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_kron(self, device, dtype): diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 3b3616f64220..88a01e48b5f2 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -12,7 +12,7 @@ all_operators_with_namedtuple_return = { 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', - 'triangular_solve', 'cummax', 'cummin' + 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh' } @@ -64,17 +64,24 @@ def test_namedtuple_return(self): op(operators=['symeig', 'eig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True), op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True), op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True), + op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True), ] for op in operators: for f in op.operators: - ret = getattr(a, f)(*op.input) - for i, name in enumerate(op.names): - self.assertIs(getattr(ret, name), ret[i]) - if op.hasout: - ret1 = getattr(torch, f)(a, *op.input, out=tuple(ret)) + if 'linalg_' in f: + ret = getattr(torch.linalg, f[7:])(a, *op.input) + ret1 = getattr(torch.linalg, f[7:])(a, *op.input, out=tuple(ret)) for i, name in enumerate(op.names): self.assertIs(getattr(ret, name), ret[i]) + else: + ret = getattr(a, f)(*op.input) + for i, name in enumerate(op.names): + self.assertIs(getattr(ret, name), ret[i]) + if op.hasout: + ret1 = getattr(torch, f)(a, *op.input, out=tuple(ret)) + for i, name in enumerate(op.names): + self.assertIs(getattr(ret, name), ret[i]) all_covered_operators = set([x for y in operators for x in y.operators]) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 62820647db00..dadfe6018939 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1044,6 +1044,12 @@ - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) self: symeig_backward(grads, self, eigenvectors, upper, eigenvalues, eigenvectors_return) +- name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) + self: symeig_backward(grads, self, /*eigenvectors=*/true, /*upper=*/true, eigenvalues, eigenvectors) + +- name: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor + self: non_differentiable + - name: t(Tensor(a) self) -> Tensor(a) self: grad.t() diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 7233c0cd3abc..b149194dc7d6 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -77,7 +77,8 @@ 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', 'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', - 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv' + 'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv', + 'linalg_eigh', } # Some operators invalidate the grad_accumulator. Let's reset it. diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 0ef9e52314b1..dd44e16ff3f1 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -20,6 +20,22 @@ inline Tensor det(const Tensor& self) { return torch::linalg_det(self); } +inline std::tuple eigh(const Tensor& self, std::string uplo) { + return torch::linalg_eigh(self, uplo); +} + +inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { + return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo); +} + +inline Tensor eigvalsh(const Tensor& self, std::string uplo) { + return torch::linalg_eigvalsh(self, uplo); +} + +inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { + return torch::linalg_eigvalsh_out(result, self, uplo); +} + inline Tensor norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } @@ -79,6 +95,28 @@ inline Tensor linalg_det(const Tensor& self) { return detail::det(self); } +/// Computes eigenvalues and eigenvectors +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigh +inline std::tuple eigh(const Tensor& self, std::string uplo) { + return detail::eigh(self, uplo); +} + +inline std::tuple eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) { + return detail::eigh_out(eigvals, eigvecs, self, uplo); +} + +/// Computes eigenvalues +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigvalsh +inline Tensor eigvalsh(const Tensor& self, std::string uplo) { + return detail::eigvalsh(self, uplo); +} + +inline Tensor& eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) { + return detail::eigvalsh_out(result, self, uplo); +} + inline Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); } diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 5148665550bc..5e9a22f9ebcb 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1974,32 +1974,29 @@ Tensor symeig_backward(const std::vector &grads, cons auto glambda = grads[0]; auto gv = grads[1]; - auto vt = v.transpose(-2, -1); + auto vh = v.conj().transpose(-2, -1); Tensor result; if (gv.defined()) { Tensor F = lambda.unsqueeze(-2) - lambda.unsqueeze(-1); F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); F.pow_(-1); - if (inplace_is_vmap_compatible(F, gv)) { - F.mul_(at::matmul(vt, gv)); - } else { - F = F.mul(at::matmul(vt, gv)); - } - result = at::matmul(v, at::matmul(F, vt)); + result = at::matmul(v, at::matmul(F * at::matmul(vh, gv), vh)); } else { result = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (glambda.defined()) { - auto tmp = at::matmul(at::matmul(v, at::diag_embed(glambda, /*offset=*/0, /*dim1=*/-2, /*dim2=*/-1)), vt); - if (inplace_is_vmap_compatible(result, tmp)) { - result.add_(tmp); - } else { - result = result + tmp; - } + glambda = glambda.to(self.dtype()); + // computes v @ diag(glambda) @ vh + Tensor glambda_term = at::matmul(v * glambda.unsqueeze(-2), vh); + if (inplace_is_vmap_compatible(result, glambda_term)) { + result.add_(glambda_term); + } else { + result = result + glambda_term; + } } - return result.add(result.transpose(-2, -1)).mul_(0.5); + return result.add(result.conj().transpose(-2, -1)).mul_(0.5); } Tensor qr_backward(const std::vector &grads, const Tensor& self, diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index c88aeda4006d..edd4d8a8afa6 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -88,6 +88,131 @@ Alias of :func:`torch.det`. """) +eigh = _add_docstr(_linalg.linalg_eigh, r""" +linalg.eigh(input, UPLO='L') -> tuple(Tensor, Tensor) + +This function computes the eigenvalues and eigenvectors +of a complex Hermitian (or real symmetric) matrix, or batch of such matrices, :attr:`input`. +For a single matrix :attr:`input`, the tensor of eigenvalues :math:`w` and the tensor of eigenvectors :math:`V` +decompose the :attr:`input` such that :math:`\text{input} = V \text{diag}(w) V^H`, +where :math:`^H` is the conjugate transpose operation. + +Since the matrix or matrices in :attr:`input` are assumed to be Hermitian, the imaginary part of their diagonals +is always treated as zero. When :attr:`UPLO` is "L", its default value, only the lower triangular part of +each matrix is used in the computation. When :attr:`UPLO` is "U" only the upper triangular part of each matrix is used. + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` data types. + +See :func:`torch.linalg.eigvalsh` for a related function that computes only eigenvalues, +however that function is not differentiable. + +.. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. + +.. note:: The eigenvectors of matrices are not unique, so any eigenvector multiplied by a constant remains + a valid eigenvector. This function may compute different eigenvector representations on + different device types. Usually the difference is only in the sign of the eigenvector. + +.. note:: The eigenvalues/eigenvectors are computed using LAPACK/MAGMA routines ``_syevd`` and ``_heevd``. + This function always checks whether the call to LAPACK/MAGMA is successful + using ``info`` argument of ``_syevd``, ``_heevd`` and throws a RuntimeError if it isn't. + On CUDA this causes a cross-device memory synchronization. + +Args: + input (Tensor): the Hermitian :math:`n \times n` matrix or the batch + of such matrices of size :math:`(*, n, n)` where `*` is one or more batch dimensions. + UPLO ('L', 'U', optional): controls whether to use the upper-triangular or the lower-triangular part + of :attr:`input` in the computations. Default: ``'L'`` + +Returns: + (Tensor, Tensor): A namedtuple (eigenvalues, eigenvectors) containing + + - **eigenvalues** (*Tensor*): Shape :math:`(*, m)`. + The eigenvalues in ascending order. + - **eigenvectors** (*Tensor*): Shape :math:`(*, m, m)`. + The orthonormal eigenvectors of the ``input``. + +Examples:: + + >>> a = torch.randn(2, 2, dtype=torch.complex128) + >>> a = a + a.t().conj() # creates a Hermitian matrix + >>> a + tensor([[2.9228+0.0000j, 0.2029-0.0862j], + [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) + >>> w, v = torch.linalg.eigh(a) + >>> w + tensor([0.3277, 2.9415], dtype=torch.float64) + >>> v + tensor([[-0.0846+-0.0000j, -0.9964+0.0000j], + [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128) + >>> torch.allclose(torch.matmul(v, torch.matmul(w.to(v.dtype).diag_embed(), v.t().conj())), a) + True + + >>> a = torch.randn(3, 2, 2, dtype=torch.float64) + >>> a = a + a.transpose(-2, -1) # creates a symmetric matrix + >>> w, v = torch.linalg.eigh(a) + >>> torch.allclose(torch.matmul(v, torch.matmul(w.diag_embed(), v.transpose(-2, -1))), a) + True +""") + +eigvalsh = _add_docstr(_linalg.linalg_eigvalsh, r""" +linalg.eigvalsh(input, UPLO='L') -> Tensor + +This function computes the eigenvalues of a complex Hermitian (or real symmetric) matrix, +or batch of such matrices, :attr:`input`. The eigenvalues are returned in ascending order. + +Since the matrix or matrices in :attr:`input` are assumed to be Hermitian, the imaginary part of their diagonals +is always treated as zero. When :attr:`UPLO` is "L", its default value, only the lower triangular part of +each matrix is used in the computation. When :attr:`UPLO` is "U" only the upper triangular part of each matrix is used. + +Supports input of ``float``, ``double``, ``cfloat`` and ``cdouble`` data types. + +See :func:`torch.linalg.eigh` for a related function that computes both eigenvalues and eigenvectors. + +.. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. + +.. note:: The eigenvalues/eigenvectors are computed using LAPACK/MAGMA routines ``_syevd`` and ``_heevd``. + This function always checks whether the call to LAPACK/MAGMA is successful + using ``info`` argument of ``_syevd``, ``_heevd`` and throws a RuntimeError if it isn't. + On CUDA this causes a cross-device memory synchronization. + +.. note:: This function doesn't support backpropagation, please use :func:`torch.linalg.eigh` instead, + that also computes the eigenvectors. + +Args: + input (Tensor): the Hermitian :math:`n \times n` matrix or the batch + of such matrices of size :math:`(*, n, n)` where `*` is one or more batch dimensions. + UPLO ('L', 'U', optional): controls whether to use the upper-triangular or the lower-triangular part + of :attr:`input` in the computations. Default: ``'L'`` + +Examples:: + + >>> a = torch.randn(2, 2, dtype=torch.complex128) + >>> a = a + a.t().conj() # creates a Hermitian matrix + >>> a + tensor([[2.9228+0.0000j, 0.2029-0.0862j], + [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) + >>> w = torch.linalg.eigvalsh(a) + >>> w + tensor([0.3277, 2.9415], dtype=torch.float64) + + >>> a = torch.randn(3, 2, 2, dtype=torch.float64) + >>> a = a + a.transpose(-2, -1) # creates a symmetric matrix + >>> a + tensor([[[ 2.8050, -0.3850], + [-0.3850, 3.2376]], + + [[-1.0307, -2.7457], + [-2.7457, -1.7517]], + + [[ 1.7166, 2.2207], + [ 2.2207, -2.0898]]], dtype=torch.float64) + >>> w = torch.linalg.eigvalsh(a) + >>> w + tensor([[ 2.5797, 3.4629], + [-4.1605, 1.3780], + [-3.1113, 2.7381]], dtype=torch.float64) +""") + norm = _add_docstr(_linalg.linalg_norm, r""" linalg.norm(input, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index 8b7187ef4d38..191dc9cb910e 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -329,6 +329,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.hsmm: lambda mat1, mat2: -1, torch.dstack: lambda tensors, out=None: -1, torch.eig: lambda input, eigenvectors=False, out=None: -1, + torch.linalg.eigh: lambda input, UPLO="L", out=None: -1, + torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1, torch.einsum: lambda equation, *operands: -1, torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1),