From 4786d9b0efcd567c06b67ce4bf476d26c3452f05 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 10 Dec 2020 15:58:12 +0000 Subject: [PATCH 01/12] enable the first eig/complex test, which currently fails [ci skip] --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 1b5e2f0ee712..12b6530888a0 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1468,7 +1468,7 @@ def test_norm_fastpaths(self, device): @skipCPUIfNoLapack @skipCUDAIfNoMagma - @dtypes(torch.double, torch.float) + @dtypes(*floating_and_complex_types()) def test_eig_basic(self, device, dtype): a = torch.tensor([[1.96, 0.00, 0.00, 0.00, 0.00], [-6.49, 3.80, 0.00, 0.00, 0.00], From 3809df9d912fd5006b8bc3261f352101f521c06b Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 11 Dec 2020 16:36:36 +0000 Subject: [PATCH 02/12] prepare lapackEig to support complex types: instead of passing wr and wi separately, we pass only w and compute the two sub-arrays later, and add a (so far unused) rwork argument --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 12 ++++++++++-- aten/src/ATen/native/BatchLinearAlgebra.h | 2 +- aten/src/ATen/native/BatchLinearAlgebraKernel.cpp | 5 ++--- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index f8191c633d8b..4ff7f71cf977 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -310,11 +310,19 @@ template<> void lapackSyevd(char jobz, char uplo, int n, float *a, int ld ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); } -template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) { +template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) { + // lapack [sd]geev wants to separate output arrays: wr and wi for the real + // and imaginary parts + double *wr = w; + double *wi = w + n; dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); } -template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *wr, float *wi, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) { +template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) { + // lapack [sd]geev wants to separate output arrays: wr and wi for the real + // and imaginary parts + float *wr = w; + float *wi = w + n; sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); } diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index 95fc2c6097ce..b8a1ce5d508f 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -13,7 +13,7 @@ namespace at { namespace native { // linear algebra operations template -void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); #endif diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index d251245c60c5..ca97f781814c 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -21,7 +21,6 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec auto vals_data = vals_.data_ptr(); scalar_t* wr = vals_data; - scalar_t* wi = vals_data + n; scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr() : nullptr; int ldvr = eigenvectors ? n : 1; @@ -30,13 +29,13 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec // call lapackEig once to get the optimal size for work data scalar_t wkopt; int info; - lapackEig('N', jobvr, n, self_data, n, wr, wi, + lapackEig('N', jobvr, n, self_data, n, wr, nullptr, 1, vecs_data, ldvr, &wkopt, -1, &info); int lwork = static_cast(wkopt); // call again to do the actual work Tensor work = at::empty({lwork}, self.dtype()); - lapackEig('N', jobvr, n, self_data, n, wr, wi, + lapackEig('N', jobvr, n, self_data, n, wr, nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, &info); *info_ptr = info; } From a0034e873b38c81dadacb64aa28d81808dd2ec4d Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 11 Dec 2020 17:14:38 +0000 Subject: [PATCH 03/12] progress: dispatch apply_eig also to complex types, and tweak things around until the code compiles again. Add empty implementations for lapackEig on complex types --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 15 +++++++++++++-- aten/src/ATen/native/BatchLinearAlgebra.h | 4 ++-- .../src/ATen/native/BatchLinearAlgebraKernel.cpp | 16 ++++++++++------ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 4ff7f71cf977..da6c52faf40c 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -310,22 +310,33 @@ template<> void lapackSyevd(char jobz, char uplo, int n, float *a, int ld ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); } -template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) { +template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, double *rwork, int *info) { // lapack [sd]geev wants to separate output arrays: wr and wi for the real // and imaginary parts double *wr = w; double *wi = w + n; + (void)rwork; // unused dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); } -template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) { +template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, float *rwork, int *info) { // lapack [sd]geev wants to separate output arrays: wr and wi for the real // and imaginary parts float *wr = w; float *wi = w + n; + (void)rwork; // unused sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); } +template<> void lapackEig, float>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, float *rwork, int *info) { + AT_ERROR("lapackEig>"); +} + +template<> void lapackEig, double>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, double *rwork, int *info) { + AT_ERROR("lapackEig>"); +} + + template<> void lapackSvd, double>(char jobz, int m, int n, c10::complex *a, int lda, double *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, double *rwork, int *iwork, int *info) { zgesdd_(&jobz, &m, &n, reinterpret_cast*>(a), &lda, s, reinterpret_cast*>(u), &ldu, diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index b8a1ce5d508f..0d1784aa3cb6 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -12,8 +12,8 @@ namespace at { namespace native { // Define per-batch functions to be used in the implementation of batched // linear algebra operations -template -void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); +template +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info); #endif diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index ca97f781814c..b6b4f76a2586 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include // for USE_LAPACK @@ -15,6 +16,8 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ", "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); #else + using value_t = typename c10::scalar_value_type::type; + char jobvr = eigenvectors ? 'V' : 'N'; int64_t n = self.size(-1); auto self_data = self.data_ptr(); @@ -24,19 +27,20 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr() : nullptr; int ldvr = eigenvectors ? n : 1; + value_t* rwork = nullptr; // XXX: remember to malloc() it for complex if (n > 0) { // call lapackEig once to get the optimal size for work data scalar_t wkopt; int info; - lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, &wkopt, -1, &info); - int lwork = static_cast(wkopt); + lapackEig('N', jobvr, n, self_data, n, wr, + nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork, &info); + int lwork = static_cast(real_impl(wkopt)); // call again to do the actual work Tensor work = at::empty({lwork}, self.dtype()); - lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, &info); + lapackEig('N', jobvr, n, self_data, n, wr, + nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, rwork, &info); *info_ptr = info; } #endif @@ -60,7 +64,7 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector : Tensor(); int64_t info; - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cpu", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cpu", [&]{ apply_eig(self_, eigenvectors, vals_, vecs_, &info); }); singleCheckErrors(info, "eig_cpu"); From 5b2793d1f64673e1c45ac6dadd753c637a9f74f3 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Fri, 11 Dec 2020 17:32:49 +0000 Subject: [PATCH 04/12] WIP: implement lapackEig>, and the corresponding call to cgeev_; test_eig_basic stil fails though --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 22 +++++++++++++++---- .../ATen/native/BatchLinearAlgebraKernel.cpp | 14 +++++++++--- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index da6c52faf40c..45085fbebf3c 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -82,6 +82,15 @@ extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, floa // geev extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); +extern "C" void cgeev_(char *jobvl, char *jobvr, int *n, + std::complex *a, int *lda, + std::complex *w, + std::complex *vl, int *ldvl, + std::complex *vr, int *ldvr, + std::complex *work, int *lwork, + float *rwork, + int *info); + // gesdd extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex *a, int *lda, @@ -328,14 +337,19 @@ template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int ld sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); } -template<> void lapackEig, float>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, float *rwork, int *info) { - AT_ERROR("lapackEig>"); -} - template<> void lapackEig, double>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, double *rwork, int *info) { AT_ERROR("lapackEig>"); } +template<> void lapackEig, float>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, float *rwork, int *info) { + cgeev_(&jobvl, &jobvr, &n, + reinterpret_cast*>(a), &lda, + reinterpret_cast*>(w), + reinterpret_cast*>(vl), &ldvl, + reinterpret_cast*>(vr), &ldvr, + reinterpret_cast*>(work), &lwork, + rwork, info); +} template<> void lapackSvd, double>(char jobz, int m, int n, c10::complex *a, int lda, double *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, double *rwork, int *iwork, int *info) { diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index b6b4f76a2586..a8eb10575222 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -27,20 +27,28 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr() : nullptr; int ldvr = eigenvectors ? n : 1; - value_t* rwork = nullptr; // XXX: remember to malloc() it for complex + + Tensor rwork; + value_t* rwork_data = nullptr; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); + // XXX: I would like a double/float dtype, not a complex dtype + rwork = at::empty({n*2}, self.options().dtype(dtype)); + rwork_data = rwork.data_ptr(); + } if (n > 0) { // call lapackEig once to get the optimal size for work data scalar_t wkopt; int info; lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork, &info); + nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, &info); int lwork = static_cast(real_impl(wkopt)); // call again to do the actual work Tensor work = at::empty({lwork}, self.dtype()); lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, rwork, &info); + nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, rwork_data, &info); *info_ptr = info; } #endif From e1daece74a9f07ec5e57028ae815490b1150f2bd Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 24 Dec 2020 10:36:34 +0000 Subject: [PATCH 05/12] make sure that the shape of eigenvals is correct in the complex case --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 6 +++++- aten/src/ATen/native/BatchLinearAlgebraKernel.cpp | 12 +++++++++++- test/test_linalg.py | 11 +++++++---- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 45085fbebf3c..313d1eab634e 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1214,7 +1214,11 @@ std::tuple eig_out(Tensor& e, Tensor& v, const Tensor& self, b TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype()); int64_t n = self.size(-1); - at::native::resize_output(e, {n, 2}); + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + at::native::resize_output(e, {n}); + } else { + at::native::resize_output(e, {n, 2}); + } if (eigenvectors) { at::native::resize_output(v, self.sizes()); } diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index a8eb10575222..63408f486a67 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -66,7 +66,17 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector self_.copy_(self); auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); - Tensor vals_ = at::empty_strided({n, 2}, {1, n}, options); + + // the API is slightly different for the complex vs real case: if the input + // is complex, eigenvals will be a vector of complex. If the input is real, + // eigenvals will be a (n, 2) matrix containing the real and imaginary parts + // in each column + Tensor vals_; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + vals_ = at::empty({n}, options); + } else { + vals_ = at::empty_strided({n, 2}, {1, n}, options); + } Tensor vecs_ = eigenvectors ? at::empty_strided({n, n}, {1, n}, options) : Tensor(); diff --git a/test/test_linalg.py b/test/test_linalg.py index 12b6530888a0..5d70b32060d6 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1489,10 +1489,13 @@ def test_eig_basic(self, device, dtype): # # compare with numpy np_e, np_v = np.linalg.eig(a.cpu().numpy()) - # np_e.shape == (n, 2), where each column contain the real and - # imaginary parts of the result - self.assertEqual(ee[:, 0], np_e) # real part - self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part + if dtype.is_complex: + self.assertEqual(ee, np_e) + else: + # np_e.shape == (n, 2), where each column contain the real and + # imaginary parts of the result + self.assertEqual(ee[:, 0], np_e) # real part + self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part self.assertEqual(vv, np_v) @skipCPUIfNoLapack From ab8ac04f155266da678207dd9dbd51b54da8eacd Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 24 Dec 2020 10:46:25 +0000 Subject: [PATCH 06/12] add support for complex128 --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 313d1eab634e..f793a4356ac0 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -90,7 +90,14 @@ extern "C" void cgeev_(char *jobvl, char *jobvr, int *n, std::complex *work, int *lwork, float *rwork, int *info); - +extern "C" void zgeev_(char *jobvl, char *jobvr, int *n, + std::complex *a, int *lda, + std::complex *w, + std::complex *vl, int *ldvl, + std::complex *vr, int *ldvr, + std::complex *work, int *lwork, + double *rwork, + int *info); // gesdd extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex *a, int *lda, @@ -338,7 +345,13 @@ template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int ld } template<> void lapackEig, double>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, double *rwork, int *info) { - AT_ERROR("lapackEig>"); + zgeev_(&jobvl, &jobvr, &n, + reinterpret_cast*>(a), &lda, + reinterpret_cast*>(w), + reinterpret_cast*>(vl), &ldvl, + reinterpret_cast*>(vr), &ldvr, + reinterpret_cast*>(work), &lwork, + rwork, info); } template<> void lapackEig, float>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, float *rwork, int *info) { From e909de64c8df50ed0374ae30cb9a37fd404e4944 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 24 Dec 2020 16:30:59 +0000 Subject: [PATCH 07/12] start to add CUDA support: change the signature of magmaEig to match the one used by lapackEig --- .../ATen/native/cuda/BatchLinearAlgebra.cu | 48 ++++++++++++++----- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index e5804ba389c5..e1d6a4ece3e9 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -125,11 +125,13 @@ void magmaSymeig( value_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork, value_t* rwork, magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info); -template +template void magmaEig( magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, scalar_t *A, magma_int_t lda, - scalar_t *wr, scalar_t *wi, scalar_t *VL, magma_int_t ldvl, - scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork, magma_int_t *info); + scalar_t *w, scalar_t *VL, magma_int_t ldvl, + scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork, + value_t *rwork, + magma_int_t *info); template void magmaSvd( @@ -975,23 +977,41 @@ void magmaSymeig, float>( ldwa, reinterpret_cast(work), lwork, rwork, lrwork, iwork, liwork, info); AT_CUDA_CHECK(cudaGetLastError()); } - + template<> void magmaEig( - magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, double *A, magma_int_t lda, - double *wr, double *wi, double *VL, magma_int_t ldvl, - double *VR, magma_int_t ldvr, double *work, magma_int_t lwork, magma_int_t *info) { + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + double *A, magma_int_t lda, + double *w, + double *VL, magma_int_t ldvl, + double *VR, magma_int_t ldvr, + double *work, magma_int_t lwork, + double *rwork, + magma_int_t *info) { MagmaStreamSyncGuard guard; + // magma [sd]geev wants to separate output arrays: wr and wi for the real + // and imaginary parts + double *wr = w; + double *wi = w + n; + (void)rwork; // unused magma_dgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info); AT_CUDA_CHECK(cudaGetLastError()); } template<> void magmaEig( - magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, float *A, magma_int_t lda, - float *wr, float *wi, float *VL, magma_int_t ldvl, - float *VR, magma_int_t ldvr, float *work, magma_int_t lwork, magma_int_t *info) { + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + float *A, magma_int_t lda, + float *w, + float *VL, magma_int_t ldvl, + float *VR, magma_int_t ldvr, + float *work, magma_int_t lwork, + float *rwork, + magma_int_t *info) { MagmaStreamSyncGuard guard; + float *wr = w; + float *wi = w + n; + (void)rwork; // unused magma_sgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info); AT_CUDA_CHECK(cudaGetLastError()); } @@ -1910,13 +1930,13 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc "Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA."); #else TORCH_INTERNAL_ASSERT(self.device() == at::kCPU, "Internal error: apply_eig needs a CPU tensor"); + using value_t = typename c10::scalar_value_type::type; magma_vec_t jobvr = eigenvectors ? MagmaVec : MagmaNoVec; magma_int_t n = magma_int_cast(self.size(-1), "n"); auto self_data = self.data_ptr(); auto out_eigvals_data = out_eigvals.data_ptr(); scalar_t *wr = out_eigvals_data; - scalar_t *wi = out_eigvals_data+n; scalar_t *vr_data = NULL; magma_int_t ldvr = 1; @@ -1926,17 +1946,19 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc ldvr = n; } + value_t *rwork_data = nullptr; + if (n > 0) { // call magmaEig once to get the optimal size of work_data scalar_t wkopt; magma_int_t info; - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info); + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, &wkopt, -1, rwork_data, &info); magma_int_t lwork = (magma_int_t) wkopt; // call it a 2nd time to to the actual work scalar_t *work_data = nullptr; ALLOCATE_ARRAY(work_data, scalar_t, lwork); - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info); + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, work_data, lwork, rwork_data, &info); *info_ptr = info; } #endif From 07db9c63176b9df784c290c9ea89409f3d329e38 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 24 Dec 2020 16:44:10 +0000 Subject: [PATCH 08/12] WIP, untested: implement magmaEig for complex types --- .../ATen/native/cuda/BatchLinearAlgebra.cu | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index e1d6a4ece3e9..bd77d00a2926 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1016,6 +1016,48 @@ void magmaEig( AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaEig, double>( + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + c10::complex *A, magma_int_t lda, + c10::complex *w, + c10::complex *VL, magma_int_t ldvl, + c10::complex *VR, magma_int_t ldvr, + c10::complex *work, magma_int_t lwork, + double *rwork, + magma_int_t *info) { + MagmaStreamSyncGuard guard; + magma_zgeev(jobvl, jobvr, n, + reinterpret_cast(A), lda, + reinterpret_cast(w), + reinterpret_cast(VL), ldvl, + reinterpret_cast(VR), ldvr, + reinterpret_cast(work), lwork, + rwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaEig, float>( + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + c10::complex *A, magma_int_t lda, + c10::complex *w, + c10::complex *VL, magma_int_t ldvl, + c10::complex *VR, magma_int_t ldvr, + c10::complex *work, magma_int_t lwork, + float *rwork, + magma_int_t *info) { + MagmaStreamSyncGuard guard; + magma_cgeev(jobvl, jobvr, n, + reinterpret_cast(A), lda, + reinterpret_cast(w), + reinterpret_cast(VL), ldvl, + reinterpret_cast(VR), ldvr, + reinterpret_cast(work), lwork, + rwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaSvd( magma_vec_t jobz, magma_int_t m, magma_int_t n, double* A, From 5080f8853bc005ff7879291c52370c48d30e3f83 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 24 Dec 2020 17:06:03 +0000 Subject: [PATCH 09/12] add complext support for CUDA eig --- .../src/ATen/native/cuda/BatchLinearAlgebra.cu | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index bd77d00a2926..3719b921df23 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -1989,18 +1989,21 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc } value_t *rwork_data = nullptr; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + ALLOCATE_ARRAY(rwork_data, value_t, n*2); + } if (n > 0) { // call magmaEig once to get the optimal size of work_data scalar_t wkopt; magma_int_t info; - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, &wkopt, -1, rwork_data, &info); - magma_int_t lwork = (magma_int_t) wkopt; + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, &wkopt, -1, rwork_data, &info); + magma_int_t lwork = static_cast(real_impl(wkopt)); // call it a 2nd time to to the actual work scalar_t *work_data = nullptr; ALLOCATE_ARRAY(work_data, scalar_t, lwork); - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, work_data, lwork, rwork_data, &info); + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, work_data, lwork, rwork_data, &info); *info_ptr = info; } #endif @@ -2023,13 +2026,18 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector // tensors holding the results. We use empty_strided to make them column-ordered auto options = self.options().device(at::kCPU).memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto out_eigvals = at::empty_strided({n, 2}, {1, n}, options); + Tensor out_eigvals; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + out_eigvals = at::empty({n}, options); + } else { + out_eigvals = at::empty_strided({n, 2}, {1, n}, options); + } auto out_eigvecs = eigenvectors ? at::empty_strided({n, n}, {1, n}, options) : Tensor(); int64_t info; - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cuda", [&]{ apply_eig(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, &info); }); singleCheckErrors(info, "eig_cuda"); From e1c1ebeac6ae01e525fc4e35f421d45dc802dab8 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Thu, 24 Dec 2020 17:28:36 +0000 Subject: [PATCH 10/12] kill the comment as the code already does what the XXX wanted. Use a better name for real_dtype to make it clearer --- aten/src/ATen/native/BatchLinearAlgebraKernel.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 63408f486a67..d14cdc782c77 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -31,9 +31,8 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec Tensor rwork; value_t* rwork_data = nullptr; if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { - ScalarType dtype = toValueType(typeMetaToScalarType(self.dtype())); - // XXX: I would like a double/float dtype, not a complex dtype - rwork = at::empty({n*2}, self.options().dtype(dtype)); + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + rwork = at::empty({n*2}, self.options().dtype(real_dtype)); rwork_data = rwork.data_ptr(); } From 7f4060d6601f80a3a12ffe05c0ad0b562175c0d1 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Mon, 18 Jan 2021 13:41:38 +0000 Subject: [PATCH 11/12] this is a much simpler way to check whether the tensor is complex --- aten/src/ATen/native/BatchLinearAlgebraKernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index d14cdc782c77..df2a407c4c19 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -30,7 +30,7 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec Tensor rwork; value_t* rwork_data = nullptr; - if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + if (self.is_complex()) { ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); rwork = at::empty({n*2}, self.options().dtype(real_dtype)); rwork_data = rwork.data_ptr(); @@ -71,7 +71,7 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector // eigenvals will be a (n, 2) matrix containing the real and imaginary parts // in each column Tensor vals_; - if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + if (self.is_complex()) { vals_ = at::empty({n}, options); } else { vals_ = at::empty_strided({n, 2}, {1, n}, options); From 01c168dab23e13dd7c07739e8ef5a0ad03aff284 Mon Sep 17 00:00:00 2001 From: Antonio Cuni Date: Tue, 19 Jan 2021 14:36:06 +0000 Subject: [PATCH 12/12] add a test to check what happens with eig on complex types --- test/test_linalg.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index 4c35929cc85b..8c899b5bc199 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1539,6 +1539,17 @@ def test_eig_basic(self, device, dtype): self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part self.assertEqual(vv, np_v) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.complex64, torch.complex128) + def test_eig_backward_complex(self, device, dtype): + # torch.eig's backward is not supported yet for complex types. We + # should kill this test once it's implemented. + a = torch.tensor([[1., 2], [3, 4]], device=device, dtype=dtype, requires_grad=True) + with self.assertRaisesRegex(RuntimeError, + "eig does not support automatic differentiation for outputs with complex dtype"): + e, v = torch.eig(a, True) + @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.double, torch.float)