Revert D24730264: [pytorch][PR] Added CUDA support for complex input for torch.inverse

Test Plan: revert-hammer

Differential Revision: D24730264 (33acbed)

Original commit changeset: b9c94ec46301

fbshipit-source-id: beb9263700e9bc92685f74c37c46aa33f3b595b9
ezyang authored and facebook-github-bot committed Nov 6, 2020
1 parent f3ad7b2 commit 1aeefcd
Showing 13 changed files with 161 additions and 497 deletions.
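For context, the user-visible effect of this revert is that complex inputs to inverse on CUDA should once again fail at dispatch instead of running the new kernels. A minimal sketch using the C++ frontend, assuming a CUDA-enabled build; torch::randn, torch::kComplexDouble, and torch::inverse are real libtorch API, but the exact error text is an assumption inferred from the AT_DISPATCH change in BatchLinearAlgebra.cu below:

#include <torch/torch.h>

int main() {
  // Hypothetical repro: a complex CUDA tensor fed to inverse.
  auto a = torch::randn({3, 3},
      torch::dtype(torch::kComplexDouble).device(torch::kCUDA));
  // With the complex dispatch removed, this is expected to throw a
  // '"inverse_cuda" not implemented for ComplexDouble'-style error
  // rather than compute the inverse.
  auto b = torch::inverse(a);
  return 0;
}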
82 changes: 0 additions & 82 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -586,44 +586,6 @@ void getrfBatched<float>(
handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}

template <>
void getrfBatched<c10::complex<double>>(
int n,
c10::complex<double>** dA_array,
int ldda,
int* ipiv_array,
int* info_array,
int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasZgetrfBatched(
handle,
n,
reinterpret_cast<cuDoubleComplex**>(dA_array),
ldda,
ipiv_array,
info_array,
batchsize));
}

template <>
void getrfBatched<c10::complex<float>>(
int n,
c10::complex<float>** dA_array,
int ldda,
int* ipiv_array,
int* info_array,
int batchsize) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasCgetrfBatched(
handle,
n,
reinterpret_cast<cuComplex**>(dA_array),
ldda,
ipiv_array,
info_array,
batchsize));
}

template <>
void getriBatched<double>(
int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize, double** dC_array) {
@@ -640,50 +602,6 @@ void getriBatched<float>(
handle, n, dA_array, ldda, ipiv_array, dC_array, n, info_array, batchsize));
}

template <>
void getriBatched<c10::complex<double>>(
int n,
c10::complex<double>** dA_array,
int ldda,
int* ipiv_array,
int* info_array,
int batchsize,
c10::complex<double>** dC_array) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasZgetriBatched(
handle,
n,
reinterpret_cast<cuDoubleComplex**>(dA_array),
ldda,
ipiv_array,
reinterpret_cast<cuDoubleComplex**>(dC_array),
n,
info_array,
batchsize));
}

template <>
void getriBatched<c10::complex<float>>(
int n,
c10::complex<float>** dA_array,
int ldda,
int* ipiv_array,
int* info_array,
int batchsize,
c10::complex<float>** dC_array) {
auto handle = at::cuda::getCurrentCUDABlasHandle();
TORCH_CUDABLAS_CHECK(cublasCgetriBatched(
handle,
n,
reinterpret_cast<cuComplex**>(dA_array),
ldda,
ipiv_array,
reinterpret_cast<cuComplex**>(dC_array),
n,
info_array,
batchsize));
}

#endif // CUDART_VERSION

} // namespace blas
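The deleted specializations above rely on c10::complex<T> being layout-compatible with cuComplex/cuDoubleComplex, so pointers can be reinterpret_cast across the cuBLAS boundary. A standalone sketch of that pattern; FakeCuComplex and fake_cublas_call are hypothetical stand-ins so the example compiles without CUDA:

#include <complex>
#include <cstdio>

// Stand-in for cuComplex: two packed floats (real, imag). std::complex<float>
// and c10::complex<float> share this layout, which is what makes the
// reinterpret_cast in the deleted specializations work in practice.
struct FakeCuComplex { float x; float y; };

// Stand-in for a cuBLAS entry point that consumes cuComplex buffers.
void fake_cublas_call(FakeCuComplex* p, int n) {
  for (int i = 0; i < n; ++i) {
    std::printf("(%g, %g)\n", p[i].x, p[i].y);
  }
}

int main() {
  std::complex<float> data[2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  // Same cast pattern as the deleted getrfBatched<c10::complex<float>> body.
  fake_cublas_call(reinterpret_cast<FakeCuComplex*>(data), 2);
  return 0;
}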
8 changes: 0 additions & 8 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -155,10 +155,6 @@ template<>
void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
template<>
void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
template<>
void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
template<>
void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));


#define CUDABLAS_GETRI_ARGTYPES(Dtype) \
@@ -172,10 +168,6 @@ template<>
void getriBatched<float>(CUDABLAS_GETRI_ARGTYPES(float));
template<>
void getriBatched<double>(CUDABLAS_GETRI_ARGTYPES(double));
template<>
void getriBatched<c10::complex<double>>(CUDABLAS_GETRI_ARGTYPES(c10::complex<double>));
template<>
void getriBatched<c10::complex<float>>(CUDABLAS_GETRI_ARGTYPES(c10::complex<float>));

#endif // CUDART_VERSION

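The declarations removed from this header follow ATen's macro-plus-explicit-specialization pattern: a shared CUDABLAS_GETRF_ARGTYPES(Dtype) argument list, a primary template that errors for unsupported dtypes, and one specialization per backend routine. A simplified sketch under assumed names (GETRF_ARGTYPES and the runtime error below are illustrative, not the actual ATen definitions):

#include <cstdio>
#include <stdexcept>

// Shared argument list, mirroring CUDABLAS_GETRF_ARGTYPES(Dtype) (simplified).
#define GETRF_ARGTYPES(Dtype) \
  int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize

// Primary template: reject dtypes with no cuBLAS backend at runtime,
// analogous to how ATen errors on unsupported scalar types.
template <typename Dtype>
void getrfBatched(GETRF_ARGTYPES(Dtype)) {
  throw std::runtime_error("getrfBatched: not implemented for this dtype");
}

// Explicit specialization, analogous to getrfBatched<float> above. With the
// complex specializations deleted, complex calls hit the primary template.
template <>
void getrfBatched<float>(GETRF_ARGTYPES(float)) {
  std::printf("would call cublasSgetrfBatched on %d matrices\n", batchsize);
}

int main() {
  float* batch[1] = {nullptr};
  int ipiv[2], info[1];
  getrfBatched<float>(2, batch, 2, ipiv, info, 1);
  return 0;
}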
98 changes: 0 additions & 98 deletions aten/src/ATen/cuda/CUDASolver.cpp
@@ -33,56 +33,6 @@ void getrf<float>(
handle, m, n, dA, ldda, static_cast<float*>(dataPtr.get()), ipiv, info));
}

template <>
void getrf<c10::complex<double>>(
cusolverDnHandle_t handle,
int m,
int n,
c10::complex<double>* dA,
int ldda,
int* ipiv,
int* info) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize(
handle, m, n, reinterpret_cast<cuDoubleComplex*>(dA), ldda, &lwork));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
void* buffer = allocator.allocate(sizeof(cuDoubleComplex) * lwork).get();
TORCH_CUSOLVER_CHECK(cusolverDnZgetrf(
handle,
m,
n,
reinterpret_cast<cuDoubleComplex*>(dA),
ldda,
static_cast<cuDoubleComplex*>(buffer),
ipiv,
info));
}

template <>
void getrf<c10::complex<float>>(
cusolverDnHandle_t handle,
int m,
int n,
c10::complex<float>* dA,
int ldda,
int* ipiv,
int* info) {
int lwork;
TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize(
handle, m, n, reinterpret_cast<cuComplex*>(dA), ldda, &lwork));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
void* buffer = allocator.allocate(sizeof(cuComplex) * lwork).get();
TORCH_CUSOLVER_CHECK(cusolverDnCgetrf(
handle,
m,
n,
reinterpret_cast<cuComplex*>(dA),
ldda,
static_cast<cuComplex*>(buffer),
ipiv,
info));
}

template <>
void getrs<double>(
cusolverDnHandle_t handle, int n, int nrhs, double* dA, int lda, int* ipiv, double* ret, int ldb, int* info) {
@@ -97,54 +47,6 @@ void getrs<float>(
handle, CUBLAS_OP_N, n, nrhs, dA, lda, ipiv, ret, ldb, info));
}

template <>
void getrs<c10::complex<double>>(
cusolverDnHandle_t handle,
int n,
int nrhs,
c10::complex<double>* dA,
int lda,
int* ipiv,
c10::complex<double>* ret,
int ldb,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnZgetrs(
handle,
CUBLAS_OP_N,
n,
nrhs,
reinterpret_cast<cuDoubleComplex*>(dA),
lda,
ipiv,
reinterpret_cast<cuDoubleComplex*>(ret),
ldb,
info));
}

template <>
void getrs<c10::complex<float>>(
cusolverDnHandle_t handle,
int n,
int nrhs,
c10::complex<float>* dA,
int lda,
int* ipiv,
c10::complex<float>* ret,
int ldb,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnCgetrs(
handle,
CUBLAS_OP_N,
n,
nrhs,
reinterpret_cast<cuComplex*>(dA),
lda,
ipiv,
reinterpret_cast<cuComplex*>(ret),
ldb,
info));
}

} // namespace solver
} // namespace cuda
} // namespace at
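The removed cusolver specializations all follow the same two-phase workspace protocol: query the required scratch size with *_bufferSize, allocate it (via the CUDACachingAllocator in the deleted code), then run the factorization. A host-only sketch with fake_getrf_buffer_size and fake_getrf as hypothetical stand-ins for the cusolverDn calls:

#include <cstdio>
#include <vector>

// Stand-in for cusolverDn{C,Z}getrf_bufferSize: report required scratch size.
int fake_getrf_buffer_size(int m, int n) { return m * n; }

// Stand-in for cusolverDn{C,Z}getrf: factorize using the caller's workspace.
void fake_getrf(float* dA, float* workspace, int* ipiv, int* info) {
  *info = 0;
  (void)dA; (void)workspace; (void)ipiv;
}

int main() {
  const int m = 4, n = 4;
  std::vector<float> A(m * n, 1.0f);
  std::vector<int> ipiv(m);
  int info = 0;

  // 1. Query the workspace size (the deleted code calls *_bufferSize).
  int lwork = fake_getrf_buffer_size(m, n);
  // 2. Allocate scratch; ATen uses the CUDACachingAllocator instead.
  std::vector<float> workspace(lwork);
  // 3. Run the factorization with the workspace attached.
  fake_getrf(A.data(), workspace.data(), ipiv.data(), &info);
  std::printf("getrf finished with info = %d (lwork = %d)\n", info, lwork);
  return 0;
}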
8 changes: 0 additions & 8 deletions aten/src/ATen/cuda/CUDASolver.h
@@ -19,10 +19,6 @@ template<>
void getrf<float>(CUDASOLVER_GETRF_ARGTYPES(float));
template<>
void getrf<double>(CUDASOLVER_GETRF_ARGTYPES(double));
template<>
void getrf<c10::complex<double>>(CUDASOLVER_GETRF_ARGTYPES(c10::complex<double>));
template<>
void getrf<c10::complex<float>>(CUDASOLVER_GETRF_ARGTYPES(c10::complex<float>));


#define CUDASOLVER_GETRS_ARGTYPES(Dtype) \
@@ -36,10 +32,6 @@ template<>
void getrs<float>(CUDASOLVER_GETRS_ARGTYPES(float));
template<>
void getrs<double>(CUDASOLVER_GETRS_ARGTYPES(double));
template<>
void getrs<c10::complex<double>>(CUDASOLVER_GETRS_ARGTYPES(c10::complex<double>));
template<>
void getrs<c10::complex<float>>(CUDASOLVER_GETRS_ARGTYPES(c10::complex<float>));


} // namespace solver
106 changes: 2 additions & 104 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -328,18 +328,6 @@ inline magma_int_t magmaGetriOptimalBlocksize<float>(magma_int_t n) {
return magma_get_sgetri_nb(n);
}

template <>
inline magma_int_t magmaGetriOptimalBlocksize<c10::complex<double>>(
magma_int_t n) {
return magma_get_zgetri_nb(n);
}

template <>
inline magma_int_t magmaGetriOptimalBlocksize<c10::complex<float>>(
magma_int_t n) {
return magma_get_cgetri_nb(n);
}

template<>
void magmaGetri<double>(
magma_int_t n, double* dA, magma_int_t ldda, magma_int_t* ipiv, double* dwork,
@@ -358,48 +346,6 @@ void magmaGetri<float>(
AT_CUDA_CHECK(cudaGetLastError());
}

template <>
void magmaGetri<c10::complex<double>>(
magma_int_t n,
c10::complex<double>* dA,
magma_int_t ldda,
magma_int_t* ipiv,
c10::complex<double>* dwork,
magma_int_t lwork,
magma_int_t* info) {
MagmaStreamSyncGuard guard;
magma_zgetri_gpu(
n,
reinterpret_cast<magmaDoubleComplex*>(dA),
ldda,
ipiv,
reinterpret_cast<magmaDoubleComplex*>(dwork),
lwork,
info);
AT_CUDA_CHECK(cudaGetLastError());
}

template <>
void magmaGetri<c10::complex<float>>(
magma_int_t n,
c10::complex<float>* dA,
magma_int_t ldda,
magma_int_t* ipiv,
c10::complex<float>* dwork,
magma_int_t lwork,
magma_int_t* info) {
MagmaStreamSyncGuard guard;
magma_cgetri_gpu(
n,
reinterpret_cast<magmaFloatComplex*>(dA),
ldda,
ipiv,
reinterpret_cast<magmaFloatComplex*>(dwork),
lwork,
info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaGetriBatched<double>(
magma_int_t n, double** dA_array, magma_int_t ldda,
@@ -418,54 +364,6 @@ void magmaGetriBatched<float>(
AT_CUDA_CHECK(cudaGetLastError());
}

template <>
void magmaGetriBatched<c10::complex<double>>(
magma_int_t n,
c10::complex<double>** dA_array,
magma_int_t ldda,
magma_int_t** ipiv_array,
c10::complex<double>** dinvA_array,
magma_int_t lddia,
magma_int_t* info_array,
magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magma_zgetri_outofplace_batched(
n,
reinterpret_cast<magmaDoubleComplex**>(dA_array),
ldda,
ipiv_array,
reinterpret_cast<magmaDoubleComplex**>(dinvA_array),
lddia,
info_array,
batchsize,
magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
}

template <>
void magmaGetriBatched<c10::complex<float>>(
magma_int_t n,
c10::complex<float>** dA_array,
magma_int_t ldda,
magma_int_t** ipiv_array,
c10::complex<float>** dinvA_array,
magma_int_t lddia,
magma_int_t* info_array,
magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magma_cgetri_outofplace_batched(
n,
reinterpret_cast<magmaFloatComplex**>(dA_array),
ldda,
ipiv_array,
reinterpret_cast<magmaFloatComplex**>(dinvA_array),
lddia,
info_array,
batchsize,
magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaCholeskySolve<double>(
magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda,
@@ -1121,14 +1019,14 @@ Tensor _inverse_helper_cuda_legacy(const Tensor& self) {
   if (self.dim() > 2) {
     std::vector<int64_t> infos(batchCount(self), 0);
     auto self_working_copy = cloneBatchedColumnMajor(self);
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{
+    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
       apply_batched_inverse<scalar_t>(
         self_working_copy, self_inv_working_copy, infos);
     });
     batchCheckErrors(infos, "inverse_cuda");
   } else {
     int64_t info = 0;
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{
+    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
       apply_single_inverse<scalar_t>(self_inv_working_copy, info);
     });
     singleCheckErrors(info, "inverse_cuda");
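The two AT_DISPATCH changes above are the behavioral core of the revert: AT_DISPATCH_FLOATING_TYPES has no cases for complex scalar types, so complex inputs now hit the macro's error branch instead of apply_*_inverse. A rough standalone model of that dispatch mechanism; ScalarType and dispatch_floating_types here are illustrative simplifications of c10::ScalarType and the real macro, not the ATen definitions:

#include <cstdio>
#include <stdexcept>
#include <string>

// Illustrative runtime dtype tag; the real one is c10::ScalarType.
enum class ScalarType { Float, Double, ComplexFloat, ComplexDouble };

// Rough model of AT_DISPATCH_FLOATING_TYPES: switch on the runtime dtype and
// call the lambda with a value whose static type becomes scalar_t. Complex
// tags fall through to the error branch, which is why this revert makes
// complex CUDA inputs to inverse fail instead of dispatching a kernel.
template <typename F>
void dispatch_floating_types(ScalarType t, const std::string& name, F&& f) {
  switch (t) {
    case ScalarType::Float:  f(float{});  break;   // scalar_t = float
    case ScalarType::Double: f(double{}); break;   // scalar_t = double
    default:
      throw std::runtime_error('"' + name + "\" not implemented for this dtype");
  }
}

int main() {
  dispatch_floating_types(ScalarType::Double, "inverse_cuda", [](auto tag) {
    using scalar_t = decltype(tag);
    std::printf("dispatched with sizeof(scalar_t) = %zu\n", sizeof(scalar_t));
  });
  // ScalarType::ComplexDouble would throw: "inverse_cuda" not implemented ...
  return 0;
}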
