Add batched version of trtrs #18025
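This PR adds batched support for trtrs (triangular solve) on CPU (LAPACK) and CUDA (MAGMA). Below is a minimal usage sketch, under the assumption that the Python-level signature stays torch.trtrs(b, A, upper=True, transpose=False, unitriangular=False) and simply gains support for leading batch dimensions; the shapes are illustrative only:

```python
import torch

# Illustrative batched triangular solve: A @ X = b for each matrix in the batch.
A = torch.randn(4, 3, 3).triu_() + 3 * torch.eye(3)  # 4 upper-triangular 3x3 matrices, kept well away from singular
b = torch.randn(4, 3, 2)                              # 4 right-hand sides with 2 columns each

# Same keyword arguments as the existing single-matrix trtrs; the second return value is a clone of A.
X, A_clone = torch.trtrs(b, A, upper=True, transpose=False, unitriangular=False)

print(torch.allclose(torch.matmul(A, X), b, atol=1e-5))  # expected: True
```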

Closed · wants to merge 7 commits
Changes from 6 commits
32 changes: 0 additions & 32 deletions aten/src/ATen/Declarations.cwrap
@@ -2242,38 +2242,6 @@
    - THTensor* self
    - THTensor* A
]]
[[
  name: _th_trtrs
  cname: trtrs
  types:
    - Float
    - Double
  backends:
    - CPU
    - CUDA
  variants:
    - function
  return: argument 0,1
  arguments:
    - arg: THTensor* res1
      output: True
    - arg: THTensor* res2
      output: True
    - THTensor* self
    - THTensor* A
    - arg: bool upper
      if_true: U
      if_false: L
      default: U
    - arg: bool transpose
      if_true: T
      if_false: N
      default: N
    - arg: bool unitriangular
      if_true: U
      if_false: N
      default: N
]]
[[
  name: _th_symeig
  cname: syev
102 changes: 100 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -19,9 +19,11 @@
extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);

// inverse
// getrf
extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);

// getri
extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);

@@ -32,6 +34,10 @@ extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float
// potrf
extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);

// trtrs
extern "C" void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
extern "C" void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
#endif

namespace at {
@@ -64,6 +70,11 @@ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) {
AT_ERROR("cholesky only takes float or double Tensors");
}

template<class scalar_t>
void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
AT_ERROR("trtrs only takes float or double Tensors");
}

#ifdef USE_LAPACK
template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
@@ -104,6 +115,14 @@ template<> void lapackCholesky<double>(char uplo, int n, double *a, int lda, int
template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *info) {
spotrf_(&uplo, &n, a, &lda, info);
}

template<> void lapackTrtrs<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
}

template<> void lapackTrtrs<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
}
#endif

// Below are the definitions of the functions operating on a batch that are going to be dispatched
@@ -320,7 +339,9 @@ Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A,
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.cholesky_solve() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
result = at::_cholesky_solve_helper(self, A, upper);
Tensor result_tmp;
result_tmp = at::_cholesky_solve_helper(self, A, upper);
result.resize_as_(result_tmp).copy_(result_tmp);
return result;

Review comment (Contributor): Nice catch!

}
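
A side note on the change above (the same pattern is used by trtrs_out later in this file): assigning the helper's result directly to the result reference only rebinds that Tensor handle, so the storage of the tensor the caller supplied via out= is never written. Computing into a temporary and then resizing and copying into the provided tensor fixes this. A minimal sketch of the pattern, with some_op_out and compute_result as hypothetical placeholders rather than functions from this PR:

```cpp
#include <ATen/ATen.h>
using at::Tensor;

// compute_result is a hypothetical stand-in for a dispatched helper such as at::_trtrs_helper.
Tensor compute_result(const Tensor& self) {
  return self.clone();  // placeholder computation
}

// The out-variant pattern: fill the caller-provided tensor instead of rebinding the reference.
Tensor& some_op_out(Tensor& result, const Tensor& self) {
  Tensor result_tmp = compute_result(self);          // do the work in a fresh tensor
  result.resize_as_(result_tmp).copy_(result_tmp);   // write into the caller's storage
  return result;
}
```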

@@ -483,6 +504,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t, bool upper>
static void apply_triu_tril_single(
scalar_t* result, scalar_t* self, bool inplace,
@@ -621,4 +644,79 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trtrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Solves the triangular system op(A) * X = b with LAPACK ?trtrs, overwriting b with the solution.
// A nonzero LAPACK info is recorded per batch element in `infos`.
template<typename scalar_t>
static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular, std::vector<int64_t>& infos) {
#ifndef USE_LAPACK
  AT_ERROR("trtrs: LAPACK library not found in compilation");
#else
  char uplo = upper ? 'U' : 'L';
  char trans = transpose ? 'T' : 'N';
  char diag = unitriangular ? 'U' : 'N';

  auto A_data = A.data<scalar_t>();
  auto b_data = b.data<scalar_t>();
  auto n = A.size(-2);
  auto nrhs = b.size(-1);

  int info;
  if (b.dim() == 2) {
    lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
    infos[0] = info;
  } else {
    auto A_mat_stride = matrixStride(A);
    auto b_mat_stride = matrixStride(b);
    auto batch_size = batchCount(A);
    for (int64_t i = 0; i < batch_size; i++) {
      scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
      scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
      lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
      infos[i] = info;
      if (info != 0) {
        return;
      }
    }
  }
#endif
}

std::tuple<Tensor, Tensor> _trtrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
  auto self_working_copy = cloneBatchedColumnMajor(self);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  std::vector<int64_t> infos(batchCount(self), 0);
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trtrs_cpu", [&]{
    apply_trtrs<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular, infos);
  });
  if (self.dim() > 2) {
    batchCheckErrors(infos, "trtrs_cpu");
  } else {
    singleCheckErrors(infos[0], "trtrs_cpu");
  }
  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
}

// Supports arbitrary batch dimensions for self and A
std::tuple<Tensor, Tensor> trtrs(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
AT_CHECK(self.dim() >= 2,
"b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
AT_CHECK(A.dim() >= 2,
"u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
Tensor self_broadcasted, A_broadcasted;
std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
return at::_trtrs_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular);
}

std::tuple<Tensor&, Tensor&> trtrs_out(Tensor& result, Tensor& clone_A,
const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.trtrs() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
Tensor result_tmp, clone_A_tmp;
std::tie(result_tmp, clone_A_tmp) = at::_trtrs_helper(self, A, upper, transpose, unitriangular);
result.resize_as_(result_tmp).copy_(result_tmp);
clone_A.resize_as_(clone_A_tmp).copy_(clone_A_tmp);
return std::tuple<Tensor&, Tensor&>(result, clone_A);
}

}} // namespace at::native
8 changes: 0 additions & 8 deletions aten/src/ATen/native/LegacyDefinitions.cpp
@@ -424,14 +424,6 @@ std::tuple<Tensor,Tensor> gels(const Tensor & self, const Tensor & A) {
return at::legacy::th::_th_gels(self, A);
}

std::tuple<Tensor &,Tensor &> trtrs_out(Tensor & X, Tensor & M, const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) {
return at::legacy::th::_th_trtrs_out(X, M, self, A, upper, transpose, unitriangular);
}

std::tuple<Tensor,Tensor> trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) {
return at::legacy::th::_th_trtrs(self, A, upper, transpose, unitriangular);
}

std::tuple<Tensor &,Tensor &> symeig_out(Tensor & e, Tensor & V, const Tensor & self, bool eigenvectors, bool upper) {
return at::legacy::th::_th_symeig_out(e, V, self, eigenvectors, upper);
}
124 changes: 112 additions & 12 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -85,20 +85,19 @@ void magmaCholeskyBatched(
AT_ERROR("cholesky only takes float or double Tensors");
}

template<>
void magmaGesvBatched<double>(
magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_dgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
template<class scalar_t>
void magmaTrsm(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb) {
AT_ERROR("trtrs only takes float or double Tensors");
}

template<>
void magmaGesvBatched<float>(
magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_sgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
template<class scalar_t>
void magmaTrsmBatched(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
AT_ERROR("trtrs only takes float or double Tensors");
}

template<>
@@ -115,6 +114,22 @@ void magmaGesv<float>(
magma_sgesv_gpu(n, nrhs, dA, ldda, ipiv, dB, lddb, info);
}

template<>
void magmaGesvBatched<double>(
magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_dgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
}

template<>
void magmaGesvBatched<float>(
magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
magma_sgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
}

template<>
void magmaGetrfBatched<double>(
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
Expand Down Expand Up @@ -216,6 +231,36 @@ void magmaCholeskyBatched<float>(
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
magma_spotrf_batched(uplo, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
}

template<>
void magmaTrsm<double>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) {
magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
}

template<>
void magmaTrsm<float>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) {
magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
}

template<>
void magmaTrsmBatched<double>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magmablas_dtrsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
}

template<>
void magmaTrsmBatched<float>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magmablas_strsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
}
#endif

#define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
@@ -554,6 +599,8 @@ std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, boo
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename scalar_t, bool upper>
__global__
void triu_tril_kernel(
@@ -637,6 +684,59 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
return triu_tril_cuda_template<true>(result, self_c, k, "triu");
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trsm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Solves the triangular system on the GPU with MAGMA trsm (batched trsm for batched inputs),
// overwriting b with the solution. Unlike LAPACK trtrs, trsm does not report singular diagonals.
template <typename scalar_t>
static void apply_trsm(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
#ifndef USE_MAGMA
  AT_ERROR("trtrs: MAGMA library not found in "
           "compilation. Please rebuild with MAGMA.");
#else
  magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
  magma_trans_t trans = transpose ? MagmaTrans : MagmaNoTrans;
  magma_diag_t diag = unitriangular ? MagmaUnit : MagmaNonUnit;

  auto A_data = A.data<scalar_t>();
  auto b_data = b.data<scalar_t>();
  magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");

  if (b.dim() == 2) {
    magmaTrsm<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
  } else {
    auto A_mat_stride = matrixStride(A);
    auto b_mat_stride = matrixStride(b);
    magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");

    scalar_t** A_array;
    scalar_t** b_array;

    ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
    ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);

    // Set up the created arrays
    for (int64_t i = 0; i < batch_size; i++) {
      A_array[i] = &A_data[i * A_mat_stride];
      b_array[i] = &b_data[i * b_mat_stride];
    }

    MAGMAQueue magma_queue(b.get_device());
    magmaTrsmBatched<scalar_t>(
        uplo, trans, diag, n, nrhs, A_array, n,
        b_array, n, batch_size, magma_queue);
  }
#endif
}

std::tuple<Tensor, Tensor> _trsm_helper_cuda(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
  auto self_working_copy = cloneBatchedColumnMajor(self);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trsm_cuda", [&]{
    apply_trsm<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
  });
  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
}

}} // namespace at::native

#undef ALLOCATE_ARRAY
7 changes: 7 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3718,6 +3718,13 @@
  matches_jit_signature: True
  variants: method, function

- func: _trtrs_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor)
  matches_jit_signature: True
  variants: function
  dispatch:
    CPU: _trtrs_helper_cpu
    CUDA: _trsm_helper_cuda

- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
  matches_jit_signature: True
