Skip to content

Commit

Permalink
Fix incorrect usage of CUDACachingAllocator [v2] (#48817)
Browse files Browse the repository at this point in the history
Summary:
This is similar to #46605, where the c10::complex part of the code was not merged yet at that moment.

Pull Request resolved: #48817

Reviewed By: malfet

Differential Revision: D25333179

Pulled By: ezyang

fbshipit-source-id: a92bdad5ad4b36bef7f050b21a59676c38e7b1fc
  • Loading branch information
xwang233 authored and facebook-github-bot committed Dec 7, 2020
1 parent 8bc6023 commit 36df253
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions aten/src/ATen/cuda/CUDASolver.cpp
Expand Up @@ -46,14 +46,14 @@ void getrf<c10::complex<double>>(
TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize(
handle, m, n, reinterpret_cast<cuDoubleComplex*>(dA), ldda, &lwork));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
void* buffer = allocator.allocate(sizeof(cuDoubleComplex) * lwork).get();
auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex) * lwork);
TORCH_CUSOLVER_CHECK(cusolverDnZgetrf(
handle,
m,
n,
reinterpret_cast<cuDoubleComplex*>(dA),
ldda,
static_cast<cuDoubleComplex*>(buffer),
static_cast<cuDoubleComplex*>(dataPtr.get()),
ipiv,
info));
}
Expand All @@ -71,14 +71,14 @@ void getrf<c10::complex<float>>(
TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize(
handle, m, n, reinterpret_cast<cuComplex*>(dA), ldda, &lwork));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
void* buffer = allocator.allocate(sizeof(cuComplex) * lwork).get();
auto dataPtr = allocator.allocate(sizeof(cuComplex) * lwork);
TORCH_CUSOLVER_CHECK(cusolverDnCgetrf(
handle,
m,
n,
reinterpret_cast<cuComplex*>(dA),
ldda,
static_cast<cuComplex*>(buffer),
static_cast<cuComplex*>(dataPtr.get()),
ipiv,
info));
}
Expand Down

0 comments on commit 36df253

Please sign in to comment.