Skip to content

Commit

Permalink
Added support for complex input for torch.lu_solve pytorch#2 (pytorch…
Browse files Browse the repository at this point in the history
…#48028)

Summary:
Relanding pytorch#46862
There was an issue with the simultaneous merge of two slightly conflicting PRs.

This PR adds support for complex inputs to `torch.lu_solve` on both CPU and GPU.

Pull Request resolved: pytorch#48028

Reviewed By: linbinyu

Differential Revision: D25003700

Pulled By: zou3519

fbshipit-source-id: 24cd1babe9ccdbaa4e2ed23f08a9153d40d0f0cd
  • Loading branch information
IvanYashchuk authored and shaibagon committed Dec 3, 2020
1 parent 0b6b973 commit 32ec54d
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 12 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
Expand Up @@ -1315,7 +1315,7 @@ Tensor _lu_solve_helper_cpu(const Tensor& self, const Tensor& LU_data, const Ten
if (self.numel() == 0 || LU_data.numel() == 0) {
return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, infos);
});
if (self.dim() > 2) {
Expand Down
37 changes: 36 additions & 1 deletion aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
Expand Up @@ -1025,6 +1025,23 @@ void magmaLuSolve<float>(
AT_CUDA_CHECK(cudaGetLastError());
}

// magmaLuSolve specialization for c10::complex<double>.
// Forwards to magma_zgetrs_gpu with MagmaNoTrans; the status is written
// through `info` by the MAGMA call.
template<>
void magmaLuSolve<c10::complex<double>>(
    magma_int_t n, magma_int_t nrhs, c10::complex<double>* dA, magma_int_t ldda, magma_int_t* ipiv,
    c10::complex<double>* dB, magma_int_t lddb, magma_int_t* info) {
  MagmaStreamSyncGuard guard;
  // MAGMA uses its own complex type, so the pointers are reinterpreted (not
  // converted). NOTE(review): this assumes c10::complex<double> and
  // magmaDoubleComplex share the same memory layout -- confirm.
  magmaDoubleComplex* const a_magma = reinterpret_cast<magmaDoubleComplex*>(dA);
  magmaDoubleComplex* const b_magma = reinterpret_cast<magmaDoubleComplex*>(dB);
  magma_zgetrs_gpu(MagmaNoTrans, n, nrhs, a_magma, ldda, ipiv, b_magma, lddb, info);
  // Raise if the CUDA runtime has recorded an error.
  AT_CUDA_CHECK(cudaGetLastError());
}

// magmaLuSolve specialization for c10::complex<float>.
// Forwards to magma_cgetrs_gpu with MagmaNoTrans; the status is written
// through `info` by the MAGMA call.
template<>
void magmaLuSolve<c10::complex<float>>(
    magma_int_t n, magma_int_t nrhs, c10::complex<float>* dA, magma_int_t ldda, magma_int_t* ipiv,
    c10::complex<float>* dB, magma_int_t lddb, magma_int_t* info) {
  MagmaStreamSyncGuard guard;
  // MAGMA uses its own complex type, so the pointers are reinterpreted (not
  // converted). NOTE(review): this assumes c10::complex<float> and
  // magmaFloatComplex share the same memory layout -- confirm.
  magmaFloatComplex* const a_magma = reinterpret_cast<magmaFloatComplex*>(dA);
  magmaFloatComplex* const b_magma = reinterpret_cast<magmaFloatComplex*>(dB);
  magma_cgetrs_gpu(MagmaNoTrans, n, nrhs, a_magma, ldda, ipiv, b_magma, lddb, info);
  // Raise if the CUDA runtime has recorded an error.
  AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaLuSolveBatched<double>(
Expand All @@ -1043,6 +1060,24 @@ void magmaLuSolveBatched<float>(
info = magma_sgetrs_batched(MagmaNoTrans, n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, batchsize, magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
}

// magmaLuSolveBatched specialization for c10::complex<double>.
// Forwards to magma_zgetrs_batched with MagmaNoTrans on the queue held by
// `magma_queue`, and stores the call's return value in `info`.
template<>
void magmaLuSolveBatched<c10::complex<double>>(
    magma_int_t n, magma_int_t nrhs, c10::complex<double>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
    c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t& info,
    magma_int_t batchsize, const MAGMAQueue& magma_queue) {
  // MAGMA uses its own complex type, so the pointer arrays are reinterpreted
  // (not converted). NOTE(review): this assumes c10::complex<double> and
  // magmaDoubleComplex share the same memory layout -- confirm.
  magmaDoubleComplex** const a_magma = reinterpret_cast<magmaDoubleComplex**>(dA_array);
  magmaDoubleComplex** const b_magma = reinterpret_cast<magmaDoubleComplex**>(dB_array);
  info = magma_zgetrs_batched(
      MagmaNoTrans, n, nrhs, a_magma, ldda, dipiv_array, b_magma, lddb,
      batchsize, magma_queue.get_queue());
  // Raise if the CUDA runtime has recorded an error.
  AT_CUDA_CHECK(cudaGetLastError());
}

// magmaLuSolveBatched specialization for c10::complex<float>.
// Forwards to magma_cgetrs_batched with MagmaNoTrans on the queue held by
// `magma_queue`, and stores the call's return value in `info`.
template<>
void magmaLuSolveBatched<c10::complex<float>>(
    magma_int_t n, magma_int_t nrhs, c10::complex<float>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
    c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t& info,
    magma_int_t batchsize, const MAGMAQueue& magma_queue) {
  // MAGMA uses its own complex type, so the pointer arrays are reinterpreted
  // (not converted). NOTE(review): this assumes c10::complex<float> and
  // magmaFloatComplex share the same memory layout -- confirm.
  magmaFloatComplex** const a_magma = reinterpret_cast<magmaFloatComplex**>(dA_array);
  magmaFloatComplex** const b_magma = reinterpret_cast<magmaFloatComplex**>(dB_array);
  info = magma_cgetrs_batched(
      MagmaNoTrans, n, nrhs, a_magma, ldda, dipiv_array, b_magma, lddb,
      batchsize, magma_queue.get_queue());
  // Raise if the CUDA runtime has recorded an error.
  AT_CUDA_CHECK(cudaGetLastError());
}
#endif

#define ALLOCATE_ARRAY(name, type, size) \
Expand Down Expand Up @@ -2149,7 +2184,7 @@ Tensor _lu_solve_helper_cuda(const Tensor& self, const Tensor& LU_data, const Te
if (self.numel() == 0 || LU_data.numel() == 0) {
return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, info);
});
TORCH_CHECK(info == 0, "MAGMA lu_solve : invalid argument: ", -info);
Expand Down
25 changes: 15 additions & 10 deletions test/test_linalg.py
Expand Up @@ -4839,7 +4839,7 @@ def maybe_squeeze_result(l, r, result):

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.double)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_lu_solve_batched_non_contiguous(self, device, dtype):
from numpy.linalg import solve
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
Expand All @@ -4858,28 +4858,32 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

b = torch.randn(*b_dims, dtype=dtype, device=device)
A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device)
LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
self.assertEqual(info, torch.zeros_like(info))
return b, A, LU_data, LU_pivots

@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.double)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve(self, device, dtype):
def sub_test(pivot):
for k, n in zip([2, 3, 5], [3, 5, 7]):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
self.assertLessEqual(b.dist(A.mm(x)), 1e-12)
self.assertEqual(b, A.mm(x))

