Add batched version of trtrs (#18025)
Summary:
- Remove single batch TH/THC implementations
- Remove `_batch_trtrs_lower` from `multivariate_normal`
- Add tests for batched behavior
- Modify trtrs_backward to accommodate the batched case
- Modify docs

In a future PR, this will be renamed to `triangular_solve`.
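
A minimal usage sketch of the batched behavior this adds (a hypothetical example; the shapes, values, and tolerance are illustrative and not taken from this PR or its tests):

import torch

# Batch of 3 lower-triangular 4x4 systems A X = b; the added diagonal keeps them well conditioned.
A = torch.randn(3, 4, 4).tril_() + 4 * torch.eye(4)
b = torch.randn(3, 4, 2)

# trtrs solves A X = b for triangular A and returns (solution, clone of A).
X, A_clone = torch.trtrs(b, A, upper=False)
print(torch.allclose(torch.matmul(A, X), b, atol=1e-5))  # expected: True
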
Pull Request resolved: pytorch/pytorch#18025

Differential Revision: D14523004

Pulled By: ifedan

fbshipit-source-id: 11c6a967d107f969b60e5a5c73ce6bb8099ebbe1
vishwakftw authored and facebook-github-bot committed Mar 20, 2019
1 parent 3263457 commit 5ca087c
Showing 11 changed files with 219 additions and 163 deletions.
32 changes: 0 additions & 32 deletions aten/src/ATen/Declarations.cwrap
@@ -2242,38 +2242,6 @@
- THTensor* self
- THTensor* A
]]
[[
name: _th_trtrs
cname: trtrs
types:
- Float
- Double
backends:
- CPU
- CUDA
variants:
- function
return: argument 0,1
arguments:
- arg: THTensor* res1
output: True
- arg: THTensor* res2
output: True
- THTensor* self
- THTensor* A
- arg: bool upper
if_true: U
if_false: L
default: U
- arg: bool transpose
if_true: T
if_false: N
default: N
- arg: bool unitriangular
if_true: U
if_false: N
default: N
]]
[[
name: _th_symeig
cname: syev
102 changes: 100 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -19,9 +19,11 @@
extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);

// inverse
// getrf
extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);

// getri
extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);

@@ -32,6 +34,10 @@ extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float
// potrf
extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);

// trtrs
extern "C" void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
extern "C" void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
#endif

namespace at {
@@ -64,6 +70,11 @@ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) {
AT_ERROR("cholesky only takes float or double Tensors");
}

template<class scalar_t>
void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
AT_ERROR("trtrs only takes float or double Tensors");
}

#ifdef USE_LAPACK
template<> void lapackSolve<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
@@ -104,6 +115,14 @@ template<> void lapackCholesky<double>(char uplo, int n, double *a, int lda, int
template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *info) {
spotrf_(&uplo, &n, a, &lda, info);
}

template<> void lapackTrtrs<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
}

template<> void lapackTrtrs<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
}
#endif

// Below are the definitions of the functions operating on a batch that are going to be dispatched
@@ -317,7 +336,9 @@ Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A,
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.cholesky_solve() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
result = at::_cholesky_solve_helper(self, A, upper);
Tensor result_tmp;
result_tmp = at::_cholesky_solve_helper(self, A, upper);
result.resize_as_(result_tmp).copy_(result_tmp);
return result;
}

@@ -480,6 +501,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t, bool upper>
static void apply_triu_tril_single(
scalar_t* result, scalar_t* self, bool inplace,
@@ -618,4 +641,79 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trtrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
AT_ERROR("trtrs: LAPACK library not found in compilation");
#else
char uplo = upper ? 'U' : 'L';
char trans = transpose ? 'T' : 'N';
char diag = unitriangular ? 'U' : 'N';

auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
auto n = A.size(-2);
auto nrhs = b.size(-1);

int info;
if (b.dim() == 2) {
lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
infos[0] = info;
} else {
auto A_mat_stride = matrixStride(A);
auto b_mat_stride = matrixStride(b);
auto batch_size = batchCount(A);
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
infos[i] = info;
if (info != 0) {
return;
}
}
}
#endif
}

std::tuple<Tensor, Tensor> _trtrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
std::vector<int64_t> infos(batchCount(self), 0);
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trtrs_cpu", [&]{
apply_trtrs<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular, infos);
});
if (self.dim() > 2) {
batchCheckErrors(infos, "trtrs_cpu");
} else {
singleCheckErrors(infos[0], "trtrs_cpu");
}
return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
}

// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor, Tensor> trtrs(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
AT_CHECK(self.dim() >= 2,
"b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
AT_CHECK(A.dim() >= 2,
"u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
Tensor self_broadcasted, A_broadcasted;
std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
return at::_trtrs_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular);
}

std::tuple<Tensor&, Tensor&> trtrs_out(Tensor& result, Tensor& clone_A,
const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.trtrs() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
Tensor result_tmp, clone_A_tmp;
std::tie(result_tmp, clone_A_tmp) = at::_trtrs_helper(self, A, upper, transpose, unitriangular);
result.resize_as_(result_tmp).copy_(result_tmp);
clone_A.resize_as_(clone_A_tmp).copy_(clone_A_tmp);
return std::tuple<Tensor&, Tensor&>(result, clone_A);
}

}} // namespace at::native
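
The CPU helper above loops LAPACK ?trtrs over the matrices in the batch, so a batched call should agree with solving each matrix/right-hand-side pair independently. A rough sketch of that equivalence (hypothetical shapes and values, not from the PR's test suite):

import torch

A = torch.randn(5, 3, 3).triu_() + 3 * torch.eye(3)   # batch of upper-triangular matrices
b = torch.randn(5, 3, 1)

X_batched, _ = torch.trtrs(b, A)                        # upper=True by default
X_looped = torch.stack([torch.trtrs(b[i], A[i])[0] for i in range(5)])
print(torch.allclose(X_batched, X_looped))              # expected: True
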
8 changes: 0 additions & 8 deletions aten/src/ATen/native/LegacyDefinitions.cpp
@@ -424,14 +424,6 @@ std::tuple<Tensor,Tensor> gels(const Tensor & self, const Tensor & A) {
return at::legacy::th::_th_gels(self, A);
}

std::tuple<Tensor &,Tensor &> trtrs_out(Tensor & X, Tensor & M, const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) {
return at::legacy::th::_th_trtrs_out(X, M, self, A, upper, transpose, unitriangular);
}

std::tuple<Tensor,Tensor> trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) {
return at::legacy::th::_th_trtrs(self, A, upper, transpose, unitriangular);
}

std::tuple<Tensor &,Tensor &> symeig_out(Tensor & e, Tensor & V, const Tensor & self, bool eigenvectors, bool upper) {
return at::legacy::th::_th_symeig_out(e, V, self, eigenvectors, upper);
}
124 changes: 112 additions & 12 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -85,20 +85,19 @@ void magmaCholeskyBatched(
AT_ERROR("cholesky only takes float or double Tensors");
}

template<>
void magmaSolveBatched<double>(
magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_dgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
template<class scalar_t>
void magmaTrsm(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb) {
AT_ERROR("trtrs only takes float or double Tensors");
}

template<>
void magmaSolveBatched<float>(
magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_sgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
template<class scalar_t>
void magmaTrsmBatched(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
AT_ERROR("trtrs only takes float or double Tensors");
}

template<>
@@ -115,6 +114,22 @@ void magmaSolve<float>(
magma_sgesv_gpu(n, nrhs, dA, ldda, ipiv, dB, lddb, info);
}

template<>
void magmaSolveBatched<double>(
magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_dgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
}

template<>
void magmaSolveBatched<float>(
magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_sgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
}

template<>
void magmaGetrfBatched<double>(
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
@@ -216,6 +231,36 @@ void magmaCholeskyBatched<float>(
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
magma_spotrf_batched(uplo, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
}

template<>
void magmaTrsm<double>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) {
magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
}

template<>
void magmaTrsm<float>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) {
magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
}

template<>
void magmaTrsmBatched<double>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magmablas_dtrsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
}

template<>
void magmaTrsmBatched<float>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magmablas_strsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
}
#endif

#define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
@@ -554,6 +599,8 @@ std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, boo
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t, bool upper>
__global__
void triu_tril_kernel(
@@ -637,6 +684,59 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
return triu_tril_cuda_template<true>(result, self_c, k, "triu");
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trsm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t>
static void apply_trsm(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
#ifndef USE_MAGMA
AT_ERROR("cholesky_solve: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
magma_trans_t trans = transpose ? MagmaTrans : MagmaNoTrans;
magma_diag_t diag = unitriangular ? MagmaUnit : MagmaNonUnit;

auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");

if (b.dim() == 2) {
magmaTrsm<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
} else {
auto A_mat_stride = matrixStride(A);
auto b_mat_stride = matrixStride(b);
magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");

scalar_t** A_array;
scalar_t** b_array;

ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);

// Set up the created arrays
for (int64_t i = 0; i < batch_size; i++) {
A_array[i] = &A_data[i * A_mat_stride];
b_array[i] = &b_data[i * b_mat_stride];
}

MAGMAQueue magma_queue(b.get_device());
magmaTrsmBatched<scalar_t>(
uplo, trans, diag, n, nrhs, A_array, n,
b_array, n, batch_size, magma_queue);
}
#endif
}

std::tuple<Tensor, Tensor> _trsm_helper_cuda(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trsm_cuda", [&]{
apply_trsm<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
});
return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
}

}} // namespace at::native

#undef ALLOCATE_ARRAY
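
The CUDA helper routes a single matrix to magma_?trsm and a batch to magmablas_?trsm_batched; either way the result should match the CPU path up to floating-point error. A quick cross-check sketch (assumes a CUDA build with MAGMA; shapes and tolerance are illustrative):

import torch

A = torch.randn(4, 6, 6).tril_() + 6 * torch.eye(6)
b = torch.randn(4, 6, 3)

X_cpu, _ = torch.trtrs(b, A, upper=False)
X_cuda, _ = torch.trtrs(b.cuda(), A.cuda(), upper=False)
print(torch.allclose(X_cpu, X_cuda.cpu(), atol=1e-5))   # expected: True
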
7 changes: 7 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3713,6 +3713,13 @@
matches_jit_signature: True
variants: method, function

- func: _trtrs_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:
CPU: _trtrs_helper_cpu
CUDA: _trsm_helper_cuda

- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
matches_jit_signature: True
