New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Added linalg.eigh, linalg.eigvalsh #45526
Changes from 52 commits
018604f
4662eb7
bdcda38
b6f814c
781fb2c
8f6325b
6654f4c
ecd0c80
c871eb6
7a15c47
432743f
5a94008
e8180d8
4f8542c
629d934
e4d5a23
62f3e29
f151bfa
9e94c8b
3ce51ca
ec8f824
0c0dbac
b55f824
3c08f0f
8fdf8b2
4477796
87b7c22
794d7cd
4e405bb
a3adeb4
93b68c4
685f383
f880538
0be9cac
efdbb61
0ce7224
a8f8366
434aa9a
9ca046a
d9df1b5
890478d
9abfb24
8a775fd
0a8669d
2603839
c65451b
78f6b4e
26971ec
5ed0159
dbc7a72
2911420
85c6c80
bd41343
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,6 +72,12 @@ extern "C" void cheev_(char *jobz, char *uplo, int *n, std::complex<float> *a, i | |
extern "C" void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info); | ||
extern "C" void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info); | ||
|
||
// syevd | ||
extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info); | ||
extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info); | ||
extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info); | ||
extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info); | ||
|
||
// gesdd | ||
extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda, | ||
double *s, std::complex<double> *u, int *ldu, std::complex<double> *vt, int *ldvt, std::complex<double> *work, int *lwork, double *rwork, int *iwork, int *info); | ||
|
@@ -122,6 +128,9 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala | |
template<class scalar_t, class value_t=scalar_t> | ||
void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info); | ||
|
||
template<class scalar_t, class value_t=scalar_t> | ||
void lapackSyevd(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int lrwork, int *iwork, int liwork, int *info); | ||
|
||
template<class scalar_t, class value_t=scalar_t> | ||
void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, | ||
value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); | ||
|
@@ -276,6 +285,26 @@ template<> void lapackSymeig<float>(char jobz, char uplo, int n, float *a, int l | |
ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info); | ||
} | ||
|
||
// Specializations of lapackSyevd dispatching to LAPACK's divide-and-conquer
// symmetric/Hermitian eigensolvers (?heevd for complex, ?syevd for real).
// The complex variants forward the rwork/lrwork workspace to ?heevd; the real
// variants ignore those two arguments because ?syevd takes no real workspace.
template<> void lapackSyevd<c10::complex<double>, double>(char jobz, char uplo, int n, c10::complex<double> *a, int lda, double *w, c10::complex<double> *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
  zheevd_(&jobz, &uplo, &n,
          reinterpret_cast<std::complex<double>*>(a), &lda, w,
          reinterpret_cast<std::complex<double>*>(work), &lwork,
          rwork, &lrwork, iwork, &liwork, info);
}

template<> void lapackSyevd<c10::complex<float>, float>(char jobz, char uplo, int n, c10::complex<float> *a, int lda, float *w, c10::complex<float> *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
  cheevd_(&jobz, &uplo, &n,
          reinterpret_cast<std::complex<float>*>(a), &lda, w,
          reinterpret_cast<std::complex<float>*>(work), &lwork,
          rwork, &lrwork, iwork, &liwork, info);
}

template<> void lapackSyevd<double>(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
  // dsyevd_ has no rwork/lrwork parameters; silence unused-parameter warnings.
  (void)rwork;
  (void)lrwork;
  dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
  // ssyevd_ has no rwork/lrwork parameters; silence unused-parameter warnings.
  (void)rwork;
  (void)lrwork;
  ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}
