Added linalg.eigh, linalg.eigvalsh #45526

Closed
wants to merge 53 commits
Changes from 50 commits

Commits (53):
018604f
Started torch.linalg.eigh and torch.linalg.eigvalsh
IvanYashchuk Sep 29, 2020
4662eb7
Fixed syevd cpu
IvanYashchuk Sep 29, 2020
bdcda38
Added test for linalg.eigh
IvanYashchuk Sep 29, 2020
b6f814c
Expose linalg_eigh, linalg_eigvalsh functions in torch.linalg
IvanYashchuk Sep 29, 2020
781fb2c
Use smaller shape
IvanYashchuk Sep 29, 2020
8f6325b
Added linalg_eigh documentation
IvanYashchuk Oct 5, 2020
6654f4c
Added linalg_eigvalsh documentation
IvanYashchuk Oct 5, 2020
ecd0c80
Remove code duplication due to rebasing
IvanYashchuk Oct 5, 2020
c871eb6
Updated test_eigh
IvanYashchuk Oct 5, 2020
7a15c47
Added test_eigvalsh
IvanYashchuk Oct 5, 2020
432743f
Remove c10::optional uplo
IvanYashchuk Oct 5, 2020
5a94008
Added _syevd_helper_cuda
IvanYashchuk Oct 5, 2020
e8180d8
Use (void) to mark unused arguments.
IvanYashchuk Oct 5, 2020
4f8542c
Renamed compute_v -> compute_eigenvectors
IvanYashchuk Oct 7, 2020
629d934
Added test check with sign flipping of eigenvectors
IvanYashchuk Oct 8, 2020
e4d5a23
Merge remote-tracking branch 'upstream/master' into numpy-eig
IvanYashchuk Nov 2, 2020
62f3e29
Remove skip if numpy not available
IvanYashchuk Nov 2, 2020
f151bfa
Update input argument description
IvanYashchuk Nov 2, 2020
9e94c8b
Added non-contiguous input tests
IvanYashchuk Nov 2, 2020
3ce51ca
Updated docs
IvanYashchuk Nov 2, 2020
ec8f824
Make eigh return named tensors, add default value for uplo
IvanYashchuk Nov 2, 2020
0c0dbac
Updated UPLO argument implementation
IvanYashchuk Nov 2, 2020
b55f824
Added out= variant
IvanYashchuk Nov 2, 2020
3c08f0f
Added wrappers to torch/linalg.h
IvanYashchuk Nov 2, 2020
8fdf8b2
Added overrides.py entry
IvanYashchuk Nov 2, 2020
4477796
Added derivative rules for linalg_eigh
IvanYashchuk Nov 2, 2020
87b7c22
Check dtypes for out= variants
IvanYashchuk Nov 2, 2020
794d7cd
Merge remote-tracking branch 'upstream/master' into numpy-eig
IvanYashchuk Nov 2, 2020
4e405bb
flake8
IvanYashchuk Nov 2, 2020
a3adeb4
Merge remote-tracking branch 'upstream/master' into numpy-eig
IvanYashchuk Nov 3, 2020
93b68c4
Add namedtuple entry
IvanYashchuk Nov 3, 2020
685f383
Fix typo
IvanYashchuk Nov 3, 2020
f880538
Merge remote-tracking branch 'upstream/master' into numpy-eig
IvanYashchuk Nov 4, 2020
0be9cac
Added entry for linalg_eigh namedtuple test
IvanYashchuk Nov 4, 2020
efdbb61
Merge branch 'master' into numpy-eig
IvanYashchuk Nov 6, 2020
0ce7224
Typed std::max; added a few comments on the implementation
IvanYashchuk Nov 12, 2020
a8f8366
Remove method variant
IvanYashchuk Nov 12, 2020
434aa9a
Test torch.linalg.eigh namedtuple
IvanYashchuk Nov 12, 2020
9ca046a
Added test with lower case uplo
IvanYashchuk Nov 12, 2020
d9df1b5
Added a few comments describing new functions
IvanYashchuk Nov 12, 2020
890478d
Updated documentation
IvanYashchuk Nov 12, 2020
9abfb24
Merge remote-tracking branch 'upstream/master' into numpy-eig
IvanYashchuk Nov 12, 2020
8a775fd
Doc fixes
IvanYashchuk Nov 12, 2020
0a8669d
Remove unused import
IvanYashchuk Nov 12, 2020
2603839
Added a note on non-uniqueness of eigenvectors
IvanYashchuk Nov 12, 2020
c65451b
Merge remote-tracking branch 'upstream/master' into numpy-eig
IvanYashchuk Nov 17, 2020
78f6b4e
Use same error message for syevd as for symeig when failed to converge
IvanYashchuk Nov 17, 2020
26971ec
Updated documentation
IvanYashchuk Nov 17, 2020
5ed0159
Complex batched matmul now works on cuda
IvanYashchuk Nov 17, 2020
dbc7a72
, -> and
IvanYashchuk Nov 17, 2020
2911420
Safe use of std::toupper with plain chars
IvanYashchuk Nov 18, 2020
85c6c80
Merge remote-tracking branch 'upstream/master' into numpy-eig
IvanYashchuk Nov 20, 2020
bd41343
Fixed backward; can't in-place multiply real and complex tensors
IvanYashchuk Nov 20, 2020
180 changes: 180 additions & 0 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -72,6 +72,12 @@ extern "C" void cheev_(char *jobz, char *uplo, int *n, std::complex<float> *a, i
extern "C" void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
extern "C" void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);

// syevd
extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info);
extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info);
extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info);
extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info);

// gesdd
extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda,
double *s, std::complex<double> *u, int *ldu, std::complex<double> *vt, int *ldvt, std::complex<double> *work, int *lwork, double *rwork, int *iwork, int *info);
@@ -122,6 +128,9 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala
template<class scalar_t, class value_t=scalar_t>
void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info);

template<class scalar_t, class value_t=scalar_t>
void lapackSyevd(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int lrwork, int *iwork, int liwork, int *info);

template<class scalar_t, class value_t=scalar_t>
void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda,
value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
@@ -276,6 +285,26 @@ template<> void lapackSymeig<float>(char jobz, char uplo, int n, float *a, int l
ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
}

template<> void lapackSyevd<c10::complex<double>, double>(char jobz, char uplo, int n, c10::complex<double> *a, int lda, double *w, c10::complex<double> *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
zheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, w, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
}

template<> void lapackSyevd<c10::complex<float>, float>(char jobz, char uplo, int n, c10::complex<float> *a, int lda, float *w, c10::complex<float> *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
cheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, w, reinterpret_cast<std::complex<float>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
}

template<> void lapackSyevd<double>(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
(void)rwork; // unused
(void)lrwork; // unused
dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
(void)rwork; // unused
(void)lrwork; // unused
ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda,
double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) {
zgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, s, reinterpret_cast<std::complex<double>*>(u), &ldu,
@@ -879,6 +908,157 @@ std::tuple<Tensor&,Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& self, boo
return std::tuple<Tensor&, Tensor&>(Q, R);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// This function computes the eigenvalues 'w' and eigenvectors 'v' of the input that is initially stored in 'v'
// The computation is done in-place: 'v' holds the input and is overwritten with the eigenvectors, 'w' should be an allocated empty array
// compute_v controls whether eigenvectors should be computed
// uplo_str controls the portion of the input matrix to consider in computations; allowed values are "u", "U", "l", "L"
// infos stores the LAPACK error codes for later error checking
// This function doesn't do any error checks and it's assumed that every argument is valid
template <typename scalar_t>
static void apply_syevd(Tensor& w, Tensor& v, bool compute_v, const std::string& uplo_str, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
AT_ERROR("syevd: LAPACK library not found in compilation");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;

auto v_data = v.data_ptr<scalar_t>();
auto w_data = w.data_ptr<value_t>();
auto v_matrix_stride = matrixStride(v);
auto w_stride = w.size(-1);
auto batch_size = batchCount(v);
auto n = v.size(-1);
auto lda = std::max(int64_t{1}, n);

// NumPy allows lowercase input for UPLO argument
// It is assumed that uplo_str is either "U" or "L"
char uplo = std::toupper(uplo_str[0]);
char jobz = compute_v ? 'V' : 'N';

// Using 'int' instead of int32_t or int64_t is consistent with the current LAPACK interface.
// In the future this should be changed to something like lapack_int, whose width depends on the specific LAPACK library that is linked,
// or LAPACK with 64-bit indexing should be required by default.
int info;
Collaborator: Add a comment here explaining that the use of imprecise types, like int, is consistent with the LAPACK interface (otherwise we'd use a precise type, like int32_t).

Collaborator: Nit: if we were being extremely correct we would add a using declaration to make it easy to change this int type later.

Collaborator Author: I think this kind of type change should happen at once for all functions that call into LAPACK. There should be a lapack_int type that depends on the type of the linked LAPACK, or only 64-bit indexing should be used by default and only the 64-bit version of LAPACK allowed.

Collaborator: I agree with you, and the comment looks great.
int lwork = -1;
int lrwork = -1;
int liwork = -1;
scalar_t work_query;
value_t rwork_query;
int iwork_query;

// Run lapackSyevd once, first to get the optimum work size.
// Since we deal with batches of matrices with the same dimensions, doing this outside
// the main loop saves (batch_size - 1) workspace queries which would provide the same result
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, v_data, lda, w_data, &work_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, &info);

lwork = std::max<int>(1, real_impl<scalar_t, value_t>(work_query));
Tensor work = at::empty({lwork}, v.options());
liwork = std::max<int>(1, iwork_query);
Tensor iwork = at::empty({liwork}, at::kInt);

Tensor rwork;
value_t* rwork_data = nullptr;
if (isComplexType(at::typeMetaToScalarType(v.dtype()))) {
lrwork = std::max<int>(1, rwork_query);
rwork = at::empty({lrwork}, w.options());
rwork_data = rwork.data_ptr<value_t>();
}

// Now call lapackSyevd for each matrix in the batched input
for (auto i = decltype(batch_size){0}; i < batch_size; i++) {
scalar_t* v_working_ptr = &v_data[i * v_matrix_stride];
value_t* w_working_ptr = &w_data[i * w_stride];
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, v_working_ptr, lda, w_working_ptr, work.data_ptr<scalar_t>(), lwork, rwork_data, lrwork, iwork.data_ptr<int>(), liwork, &info);
infos[i] = info;
// The current behavior for linear algebra functions is to raise an error if something goes wrong or the input doesn't satisfy some requirement,
// so return early since further computations would be wasted anyway
if (info != 0) {
return;
Collaborator: Add a comment here explaining what triggers this return, and what callers should do if this is triggered.

Collaborator: Follow-up question:

In the future we expect to add _ex variants of functions in linear algebra that have additional error information. For example, let's say we have torch.linalg.foo that computes the foo operation on a non-singular matrix. If the matrix is singular, however, then the math library we're using returns error information in a device tensor explaining that the matrix is singular.

On the CPU, torch.linalg.foo(cpu_matrix) can throw an error when a singular tensor is encountered. On CUDA, however, translating the device tensor containing the error information requires a cross-device sync, which we try to avoid. torch.linalg.foo(cuda_matrix) should perform and document this sync, however.

torch.linalg.foo_ex, however, will return both device tensors and not perform the error check. This allows the user to check the error themselves, and avoid an immediate cross-device sync on CUDA if they want.

Collaborator Author: The current behavior for linalg (LAPACK/MAGMA based) functions is to raise an error if something goes wrong or the input doesn't satisfy some requirement, therefore all batched CPU functions return early since further computations would be wasted anyway. In the future, this early return should be removed.

What does the _ex suffix mean?
I agree that this kind of change would improve performance. I think the default behavior should be the same on CPU and CUDA, so CPU shouldn't raise an error, even though it's cheap to check for, if CUDA doesn't.
What do you think about, instead of introducing new functions with an _ex suffix in the Python interface, modifying the existing ones to accept an optional info= argument, similar to out=, that would be filled with the returned LAPACK/MAGMA/cuSOLVER error codes for the user to check if they want to?

eigenvalues, eigenvectors = torch.linalg.eigh(input, info=torch.empty(batch_size))
# instead of
eigenvalues, eigenvectors, info = torch.linalg.eigh_ex(input)

Collaborator: > The current behavior for linalg (LAPACK/MAGMA based) functions is to raise an error if something goes wrong or the input doesn't satisfy some requirement, therefore all batched CPU functions return early since further computations would be wasted anyway. In the future, this early return should be removed.

Sounds good for now.

> What does the _ex suffix mean?

"Extra."

> I agree that this kind of change would improve performance. I think the default behavior should be the same on CPU and CUDA, so CPU shouldn't raise an error, even though it's cheap to check for, if CUDA doesn't.

Agreed. Both the CPU and CUDA _ex variants will not perform their own error checking and will return the info tensor for callers to evaluate.

> What do you think about, instead of introducing new functions with an _ex suffix in the Python interface, modifying the existing ones to accept an optional info= argument, similar to out=, that would be filled with the returned LAPACK/MAGMA/cuSOLVER error codes for the user to check if they want to?

This is a really good idea and we did consider an approach like this, but there are technical issues with this approach that make it prohibitive.
}
}
#endif
}

// This function computes the eigenvalues 'w' and eigenvectors 'v' of the tensor 'self'
// compute_eigenvectors controls whether eigenvectors should be computed
// uplo controls the portion of the input matrix to consider in computations; allowed values are "u", "U", "l", "L"
// This function prepares the correct input for 'apply_syevd' and checks for possible errors using 'infos'
std::tuple<Tensor, Tensor> _syevd_helper_cpu(const Tensor& self, bool compute_eigenvectors, std::string uplo) {
std::vector<int64_t> infos(batchCount(self), 0);

auto self_sizes = self.sizes().vec();
self_sizes.pop_back();
ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype()));
auto eigvals = at::empty(self_sizes, self.options().dtype(dtype));

auto eigvecs = cloneBatchedColumnMajor(self);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "syevd_cpu", [&]{
apply_syevd<scalar_t>(eigvals, eigvecs, compute_eigenvectors, uplo, infos);
});

if (self.dim() > 2) {
batchCheckErrors(infos, "syevd_cpu");
} else {
singleCheckErrors(infos[0], "syevd_cpu");
}
if (compute_eigenvectors) {
return std::tuple<Tensor, Tensor>(eigvals, eigvecs);
} else {
return std::tuple<Tensor, Tensor>(eigvals, at::empty({0}, self.options()));
}
}

std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& self, std::string uplo) {
squareCheckInputs(self);
checkUplo(uplo);
return at::_syevd_helper(self, /*compute_eigenvectors=*/true, uplo);
}
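A minimal usage sketch of the function defined above, assuming a build containing this PR (names per the docs it adds):

import torch

a = torch.randn(2, 3, 3, dtype=torch.complex128)
a = a + a.conj().transpose(-2, -1)     # batch of Hermitian matrices
w, v = torch.linalg.eigh(a)            # named tuple (eigenvalues, eigenvectors)
# Eigenvectors are unique only up to sign (a phase factor, for complex input),
# so verify the reconstruction a @ v == v * w instead of comparing raw vectors.
assert torch.allclose(a @ v, v * w.unsqueeze(-2).to(v.dtype))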

// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
std::tuple<Tensor&, Tensor&> linalg_eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) {
IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved
TORCH_CHECK(eigvecs.scalar_type() == self.scalar_type(),
"eigvecs dtype ", eigvecs.scalar_type(), " does not match self dtype ", self.scalar_type());
ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
TORCH_CHECK(eigvals.scalar_type() == real_dtype,
"eigvals dtype ", eigvals.scalar_type(), " does not match the value type of self, ", real_dtype);

Tensor eigvals_tmp, eigvecs_tmp;
std::tie(eigvals_tmp, eigvecs_tmp) = at::linalg_eigh(self, uplo);

at::native::resize_output(eigvals, eigvals_tmp.sizes());
eigvals.copy_(eigvals_tmp);
at::native::resize_output(eigvecs, eigvecs_tmp.sizes());
eigvecs.copy_(eigvecs_tmp);

return std::tuple<Tensor&, Tensor&>(eigvals, eigvecs);
}
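A sketch of the out= behavior implemented above: results are computed into temporaries, then the user-supplied tensors are resized and copied into, with dtypes checked first (eigvecs must match the input dtype, eigvals its real value type). Assumes a build containing this PR:

import torch

a = torch.randn(3, 3, dtype=torch.float64)
a = a + a.T
w = torch.empty(0, dtype=torch.float64)     # eigenvalues: the real value type
v = torch.empty(0, dtype=torch.float64)     # eigenvectors: the input dtype
torch.linalg.eigh(a, UPLO='L', out=(w, v))  # both tensors are resized and filled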

Tensor linalg_eigvalsh(const Tensor& self, std::string uplo) {
squareCheckInputs(self);
checkUplo(uplo);
Tensor eigvals, eigvecs;
std::tie(eigvals, eigvecs) = at::_syevd_helper(self, /*compute_eigenvectors=*/false, uplo);
return eigvals;
}
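eigvalsh calls the same helper with compute_eigenvectors=false (so LAPACK skips the eigenvector work) and discards the empty eigenvector output. A quick sketch of the equivalence:

import torch

a = torch.randn(3, 3, dtype=torch.float64)
a = a + a.T
w = torch.linalg.eigvalsh(a)           # eigenvalues only, default UPLO="L"
w_ref, _ = torch.linalg.eigh(a)
assert torch.allclose(w, w_ref)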

// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigvalsh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) {
IvanYashchuk marked this conversation as resolved.
Show resolved Hide resolved
ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
TORCH_CHECK(result.scalar_type() == real_dtype,
"result dtype ", result.scalar_type(), " does not match the value type of self, ", real_dtype);

Tensor result_tmp = at::linalg_eigvalsh(self, uplo);

at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);

return result;
}
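The matching sketch for the eigvalsh out= variant; per the dtype check above, the out tensor must have the real value type of the input (float64 for a complex128 input):

import torch

a = torch.randn(4, 4, dtype=torch.complex128)
a = a + a.conj().transpose(-2, -1)          # Hermitian input
out = torch.empty(0, dtype=torch.float64)   # value type of complex128
torch.linalg.eigvalsh(a, out=out)           # resized to (4,) and filled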

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t>
10 changes: 9 additions & 1 deletion aten/src/ATen/native/LinearAlgebraUtils.h
@@ -97,7 +97,7 @@ static inline void batchCheckErrors(std::vector<int64_t>& infos, const char* nam
} else if (info > 0) {
if (strstr(name, "svd")) {
AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")");
} else if (strstr(name, "symeig")) {
} else if (strstr(name, "symeig") || strstr(name, "syevd")) {
AT_ERROR(name, ": For batch ", i, ": the algorithm failed to converge; ", info,
" off-diagonal elements of an intermediate tridiagonal form did not converge to zero.");
} else if (!allow_singular) {
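An illustration of the shared convergence-error message extended above. Whether a particular LAPACK build reports info > 0 here varies, so treat this as a hedged sketch rather than guaranteed behavior; a NaN-filled input is a common trigger:

import torch

bad = torch.full((2, 2), float('nan'), dtype=torch.float64)
try:
    torch.linalg.eigh(bad)
except RuntimeError as e:
    print(e)   # "...the algorithm failed to converge..." on typical LAPACK builds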
@@ -333,4 +333,12 @@ static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
}

// This function checks whether the uplo argument input is valid
// Allowed strings are "u", "U", "l", "L"
static inline void checkUplo(const std::string& uplo) {
char uplo_uppercase = std::toupper(uplo[0]);
TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
"Expected UPLO argument to be 'L' or 'U', but got ", uplo);
}
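The user-visible behavior of checkUplo, sketched from Python against a build containing this PR: a single-character 'u'/'U'/'l'/'L' passes (lowercase is accepted for NumPy compatibility), anything else raises:

import torch

a = torch.eye(2, dtype=torch.float64)
torch.linalg.eigh(a, UPLO='u')   # accepted, treated as 'U'
try:
    torch.linalg.eigh(a, UPLO='X')
except RuntimeError as e:
    print(e)   # Expected UPLO argument to be 'L' or 'U', but got X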

}} // namespace at::native
15 changes: 15 additions & 0 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1783,6 +1783,21 @@ std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvec
}
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// This function computes the eigenvalues 'w' and eigenvectors 'v' of the tensor 'self'
// compute_eigenvectors controls whether eigenvectors should be computed
// uplo controls the portion of the input matrix to consider in computations; allowed values are "u", "U", "l", "L"
// On CUDA this helper currently delegates to '_symeig_helper_cuda', which prepares the input for 'apply_symeig' and checks for possible errors using 'infos'
// See also the CPU implementation in aten/src/ATen/native/BatchLinearAlgebra.cpp
std::tuple<Tensor, Tensor> _syevd_helper_cuda(const Tensor& self, bool compute_eigenvectors, std::string uplo_str) {
// NumPy allows lowercase input for UPLO argument
// It is assumed that uplo_str is either "U" or "L"
char uplo = std::toupper(uplo_str[0]);
bool upper = uplo == 'U';
return _symeig_helper_cuda(self, compute_eigenvectors, upper);
}
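On a CUDA build, the same Python-level call dispatches to the helper above (currently reusing the MAGMA-backed symeig path), so CPU and CUDA eigenvalues should agree to rounding error. A sketch assuming a CUDA device is available:

import torch

a = torch.randn(3, 3, dtype=torch.float64, device='cuda')
a = a + a.transpose(-2, -1)
w_cuda, _ = torch.linalg.eigh(a)
w_cpu, _ = torch.linalg.eigh(a.cpu())
assert torch.allclose(w_cuda.cpu(), w_cpu)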

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
31 changes: 31 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -8943,6 +8943,37 @@
dispatch:
DefaultBackend: det

- func: _syevd_helper(Tensor self, bool compute_eigenvectors, str uplo) -> (Tensor, Tensor)
use_c10_dispatcher: full
variants: function
dispatch:
CPU: _syevd_helper_cpu
CUDA: _syevd_helper_cuda

- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
python_module: linalg
use_c10_dispatcher: full
variants: function
dispatch:
DefaultBackend: linalg_eigh

- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
python_module: linalg
dispatch:
DefaultBackend: linalg_eigh_out

- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
python_module: linalg
use_c10_dispatcher: full
variants: function
dispatch:
DefaultBackend: linalg_eigvalsh

- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!)
Collaborator Author: 'L' is used here instead of "L" because the double quotes caused errors during compilation. I didn't investigate it further; I hope it's fine and doesn't break anything.

python_module: linalg
dispatch:
DefaultBackend: linalg_eigvalsh_out
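
The UPLO="L" defaults registered above mean only one triangle of the input is ever read; a sketch showing that the other triangle is ignored entirely:

import torch

a = torch.randn(3, 3, dtype=torch.float64)    # not symmetric
sym = torch.tril(a) + torch.tril(a, -1).T     # symmetrize using the lower triangle
w1, _ = torch.linalg.eigh(a, UPLO='L')        # upper triangle of `a` is never read
w2, _ = torch.linalg.eigh(sym, UPLO='L')
assert torch.allclose(w1, w2)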

# torch.outer, alias for torch.ger
- func: outer(Tensor self, Tensor vec2) -> Tensor
use_c10_dispatcher: full
2 changes: 2 additions & 0 deletions docs/source/linalg.rst
@@ -14,5 +14,7 @@ Functions

.. autofunction:: cholesky
.. autofunction:: det
.. autofunction:: eigh
.. autofunction:: eigvalsh
.. autofunction:: norm
.. autofunction:: tensorsolve