diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 37b7c5bbb223..9cc040b4dc8f 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1315,7 +1315,7 @@ Tensor _lu_solve_helper_cpu(const Tensor& self, const Tensor& LU_data, const Ten
   if (self.numel() == 0 || LU_data.numel() == 0) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
     apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, infos);
   });
   if (self.dim() > 2) {
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index a9cdfbb65705..eaee3c87b1f8 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1025,6 +1025,23 @@ void magmaLuSolve<float>(
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
+template<>
+void magmaLuSolve<c10::complex<double>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<double>* dA, magma_int_t ldda, magma_int_t* ipiv,
+    c10::complex<double>* dB, magma_int_t lddb, magma_int_t* info) {
+  MagmaStreamSyncGuard guard;
+  magma_zgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaDoubleComplex*>(dA), ldda, ipiv, reinterpret_cast<magmaDoubleComplex*>(dB), lddb, info);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+template<>
+void magmaLuSolve<c10::complex<float>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<float>* dA, magma_int_t ldda, magma_int_t* ipiv,
+    c10::complex<float>* dB, magma_int_t lddb, magma_int_t* info) {
+  MagmaStreamSyncGuard guard;
+  magma_cgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaFloatComplex*>(dA), ldda, ipiv, reinterpret_cast<magmaFloatComplex*>(dB), lddb, info);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
 
 template<>
 void magmaLuSolveBatched<float>(
@@ -1043,6 +1060,24 @@
   info = magma_sgetrs_batched(MagmaNoTrans, n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, batchsize, magma_queue.get_queue());
   AT_CUDA_CHECK(cudaGetLastError());
 }
+
+template<>
+void magmaLuSolveBatched<c10::complex<double>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<double>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
+    c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t& info,
+    magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  info = magma_zgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaDoubleComplex**>(dA_array), ldda, dipiv_array, reinterpret_cast<magmaDoubleComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+template<>
+void magmaLuSolveBatched<c10::complex<float>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<float>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
+    c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t& info,
+    magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  info = magma_cgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaFloatComplex**>(dA_array), ldda, dipiv_array, reinterpret_cast<magmaFloatComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+  AT_CUDA_CHECK(cudaGetLastError());
+}
 #endif
 
 #define ALLOCATE_ARRAY(name, type, size) \
@@ -2149,7 +2184,7 @@ Tensor _lu_solve_helper_cuda(const Tensor& self, const Tensor& LU_data, const Te
   if (self.numel() == 0 || LU_data.numel() == 0) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
     apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, info);
   });
   TORCH_CHECK(info == 0, "MAGMA lu_solve : invalid argument: ", -info);
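Taken together, the widened dispatch macros and the new MAGMA specializations above let torch.lu_solve accept complex inputs on both CPU (via LAPACK) and CUDA (via magma_zgetrs/magma_cgetrs). A minimal sketch of what this enables; it is not part of the patch, and the tensor shapes and the well-conditioning assumption are illustrative:

import torch

# Build a complex system A x = b; a random square matrix is assumed
# to be well-conditioned here, which holds with high probability.
A = torch.randn(3, 3, dtype=torch.complex128)
b = torch.randn(3, 2, dtype=torch.complex128)

# Factor once, then solve; both steps now dispatch for complex dtypes.
LU_data, LU_pivots = torch.lu(A)
x = torch.lu_solve(b, LU_data, LU_pivots)

# The solution should reproduce b up to floating-point error.
assert torch.allclose(A @ x, b)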
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 2b84e22491d6..33c61fb60ff9 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -4839,7 +4839,7 @@ def maybe_squeeze_result(l, r, result):
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
-    @dtypes(torch.double)
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_lu_solve_batched_non_contiguous(self, device, dtype):
         from numpy.linalg import solve
         from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@@ -4858,20 +4858,22 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
         from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
 
         b = torch.randn(*b_dims, dtype=dtype, device=device)
-        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
+        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device)
         LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
         self.assertEqual(info, torch.zeros_like(info))
         return b, A, LU_data, LU_pivots
 
     @skipCPUIfNoLapack
     @skipCUDAIfNoMagma
-    @dtypes(torch.double)
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                        torch.float64: 1e-8, torch.complex128: 1e-8})
     def test_lu_solve(self, device, dtype):
         def sub_test(pivot):
             for k, n in zip([2, 3, 5], [3, 5, 7]):
                 b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
                 x = torch.lu_solve(b, LU_data, LU_pivots)
-                self.assertLessEqual(b.dist(A.mm(x)), 1e-12)
+                self.assertEqual(b, A.mm(x))
 
         sub_test(True)
         if self.device_type == 'cuda':
@@ -4879,7 +4881,9 @@ def sub_test(pivot):
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
-    @dtypes(torch.double)
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                        torch.float64: 1e-8, torch.complex128: 1e-8})
     def test_lu_solve_batched(self, device, dtype):
         def sub_test(pivot):
             def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
@@ -4890,7 +4894,8 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
                 x_exp = torch.stack(x_exp_list)  # Stacked output
                 x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
                 self.assertEqual(x_exp, x_act)  # Equality check
-                self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12)  # Correctness check
+                Ax = torch.matmul(A, x_act)
+                self.assertEqual(b, Ax)
 
             for batchsize in [1, 3, 4]:
                 lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
@@ -4908,20 +4913,20 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
     @slowTest
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
-    @dtypes(torch.double)
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_lu_solve_batched_many_batches(self, device, dtype):
         def run_test(A_dims, b_dims):
             b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
             x = torch.lu_solve(b, LU_data, LU_pivots)
-            b_ = torch.matmul(A, x)
-            self.assertEqual(b_, b.expand_as(b_))
+            Ax = torch.matmul(A, x)
+            self.assertEqual(Ax, b.expand_as(Ax))
 
         run_test((5, 65536), (65536, 5, 10))
         run_test((5, 262144), (262144, 5, 10))
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
-    @dtypes(torch.double)
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_lu_solve_batched_broadcasting(self, device, dtype):
         from numpy.linalg import solve
         from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
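The assertion changes above replace hard-coded 1e-12 bounds, which only single out double precision, with assertEqual under per-dtype precisionOverride tolerances. A rough standalone analogue of that tolerance scheme; the dtype-to-atol mapping is copied from the decorators, while the matrix construction and everything else is assumed for illustration:

import torch

# Per-dtype absolute tolerances mirroring the precisionOverride values.
atol = {torch.float32: 1e-3, torch.complex64: 1e-3,
        torch.float64: 1e-8, torch.complex128: 1e-8}

for dtype, tol in atol.items():
    # Shift the spectrum away from zero so A stays well-conditioned.
    A = torch.randn(5, 5, dtype=dtype) + 5 * torch.eye(5, dtype=dtype)
    b = torch.randn(5, 3, dtype=dtype)
    LU_data, LU_pivots = torch.lu(A)
    x = torch.lu_solve(b, LU_data, LU_pivots)
    # The reconstruction error of A x against b scales with the
    # dtype's precision, hence the per-dtype tolerance.
    assert torch.allclose(A @ x, b, atol=tol, rtol=0)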
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 3b6ee12e7a68..dd4be74dde80 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4577,6 +4577,8 @@ def merge_dicts(*dicts):
 Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
 LU factorization of A from :meth:`torch.lu`.
 
+This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.
+
 Arguments:
     b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
                 is zero or more batch dimensions.
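Since the docstring now advertises complex support, a doctest-style usage example in the spirit of the existing torch.lu_solve docs might look like the following; the shapes, the cfloat choice, and the allclose check are illustrative, not text from the patch:

>>> A = torch.randn(2, 3, 3, dtype=torch.cfloat)
>>> b = torch.randn(2, 3, 1, dtype=torch.cfloat)
>>> A_LU = torch.lu(A)
>>> x = torch.lu_solve(b, *A_LU)
>>> torch.allclose(torch.matmul(A, x), b, atol=1e-4)
True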