|
||
template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda, | ||
double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) { | ||
zgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, s, reinterpret_cast<std::complex<double>*>(u), &ldu, | ||
|
@@ -879,6 +908,157 @@ std::tuple<Tensor&,Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& self, boo | |
return std::tuple<Tensor&, Tensor&>(Q, R); | ||
} | ||
|
||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
// This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v' | ||
// The computation is done in-place: 'v' stores the input and will be overriden, 'w' should be an allocated empty array | ||
// compute_v controls whether eigenvectors should be computed | ||
// uplo_str controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" | ||
// infos is used to store information for possible checks for error | ||
// This function doesn't do any error checks and it's assumed that every argument is valid | ||
template <typename scalar_t> | ||
static void apply_syevd(Tensor& w, Tensor& v, bool compute_v, const std::string& uplo_str, std::vector<int64_t>& infos) { | ||
#ifndef USE_LAPACK | ||
AT_ERROR("syevd: LAPACK library not found in compilation"); | ||
#else | ||
using value_t = typename c10::scalar_value_type<scalar_t>::type; | ||
|
||
auto v_data = v.data_ptr<scalar_t>(); | ||
auto w_data = w.data_ptr<value_t>(); | ||
auto v_matrix_stride = matrixStride(v); | ||
auto w_stride = w.size(-1); | ||
auto batch_size = batchCount(v); | ||
auto n = v.size(-1); | ||
auto lda = std::max(int64_t{1}, n); | ||
|
||
// NumPy allows lowercase input for UPLO argument | ||
// It is assumed that uplo_str is either "U" or "L" | ||
char uplo = std::toupper(uplo_str[0]); | ||
char jobz = compute_v ? 'V' : 'N'; | ||
|
||
// Using 'int' instead of int32_t or int64_t is consistent with the current LAPACK interface | ||
// It really should be changed in the future to something like lapack_int that depends on the specific LAPACK library that is linked | ||
// or switch to supporting only 64-bit indexing by default. | ||
int info; | ||
int lwork = -1; | ||
int lrwork = -1; | ||
int liwork = -1; | ||
scalar_t work_query; | ||
value_t rwork_query; | ||
int iwork_query; | ||
|
||
// Run lapackSyevd once, first to get the optimum work size. | ||
// Since we deal with batches of matrices with the same dimensions, doing this outside | ||
// the main loop saves (batch_size - 1) workspace queries which would provide the same result | ||
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() | ||
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, v_data, lda, w_data, &work_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, &info); | ||
|
||
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(work_query)); | ||
Tensor work = at::empty({lwork}, v.options()); | ||
liwork = std::max<int>(1, iwork_query); | ||
Tensor iwork = at::empty({liwork}, at::kInt); | ||
|
||
Tensor rwork; | ||
value_t* rwork_data = nullptr; | ||
if (isComplexType(at::typeMetaToScalarType(v.dtype()))) { | ||
lrwork = std::max<int>(1, rwork_query); | ||
rwork = at::empty({lrwork}, w.options()); | ||
rwork_data = rwork.data_ptr<value_t>(); | ||
} | ||
|
||
// Now call lapackSyevd for each matrix in the batched input | ||
for (auto i = decltype(batch_size){0}; i < batch_size; i++) { | ||
scalar_t* v_working_ptr = &v_data[i * v_matrix_stride]; | ||
value_t* w_working_ptr = &w_data[i * w_stride]; | ||
lapackSyevd<scalar_t, value_t>(jobz, uplo, n, v_working_ptr, lda, w_working_ptr, work.data_ptr<scalar_t>(), lwork, rwork_data, lrwork, iwork.data_ptr<int>(), liwork, &info); | ||
infos[i] = info; | ||
// The current behaviour for Linear Algebra functions to raise an error if something goes wrong or input doesn't satisfy some requirement | ||
// therefore return early since further computations will be wasted anyway | ||
if (info != 0) { | ||
return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment here explaining what triggers this return, and what callers should do if this is triggered There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Follow-up question: In the future we expect to add On the CPU, torch.linalg.foo(cpu_matrix) can throw an error when a singular tensor is encountered. On CUDA, however, translating the device tensor containing the error information requires a cross-device sync, which we try to avoid. torch.linalg.foo(cuda_matrix) should perform and document this sync, however. torch.linalg.foo_ex, however, will return both device tensors and not perform the error check. This allows the user to check the error themselves, and avoid an immediate cross-device sync on CUDA if they want. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current behavior for linalg (LAPACK/MAGMA based) functions to raise an error if something goes wrong or input doesn't satisfy some requirement therefore all batched CPU functions return early since further computations will be wasted anyway. In the future, this early return should be removed. What does eigenvalues, eigenvectors = torch.linalg.eigh(input, info=torch.empty(batch_size))
# instead of
eigenvalues, eigenvectors, info = torch.linalg.eigh_ex(input) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Sounds good for now.
"Extra."
Agreed. Both the CPU and CUDA ex variants will not perform their own error checking and return the info tensor for callers to evaluate.
This is a really good idea and we did consider an approach like this, but there are technical issues with this approach that make it prohibitive. |
||
} | ||
} | ||
#endif | ||
} | ||
|
||
// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self' | ||
// compute_eigenvectors controls whether eigenvectors should be computed | ||
// uplo controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L" | ||
// This function prepares correct input for 'apply_syevd' and checks for possible errors using 'infos' | ||
std::tuple<Tensor, Tensor> _syevd_helper_cpu(const Tensor& self, bool compute_eigenvectors, std::string uplo) { | ||
IvanYashchuk marked this conversation as resolved.
Show resolved
Hide resolved
IvanYashchuk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
std::vector<int64_t> infos(batchCount(self), 0); | ||
|
||
auto self_sizes = self.sizes().vec(); | ||
self_sizes.pop_back(); | ||
ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); | ||
auto eigvals = at::empty(self_sizes, self.options().dtype(dtype)); | ||
|
||
auto eigvecs = cloneBatchedColumnMajor(self); | ||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "syevd_cpu", [&]{ | ||
apply_syevd<scalar_t>(eigvals, eigvecs, compute_eigenvectors, uplo, infos); | ||
}); | ||
|
||
if (self.dim() > 2) { | ||
batchCheckErrors(infos, "syevd_cpu"); | ||
} else { | ||
singleCheckErrors(infos[0], "syevd_cpu"); | ||
} | ||
if (compute_eigenvectors) { | ||
return std::tuple<Tensor, Tensor>(eigvals, eigvecs); | ||
} else { | ||
return std::tuple<Tensor, Tensor>(eigvals, at::empty({0}, self.options())); | ||
} | ||
} | ||
|
||
std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& self, std::string uplo) { | ||
squareCheckInputs(self); | ||
checkUplo(uplo); | ||
return at::_syevd_helper(self, /*compute_eigenvectors=*/true, uplo); | ||
} | ||
|
||
// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
std::tuple<Tensor&, Tensor&> linalg_eigh_out(Tensor& eigvals, Tensor& eigvecs, const Tensor& self, std::string uplo) {
  // Eigenvectors share the input dtype; eigenvalues are always real-valued
  TORCH_CHECK(eigvecs.scalar_type() == self.scalar_type(),
    "eigvecs dtype ", eigvecs.scalar_type(), " does not match self dtype ", self.scalar_type());
  ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
  TORCH_CHECK(eigvals.scalar_type() == real_dtype,
    "eigvals dtype ", eigvals.scalar_type(), " does not match self dtype ", real_dtype);

  // Compute into temporaries, then move the results into the provided outputs
  Tensor eigvals_tmp, eigvecs_tmp;
  std::tie(eigvals_tmp, eigvecs_tmp) = at::linalg_eigh(self, uplo);

  at::native::resize_output(eigvals, eigvals_tmp.sizes());
  eigvals.copy_(eigvals_tmp);
  at::native::resize_output(eigvecs, eigvecs_tmp.sizes());
  eigvecs.copy_(eigvecs_tmp);

  return std::tuple<Tensor&, Tensor&>(eigvals, eigvecs);
}
|
||
// Entry point for torch.linalg.eigvalsh: eigenvalues of a symmetric/Hermitian matrix.
// Same validation as linalg_eigh, but the helper is asked to skip eigenvector
// computation and only the eigenvalue tensor is returned.
Tensor linalg_eigvalsh(const Tensor& self, std::string uplo) {
  squareCheckInputs(self);
  checkUplo(uplo);
  return std::get<0>(at::_syevd_helper(self, /*compute_eigenvectors=*/false, uplo));
}
|
||
// TODO: it's possible to make the _out variant to be a primal function and implement linalg_eigvalsh on top of _out
// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_eigvalsh_out(Tensor& result, const Tensor& self, std::string uplo) {
  // Eigenvalues of a Hermitian matrix are real, so the output dtype must be
  // the real value type of the input dtype
  ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
  TORCH_CHECK(result.scalar_type() == real_dtype,
    "result dtype ", result.scalar_type(), " does not match self dtype ", real_dtype);

  // Compute into a temporary, then move the values into the provided output
  Tensor result_tmp = at::linalg_eigvalsh(self, uplo);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);

  return result;
}
|
||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
template <typename scalar_t> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9361,6 +9361,37 @@ | |
dispatch: | ||
DefaultBackend: det | ||
|
||
- func: _syevd_helper(Tensor self, bool compute_eigenvectors, str uplo) -> (Tensor, Tensor) | ||
use_c10_dispatcher: full | ||
variants: function | ||
dispatch: | ||
CPU: _syevd_helper_cpu | ||
CUDA: _syevd_helper_cuda | ||
|
||
- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) | ||
python_module: linalg | ||
use_c10_dispatcher: full | ||
variants: function | ||
dispatch: | ||
DefaultBackend: linalg_eigh | ||
|
||
- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) | ||
python_module: linalg | ||
dispatch: | ||
DefaultBackend: linalg_eigh_out | ||
|
||
- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor | ||
python_module: linalg | ||
use_c10_dispatcher: full | ||
variants: function | ||
dispatch: | ||
DefaultBackend: linalg_eigvalsh | ||
|
||
- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
python_module: linalg | ||
dispatch: | ||
DefaultBackend: linalg_eigvalsh_out | ||
|
||
# torch.outer, alias for torch.ger | ||
- func: outer(Tensor self, Tensor vec2) -> Tensor | ||
use_c10_dispatcher: full | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment here explaining that the use of imprecise types, like int, is consistent with the LAPACK interface (otherwise we'd use a precise type, like int32_t).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: if we were being extremely correct we would add a using declaration to make it easy to change this int type later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this kind of type change should happen at once for all functions that call into LAPACK. There should be a `lapack_int` type that depends on the type of the linked LAPACK, or only 64-bit indexing should be used by default and only the 64-bit version of LAPACK allowed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you and the comment looks great.