Added linalg.eigh, linalg.eigvalsh (#45526)
Summary:
This PR adds `torch.linalg.eigh` and `torch.linalg.eigvalsh` for NumPy compatibility.
On CPU, the current `torch.symeig` uses a different LAPACK routine than NumPy (`syev` vs. `syevd`). Even though the choice shouldn't matter in practice, `torch.linalg.eigh` uses `syevd`, as NumPy does.

Ref #42666
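
For illustration, here is a minimal ATen-level sketch of the new functions (an editorial addition, not part of the commit; it assumes a build with this patch applied, with `at::linalg_eigh`/`at::linalg_eigvalsh` as the C++ entry points the Python bindings forward to):

```cpp
// Editorial sketch: exercising the new functions through the ATen C++ API.
#include <ATen/ATen.h>
#include <iostream>
#include <tuple>

int main() {
  // Build a real symmetric matrix; a complex input would need to be Hermitian.
  at::Tensor a = at::randn({3, 3}, at::kDouble);
  a = a + a.t();

  at::Tensor w, v;
  std::tie(w, v) = at::linalg_eigh(a, "L");         // eigenvalues and eigenvectors
  at::Tensor w_only = at::linalg_eigvalsh(a, "L");  // eigenvalues only

  std::cout << w << std::endl;  // eigenvalues in ascending order, as syevd returns them
  return 0;
}
```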

Pull Request resolved: #45526

Reviewed By: gchanan

Differential Revision: D25022659

Pulled By: mruberry

fbshipit-source-id: 3676b77a121c4b5abdb712ad06702ac4944e900a
IvanYashchuk authored and facebook-github-bot committed Nov 22, 2020
1 parent b665490 commit 4ed7f36
Showing 13 changed files with 689 additions and 22 deletions.
180 changes: 180 additions & 0 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -72,6 +72,12 @@ extern "C" void cheev_(char *jobz, char *uplo, int *n, std::complex<float> *a, i
extern "C" void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
extern "C" void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);

// syevd
extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info);
extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info);
extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info);
extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info);

// gesdd
extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda,
double *s, std::complex<double> *u, int *ldu, std::complex<double> *vt, int *ldvt, std::complex<double> *work, int *lwork, double *rwork, int *iwork, int *info);
@@ -122,6 +128,9 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala
template<class scalar_t, class value_t=scalar_t>
void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info);

template<class scalar_t, class value_t=scalar_t>
void lapackSyevd(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int lrwork, int *iwork, int liwork, int *info);

template<class scalar_t, class value_t=scalar_t>
void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda,
value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
@@ -276,6 +285,26 @@ template<> void lapackSymeig<float>(char jobz, char uplo, int n, float *a, int l
ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
}

template<> void lapackSyevd<c10::complex<double>, double>(char jobz, char uplo, int n, c10::complex<double> *a, int lda, double *w, c10::complex<double> *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
zheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, w, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
}

template<> void lapackSyevd<c10::complex<float>, float>(char jobz, char uplo, int n, c10::complex<float> *a, int lda, float *w, c10::complex<float> *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
cheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, w, reinterpret_cast<std::complex<float>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
}

template<> void lapackSyevd<double>(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
(void)rwork; // unused
(void)lrwork; // unused
dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
(void)rwork; // unused
(void)lrwork; // unused
ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda,
double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) {
zgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, s, reinterpret_cast<std::complex<double>*>(u), &ldu,
@@ -879,6 +908,157 @@ std::tuple<Tensor&,Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& self, boo
return std::tuple<Tensor&, Tensor&>(Q, R);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// This function computes the eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v'
// The computation is done in-place: 'v' stores the input and will be overwritten, 'w' should be an allocated empty array
// compute_v controls whether eigenvectors should be computed
// uplo_str controls the portion of the input matrix to consider in computations, allowed values are "u", "U", "l", "L"
// infos is used to store information for possible checks for errors
// This function doesn't do any error checks and it's assumed that every argument is valid
template <typename scalar_t>
static void apply_syevd(Tensor& w, Tensor& v, bool compute_v, const std::string& uplo_str, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
AT_ERROR("syevd: LAPACK library not found in compilation");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;

auto v_data = v.data_ptr<scalar_t>();
auto w_data = w.data_ptr<value_t>();
auto v_matrix_stride = matrixStride(v);
auto w_stride = w.size(-1);
auto batch_size = batchCount(v);
auto n = v.size(-1);
auto lda = std::max(int64_t{1}, n);

// NumPy allows lowercase input for UPLO argument
// It is assumed that uplo_str is either "U" or "L"
char uplo = std::toupper(uplo_str[0]);
char jobz = compute_v ? 'V' : 'N';

// Using 'int' instead of int32_t or int64_t is consistent with the current LAPACK interface
// It really should be changed in the future to something like lapack_int that depends on the specific LAPACK library that is linked,
// or switched to supporting only 64-bit indexing by default.
int info;
int lwork = -1;
int lrwork = -1;
int liwork = -1;
scalar_t work_query;
value_t rwork_query;
int iwork_query;

// Run lapackSyevd once first to get the optimal workspace sizes.
// Since we deal with batches of matrices with the same dimensions, doing this outside
// the main loop saves (batch_size - 1) workspace queries which would provide the same result
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, v_data, lda, w_data, &work_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, &info);

lwork = std::max<int>(1, real_impl<scalar_t, value_t>(work_query));
Tensor work = at::empty({lwork}, v.options());
liwork = std::max<int>(1, iwork_query);
Tensor iwork = at::empty({liwork}, at::kInt);

Tensor rwork;
value_t* rwork_data = nullptr;
if (isComplexType(at::typeMetaToScalarType(v.dtype()))) {
lrwork = std::max<int>(1, rwork_query);
rwork = at::empty({lrwork}, w.options());
rwork_data = rwork.data_ptr<value_t>();
}

// Now call lapackSyevd for each matrix in the batched input
for (auto i = decltype(batch_size){0}; i < batch_size; i++) {
scalar_t* v_working_ptr = &v_data[i * v_matrix_stride];
value_t* w_working_ptr = &w_data[i * w_stride];
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, v_working_ptr, lda, w_working_ptr, work.data_ptr<scalar_t>(), lwork, rwork_data, lrwork, iwork.data_ptr<int>(), liwork, &info);
infos[i] = info;
// The current behaviour for linear algebra functions is to raise an error if something goes wrong or the input doesn't satisfy some requirement,
// therefore return early since further computations will be wasted anyway
if (info != 0) {
return;
}
}
#endif
}

// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self'
// compute_eigenvectors controls whether eigenvectors should be computed
// uplo controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
// This function prepares correct input for 'apply_syevd' and checks for possible errors using 'infos'
std::tuple<Tensor, Tensor> _syevd_helper_cpu(const Tensor& self, bool compute_eigenvectors, std::string uplo) {
std::vector<int64_t> infos(batchCount(self), 0);

auto self_sizes = self.sizes().vec();
self_sizes.pop_back();
ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype()));
auto eigvals = at::empty(self_sizes, self.options().dtype(dtype));

auto eigvecs = cloneBatchedColumnMajor(self);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "syevd_cpu", [&]{
apply_syevd<scalar_t>(eigvals, eigvecs, compute_eigenvectors, uplo, infos);
});

if (self.dim() > 2) {
batchCheckErrors(infos, "syevd_cpu");
} else {
singleCheckErrors(infos[0], "syevd_cpu");
}
if (compute_eigenvectors) {
return std::tuple<Tensor, Tensor>(eigvals, eigvecs);
} else {
return std::tuple<Tensor, Tensor>(eigvals, at::empty({0}, self.options()));
}
}

std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& self, std::string uplo) {
squareCheckInputs(self);
checkUplo(uplo);
return at::_syevd_helper(self, /*compute_eigenvectors=*/true, uplo);
}

// TODO: it's possible to make the _out variant a primal function and implement linalg_eigh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
std::tuple<Tensor&, Tensor&> linalg_eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) {
TORCH_CHECK(eigvecs.scalar_type() == self.scalar_type(),
"eigvecs dtype ", eigvecs.scalar_type(), " does not match self dtype ", self.scalar_type());
ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
TORCH_CHECK(eigvals.scalar_type() == real_dtype,
"eigvals dtype ", eigvals.scalar_type(), " does not match the expected dtype ", real_dtype);

Tensor eigvals_tmp, eigvecs_tmp;
std::tie(eigvals_tmp, eigvecs_tmp) = at::linalg_eigh(self, uplo);

at::native::resize_output(eigvals, eigvals_tmp.sizes());
eigvals.copy_(eigvals_tmp);
at::native::resize_output(eigvecs, eigvecs_tmp.sizes());
eigvecs.copy_(eigvecs_tmp);

return std::tuple<Tensor&, Tensor&>(eigvals, eigvecs);
}

Tensor linalg_eigvalsh(const Tensor& self, std::string uplo) {
squareCheckInputs(self);
checkUplo(uplo);
Tensor eigvals, eigvecs;
std::tie(eigvals, eigvecs) = at::_syevd_helper(self, /*compute_eigenvectors=*/false, uplo);
return eigvals;
}

// TODO: it's possible to make the _out variant a primal function and implement linalg_eigvalsh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) {
ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
TORCH_CHECK(result.scalar_type() == real_dtype,
"result dtype ", result.scalar_type(), " does not match the expected dtype ", real_dtype);

Tensor result_tmp = at::linalg_eigvalsh(self, uplo);

at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);

return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t>
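Editorial aside: the `lwork`/`liwork = -1` workspace-query protocol that `apply_syevd` above performs once per batch can be seen in isolation in the following standalone sketch (an illustration under stated assumptions, not part of this diff; it assumes a LAPACK implementation is linked):

```cpp
// Editorial sketch of the LAPACK workspace-query protocol used by apply_syevd,
// shown on a single 2x2 matrix. Assumes a LAPACK library is linked.
#include <algorithm>
#include <cstdio>
#include <vector>

extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda,
                        double *w, double *work, int *lwork,
                        int *iwork, int *liwork, int *info);

int main() {
  // Column-major symmetric matrix {{2, 1}, {1, 2}}; its eigenvalues are 1 and 3.
  double a[4] = {2.0, 1.0, 1.0, 2.0};
  double w[2];
  char jobz = 'V', uplo = 'L';
  int n = 2, lda = 2, info = 0;

  // First call with lwork = liwork = -1: LAPACK reports the optimal workspace
  // sizes in work[0] and iwork[0] and performs no computation.
  double work_query;
  int iwork_query;
  int lwork = -1, liwork = -1;
  dsyevd_(&jobz, &uplo, &n, a, &lda, w, &work_query, &lwork, &iwork_query, &liwork, &info);

  lwork = std::max(1, static_cast<int>(work_query));
  liwork = std::max(1, iwork_query);
  std::vector<double> work(lwork);
  std::vector<int> iwork(liwork);

  // Second call does the eigendecomposition in-place: 'a' is overwritten with
  // the eigenvectors and 'w' receives the eigenvalues in ascending order.
  dsyevd_(&jobz, &uplo, &n, a, &lda, w, work.data(), &lwork, iwork.data(), &liwork, &info);
  std::printf("info=%d, eigenvalues: %g %g\n", info, w[0], w[1]);
  return 0;
}
```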
12 changes: 11 additions & 1 deletion aten/src/ATen/native/LinearAlgebraUtils.h
@@ -7,6 +7,7 @@
#include <limits>
#include <sstream>
#include <cstring>
#include <cctype>

namespace at { namespace native {

@@ -97,7 +98,7 @@ static inline void batchCheckErrors(std::vector<int64_t>& infos, const char* nam
} else if (info > 0) {
if (strstr(name, "svd")) {
AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")");
} else if (strstr(name, "symeig")) {
} else if (strstr(name, "symeig") || strstr(name, "syevd")) {
AT_ERROR(name, ": For batch ", i, ": the algorithm failed to converge; ", info,
" off-diagonal elements of an intermediate tridiagonal form did not converge to zero.");
} else if (!allow_singular) {
@@ -333,4 +334,13 @@ static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
}

// This function checks whether the uplo argument is valid
// Allowed strings are "u", "U", "l", "L"
static inline void checkUplo(const std::string& uplo) {
// To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
"Expected UPLO argument to be 'L' or 'U', but got ", uplo);
}

}} // namespace at::native
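
Editorial aside: the unsigned-char cast in `checkUplo` is easy to miss, so the sketch below isolates the pattern (not part of this diff). Passing a plain `char` with a negative value to `std::toupper` is undefined behavior, which is why the cast is needed.

```cpp
// Editorial sketch: the safe std::toupper pattern used by checkUplo.
#include <cassert>
#include <cctype>

inline char safe_toupper(char c) {
  // Convert through unsigned char so byte values >= 0x80 don't become
  // negative ints, which std::toupper is not required to handle.
  return static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
}

int main() {
  assert(safe_toupper('l') == 'L');  // lowercase is accepted, as in NumPy
  assert(safe_toupper('U') == 'U');  // uppercase passes through unchanged
  return 0;
}
```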
15 changes: 15 additions & 0 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1783,6 +1783,21 @@ std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvec
}
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self'
// compute_eigenvectors controls whether eigenvectors should be computed
// uplo controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
// '_symeig_helper_cuda' prepares correct input for 'apply_symeig' and checks for possible errors using 'infos'
// See also CPU implementation in aten/src/ATen/native/BatchLinearAlgebra.cpp
std::tuple<Tensor, Tensor> _syevd_helper_cuda(const Tensor& self, bool compute_eigenvectors, std::string uplo_str) {
// NumPy allows lowercase input for UPLO argument
// It is assumed that uplo_str is either "U" or "L"
char uplo = std::toupper(uplo_str[0]);
bool upper = uplo == 'U';
return _symeig_helper_cuda(self, compute_eigenvectors, upper);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
31 changes: 31 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -9373,6 +9373,37 @@
dispatch:
DefaultBackend: det

- func: _syevd_helper(Tensor self, bool compute_eigenvectors, str uplo) -> (Tensor, Tensor)
use_c10_dispatcher: full
variants: function
dispatch:
CPU: _syevd_helper_cpu
CUDA: _syevd_helper_cuda

- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
python_module: linalg
use_c10_dispatcher: full
variants: function
dispatch:
DefaultBackend: linalg_eigh

- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
python_module: linalg
dispatch:
DefaultBackend: linalg_eigh_out

- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
python_module: linalg
use_c10_dispatcher: full
variants: function
dispatch:
DefaultBackend: linalg_eigvalsh

- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
DefaultBackend: linalg_eigvalsh_out

# torch.outer, alias for torch.ger
- func: outer(Tensor self, Tensor vec2) -> Tensor
use_c10_dispatcher: full
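Editorial aside: given these schemas, the code generator should expose the out variants with the out tensors leading, matching the native signatures in BatchLinearAlgebra.cpp above. A hedged usage sketch, assuming a real double input so the eigenvalue dtype check passes:

```cpp
// Editorial sketch of calling the new _out variants through the ATen C++ API.
#include <ATen/ATen.h>

void eigh_out_demo(const at::Tensor& sym /* real symmetric, dtype double */) {
  // Empty outputs are fine: resize_output grows them to the result shape.
  at::Tensor eigvals = at::empty({0}, sym.options());  // real dtype matches the double input
  at::Tensor eigvecs = at::empty({0}, sym.options());
  at::linalg_eigh_out(eigvals, eigvecs, sym, "L");

  at::Tensor w = at::empty({0}, sym.options());
  at::linalg_eigvalsh_out(w, sym, "U");
}
```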
2 changes: 2 additions & 0 deletions docs/source/linalg.rst
@@ -14,6 +14,8 @@ Functions

.. autofunction:: cholesky
.. autofunction:: det
.. autofunction:: eigh
.. autofunction:: eigvalsh
.. autofunction:: norm
.. autofunction:: tensorinv
.. autofunction:: tensorsolve