sub_test(True)
if self.device_type == 'cuda':
sub_test(False)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.double)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve_batched(self, device, dtype):
def sub_test(pivot):
def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
Expand All @@ -4890,7 +4894,8 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
x_exp = torch.stack(x_exp_list) # Stacked output
x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output
self.assertEqual(x_exp, x_act) # Equality check
self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check
Ax = torch.matmul(A, x_act)
self.assertEqual(b, Ax)

for batchsize in [1, 3, 4]:
lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
Expand All @@ -4908,20 +4913,20 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.double)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_lu_solve_batched_many_batches(self, device, dtype):
def run_test(A_dims, b_dims):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
b_ = torch.matmul(A, x)
self.assertEqual(b_, b.expand_as(b_))
Ax = torch.matmul(A, x)
self.assertEqual(Ax, b.expand_as(Ax))

run_test((5, 65536), (65536, 5, 10))
run_test((5, 262144), (262144, 5, 10))

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.double)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_lu_solve_batched_broadcasting(self, device, dtype):
from numpy.linalg import solve
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
Expand Down
2 changes: 2 additions & 0 deletions torch/_torch_docs.py
Expand Up @@ -4577,6 +4577,8 @@ def merge_dicts(*dicts):
Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
LU factorization of A from :meth:`torch.lu`.
This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
Arguments:
b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
is zero or more batch dimensions.
Expand Down

0 comments on commit 32ec54d

Please sign in to comment.