diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 4749d0fb30773..8b670b7db17ea 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2671,7 +2671,7 @@
       default: U
 ]]
 [[
-  name: _th_potrs
+  name: _th_potrs_single
   cname: potrs
   types:
     - Float
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index c9cf16b92f719..b441c7a412eb3 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -81,8 +81,6 @@ _(aten, _floor) \
 _(aten, _fused_dropout) \
 _(aten, _ger) \
 _(aten, _gesv_helper) \
-_(aten, _gesv_single) \
-_(aten, _getri_single) \
 _(aten, _indexCopy) \
 _(aten, _indices) \
 _(aten, _inverse_helper) \
@@ -103,6 +101,7 @@ _(aten, _pack_padded_sequence_backward) \
 _(aten, _pad_packed_sequence) \
 _(aten, _pdist_backward) \
 _(aten, _pdist_forward) \
+_(aten, _potrs_helper) \
 _(aten, _prod) \
 _(aten, _prodall) \
 _(aten, _range) \
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 6d8aca9f4884f..6f451f078af0a 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1,7 +1,6 @@
 #include "ATen/ATen.h"
 #include "ATen/CPUApplyUtils.h"
 #include "ATen/Dispatch.h"
-#include "ATen/ExpandUtils.h"
 #include "ATen/NativeFunctions.h"
 
 #include "ATen/native/LinearAlgebraUtils.h"
@@ -16,14 +15,18 @@
 #ifdef USE_LAPACK
 
 // gesv
-extern "C" void dgesv_(int* n, int* nrhs, double* a, int* lda, int *ipiv, double* b, int* ldb, int* info);
-extern "C" void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b, int* ldb, int* info);
+extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
+extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
 
 // inverse
 extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
 extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
 extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
 extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
+
+// potrs
+extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
+extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
 #endif
 
 namespace at {
@@ -32,12 +35,12 @@ namespace native {
 // Define the per-batch functions to be used in the main implementation of the batched
 // linear algebra operations
 template<class scalar_t>
-void lapackGesv(int n, int nrhs, scalar_t* a, int lda, int* ipiv, scalar_t* b, int ldb, int* info) {
+void lapackGesv(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
   AT_ERROR("gesv only takes float or double Tensors");
 }
 
 template<class scalar_t>
-void lapackGetrf(int m, int n, scalar_t* a, int lda, int *ipiv, int *info) {
+void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
   AT_ERROR("getrf only takes float or double Tensors");
 }
 
@@ -46,12 +49,17 @@ void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwo
   AT_ERROR("getri only takes float or double Tensors");
 }
 
+template<class scalar_t>
+void lapackPotrs(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("potrs only takes float or double Tensors");
+}
+
 #ifdef USE_LAPACK
 
-template<> void lapackGesv<double>(int n, int nrhs, double* a, int lda, int* ipiv, double* b, int ldb, int* info) {
+template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
   dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
-template<> void lapackGesv<float>(int n, int nrhs, float* a, int lda, int* ipiv, float* b, int ldb, int* info) {
+template<> void lapackGesv<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
   sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
@@ -70,6 +78,14 @@ template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv,
 template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
   sgetrf_(&m, &n, a, &lda, ipiv, info);
 }
+
+template<> void lapackPotrs<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
+  dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
+
+template<> void lapackPotrs<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
+  spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
+}
 #endif
 
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
@@ -105,8 +121,16 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
   }
 }
 
-// These utilities are specified in LinearAlgebraUtils.h
-GENERATE_LINALG_HELPER_2_ARGS(gesv, self, A, cpu)
+std::tuple<Tensor, Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
+    apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
+  });
+  batchCheckErrors(infos, "gesv");
+  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
+}
 
 // Supports arbitrary batch dimensions for self and A
 std::tuple<Tensor, Tensor> gesv(const Tensor& self, const Tensor& A) {
@@ -117,21 +141,8 @@ std::tuple<Tensor, Tensor> gesv(const Tensor& self, const Tensor& A) {
     return at::_th_gesv_single(self, A);
   }
 
-  gesvCheckInputs(self, A);
-
-  // broadcast the batch dimensions of self and A.
-  IntList self_batch_sizes(self.sizes().data(), self.ndimension() - 2);
-  IntList A_batch_sizes(A.sizes().data(), A.ndimension() - 2);
-  std::vector<int64_t> expand_batch_portion = infer_size(self_batch_sizes, A_batch_sizes);
-
-  std::vector<int64_t> self_expand_size({expand_batch_portion});
-  self_expand_size.insert(self_expand_size.end(), { self.size(-2), self.size(-1) });
-
-  std::vector<int64_t> A_expand_size({expand_batch_portion});
-  A_expand_size.insert(A_expand_size.end(), { A.size(-2), A.size(-1) });
-
-  Tensor self_broadcasted = self.expand(self_expand_size);
-  Tensor A_broadcasted = A.expand(A_expand_size);
+  Tensor self_broadcasted, A_broadcasted;
+  std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
   return at::_gesv_helper(self_broadcasted, A_broadcasted);
 }
 
@@ -185,7 +196,15 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
   }
 }
 
-GENERATE_LINALG_HELPER_1_ARGS(inverse, self, cpu)
+Tensor _inverse_helper_cpu(const Tensor& self) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "inverse", [&]{
+    apply_inverse<scalar_t>(self_working_copy, infos);
+  });
+  batchCheckErrors(infos, "inverse");
+  return self_working_copy;
+}
 
 Tensor inverse(const Tensor &self) {
   if (self.size(-1) == 0) {
@@ -206,4 +225,63 @@ Tensor& inverse_out(Tensor &result, const Tensor &self) {
   return result;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ potrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <typename scalar_t>
+static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>& infos) {
+#ifndef USE_LAPACK
+  AT_ERROR("potrs: LAPACK library not found in compilation");
+#endif
+  char uplo = upper ? 'U' : 'L';
+
+  auto A_data = A.data<scalar_t>();
+  auto b_data = b.data<scalar_t>();
+  auto A_mat_stride = matrixStride(A);
+  auto b_mat_stride = matrixStride(b);
+
+  auto batch_size = batchCount(A);
+  auto n = A.size(-2);
+  auto nrhs = b.size(-1);
+
+  for (int64_t i = 0; i < batch_size; i++) {
+    int info;
+    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+    scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+    lapackPotrs<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+    infos[i] = info;
+    if (info != 0) {
+      return;
+    }
+  }
+}
+
+Tensor _potrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "potrs", [&]{
+    apply_potrs<scalar_t>(self_working_copy, A_working_copy, upper, infos);
+  });
+  batchCheckErrors(infos, "potrs");
+  return self_working_copy;
+}
+
+// Supports arbitrary batch dimensions for self and A
+Tensor potrs(const Tensor& self, const Tensor& A, bool upper) {
+  if (self.dim() <= 2 && A.dim() <= 2) {
+    return at::_th_potrs_single(self, A, upper);
+  }
+
+  Tensor self_broadcasted, A_broadcasted;
+  std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
+  return at::_potrs_helper(self_broadcasted, A_broadcasted, upper);
+}
+
+Tensor& potrs_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
+  AT_CHECK(self.dim() == 2 && A.dim() == 2,
+           "torch.potrs() with the `out` keyword does not support batching. "
" + "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2."); + return at::_th_potrs_single_out(result, self, A, upper); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp index a63eafdc7d71f..7b90cefb7753e 100644 --- a/aten/src/ATen/native/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/LegacyDefinitions.cpp @@ -491,14 +491,6 @@ Tensor cholesky(const Tensor & self, bool upper) { return at::_th_potrf(self, upper); } -Tensor & potrs_out(Tensor & result, const Tensor & self, const Tensor & input2, bool upper) { - return at::_th_potrs_out(result, self, input2, upper); -} - -Tensor potrs(const Tensor & self, const Tensor & input2, bool upper) { - return at::_th_potrs(self, input2, upper); -} - Tensor & potri_out(Tensor & result, const Tensor & self, bool upper) { return at::_th_potri_out(result, self, upper); } diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index d9fd1264c68f9..9c467f20b3d90 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -1,4 +1,5 @@ #include "ATen/ATen.h" +#include "ATen/ExpandUtils.h" #include namespace at { namespace native { @@ -52,8 +53,8 @@ static inline double _get_epsilon(const ScalarType& sc_type) { } } -// Validates input shapes for gesv -static inline void gesvCheckInputs(const Tensor& self, const Tensor& A) { +// Validates input shapes for linear solve methods (gesv, potrs) +static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) { AT_CHECK(A.size(-1) == A.size(-2), "A must be batches of square matrices, " "but they are ", A.size(-1), " by ", A.size(-2), " matrices"); @@ -87,29 +88,6 @@ static inline void batchCheckErrors(std::vector& infos, const char* nam } } -#define GENERATE_LINALG_HELPER_1_ARGS(NAME, ARG, BACKEND) \ - Tensor _##NAME##_helper_##BACKEND(const Tensor& ARG) { \ - std::vector infos(batchCount(ARG), 0); \ - auto ARG##_working_copy = cloneBatchedColumnMajor(ARG); \ - AT_DISPATCH_FLOATING_TYPES(ARG.type(), #NAME, [&]{ \ - apply_##NAME(ARG##_working_copy, infos); \ - }); \ - batchCheckErrors(infos, #NAME); \ - return ARG##_working_copy; \ - } - -#define GENERATE_LINALG_HELPER_2_ARGS(NAME, ARG1, ARG2, BACKEND) \ - std::tuple _##NAME##_helper_##BACKEND(const Tensor& ARG1, const Tensor& ARG2) { \ - std::vector infos(batchCount(ARG1), 0); \ - auto ARG1##_working_copy = cloneBatchedColumnMajor(ARG1); \ - auto ARG2##_working_copy = cloneBatchedColumnMajor(ARG2); \ - AT_DISPATCH_FLOATING_TYPES(ARG1.type(), #NAME, [&]{ \ - apply_##NAME(ARG1##_working_copy, ARG2##_working_copy, infos); \ - }); \ - batchCheckErrors(infos, #NAME); \ - return std::tuple(ARG1##_working_copy, ARG2##_working_copy); \ - } - // Checks if all the Tensors in a TensorList are of the same dimensions static inline void checkAllSameDim(TensorList tensors, int64_t dim) { for (auto &t : tensors) { @@ -117,4 +95,23 @@ static inline void checkAllSameDim(TensorList tensors, int64_t dim) { } } +static inline std::tuple _linear_solve_broadcast_args(const Tensor& arg1, const Tensor& arg2) { + linearSolveCheckInputs(arg1, arg2); + + // broadcast the batch dimensions of arg1 and arg2. 
+  IntList arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
+  IntList arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
+
+  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
+  arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
+
+  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
+  arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
+
+  Tensor arg1_broadcasted = arg1.expand(arg1_expand_size);
+  Tensor arg2_broadcasted = arg2.expand(arg2_expand_size);
+  return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
+}
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index ab8eabd48705a..7b2d7bd086702 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -43,6 +43,13 @@ void magmaGetriBatched(
   AT_ERROR("getri only takes float or double Tensors");
 }
 
+template<class scalar_t>
+void magmaPotrsBatched(
+    magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
+    scalar_t** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  AT_ERROR("potrs only takes float or double Tensors");
+}
+
 template<>
 void magmaGesvBatched<double>(
     magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
@@ -90,12 +97,28 @@ void magmaGetriBatched<float>(
     magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
   magma_sgetri_outofplace_batched(n, dA_array, ldda, ipiv_array, dinvA_array, lddia, info_array, batchsize, magma_queue.get_queue());
 }
+
+template<>
+void magmaPotrsBatched<double>(
+    magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
+    double** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  info = magma_dpotrs_batched(uplo, n, nrhs, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+}
+
+template<>
+void magmaPotrsBatched<float>(
+    magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
+    float** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  info = magma_spotrs_batched(uplo, n, nrhs, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+}
 #endif
 
 #define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
   auto storage_##name = pin_memory<type>(size, dummy_tensor); \
   name = static_cast<type*>(storage_##name.data());
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ gesv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 template <typename scalar_t>
 static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 #ifndef USE_MAGMA
@@ -141,7 +164,19 @@ AT_ERROR("gesv: MAGMA library not found in "
 #endif
 }
 
-GENERATE_LINALG_HELPER_2_ARGS(gesv, self, A, cuda)
+std::tuple<Tensor, Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
+    apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
+  });
+  batchCheckErrors(infos, "gesv");
+  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
+}
+
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
 static void apply_inverse(Tensor &self, Tensor &self_inv, std::vector<int64_t>& infos) {
@@ -205,6 +240,63 @@ Tensor _inverse_helper_cuda(const Tensor& self) {
   return self_inv_working_copy;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ potrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <typename scalar_t>
+static void apply_potrs(Tensor& b, Tensor& A, bool upper, int64_t& info) {
+#ifndef USE_MAGMA
+AT_ERROR("potrs: MAGMA library not found in "
+    "compilation. Please rebuild with MAGMA.");
+#else
+  magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
+
+  auto A_data = A.data<scalar_t>();
+  auto b_data = b.data<scalar_t>();
+  auto A_mat_stride = matrixStride(A);
+  auto b_mat_stride = matrixStride(b);
+
+  magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
+  magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
+  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
+
+  magma_int_t info_tmp;
+  magma_int_t* ipiv_data;
+  magma_int_t** ipiv_array;
+  scalar_t** A_array;
+  scalar_t** b_array;
+
+  ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n, b);
+  ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size, b);
+  ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
+  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
+
+  // Set up the created arrays
+  for (int64_t i = 0; i < batch_size; i++) {
+    A_array[i] = &A_data[i * A_mat_stride];
+    b_array[i] = &b_data[i * b_mat_stride];
+    ipiv_array[i] = &ipiv_data[i * n];
+  }
+
+  MAGMAQueue magma_queue(b.get_device());
+  magmaPotrsBatched<scalar_t>(
+      uplo, n, nrhs, A_array, n, b_array, n,
+      info_tmp, batch_size, magma_queue);
+
+  info = info_tmp;
+#endif
+}
+
+Tensor _potrs_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
+  int64_t info = 0;
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "potrs", [&]{
+    apply_potrs<scalar_t>(self_working_copy, A_working_copy, upper, info);
+  });
+  AT_CHECK(info == 0, "MAGMA potrs : invalid argument: ", -info);
+  return self_working_copy;
+}
+
 }} // namespace at::native
 
 #undef ALLOCATE_ARRAY
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 807cf77cd7d28..a8e219097f688 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -914,6 +914,7 @@
 - func: inverse_out(Tensor result, Tensor self) -> Tensor
 
 - func: _inverse_helper(Tensor self) -> Tensor
+  variants: function
   dispatch:
     CPU: _inverse_helper_cpu
     CUDA: _inverse_helper_cuda
@@ -2907,6 +2908,12 @@
   variants: method, function
   device_guard: false
 
+- func: _potrs_helper(Tensor self, Tensor A, bool upper) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _potrs_helper_cpu
+    CUDA: _potrs_helper_cuda
+
 - func: potri_out(Tensor result, Tensor self, bool upper=true) -> Tensor
   device_guard: false
diff --git a/test/common_utils.py b/test/common_utils.py
index 1685fa411be4b..2a3acc694fecc 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -662,9 +662,9 @@ def random_symmetric_psd_matrix(l):
     return A.mm(A.transpose(0, 1))
 
 
-def random_symmetric_pd_matrix(l, eps=1e-5):
-    A = torch.randn(l, l)
-    return A.mm(A.transpose(0, 1)) + torch.eye(l) * eps
+def random_symmetric_pd_matrix(l, *batches):
+    A = torch.randn(*(batches + (l, l)))
+    return A.matmul(A.transpose(-2, -1)) + torch.eye(l) * 1e-5
 
 
 def make_nonzero_det(A, sign=None, min_singular_value=0.1):
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 2dd8936c33f3c..c6bdc32a7c745 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1571,6 +1571,18 @@ def test_gesv_batched(self):
     def test_gesv_batched_dims(self):
         _TestTorchMixin._test_gesv_batched_dims(self, lambda t: t.cuda())
 
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_potrs(self):
+        _TestTorchMixin._test_potrs(self, lambda t: t.cuda())
+
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_potrs_batched(self):
+        _TestTorchMixin._test_potrs_batched(self, lambda t: t.cuda())
+
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_potrs_batched_dims(self):
+        _TestTorchMixin._test_potrs_batched_dims(self, lambda t: t.cuda())
+
     def test_view(self):
         _TestTorchMixin._test_view(self, lambda t: t.cuda())
diff --git a/test/test_torch.py b/test/test_torch.py
index 7a40467cca9e5..c6e669b86f682 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5346,8 +5346,8 @@ def test_cholesky(self):
         B = torch.mm(L, L.t())
         self.assertEqual(A, B, 1e-14, 'cholesky (lower) did not allow rebuilding the original matrix')
 
-    @skipIfNoLapack
-    def test_potrs(self):
+    @staticmethod
+    def _test_potrs(self, cast):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                           (-6.05, -3.30, 5.36, -4.44, 1.08),
                           (-0.45, 2.58, -2.70, 0.27, 9.04),
@@ -5359,6 +5359,7 @@ def test_potrs(self):
 
         # make sure 'a' is symmetric PSD
         a = torch.mm(a, a.t())
+        a, b = cast(a), cast(b)
 
         # upper Triangular Test
         U = torch.cholesky(a, True)
@@ -5370,6 +5371,102 @@ def test_potrs(self):
         x = torch.potrs(b, L, False)
         self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
 
+    @skipIfNoLapack
+    def test_potrs(self):
+        self._test_potrs(self, lambda t: t)
+
+    @staticmethod
+    def _test_potrs_batched(self, cast):
+        from common_utils import random_symmetric_pd_matrix
+
+        # TODO: This function should be replaced after batch potrf is ready
+        def get_cholesky(bmat, upper):
+            n = bmat.size(-1)
+            cholesky = torch.stack([m.cholesky(upper) for m in bmat.reshape(-1, n, n)])
+            return cholesky.reshape_as(bmat)
+
+        def potrs_test_helper(A_dims, b_dims, cast, upper):
+            A = cast(random_symmetric_pd_matrix(*A_dims))
+            L = get_cholesky(A, upper)
+            b = cast(torch.randn(*b_dims))
+            return A, L, b
+
+        for upper in [True, False]:
+            # test against potrs: one batch with both choices of upper
+            A, L, b = potrs_test_helper((5, 1), (1, 5, 10), cast, upper)
+            x_exp = torch.potrs(b.squeeze(0), L.squeeze(0), upper=upper)
+            x = torch.potrs(b, L, upper=upper)
+            self.assertEqual(x, x_exp.unsqueeze(0))
+
+            # test against potrs in a loop: four batches with both choices of upper
+            A, L, b = potrs_test_helper((5, 4), (4, 5, 10), cast, upper)
+            x_exp_list = list()
+            for i in range(4):
+                x_exp = torch.potrs(b[i], L[i], upper=upper)
+                x_exp_list.append(x_exp)
+            x_exp = torch.stack(x_exp_list)
+
+            x = torch.potrs(b, L, upper=upper)
+            self.assertEqual(x, x_exp)
+
+            # basic correctness test
+            A, L, b = potrs_test_helper((5, 3), (3, 5, 10), cast, upper)
+            x = torch.potrs(b, L, upper)
+            self.assertLessEqual(b.dist(torch.matmul(A, x)), 1e-12)
+
+            # Test non-contiguous inputs.
+            if not TEST_NUMPY:
+                return
+            import numpy
+            from numpy.linalg import solve
+            A = random_symmetric_pd_matrix(2, 2)
+            b = torch.randn(2, 2, 2)
+            x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy()))
+            A = cast(A).permute(0, 2, 1)
+            b = cast(b).permute(2, 1, 0)
+            assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
+            L = get_cholesky(A, upper)
+            x = torch.potrs(b, L, upper=upper)
+            self.assertEqual(x, cast(x_exp))
+
+    @skipIfNoLapack
+    def test_potrs_batched(self):
+        self._test_potrs_batched(self, lambda t: t)
+
+    @staticmethod
+    def _test_potrs_batched_dims(self, cast):
+        if not TEST_NUMPY:
+            return
+
+        from numpy.linalg import solve
+        from common_utils import random_symmetric_pd_matrix
+
+        # TODO: This function should be replaced after batch potrf is ready
+        def get_cholesky(bmat, upper):
+            n = bmat.size(-1)
+            cholesky = torch.stack([m.cholesky(upper) for m in bmat.reshape(-1, n, n)])
+            return cholesky.reshape_as(bmat)
+
+        def run_test(A_dims, b_dims, cast, upper):
+            A = random_symmetric_pd_matrix(*A_dims)
+            b = torch.randn(*b_dims)
+            x_exp = torch.Tensor(solve(A.numpy(), b.numpy()))
+            A, b = cast(A), cast(b)
+            L = get_cholesky(A, upper)
+            x = torch.potrs(b, L, upper=upper)
+            self.assertEqual(x, cast(x_exp))
+
+        for upper in [True, False]:
+            # test against numpy.linalg.solve
+            run_test((4, 2, 1, 3), (2, 1, 3, 4, 6), cast, upper)  # no broadcasting
+            run_test((4, 2, 1, 3), (4, 6), cast, upper)  # broadcasting b
+            run_test((4,), (2, 1, 3, 4, 2), cast, upper)  # broadcasting A
+            run_test((4, 1, 3, 1), (2, 1, 3, 4, 5), cast, upper)  # broadcasting A & b
+
+    @skipIfNoLapack
+    def test_potrs_batched_dims(self):
+        self._test_potrs_batched_dims(self, lambda t: t)
+
     @skipIfNoLapack
     def test_potri(self):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 98fb7ceb764e7..342794fc42d35 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -580,8 +580,8 @@
   self: not_implemented("potri")
 
 - name: potrs(Tensor self, Tensor input2, bool upper)
-  self: not_implemented("potri")
-  input2: not_implemented("potri")
+  self: not_implemented("potrs")
+  input2: not_implemented("potrs")
 
 - name: pow(Tensor self, Scalar exponent)
   self: pow_backward(grad, self, exponent)
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 3c888475b4e7f..ba894ac8b2b02 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -25,8 +25,8 @@
     'index', '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*',
-    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', 'slice',
-    'randint(_out)?',
+    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', '_potrs.*',
+    'slice', 'randint(_out)?',
     '_local_scalar', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
     'copy_sparse_to_sparse_',
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 3f6ad343fa8f5..c7b5773694d4c 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -3415,11 +3415,21 @@ def parse_kwargs(desc):
 
 .. math::
     c = (u u^T)^{-1} b
 
-.. note:: :attr:`b` is always a 2-D tensor, use `b.unsqueeze(1)` to convert a vector.
+`torch.potrs(b, u)` can take in 2D inputs `b, u` or inputs that are
+batches of 2D matrices. If the inputs are batches, then batched
+outputs `c` are returned.
+
+.. note::
+
+    The :attr:`out` keyword only supports 2D matrix inputs, that is,
+    `b, u` must be 2D matrices.
 
 Args:
-    b (Tensor): the right hand side 2-D tensor
-    u (Tensor): the input 2-D tensor, a upper or lower triangular Cholesky factor
+    b (Tensor): input matrix of size :math:`(*, m, k)`,
+                where :math:`*` is zero or more batch dimensions
+    u (Tensor): input matrix of size :math:`(*, m, m)`,
+                where :math:`*` is zero or more batch dimensions composed of
+                upper or lower triangular Cholesky factors
     upper (bool, optional): whether to return a upper (default) or lower triangular matrix
     out (Tensor, optional): the output tensor for `c`
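
Usage sketch (not part of the patch): the snippet below exercises the batched `torch.potrs` path introduced above, mirroring the new tests in test/test_torch.py. It assumes this tree's pre-1.0 API, i.e. `torch.cholesky(a, upper)` and `torch.potrs(b, u, upper=...)`; since batched potrf is not available yet (see the TODO in the tests), the Cholesky factors are computed one matrix at a time. The shapes and constants are illustrative only.

import torch

batch, n, k, upper = 3, 5, 2, False

# Batch of symmetric positive-definite matrices, built the same way as
# random_symmetric_pd_matrix in test/common_utils.py.
X = torch.randn(batch, n, n)
A = X.matmul(X.transpose(-2, -1)) + 1e-5 * torch.eye(n)

# Batched potrf is not ready, so factor each matrix separately and re-stack.
L = torch.stack([torch.cholesky(m, upper) for m in A])

b = torch.randn(batch, n, k)
x = torch.potrs(b, L, upper=upper)      # batched Cholesky solve
print((A.matmul(x) - b).abs().max())    # should be close to zero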