From 61305cd638b6fcd73a0b66b4cde7014fecb9e8ce Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Tue, 21 Jun 2022 21:46:10 +0000
Subject: [PATCH] Improve small sort performance on CUDA

Currently, `bitonicSortKVInPlace` is written to sort one array per
block of threads. If that dimension happens to be very small
(<128 elements), this results in low thread occupancy.

Instead, this changes `bitonicSortKVInPlace` to operate with a 2d
block. Sorting happens along the x dimension, and the y dimension is
a fixed-size batch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79627
Approved by: https://github.com/ngimel
---
 aten/src/ATen/native/cuda/Sort.cu       |  83 ++++++++++-------
 aten/src/ATen/native/cuda/SortUtils.cuh | 113 +++++++++++-------
 aten/src/ATen/native/cuda/TensorTopK.cu |   2 +-
 3 files changed, 106 insertions(+), 92 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index 5c08ddf5978299a..5f21b1ceb7b5d78 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -15,6 +15,19 @@
 namespace at { namespace native {
 
+template <typename T>
+static int minimum_grid_for_occupancy(T kernel, int max_block_size) {
+  int minGridSize;
+  int blockSize;
+  C10_CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(
+      &minGridSize,
+      &blockSize,
+      kernel,
+      /*dynamicSMemSize=*/0,
+      max_block_size));
+  return minGridSize;
+}
+
 // In alignment with default sort on a c++ map, this function
 // will permute key and value tensors identically, and
 // in such a way that the 'key' tensor is ordered numerically
@@ -45,43 +58,51 @@ void sortKeyValueInplace(const TensorBase& key,
   // vectorized (key, value) sort by slice segment
   TORCH_INTERNAL_ASSERT(ceilPowerOf2 <= 2048, "sortKeyValueInplace only works for sizes <= 2048 at present");
 
-  // The grid is based on the number of independent slices that we
-  // have to sort; one block per slice
-  dim3 grid;
-  TORCH_INTERNAL_ASSERT(getGridFromTiles(keySlices, grid), "Too many slices to sort");
+  const auto stream = c10::cuda::getCurrentCUDAStream();
 
-#define HANDLE_CASE(TYPE, A, SIZE)                                      \
+#define HANDLE_CASE(TYPE, A, SIZE, BATCH)                               \
   do {                                                                  \
-    int blockSize = SIZE / 2;                                           \
-    if (blockSize < 1) {                                                \
-      blockSize = 1;                                                    \
-    }                                                                   \
+    constexpr int items_per_thread = 2;                                 \
+    static_assert(SIZE % items_per_thread == 0, "");                    \
+    constexpr int block_x = SIZE / items_per_thread;                    \
+    constexpr int max_block_y = BATCH;                                  \
+                                                                        \
+    /* Scale batch size down if the grid would be too small */          \
+    const auto min_grid = minimum_grid_for_occupancy(                   \
+        bitonicSortKVInPlace<                                           \
+            A, -1, block_x, max_block_y,                                \
+            scalar_t, int64_t, LTOp<scalar_t, true>, TYPE>,             \
+        block_x * max_block_y);                                         \
+    const auto max_batch = std::max(int64_t{1}, keySlices / min_grid);  \
+    const int block_y = std::min(int64_t{max_block_y}, max_batch);      \
+    dim3 block(block_x, block_y);                                       \
                                                                         \
-    dim3 block(blockSize);                                              \
+    dim3 grid;                                                          \
+    const int grid_count = (keySlices + block_y - 1) / block_y;         \
+    TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),           \
+                          "Too many slices to sort");                   \
                                                                         \
     if (dir) {                                                          \
-      bitonicSortKVInPlace<scalar_t, int64_t, A, -1, GTOp<scalar_t, true>, TYPE, SIZE> \
-        <<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(        \
+      bitonicSortKVInPlace<A, -1, block_x, max_block_y,                 \
+                           scalar_t, int64_t, GTOp<scalar_t, true>, TYPE> \
+        <<<grid, block, 0, stream>>>(                                   \
          keyInfo,                                                       \
-          keySlices,                                                    \
+          (TYPE) keySlices,                                             \
          (TYPE) keySliceSize,                                           \
          (TYPE) keyInfo.strides[collapseKeyDim],                        \
          valueInfo,                                                     \
          (TYPE) valueInfo.strides[collapseValueDim],                    \
-          GTOp<scalar_t, true>());                                      \
+          GTOp<scalar_t, true>());                                      \
      C10_CUDA_KERNEL_LAUNCH_CHECK();                                    \
     } else {                                                            \
-      bitonicSortKVInPlace<scalar_t, int64_t, A, -1, LTOp<scalar_t, true>, TYPE, SIZE> \
-        <<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(        \
+      bitonicSortKVInPlace<A, -1, block_x, max_block_y,                 \
+                           scalar_t, int64_t, LTOp<scalar_t, true>, TYPE> \
+        <<<grid, block, 0, stream>>>(                                   \
          keyInfo,                                                       \
-          keySlices,                                                    \
+          (TYPE) keySlices,                                             \
          (TYPE) keySliceSize,                                           \
          (TYPE) keyInfo.strides[collapseKeyDim],                        \
          valueInfo,                                                     \
          (TYPE) valueInfo.strides[collapseValueDim],                    \
-          LTOp<scalar_t, true>());                                      \
+          LTOp<scalar_t, true>());                                      \
      C10_CUDA_KERNEL_LAUNCH_CHECK();                                    \
     }                                                                   \
   } while (0)
@@ -90,29 +111,29 @@ void sortKeyValueInplace(const TensorBase& key,
   {                                                                     \
     switch (ceilPowerOf2) {                                             \
       case 2048:                                                        \
-        HANDLE_CASE(TYPE, A, 2048);                                     \
-        break;                                                          \
+        HANDLE_CASE(TYPE, A, 2048, 1);                                  \
+        break;                                                          \
       case 1024:                                                        \
       case 512:                                                         \
       case 256:                                                         \
-        HANDLE_CASE(TYPE, A, 1024);                                     \
-        break;                                                          \
+        HANDLE_CASE(TYPE, A, 1024, 1);                                  \
+        break;                                                          \
       case 128:                                                         \
       case 64:                                                          \
-        HANDLE_CASE(TYPE, A, 128);                                      \
-        break;                                                          \
+        HANDLE_CASE(TYPE, A, 128, 4);                                   \
+        break;                                                          \
       case 32:                                                          \
       case 16:                                                          \
       case 8:                                                           \
       case 4:                                                           \
       case 2:                                                           \
-        HANDLE_CASE(TYPE, A, 32);                                       \
-        break;                                                          \
+        HANDLE_CASE(TYPE, A, 32, 16);                                   \
+        break;                                                          \
       case 1:                                                           \
-        /* Nothing to do, data already sorted */                        \
-        break;                                                          \
+        /* Nothing to do, data already sorted */                        \
+        break;                                                          \
       default:                                                          \
-        TORCH_INTERNAL_ASSERT(false);                                   \
+        TORCH_INTERNAL_ASSERT(false);                                   \
     }                                                                   \
   }
diff --git a/aten/src/ATen/native/cuda/SortUtils.cuh b/aten/src/ATen/native/cuda/SortUtils.cuh
index 10dd7bb3b024424..282a7ceb06b4b53 100644
--- a/aten/src/ATen/native/cuda/SortUtils.cuh
+++ b/aten/src/ATen/native/cuda/SortUtils.cuh
@@ -30,11 +30,11 @@ __device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
   }
 };
 
-template <typename Comparator, typename K, typename V, typename IndexType, int Power2SortSize>
-__device__ inline void bitonicSort(K keys[Power2SortSize],
-                                   V values[Power2SortSize],
-                                   bool valid[Power2SortSize],
+template <int Power2SortSize, typename IndexType, typename Comparator, typename K, typename V>
+__device__ inline void bitonicSort(K *keys,
+                                   V *values,
+                                   bool *valid,
                                    const Comparator& comp) {
 #if !defined(USE_ROCM)
 #pragma unroll
@@ -78,10 +78,9 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
 // at::cuda::detail::TensorInfo version
 // Sorts (key, value) pairs (in different tensors) in-place; i.e.,
 // modifies the input `keys` and `values`
-template <typename K, typename V, int KeyDims, int ValueDims, typename Comparator, typename IndexType, int Power2SortSize>
-C10_LAUNCH_BOUNDS_1(1024)
+template <int KeyDims, int ValueDims, int block_dim_x, int max_block_dim_y, typename K, typename V, typename Comparator, typename IndexType>
+C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y)
 __global__ void
 bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                      IndexType keySlices,
@@ -91,69 +90,63 @@ bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                      IndexType valueSliceStride,
                      Comparator comp) {
   // Find the slice of the tensor that we are sorting
-  const IndexType linearIndex = getLinearBlockId<IndexType>();
-  // Tiling the slices could have us be out of bounds, if there are a
-  // lot of slices to sort
-  if (linearIndex >= keySlices) {
+  // NOTE: blockDim.y may be less than max_block_dim_y
+  const IndexType blockIndex = getLinearBlockId<IndexType>();
+  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+  // If the entire block is out of bounds, exit early
+  if (blockIndex * blockDim.y >= keySlices) {
     return;
   }
+  // It's also possible for some rows of a block to be out of bounds,
+  // but all threads need to run for __syncthreads to work.
+  const bool row_valid = linearIndex < keySlices;
+
+  constexpr int items_per_thread = 2;
+  constexpr int Power2SortSize = block_dim_x * items_per_thread;
+
+  // Storage for max_block_dim_y sorts performed in parallel
+  __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize];
+  __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize];
+  __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize];
 
-  __shared__ K sharedKeys[Power2SortSize];
-  __shared__ V sharedValues[Power2SortSize];
-  __shared__ bool sharedValid[Power2SortSize];
+  auto sharedKeys = blockSharedKeys[threadIdx.y];
+  auto sharedValues = blockSharedValues[threadIdx.y];
+  auto sharedValid = blockSharedValid[threadIdx.y];
 
   const IndexType keyStartOffset =
     at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
   const IndexType valueStartOffset =
     at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);
 
-  // If the sort size is 1, the data is already sorted
-  if (Power2SortSize == 1) {
-    return;
-  } else {
-    // Otherwise, each thread is responsible for loading and storing 2
-    // elements. The sort size is guaranteed to be >= 2
-    const int elem1 = threadIdx.x;
-    const int elem2 = threadIdx.x + (Power2SortSize / 2);
-
-    bool valid1 = (elem1 < keySliceSize);
-    K k1 = valid1 ?
-      keys.data[keyStartOffset + elem1 * keySliceStride] : static_cast<K>(0);
-    V v1 = valid1 ?
-      values.data[valueStartOffset + elem1 * valueSliceStride] : static_cast<V>(0);
-
-    sharedKeys[elem1] = k1;
-    sharedValues[elem1] = v1;
-    sharedValid[elem1] = valid1;
-
-    bool valid2 = (elem2 < keySliceSize);
-    K k2 = valid2 ?
-      keys.data[keyStartOffset + elem2 * keySliceStride] : static_cast<K>(0);
-    V v2 = valid2 ?
-      values.data[valueStartOffset + elem2 * valueSliceStride] : static_cast<V>(0);
-
-    sharedKeys[elem2] = k2;
-    sharedValues[elem2] = v2;
-    sharedValid[elem2] = valid2;
-
-    // Sort!
-    bitonicSort<Comparator, K, V, IndexType, Power2SortSize>(
+  // Load 2 values per thread into the shared workspace
+  #pragma unroll
+  for (int k = 0; k < items_per_thread; ++k) {
+    auto idx = threadIdx.x + k * blockDim.x;
+    bool valid = row_valid && idx < keySliceSize;
+
+    sharedKeys[idx] = valid ?
+        keys.data[idx * keySliceStride + keyStartOffset] : K{};
+    sharedValues[idx] = valid ?
+        values.data[idx * valueSliceStride + valueStartOffset] : V{};
+    sharedValid[idx] = valid;
+  }
+
+  // Sort!
+  bitonicSort<Power2SortSize, IndexType>(
       sharedKeys, sharedValues, sharedValid, comp);
 
-    // elem1 and elem2 values might be out-of-range, if the data size we are
-    // sorting is smaller than half the power2 size
-    if (valid1) {
-      keys.data[keyStartOffset + elem1 * keySliceStride] =
-        sharedKeys[elem1];
-      values.data[valueStartOffset + elem1 * valueSliceStride] =
-        sharedValues[elem1];
-    }
+  if (!row_valid) {
+    return;
+  }
 
-    if (valid2) {
-      keys.data[keyStartOffset + elem2 * keySliceStride] =
-        sharedKeys[elem2];
-      values.data[valueStartOffset + elem2 * valueSliceStride] =
-        sharedValues[elem2];
+  // Store outputs
+  #pragma unroll
+  for (int k = 0; k < items_per_thread; ++k) {
+    auto idx = threadIdx.x + k * blockDim.x;
+    if (idx < keySliceSize) {
+      keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx];
+      values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx];
     }
   }
 }
diff --git a/aten/src/ATen/native/cuda/TensorTopK.cu b/aten/src/ATen/native/cuda/TensorTopK.cu
index a4763e2d6f0d452..1caf3ec57608641 100644
--- a/aten/src/ATen/native/cuda/TensorTopK.cu
+++ b/aten/src/ATen/native/cuda/TensorTopK.cu
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -10,7 +11,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
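
For readers skimming the patch, here is a small standalone sketch (not part of the PR) of the batching idea the commit message describes: a 2d block where the x dimension cooperates on one short row held in shared memory and the y dimension packs several rows into the same block, with the per-block batch scaled down via cudaOccupancyMaxPotentialBlockSize so the grid stays large enough to fill the device. Names such as batchedSmallSort, BLOCK_X, and MAX_BLOCK_Y are illustrative only, and the sketch uses a simple odd-even transposition sort in place of the bitonic network in SortUtils.cuh.

// batched_small_sort.cu -- toy sketch, not the PyTorch implementation
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int BLOCK_X = 64;           // threads cooperating on one row
constexpr int MAX_BLOCK_Y = 4;        // maximum rows (slices) per block
constexpr int ITEMS_PER_THREAD = 2;
constexpr int SORT_SIZE = BLOCK_X * ITEMS_PER_THREAD;  // 128 elements per row

// Each blockDim.x-wide "lane" of the 2d block sorts one row in shared memory;
// blockDim.y independent rows share the block, which keeps occupancy up when
// rows are short. Out-of-range rows keep running (flagged by row_valid) so
// every thread still reaches __syncthreads().
__global__ void batchedSmallSort(float* data, int num_rows, int row_len) {
  const int row = blockIdx.x * blockDim.y + threadIdx.y;
  const bool row_valid = row < num_rows;

  __shared__ float buf[MAX_BLOCK_Y][SORT_SIZE];
  float* my = buf[threadIdx.y];

  // Load ITEMS_PER_THREAD elements per thread; pad with +inf so padding sorts last
  for (int k = 0; k < ITEMS_PER_THREAD; ++k) {
    int idx = threadIdx.x + k * blockDim.x;
    my[idx] = (row_valid && idx < row_len) ? data[row * row_len + idx] : INFINITY;
  }
  __syncthreads();

  // Odd-even transposition sort (simpler than the bitonic network in the patch)
  for (int phase = 0; phase < SORT_SIZE; ++phase) {
    for (int k = 0; k < ITEMS_PER_THREAD; ++k) {
      int i = 2 * (threadIdx.x + k * blockDim.x) + (phase & 1);
      if (i + 1 < SORT_SIZE && my[i] > my[i + 1]) {
        float t = my[i]; my[i] = my[i + 1]; my[i + 1] = t;
      }
    }
    __syncthreads();
  }

  // Write back only real elements of real rows
  for (int k = 0; k < ITEMS_PER_THREAD; ++k) {
    int idx = threadIdx.x + k * blockDim.x;
    if (row_valid && idx < row_len) data[row * row_len + idx] = my[idx];
  }
}

int main() {
  const int num_rows = 10000, row_len = 100;   // row_len must be <= SORT_SIZE
  std::vector<float> h(size_t(num_rows) * row_len);
  for (size_t i = 0; i < h.size(); ++i) h[i] = float((1103515245u * i + 12345u) % 1000u);

  float* d = nullptr;
  cudaMalloc(&d, h.size() * sizeof(float));
  cudaMemcpy(d, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);

  // Mirror minimum_grid_for_occupancy: shrink the per-block batch if the grid
  // would otherwise be too small to occupy the device.
  int min_grid = 0, unused = 0;
  cudaOccupancyMaxPotentialBlockSize(&min_grid, &unused, batchedSmallSort,
                                     /*dynamicSMemSize=*/0,
                                     /*blockSizeLimit=*/BLOCK_X * MAX_BLOCK_Y);
  const int max_batch = std::max(1, num_rows / std::max(min_grid, 1));
  const int block_y = std::min(MAX_BLOCK_Y, max_batch);

  dim3 block(BLOCK_X, block_y);
  dim3 grid((num_rows + block_y - 1) / block_y);
  batchedSmallSort<<<grid, block>>>(d, num_rows, row_len);

  cudaMemcpy(h.data(), d, h.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("row 0 starts: %.0f %.0f %.0f ...\n", h[0], h[1], h[2]);
  cudaFree(d);
  return 0;
}

The host-side batch selection mirrors the patch's trade-off: a larger block_y packs more short sorts per block, but if there are few rows it would shrink the grid below the occupancy-derived minimum, so the batch is capped first.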