Add torch.eig complex forward (CPU, CUDA) #49168

Closed
wants to merge 14 commits
56 changes: 53 additions & 3 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -82,6 +82,22 @@ extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, floa
// geev
extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
extern "C" void cgeev_(char *jobvl, char *jobvr, int *n,
std::complex<float> *a, int *lda,
std::complex<float> *w,
std::complex<float> *vl, int *ldvl,
std::complex<float> *vr, int *ldvr,
std::complex<float> *work, int *lwork,
float *rwork,
int *info);
extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
std::complex<double> *a, int *lda,
std::complex<double> *w,
std::complex<double> *vl, int *ldvl,
std::complex<double> *vr, int *ldvr,
std::complex<double> *work, int *lwork,
double *rwork,
int *info);

// gesdd
extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda,
@@ -310,14 +326,44 @@ template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int ld
ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) {
template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, double *rwork, int *info) {
// lapack [sd]geev wants two separate output arrays: wr and wi for the real
// and imaginary parts
double *wr = w;
double *wi = w + n;
(void)rwork; // unused
dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
}

template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *wr, float *wi, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) {
template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, float *rwork, int *info) {
// lapack [sd]geev wants two separate output arrays: wr and wi for the real
// and imaginary parts
float *wr = w;
float *wi = w + n;
(void)rwork; // unused
sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
}
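For context (not part of the diff): in the real-valued path the eigenvalue tensor is allocated with shape {n, 2} and strides {1, n}, so the first n elements of its storage hold the real parts and the next n hold the imaginary parts, which is exactly the wr/wi split that LAPACK's [sd]geev expects. That is why the specializations above can simply pass w and w + n. A minimal standalone sketch of that layout (illustrative only):

#include <cstdio>
#include <vector>

int main() {
  const int n = 3;
  // Backing storage of a {n, 2} tensor with strides {1, n}:
  // element (i, 0) lives at vals[i], element (i, 1) at vals[n + i].
  std::vector<double> vals(2 * n);
  double* wr = vals.data();      // real parts, what LAPACK fills as "wr"
  double* wi = vals.data() + n;  // imaginary parts, what LAPACK fills as "wi"
  for (int i = 0; i < n; ++i) { wr[i] = i + 1.0; wi[i] = 0.5 * i; }
  for (int i = 0; i < n; ++i) {
    std::printf("eigenvalue %d: %g%+gi\n", i, vals[i], vals[n + i]);
  }
  return 0;
}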

template<> void lapackEig<c10::complex<double>, double>(char jobvl, char jobvr, int n, c10::complex<double> *a, int lda, c10::complex<double> *w, c10::complex<double> *vl, int ldvl, c10::complex<double> *vr, int ldvr, c10::complex<double> *work, int lwork, double *rwork, int *info) {
zgeev_(&jobvl, &jobvr, &n,
reinterpret_cast<std::complex<double>*>(a), &lda,
reinterpret_cast<std::complex<double>*>(w),
reinterpret_cast<std::complex<double>*>(vl), &ldvl,
reinterpret_cast<std::complex<double>*>(vr), &ldvr,
reinterpret_cast<std::complex<double>*>(work), &lwork,
rwork, info);
}

template<> void lapackEig<c10::complex<float>, float>(char jobvl, char jobvr, int n, c10::complex<float> *a, int lda, c10::complex<float> *w, c10::complex<float> *vl, int ldvl, c10::complex<float> *vr, int ldvr, c10::complex<float> *work, int lwork, float *rwork, int *info) {
cgeev_(&jobvl, &jobvr, &n,
reinterpret_cast<std::complex<float>*>(a), &lda,
reinterpret_cast<std::complex<float>*>(w),
reinterpret_cast<std::complex<float>*>(vl), &ldvl,
reinterpret_cast<std::complex<float>*>(vr), &ldvr,
reinterpret_cast<std::complex<float>*>(work), &lwork,
rwork, info);
}
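A side note on the reinterpret_casts above: they assume that c10::complex<T> and std::complex<T> are layout-compatible, i.e. both are stored as two adjacent Ts (real part followed by imaginary part). A hedged compile-time sketch of that assumption (it presumes the c10 headers are on the include path):

#include <complex>
#include <c10/util/complex.h>  // assumption: building against PyTorch/ATen headers

static_assert(sizeof(c10::complex<float>) == sizeof(std::complex<float>),
              "c10::complex<float> is expected to match std::complex<float> in size");
static_assert(sizeof(c10::complex<double>) == sizeof(std::complex<double>),
              "c10::complex<double> is expected to match std::complex<double> in size");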

template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda,
double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) {
zgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, s, reinterpret_cast<std::complex<double>*>(u), &ldu,
@@ -1373,7 +1419,11 @@ std::tuple<Tensor&, Tensor&> eig_out(Tensor& e, Tensor& v, const Tensor& self, b
TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype());
int64_t n = self.size(-1);

at::native::resize_output(e, {n, 2});
if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
at::native::resize_output(e, {n});
} else {
at::native::resize_output(e, {n, 2});
}
if (eigenvectors) {
at::native::resize_output(v, self.sizes());
}
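To illustrate the user-visible effect of this shape logic, here is a hedged ATen-level sketch (not part of the PR; it assumes an ATen build where at::eig and real-to-complex .to() conversion are available): for real inputs the eigenvalue output keeps the legacy {n, 2} layout, while for complex inputs it becomes a complex vector of length n.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto a_real = at::randn({4, 4}, at::kDouble);
  auto a_cplx = a_real.to(at::kComplexDouble);

  // Real input: eigenvalues reported as a {4, 2} real tensor (real/imag columns).
  auto vals_real = std::get<0>(at::eig(a_real, /*eigenvectors=*/false));
  // Complex input (new in this PR): eigenvalues reported as a {4} complex tensor.
  auto vals_cplx = std::get<0>(at::eig(a_cplx, /*eigenvectors=*/false));

  std::cout << vals_real.sizes() << " vs " << vals_cplx.sizes() << std::endl;  // [4, 2] vs [4]
  return 0;
}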
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.h
@@ -12,8 +12,8 @@ namespace at { namespace native {
// Define per-batch functions to be used in the implementation of batched
// linear algebra operations

template<class scalar_t>
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info);
template<class scalar_t, class value_t=scalar_t>
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);

#endif

36 changes: 28 additions & 8 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -2,6 +2,7 @@
#include <ATen/Dispatch.h>
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cpu/zmath.h>

#include <TH/TH.h> // for USE_LAPACK

@@ -15,29 +16,38 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec
TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;

char jobvr = eigenvectors ? 'V' : 'N';
int64_t n = self.size(-1);
auto self_data = self.data_ptr<scalar_t>();

auto vals_data = vals_.data_ptr<scalar_t>();
scalar_t* wr = vals_data;
scalar_t* wi = vals_data + n;

scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr<scalar_t>() : nullptr;
int ldvr = eigenvectors ? n : 1;

Tensor rwork;
value_t* rwork_data = nullptr;
if (self.is_complex()) {
ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
rwork = at::empty({n*2}, self.options().dtype(real_dtype));
rwork_data = rwork.data_ptr<value_t>();
}

if (n > 0) {
// call lapackEig once to get the optimal size for work data
scalar_t wkopt;
int info;
lapackEig<scalar_t>('N', jobvr, n, self_data, n, wr, wi,
nullptr, 1, vecs_data, ldvr, &wkopt, -1, &info);
int lwork = static_cast<int>(wkopt);
lapackEig<scalar_t, value_t>('N', jobvr, n, self_data, n, wr,
nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, &info);
int lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));

// call again to do the actual work
Tensor work = at::empty({lwork}, self.dtype());
lapackEig<scalar_t>('N', jobvr, n, self_data, n, wr, wi,
nullptr, 1, vecs_data, ldvr, work.data_ptr<scalar_t>(), lwork, &info);
lapackEig<scalar_t, value_t>('N', jobvr, n, self_data, n, wr,
nullptr, 1, vecs_data, ldvr, work.data_ptr<scalar_t>(), lwork, rwork_data, &info);
*info_ptr = info;
}
#endif
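The two lapackEig calls above follow LAPACK's standard workspace-query convention: a first call with lwork = -1 only reports the optimal workspace size in work[0], and a second call does the actual computation. For the complex drivers ([cz]geev) an additional real workspace rwork of length 2*n must also be supplied, which is what the rwork tensor above provides. A minimal standalone sketch of this pattern against the raw Fortran interface (a sketch only; it assumes linking a LAPACK that exports the zgeev_ symbol):

#include <complex>
#include <cstdio>
#include <vector>

extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
                       std::complex<double> *a, int *lda,
                       std::complex<double> *w,
                       std::complex<double> *vl, int *ldvl,
                       std::complex<double> *vr, int *ldvr,
                       std::complex<double> *work, int *lwork,
                       double *rwork, int *info);

int main() {
  int n = 2, lda = 2, ldvl = 1, ldvr = 2, info = 0;
  // Column-major 2x2 matrix [[1, 2], [-2, 1]]; its eigenvalues are 1 +/- 2i.
  std::vector<std::complex<double>> a = {{1, 0}, {-2, 0}, {2, 0}, {1, 0}};
  std::vector<std::complex<double>> w(n), vr(n * n);
  std::vector<double> rwork(2 * n);  // [cz]geev always needs 2*n real workspace
  char jobvl = 'N', jobvr = 'V';

  // 1st call: workspace query (lwork = -1); the optimal size comes back in wkopt.
  std::complex<double> wkopt;
  int lwork = -1;
  zgeev_(&jobvl, &jobvr, &n, a.data(), &lda, w.data(),
         nullptr, &ldvl, vr.data(), &ldvr, &wkopt, &lwork, rwork.data(), &info);

  // 2nd call: the actual eigendecomposition with the optimally sized workspace.
  lwork = static_cast<int>(wkopt.real());
  std::vector<std::complex<double>> work(lwork);
  zgeev_(&jobvl, &jobvr, &n, a.data(), &lda, w.data(),
         nullptr, &ldvl, vr.data(), &ldvr, work.data(), &lwork, rwork.data(), &info);

  std::printf("info=%d, w[0]=%g%+gi, w[1]=%g%+gi\n",
              info, w[0].real(), w[0].imag(), w[1].real(), w[1].imag());
  return 0;
}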
@@ -55,13 +65,23 @@ std::tuple<Tensor, Tensor> eig_kernel_impl(const Tensor& self, bool& eigenvector
self_.copy_(self);

auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor vals_ = at::empty_strided({n, 2}, {1, n}, options);

// the API is slightly different for the complex vs real case: if the input
// is complex, eigenvals will be a vector of complex. If the input is real,
// eigenvals will be a (n, 2) matrix containing the real and imaginary parts
// in each column

Inline review discussion on the comment above:

Contributor: As discussed in #43081, I think we should always return a complex tensor. This will be a BC-breaking change, so we should only update this behavior for torch.linalg.eig.

Contributor Author: I thought that the point of torch.linalg.* was to be as compatible with numpy as possible. This would be an unnecessary breakage for people porting their code from numpy to PyTorch. What about adding a flag such as returns_complex=False (or =True if we want the correct-but-numpy-incompatible behavior by default) to let the user choose?

Collaborator: Hey @antocuni, looks like NumPy is already doing it:

In [18]: x = numpy.array([[1, 2], [-2, 1]], numpy.double)

In [19]: x
Out[19]: 
array([[ 1.,  2.],
       [-2.,  1.]])

In [20]: numpy.linalg.eig(x)
Out[20]: 
(array([1.+2.j, 1.-2.j]),
 array([[0.        -0.70710678j, 0.        +0.70710678j],
        [0.70710678+0.j        , 0.70710678-0.j        ]]))

Contributor Author (@antocuni, Jan 18, 2021): Ouch, my fault, I overlooked it. OK, I'll implement the "proper" behavior then, thanks for pointing it out.

Contributor Author: @anjali411 @nikitaved I just realized that torch.linalg.eig doesn't exist yet! So I think that the current behavior of my branch is correct. I agree that torch.linalg.eig should always return complex, but I don't think there is anything we can do in this branch.
Tensor vals_;
if (self.is_complex()) {
vals_ = at::empty({n}, options);
} else {
vals_ = at::empty_strided({n, 2}, {1, n}, options);
}
Tensor vecs_ = eigenvectors
? at::empty_strided({n, n}, {1, n}, options)
: Tensor();

int64_t info;
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cpu", [&]{
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cpu", [&]{
apply_eig<scalar_t>(self_, eigenvectors, vals_, vecs_, &info);
});
singleCheckErrors(info, "eig_cpu");
104 changes: 88 additions & 16 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -133,11 +133,13 @@ void magmaSymeig(
value_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork, value_t* rwork,
magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info);

template<class scalar_t>
template<class scalar_t, class value_t=scalar_t>
void magmaEig(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, scalar_t *A, magma_int_t lda,
scalar_t *wr, scalar_t *wi, scalar_t *VL, magma_int_t ldvl,
scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork, magma_int_t *info);
scalar_t *w, scalar_t *VL, magma_int_t ldvl,
scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork,
value_t *rwork,
magma_int_t *info);

template<class scalar_t, class value_t=scalar_t>
void magmaSvd(
@@ -1055,27 +1057,87 @@ void magmaSymeig<c10::complex<float>, float>(
ldwa, reinterpret_cast<magmaFloatComplex*>(work), lwork, rwork, lrwork, iwork, liwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaEig<double>(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, double *A, magma_int_t lda,
double *wr, double *wi, double *VL, magma_int_t ldvl,
double *VR, magma_int_t ldvr, double *work, magma_int_t lwork, magma_int_t *info) {
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n,
double *A, magma_int_t lda,
double *w,
double *VL, magma_int_t ldvl,
double *VR, magma_int_t ldvr,
double *work, magma_int_t lwork,
double *rwork,
magma_int_t *info) {
MagmaStreamSyncGuard guard;
// magma [sd]geev wants two separate output arrays: wr and wi for the real
// and imaginary parts
double *wr = w;
double *wi = w + n;
(void)rwork; // unused
magma_dgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaEig<float>(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, float *A, magma_int_t lda,
float *wr, float *wi, float *VL, magma_int_t ldvl,
float *VR, magma_int_t ldvr, float *work, magma_int_t lwork, magma_int_t *info) {
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n,
float *A, magma_int_t lda,
float *w,
float *VL, magma_int_t ldvl,
float *VR, magma_int_t ldvr,
float *work, magma_int_t lwork,
float *rwork,
magma_int_t *info) {
MagmaStreamSyncGuard guard;
float *wr = w;
float *wi = w + n;
(void)rwork; // unused
magma_sgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaEig<c10::complex<double>, double>(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n,
c10::complex<double> *A, magma_int_t lda,
c10::complex<double> *w,
c10::complex<double> *VL, magma_int_t ldvl,
c10::complex<double> *VR, magma_int_t ldvr,
c10::complex<double> *work, magma_int_t lwork,
double *rwork,
magma_int_t *info) {
MagmaStreamSyncGuard guard;
magma_zgeev(jobvl, jobvr, n,
reinterpret_cast<magmaDoubleComplex*>(A), lda,
reinterpret_cast<magmaDoubleComplex*>(w),
reinterpret_cast<magmaDoubleComplex*>(VL), ldvl,
reinterpret_cast<magmaDoubleComplex*>(VR), ldvr,
reinterpret_cast<magmaDoubleComplex*>(work), lwork,
rwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaEig<c10::complex<float>, float>(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n,
c10::complex<float> *A, magma_int_t lda,
c10::complex<float> *w,
c10::complex<float> *VL, magma_int_t ldvl,
c10::complex<float> *VR, magma_int_t ldvr,
c10::complex<float> *work, magma_int_t lwork,
float *rwork,
magma_int_t *info) {
MagmaStreamSyncGuard guard;
magma_cgeev(jobvl, jobvr, n,
reinterpret_cast<magmaFloatComplex*>(A), lda,
reinterpret_cast<magmaFloatComplex*>(w),
reinterpret_cast<magmaFloatComplex*>(VL), ldvl,
reinterpret_cast<magmaFloatComplex*>(VR), ldvr,
reinterpret_cast<magmaFloatComplex*>(work), lwork,
rwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaSvd<double>(
magma_vec_t jobz, magma_int_t m, magma_int_t n, double* A,
@@ -2054,13 +2116,13 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc
"Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA.");
#else
TORCH_INTERNAL_ASSERT(self.device() == at::kCPU, "Internal error: apply_eig needs a CPU tensor");
using value_t = typename c10::scalar_value_type<scalar_t>::type;
magma_vec_t jobvr = eigenvectors ? MagmaVec : MagmaNoVec;
magma_int_t n = magma_int_cast(self.size(-1), "n");
auto self_data = self.data_ptr<scalar_t>();

auto out_eigvals_data = out_eigvals.data_ptr<scalar_t>();
scalar_t *wr = out_eigvals_data;
scalar_t *wi = out_eigvals_data+n;

scalar_t *vr_data = NULL;
magma_int_t ldvr = 1;
Expand All @@ -2070,17 +2132,22 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc
ldvr = n;
}

value_t *rwork_data = nullptr;
if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
ALLOCATE_ARRAY(rwork_data, value_t, n*2);
}

if (n > 0) {
// call magmaEig once to get the optimal size of work_data
scalar_t wkopt;
magma_int_t info;
magmaEig<scalar_t>(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info);
magma_int_t lwork = (magma_int_t) wkopt;
magmaEig<scalar_t, value_t>(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, &wkopt, -1, rwork_data, &info);
magma_int_t lwork = static_cast<magma_int_t>(real_impl<scalar_t, value_t>(wkopt));

// call it a 2nd time to do the actual work
scalar_t *work_data = nullptr;
ALLOCATE_ARRAY(work_data, scalar_t, lwork);
magmaEig<scalar_t>(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
magmaEig<scalar_t, value_t>(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, work_data, lwork, rwork_data, &info);
*info_ptr = info;
}
#endif
Expand All @@ -2103,13 +2170,18 @@ std::tuple<Tensor, Tensor> eig_kernel_impl(const Tensor& self, bool& eigenvector

// tensors holding the results. We use empty_strided to make them column-ordered
auto options = self.options().device(at::kCPU).memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto out_eigvals = at::empty_strided({n, 2}, {1, n}, options);
Tensor out_eigvals;
if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
out_eigvals = at::empty({n}, options);
} else {
out_eigvals = at::empty_strided({n, 2}, {1, n}, options);
}
auto out_eigvecs = eigenvectors
? at::empty_strided({n, n}, {1, n}, options)
: Tensor();

int64_t info;
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cuda", [&]{
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cuda", [&]{
apply_eig<scalar_t>(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, &info);
});
singleCheckErrors(info, "eig_cuda");
24 changes: 19 additions & 5 deletions test/test_linalg.py
@@ -1509,7 +1509,7 @@ def test_norm_fastpaths(self, device):

@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.double, torch.float)
@dtypes(*floating_and_complex_types())
def test_eig_basic(self, device, dtype):
a = torch.tensor([[1.96, 0.00, 0.00, 0.00, 0.00],
[-6.49, 3.80, 0.00, 0.00, 0.00],
@@ -1530,12 +1530,26 @@ def test_eig_basic(self, device, dtype):
#
# compare with numpy
np_e, np_v = np.linalg.eig(a.cpu().numpy())
# ee.shape == (n, 2), where each column contains the real and
# imaginary parts of the result
self.assertEqual(ee[:, 0], np_e) # real part
self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part
if dtype.is_complex:
self.assertEqual(ee, np_e)
else:
# ee.shape == (n, 2), where each column contains the real and
# imaginary parts of the result
self.assertEqual(ee[:, 0], np_e) # real part
self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part
self.assertEqual(vv, np_v)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.complex64, torch.complex128)
def test_eig_backward_complex(self, device, dtype):
# torch.eig's backward is not supported yet for complex types. We
# should kill this test once it's implemented.
a = torch.tensor([[1., 2], [3, 4]], device=device, dtype=dtype, requires_grad=True)
with self.assertRaisesRegex(RuntimeError,
"eig does not support automatic differentiation for outputs with complex dtype"):
e, v = torch.eig(a, True)

@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.double, torch.float)