Skip to content

Commit

Permalink
fix asserts in cuda code
Browse files Browse the repository at this point in the history
  • Loading branch information
ngimel committed May 27, 2020
1 parent c6e9e93 commit 79ee1f5
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/EmbeddingBag.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ __global__ void EmbeddingBag_updateOutputKernel(
scalar_t *weightFeat = weight + featureDim * weight_stride1;
int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it
int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
assert(end >= begin);
CUDA_KERNEL_ASSERT(end >= begin);

accscalar_t weightFeatSum = 0;
scalar_t weightFeatMax;
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/SortingRadixSelect.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct TopKTypeConfig<int16_t> {
typedef uint32_t RadixType;

static inline __device__ RadixType convert(int16_t v) {
assert(sizeof(short) == 2);
static_assert(sizeof(short) == 2);
return 32768u + v;
}

Expand All @@ -79,7 +79,7 @@ struct TopKTypeConfig<int32_t> {
typedef uint32_t RadixType;

static inline __device__ RadixType convert(int32_t v) {
assert(sizeof(int) == 4);
static_assert(sizeof(int) == 4);
return 2147483648u + v;
}

Expand All @@ -93,7 +93,7 @@ struct TopKTypeConfig<int64_t> {
typedef uint64_t RadixType;

static inline __device__ RadixType convert(int64_t v) {
assert(sizeof(int64_t) == 8);
static_assert(sizeof(int64_t) == 8);
return 9223372036854775808ull + v;
}

Expand Down
4 changes: 2 additions & 2 deletions aten/src/THC/THCTensorInfo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ TensorInfo<T, IndexType>::TensorInfo(T* p,
template <typename T, typename IndexType>
void
TensorInfo<T, IndexType>::reduceDim(int dim) {
assert(dim < dims && dim >= 0);
TORCH_INTERNAL_ASSERT(dim < dims && dim >= 0);
sizes[dim] = 1;
}

template <typename T, typename IndexType>
int
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {

assert(excludeDim >= -1 && excludeDim < dims);
TORCH_INTERNAL_ASSERT(excludeDim >= -1 && excludeDim < dims);

int stopDim = (excludeDim == -1) ? dims : excludeDim;
int newIndex = -1;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/THC/generic/THCTensorMode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void THCTensor_(mode)(THCState *state,
break;
case 1:
default:
assert(false);
TORCH_INTERNAL_ASSERT(false);
}
THCudaCheck(cudaGetLastError());

Expand Down
2 changes: 1 addition & 1 deletion aten/src/THC/generic/THCTensorSort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void THCTensor_(sortKeyValueInplace)(THCState* state,
/* Nothing to do, data already sorted */ \
break; \
default: \
assert(false); \
TORCH_INTERNAL_ASSERT(false); \
} \
}

Expand Down
6 changes: 3 additions & 3 deletions caffe2/operators/top_k_radix_selection.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct TopKTypeConfig<short> {
typedef unsigned int RadixType;

static inline __device__ RadixType convert(short v) {
CUDA_KERNEL_ASSERT(sizeof(short) == 2);
static_assert(sizeof(short) == 2);
return 32768u + v;
}

Expand All @@ -90,7 +90,7 @@ struct TopKTypeConfig<int> {
typedef unsigned int RadixType;

static inline __device__ RadixType convert(int v) {
CUDA_KERNEL_ASSERT(sizeof(int) == 4);
static_assert(sizeof(int) == 4);
return 2147483648u + v;
}

Expand All @@ -104,7 +104,7 @@ struct TopKTypeConfig<long> {
typedef unsigned long long int RadixType;

static inline __device__ RadixType convert(long v) {
CUDA_KERNEL_ASSERT(sizeof(long) == 8);
static_assert(sizeof(long) == 8);
return 9223372036854775808ull + v;
}

Expand Down

0 comments on commit 79ee1f5

Please sign in to comment.