Added support for complex input for torch.lu_solve #2 #48028

Closed · wants to merge 15 commits
2 changes: 1 addition & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1315,7 +1315,7 @@ Tensor _lu_solve_helper_cpu(const Tensor& self, const Tensor& LU_data, const Ten
if (self.numel() == 0 || LU_data.numel() == 0) {
return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
-AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, infos);
});
if (self.dim() > 2) {
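Note: this one-line dispatch change is the whole CPU fix. `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES` instantiates the existing `apply_lu_solve` LAPACK path for `c10::complex<float>` and `c10::complex<double>` in addition to the real types. A minimal sketch of what this enables from Python (assuming a LAPACK-enabled build; the data here is made up for illustration):

```python
import torch

# Solve A x = b for a complex system on CPU via the LU factorization.
A = torch.randn(3, 3, dtype=torch.complex128)
b = torch.randn(3, 2, dtype=torch.complex128)
LU_data, LU_pivots = torch.lu(A)           # compact LU factors plus pivots
x = torch.lu_solve(b, LU_data, LU_pivots)
print(torch.allclose(A @ x, b))            # True, up to rounding error
```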
37 changes: 36 additions & 1 deletion aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1025,6 +1025,23 @@ void magmaLuSolve<float>(
AT_CUDA_CHECK(cudaGetLastError());
}

+template<>
+void magmaLuSolve<c10::complex<double>>(
+magma_int_t n, magma_int_t nrhs, c10::complex<double>* dA, magma_int_t ldda, magma_int_t* ipiv,
+c10::complex<double>* dB, magma_int_t lddb, magma_int_t* info) {
+MagmaStreamSyncGuard guard;
+magma_zgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaDoubleComplex*>(dA), ldda, ipiv, reinterpret_cast<magmaDoubleComplex*>(dB), lddb, info);
+AT_CUDA_CHECK(cudaGetLastError());
+}

+template<>
+void magmaLuSolve<c10::complex<float>>(
+magma_int_t n, magma_int_t nrhs, c10::complex<float>* dA, magma_int_t ldda, magma_int_t* ipiv,
+c10::complex<float>* dB, magma_int_t lddb, magma_int_t* info) {
+MagmaStreamSyncGuard guard;
+magma_cgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaFloatComplex*>(dA), ldda, ipiv, reinterpret_cast<magmaFloatComplex*>(dB), lddb, info);
+AT_CUDA_CHECK(cudaGetLastError());
+}

template<>
void magmaLuSolveBatched<double>(
@@ -1043,6 +1060,24 @@ void magmaLuSolveBatched<float>(
info = magma_sgetrs_batched(MagmaNoTrans, n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, batchsize, magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
}

+template<>
+void magmaLuSolveBatched<c10::complex<double>>(
+magma_int_t n, magma_int_t nrhs, c10::complex<double>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
+c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t& info,
+magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+info = magma_zgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaDoubleComplex**>(dA_array), ldda, dipiv_array, reinterpret_cast<magmaDoubleComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+AT_CUDA_CHECK(cudaGetLastError());
+}

+template<>
+void magmaLuSolveBatched<c10::complex<float>>(
+magma_int_t n, magma_int_t nrhs, c10::complex<float>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
+c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t& info,
+magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+info = magma_cgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaFloatComplex**>(dA_array), ldda, dipiv_array, reinterpret_cast<magmaFloatComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+AT_CUDA_CHECK(cudaGetLastError());
+}
#endif

#define ALLOCATE_ARRAY(name, type, size) \
@@ -2149,7 +2184,7 @@ Tensor _lu_solve_helper_cuda(const Tensor& self, const Tensor& LU_data, const Te
if (self.numel() == 0 || LU_data.numel() == 0) {
return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
-AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, info);
});
TORCH_CHECK(info == 0, "MAGMA lu_solve : invalid argument: ", -info);
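The CUDA specializations above forward to MAGMA's `zgetrs`/`cgetrs` solvers, using `reinterpret_cast` to turn `c10::complex<T>` pointers into `magmaDoubleComplex*`/`magmaFloatComplex*`. This is sound because both representations store a complex number as two contiguous real values of the same width. A small sketch of that layout, visible from Python (illustration only, not part of this diff):

```python
import torch

# A complex tensor is stored as interleaved (real, imag) pairs, which is
# what makes the pointer reinterpretation on the C++ side copy-free.
z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
print(torch.view_as_real(z))
# tensor([[1., 2.],
#         [3., 4.]])
```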
25 changes: 15 additions & 10 deletions test/test_linalg.py
@@ -4746,7 +4746,7 @@ def maybe_squeeze_result(l, r, result):

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
-@dtypes(torch.double)
+@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_lu_solve_batched_non_contiguous(self, device, dtype):
from numpy.linalg import solve
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@@ -4765,28 +4765,32 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

b = torch.randn(*b_dims, dtype=dtype, device=device)
-A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
+A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device)
LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
self.assertEqual(info, torch.zeros_like(info))
return b, A, LU_data, LU_pivots

@skipCPUIfNoLapack
@skipCUDAIfNoMagma
-@dtypes(torch.double)
+@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                    torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve(self, device, dtype):
def sub_test(pivot):
for k, n in zip([2, 3, 5], [3, 5, 7]):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
-self.assertLessEqual(b.dist(A.mm(x)), 1e-12)
+self.assertEqual(b, A.mm(x))

sub_test(True)
if self.device_type == 'cuda':
sub_test(False)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
-@dtypes(torch.double)
+@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                    torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve_batched(self, device, dtype):
def sub_test(pivot):
def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
@@ -4797,7 +4801,8 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
x_exp = torch.stack(x_exp_list) # Stacked output
x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output
self.assertEqual(x_exp, x_act) # Equality check
-self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12)  # Correctness check
+Ax = torch.matmul(A, x_act)
+self.assertEqual(b, Ax)

for batchsize in [1, 3, 4]:
lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
@@ -4815,20 +4820,20 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
-@dtypes(torch.double)
+@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_lu_solve_batched_many_batches(self, device, dtype):
def run_test(A_dims, b_dims):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
-b_ = torch.matmul(A, x)
-self.assertEqual(b_, b.expand_as(b_))
+Ax = torch.matmul(A, x)
+self.assertEqual(Ax, b.expand_as(Ax))

run_test((5, 65536), (65536, 5, 10))
run_test((5, 262144), (262144, 5, 10))

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
-@dtypes(torch.double)
+@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_lu_solve_batched_broadcasting(self, device, dtype):
from numpy.linalg import solve
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
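The test changes reuse one pattern for all four dtypes: factor with `torch.lu`, solve with `torch.lu_solve`, then compare against NumPy or check the residual, with tolerances loosened for single precision via `precisionOverride`. A condensed sketch of that cross-check (made-up data; the tolerance mirrors the float32/complex64 override above):

```python
import numpy as np
import torch

# Cross-check torch.lu_solve against numpy.linalg.solve on complex data.
A = torch.randn(4, 4, dtype=torch.complex64)
b = torch.randn(4, 3, dtype=torch.complex64)
LU_data, LU_pivots = torch.lu(A)
x = torch.lu_solve(b, LU_data, LU_pivots)
x_np = np.linalg.solve(A.numpy(), b.numpy())
print(np.allclose(x.numpy(), x_np, atol=1e-3))  # loose tolerance: single precision
```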
2 changes: 2 additions & 0 deletions torch/_torch_docs.py
@@ -4577,6 +4577,8 @@ def merge_dicts(*dicts):
Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
LU factorization of A from :meth:`torch.lu`.

+This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.

Arguments:
b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
is zero or more batch dimensions.
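For intuition about what `lu_solve` computes for any of these dtypes: with `A = P @ L @ U` from `torch.lu`, solving `A x = b` amounts to one permutation plus two triangular solves. A reference sketch under that convention (it assumes `torch.lu_unpack` and complex support in `torch.triangular_solve`, neither of which is touched by this PR):

```python
import torch

def lu_solve_reference(b, LU_data, LU_pivots):
    # Recover P, L, U with A = P @ L @ U; L carries a unit diagonal.
    P, L, U = torch.lu_unpack(LU_data, LU_pivots)
    # A x = b  =>  L U x = P^T b (a permutation's transpose is its inverse).
    y = torch.triangular_solve(P.transpose(-2, -1) @ b, L, upper=False).solution
    return torch.triangular_solve(y, U, upper=True).solution

A = torch.randn(5, 5, dtype=torch.complex128)
b = torch.randn(5, 2, dtype=torch.complex128)
LU_data, LU_pivots = torch.lu(A)
print(torch.allclose(lu_solve_reference(b, LU_data, LU_pivots),
                     torch.lu_solve(b, LU_data, LU_pivots)))  # True
```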