Add Kernel Launch Checks to files under caffe2/aten/THC (#49358)
Summary:
Pull Request resolved: #49358

Added the header file (`c10/cuda/CUDAException.h`), where `C10_CUDA_KERNEL_LAUNCH_CHECK` is defined, to the files under `caffe2/aten/THC` that need it, then added a `C10_CUDA_KERNEL_LAUNCH_CHECK()` call after each kernel launch. In some cases, removed extraneous error checks that the new macro makes redundant.
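For context, the pattern applied throughout: a CUDA kernel launch is asynchronous, so configuration failures (invalid grid, too much shared memory) only surface through a later `cudaGetLastError()` or synchronization call. Checking immediately after the launch pins the error to its launch site. Below is a minimal self-contained sketch in plain CUDA; `C10_CUDA_KERNEL_LAUNCH_CHECK()` is essentially this check plus c10's error formatting, so the names here are illustrative, not PyTorch source:

```
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill(float* out, float value, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = value;
}

int main() {
  const int n = 1024;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));

  fill<<<(n + 255) / 256, 256>>>(d, 1.0f, n);
  // The moral equivalent of C10_CUDA_KERNEL_LAUNCH_CHECK():
  // fetch (and clear) any launch error right at the launch site.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "kernel launch failed: %s\n", cudaGetErrorString(err));
    return 1;
  }

  cudaFree(d);
  return 0;
}
```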

Test Plan:
Checked that the code still builds with
```
buck build //caffe2/aten:ATen-cu
```

Also ran basic ATen tests
```
buck test //caffe2/aten:atest
```

Reviewed By: r-barnes

Differential Revision: D25541113

fbshipit-source-id: df1a50e14d291a86b24ca1746ac27fa586f9757c
Amogh Akshintala authored and facebook-github-bot committed Dec 15, 2020
1 parent cb3169d commit e2510a0
Showing 11 changed files with 96 additions and 64 deletions.
38 changes: 20 additions & 18 deletions aten/src/THC/THCApply.cuh
@@ -6,6 +6,7 @@
#include <THC/THCTensorTypeUtils.cuh>
#include <THC/THCTensorCopy.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

//
// This file contains pointwise operation functions and kernels that
@@ -242,14 +243,11 @@ bool THC_pointwiseApply1(THCState* state,
// (or vice versa), the contiguous tensor can be collapsed to one
// dimension, and the loop to translate the linear index to the array
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, A) \
kernelPointwiseApply1<Op, \
ScalarTypeA, \
TYPE, A> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream(curDevice)>>>( \
OffsetInfo<ScalarTypeA, TYPE, A> \
(aInfo), \
(TYPE) totalElements, op);
#define HANDLE_CASE(TYPE, A) \
kernelPointwiseApply1<Op, ScalarTypeA, TYPE, A> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream(curDevice)>>>( \
OffsetInfo<ScalarTypeA, TYPE, A>(aInfo), (TYPE) totalElements, op); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

#define HANDLE_A_CASE(TYPE, A) { \
switch (A) { \
@@ -298,6 +296,7 @@ bool THC_pointwiseApply1(THCState* state,
uint64_t, 1>
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
aOffset, (uint64_t) totalElements, op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {

#if CUDA_VERSION < 9000
@@ -310,6 +309,7 @@ bool THC_pointwiseApply1(THCState* state,
uint64_t, -1>
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
aOffset, (uint64_t) totalElements, op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
#undef HANDLE_CASE
@@ -392,16 +392,13 @@ bool THC_pointwiseApply2(THCState* state,
// dimension, and the loop to translate the linear index to the array
// index can be similarly collapsed. That is what this unrolling is for.
#define HANDLE_CASE(TYPE, A, B) \
kernelPointwiseApply2<Op, \
ScalarTypeA, \
ScalarTypeB, \
TYPE, A, B> \
kernelPointwiseApply2<Op, ScalarTypeA, ScalarTypeB, TYPE, A, B> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream(curDevice)>>>( \
OffsetInfo<ScalarTypeA, TYPE, A> \
(aInfo), \
OffsetInfo<ScalarTypeB, TYPE, B> \
(bInfo), \
(TYPE) totalElements, op);
OffsetInfo<ScalarTypeA, TYPE, A>(aInfo), \
OffsetInfo<ScalarTypeB, TYPE, B>(bInfo), \
(TYPE) totalElements, op); \
C10_CUDA_KERNEL_LAUNCH_CHECK();


#define HANDLE_B_CASE(TYPE, A, B) { \
switch (B) { \
@@ -474,6 +471,7 @@ bool THC_pointwiseApply2(THCState* state,
uint64_t, 1, 1>
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
aOffset, bOffset, (uint64_t) totalElements, op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
#if CUDA_VERSION < 9000
grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
@@ -488,6 +486,7 @@ bool THC_pointwiseApply2(THCState* state,
uint64_t, -1, -1>
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
aOffset, bOffset, (uint64_t) totalElements, op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
#undef HANDLE_CASE
@@ -598,7 +597,8 @@ bool THC_pointwiseApply3(THCState* state,
(bInfo), \
OffsetInfo<ScalarTypeC, TYPE, C> \
(cInfo), \
(TYPE) totalElements, op);
(TYPE) totalElements, op); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

#define HANDLE_C_CASE(TYPE, A, B, C) { \
switch (C) { \
@@ -697,6 +697,7 @@ bool THC_pointwiseApply3(THCState* state,
uint64_t, 1, 1, 1>
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
aOffset, bOffset, cOffset, (uint64_t) totalElements, op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
#if CUDA_VERSION < 9000
grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x);
@@ -715,6 +716,7 @@ bool THC_pointwiseApply3(THCState* state,
uint64_t, -1, -1, -1>
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
aOffset, bOffset, cOffset, (uint64_t) totalElements, op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
#undef HANDLE_CASE
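The comment above `HANDLE_CASE` in this file explains why these kernels are specialized on a compile-time dimension count: translating a linear index into a strided offset costs a divide and modulo per dimension, and a contiguous tensor collapses to a single dimension where the linear index is the offset itself. A rough illustration of that translation, with hypothetical names rather than the actual `OffsetInfo` internals:

```
#include <cstdint>

// Sketch of a strided offset computation in the spirit of OffsetInfo.
// With Dims known at compile time the compiler can unroll the loop;
// for a collapsed contiguous tensor (one dimension, stride 1) the
// whole computation reduces to `return i`.
template <typename IndexType, int Dims>
IndexType indexToOffset(IndexType i,
                        const IndexType (&sizes)[Dims],
                        const IndexType (&strides)[Dims]) {
  IndexType offset = 0;
  for (int d = Dims - 1; d >= 0; --d) {  // innermost dimension last
    offset += (i % sizes[d]) * strides[d];
    i /= sizes[d];
  }
  return offset;
}
```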
4 changes: 4 additions & 0 deletions aten/src/THC/THCReduceAll.cuh
@@ -10,6 +10,7 @@
//

#include <THC/THCReduceApplyUtils.cuh>
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>

#ifdef __HIP_PLATFORM_HCC__
@@ -209,6 +210,7 @@ void callReduceAll(THCState* state,
<<<grid, block, smemSize, c10::cuda::getCurrentCUDAStream()>>>(
in, (IndexType) totalElements, init, modifyOp, reduceOp,
(AccT*) scratchSpace);
C10_CUDA_KERNEL_LAUNCH_CHECK();

int numPass1Blocks = grid.x;
getPass2ReduceBlockGrid<AccT>(state, totalElements, grid, block);
@@ -218,6 +220,7 @@ void callReduceAll(THCState* state,
<<<grid, block, smemSize, c10::cuda::getCurrentCUDAStream()>>>(
numPass1Blocks, init, reduceOp,
(AccT*) scratchSpace, devOut);
C10_CUDA_KERNEL_LAUNCH_CHECK();

THCudaFree(state, scratchSpace);
} else {
@@ -227,6 +230,7 @@ void callReduceAll(THCState* state,
kernelReduceAll<T, IndexType, AccT, ModifyOp, ReduceOp, ADims>
<<<grid, block, smemSize, c10::cuda::getCurrentCUDAStream()>>>(
in, (IndexType) totalElements, init, modifyOp, reduceOp, devOut);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}

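`callReduceAll` takes a two-pass path for large inputs: pass 1 reduces each block's slice into scratch space, pass 2 reduces the per-block partials into the final value, and after this change each launch gets its own check. A hedged sketch of that control flow, using a plain sum reduction and hypothetical names:

```
#include <cuda_runtime.h>

// `threads` must be a power of two for the tree reduction below.
__global__ void reducePass(const float* in, float* out, int n) {
  extern __shared__ float smem[];
  const int tid = threadIdx.x;
  float acc = 0.0f;
  // Grid-stride load: each block accumulates its share of the input.
  for (int i = blockIdx.x * blockDim.x + tid; i < n; i += gridDim.x * blockDim.x)
    acc += in[i];
  smem[tid] = acc;
  __syncthreads();
  // Tree reduction within the block.
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) smem[tid] += smem[tid + s];
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = smem[0];
}

cudaError_t reduceAllTwoPass(const float* d_in, float* d_out, int n,
                             float* d_scratch, int numBlocks, int threads) {
  const size_t smem = threads * sizeof(float);
  reducePass<<<numBlocks, threads, smem>>>(d_in, d_scratch, n);   // pass 1
  cudaError_t err = cudaGetLastError();                           // launch check
  if (err != cudaSuccess) return err;
  reducePass<<<1, threads, smem>>>(d_scratch, d_out, numBlocks);  // pass 2
  return cudaGetLastError();                                      // launch check
}
```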
8 changes: 5 additions & 3 deletions aten/src/THC/THCTensorSort.cu
@@ -1,5 +1,6 @@
#include <THC/THCTensorSort.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

void THCudaLongTensor_fillSliceWithIndex(THCState* state,
THCudaLongTensor* t,
@@ -28,8 +29,10 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state,

#define FILL_INDEX(T, DIM) \
fillSliceWithIndex<T, DIM> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>( \
info, numSlices, sliceSize, info.strides[collapseDim])
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>( \
info, numSlices, sliceSize, info.strides[collapseDim]); \
C10_CUDA_KERNEL_LAUNCH_CHECK()


if (THCTensor_canUse32BitIndexMath(state, t)) {
TensorInfo<int64_t, uint32_t> info =
@@ -59,6 +62,5 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state,
}

#undef FILL_INDEX
THCudaCheck(cudaGetLastError());
}
}
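A small detail in the `FILL_INDEX` rewrite above: the macro now ends in `C10_CUDA_KERNEL_LAUNCH_CHECK()` with no trailing semicolon, so existing call sites written as `FILL_INDEX(T, DIM);` still supply it. An alternative way to keep a multi-statement macro behaving like a single statement, shown here only as a generic sketch rather than code from this diff, is the classic `do { ... } while (0)` wrapper:

```
// Generic sketch: the do-while(0) wrapper makes the launch-plus-check
// pair behave as one statement, safe inside an unbraced if/else.
#define LAUNCH_AND_CHECK(kernel, grid, block, ...)  \
  do {                                              \
    kernel<<<(grid), (block)>>>(__VA_ARGS__);       \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                 \
  } while (0)
```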
27 changes: 16 additions & 11 deletions aten/src/THC/generic/THCTensorIndex.cu
@@ -4,6 +4,7 @@

#include <ATen/cuda/CUDAContext.h>
#include <ATen/MemoryOverlap.h>
#include <c10/cuda/CUDAException.h>

// Check tensor dimensions for index operations, and return the slice size.
// src can be nullptr in case of indexFill: in that case it is ignored.
@@ -127,11 +128,12 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT

int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \
indexCopySmallIndex<TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
dstInfo, srcInfo, indicesInfo, \
dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize);
#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM) \
indexCopySmallIndex<TENSOR_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
dstInfo, srcInfo, indicesInfo, \
dstCopyDim, srcCopyDim, sliceSize, dstCopyDimSize); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

#define LARGE_INDEX(TENSOR_TYPE, TYPE, \
DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR) \
@@ -141,7 +143,8 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT
dstInfo, srcInfo, indicesInfo, \
dstCopyDim, srcCopyDim, srcTotalSize, \
(IDX_IS_MAJOR) ? sliceSize : numIndices, \
dstCopyDimSize);
dstCopyDimSize); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
@@ -307,19 +310,21 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT

int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \
#define SMALL_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM) \
indexFillSmallIndex<TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM> \
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
dstInfo, indicesInfo, \
dstFillDim, sliceSize, dstFillDimSize, val);
<<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
dstInfo, indicesInfo, \
dstFillDim, sliceSize, dstFillDimSize, val); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

#define LARGE_INDEX(TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM, IDX_IS_MAJOR) \
indexFillLargeIndex<TENSOR_TYPE, TYPE, DST_DIM, IDX_DIM, IDX_IS_MAJOR> \
<<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
dstInfo, indicesInfo, \
dstFillDim, sliceSize * numIndices, \
(IDX_IS_MAJOR) ? sliceSize : numIndices, \
dstFillDimSize, val);
dstFillDimSize, val); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

dim3 smallIndexGrid(std::min(THCCeilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8)));
dim3 smallIndexBlock(std::min(sliceSize, (ptrdiff_t)128));
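The `smallIndexGrid` / `smallIndexBlock` math repeated above encodes one heuristic: one thread per slice element up to 128 per block, and at most eight blocks per multiprocessor in flight. Restated as a hypothetical helper, not code from this diff:

```
#include <algorithm>
#include <cstddef>

inline ptrdiff_t ceilDiv(ptrdiff_t a, ptrdiff_t b) { return (a + b - 1) / b; }

// Hypothetical restatement of the launch-shape heuristic used by the
// SMALL_INDEX path: cap blocks at 8 * multiProcessorCount and threads
// per block at 128.
inline void smallIndexShape(ptrdiff_t sliceSize, int mpc,
                            ptrdiff_t& gridX, ptrdiff_t& blockX) {
  gridX  = std::min(ceilDiv(sliceSize, (ptrdiff_t)128), (ptrdiff_t)(mpc * 8));
  blockX = std::min(sliceSize, (ptrdiff_t)128);
}
```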
4 changes: 4 additions & 0 deletions aten/src/THC/generic/THCTensorMathMagma.cu
@@ -2,6 +2,8 @@
#define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.cu"
#else

#include <c10/cuda/CUDAException.h>

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)

#ifdef USE_MAGMA
@@ -171,8 +173,10 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper
dim3 threads(128);
if (uplo == 'U') {
THCTensor_(copyUpperSymmetric)<<<blocks, threads, 0, stream>>>(input_data, n, len);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
THCTensor_(copyLowerSymmetric)<<<blocks, threads, 0, stream>>>(input_data, n, len);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

THCTensor_(freeCopyTo)(state, input, ra_);
6 changes: 4 additions & 2 deletions aten/src/THC/generic/THCTensorMathReduce.cu
@@ -41,9 +41,11 @@ void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar
dim3 threads(32);

THCTensor_kernel_renorm<scalar_t, accreal>
<<<grid, threads, 0, c10::cuda::getCurrentCUDAStream()>>>
(THCTensor_(data)(state, data), scalar_cast<accreal>(value), size, scalar_cast<accreal>(maxnorm));
<<<grid, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(THCTensor_(data)(state, data),
scalar_cast<accreal>(value), size, scalar_cast<accreal>(maxnorm));

// Do not replace with C10_CUDA_KERNEL_LAUNCH_CHECK() yet as it exhibits different behaviour from THError().
// THError() calls an error handler, or throws std::runtime_error if a custom handler hasn't been registered.
cudaError_t errcode = cudaGetLastError();
if(errcode != cudaSuccess)
THError(cudaGetErrorString(errcode));
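The comment kept in `renorm` is the reason this one launch still uses the manual pattern: the two checks can fail differently. `THError()` routes through a THC error handler that callers may have registered, while `C10_CUDA_KERNEL_LAUNCH_CHECK()` (per its definition in `c10/cuda/CUDAException.h`) boils down to `C10_CUDA_CHECK(cudaGetLastError())` and throws a `c10::Error`. A hedged sketch of the behavioral difference, with hypothetical names:

```
#include <stdexcept>
#include <string>

// Handler-style reporting (THError-like): a registered callback gets the
// message first; only without one do we throw.
using ErrorHandler = void (*)(const char* msg);
static ErrorHandler g_handler = nullptr;

void reportViaHandler(const std::string& msg) {
  if (g_handler) {
    g_handler(msg.c_str());  // the handler decides what happens next
    return;
  }
  throw std::runtime_error(msg);
}

// Exception-style reporting (C10_CUDA_KERNEL_LAUNCH_CHECK-like):
// unconditionally throw on failure.
void reportViaException(bool ok, const std::string& msg) {
  if (!ok) throw std::runtime_error("CUDA error: " + msg);
}
```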
17 changes: 9 additions & 8 deletions aten/src/THC/generic/THCTensorMode.cu
@@ -2,6 +2,7 @@
#define THC_GENERIC_FILE "THC/generic/THCTensorMode.cu"
#else

#include <c10/cuda/CUDAException.h>
#include <thrust/iterator/constant_iterator.h>

void THCTensor_(calculateMode)(THCState *state,
@@ -235,14 +236,14 @@ void THCTensor_(mode)(THCState *state,

// Macro that calls kernel --> note that we set the block dimensions here, and
// the amount of shared memory
#define HANDLE_MODE(SIZE) \
{ \
dim3 blockSize(SIZE / 2); \
\
int memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \
computeMode<scalar_t, SIZE> \
<<<grid, blockSize, memsize, c10::cuda::getCurrentCUDAStream()>>>( \
THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \
#define HANDLE_MODE(SIZE) \
{ \
const dim3 blockSize(SIZE / 2); \
const auto memsize = (sizeof(scalar_t) * SIZE) + (2 * SIZE * sizeof(unsigned int)); \
computeMode<scalar_t, SIZE> \
<<<grid, blockSize, memsize, c10::cuda::getCurrentCUDAStream()>>>( \
THCTensor_(data)(state, contiguous), tiValues, tiIndices, sliceSize); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}

// Tradeoff between compilation time and the number of specializations. Ideally we would have
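The `HANDLE_MODE` macro sizes dynamic shared memory by hand: `SIZE` scalars for the values plus `2 * SIZE` unsigned ints of scratch. A hedged sketch of launching with, and partitioning, that kind of dynamically sized shared memory (illustrative names, not the real `computeMode`, and the layout here is an assumption based on the memsize formula):

```
#include <cuda_runtime.h>

template <typename T, int Size>
__global__ void modeLikeKernel(const T* in) {
  // One extern shared buffer, carved into a value region and a scratch
  // region, matching the memsize formula in HANDLE_MODE.
  extern __shared__ unsigned char smem[];
  T* values = reinterpret_cast<T*>(smem);              // Size elements
  unsigned int* scratch =
      reinterpret_cast<unsigned int*>(values + Size);  // 2 * Size elements
  (void)in; (void)values; (void)scratch;  // body elided in this sketch
}

void launchModeLike(const float* d_in) {
  constexpr int kSize = 128;
  const dim3 grid(1), block(kSize / 2);
  const size_t memsize = sizeof(float) * kSize + 2 * kSize * sizeof(unsigned int);
  modeLikeKernel<float, kSize><<<grid, block, memsize>>>(d_in);
  cudaGetLastError();  // where C10_CUDA_KERNEL_LAUNCH_CHECK() would sit
}
```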
8 changes: 8 additions & 0 deletions aten/src/THC/generic/THCTensorRandom.cu
@@ -5,6 +5,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/Utils.h>
#include <c10/cuda/CUDAException.h>
#include <utility>

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
@@ -39,6 +40,8 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud
THCudaLongTensor_data(state, larger_short),
one, inputsize
);
C10_CUDA_KERNEL_LAUNCH_CHECK();

at::Tensor smaller_short_wrapped = THTensor_wrap(smaller_short);
at::Tensor smaller_wrapped = THTensor_wrap(smaller);
at::Tensor larger_short_wrapped = THTensor_wrap(larger_short);
@@ -57,13 +60,16 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud
THCudaLongTensor_data(state, larger_short),
inputsize - h_large_c, h_large_c
);
C10_CUDA_KERNEL_LAUNCH_CHECK();

scalar_t q_max = at::max(THTensor_wrap(_q)).item<scalar_t>();
condDiv<<<
inputBlockDim, BLOCK_SIZE, 0, c10::cuda::getCurrentCUDAStream()>>>(
THCTensor_(data)(state, _q),
THCudaLongTensor_data(state, _J),
inputsize, q_max
);
C10_CUDA_KERNEL_LAUNCH_CHECK();

THCudaLongTensor_free(state, smaller);
THCudaLongTensor_free(state, larger);
@@ -104,6 +110,8 @@ void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, T
THCTensor_(data)(state, uniform),
THCTensor_(data)(state, bernoulli)
);
C10_CUDA_KERNEL_LAUNCH_CHECK();

THCTensor_(free)(state, uniform);
THCTensor_(free)(state, bernoulli);
}
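For orientation, `multinomialAliasSetup` / `multinomialAliasDraw` implement the alias method: setup splits bins into "smaller" and "larger" piles and builds the probability table `_q` and alias table `_J`; a draw then picks a bin uniformly and keeps it with probability `q[bin]`, otherwise takes its alias. A hedged CPU sketch of the draw step (hypothetical names, not the kernel code):

```
#include <cstdint>
#include <random>
#include <vector>

// One alias-method draw: uniform bin, biased coin, possible alias hop.
int64_t aliasDraw(const std::vector<double>& q,
                  const std::vector<int64_t>& J,
                  std::mt19937& gen) {
  std::uniform_int_distribution<int64_t> pickBin(0, (int64_t)q.size() - 1);
  std::uniform_real_distribution<double> coin(0.0, 1.0);
  const int64_t k = pickBin(gen);
  return coin(gen) < q[k] ? k : J[k];
}
```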
14 changes: 6 additions & 8 deletions aten/src/THC/generic/THCTensorScatterGather.cu
@@ -2,10 +2,13 @@
#define THC_GENERIC_FILE "THC/generic/THCTensorScatterGather.cu"
#else

#include <c10/cuda/CUDAException.h>

#define RUN(TYPE, DIMS, REAL) \
THCudaTensor_gatherKernel<TYPE, REAL, DIMS> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream(curDevice)>>>( \
tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements);
THCudaTensor_gatherKernel<TYPE, REAL, DIMS> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream(curDevice)>>>( \
tensorInfo, srcInfo, indexInfo, dim, (TYPE)totalElements); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

void THCTensor_(gather)(THCState* state, THCTensor *tensor,
THCTensor *src, int dim, THCudaLongTensor *index) {
@@ -61,19 +64,15 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor,
switch (indexInfo.dims) {
case 1:
RUN(unsigned int, 1, scalar_t);
THCudaCheck(cudaGetLastError());
break;
case 2:
RUN(unsigned int, 2, scalar_t);
THCudaCheck(cudaGetLastError());
break;
case 3:
RUN(unsigned int, 3, scalar_t);
THCudaCheck(cudaGetLastError());
break;
default:
RUN(unsigned int, -1, scalar_t);
THCudaCheck(cudaGetLastError());
break;
}
} else {
@@ -84,7 +83,6 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor,
TensorInfo<int64_t, uint64_t> indexInfo =
getTensorInfo<int64_t, THCudaLongTensor, uint64_t>(state, index);
RUN(uint64_t, -1, scalar_t);
THCudaCheck(cudaGetLastError());
}
}

