Improve small sort performance on CUDA
Currently, `bitonicSortKVInPlace` is written to sort one array per
block of threads. If the dimension being sorted is very small
(< 128 elements), this results in low thread occupancy.

Instead, this changes `bitonicSortKVInPlace` to operate on a 2D
block. Sorting happens along the x dimension, and the y dimension
is a fixed-size batch of independent slices sorted in parallel.

Pull Request resolved: #79627

Approved by: https://github.com/ngimel
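
For context, a minimal, self-contained sketch (not part of this commit) of the batched launch geometry described above: the sort runs along the x dimension of the block, and each row of threads in the y dimension handles one independent slice. The names `launch_batched_sort_sketch`, `sort_size`, `num_slices`, and `batch_kernel` are hypothetical.

#include <cuda_runtime.h>
#include <cstdint>

// Sketch only: choose a 2D block where x covers one power-of-2 sort
// (two items per thread) and y batches several slices per block.
void launch_batched_sort_sketch(int64_t num_slices) {
  constexpr int items_per_thread = 2;
  constexpr int sort_size = 128;               // hypothetical power-of-2 sort size
  constexpr int block_x = sort_size / items_per_thread;
  constexpr int block_y = 4;                   // hypothetical slices per block
  dim3 block(block_x, block_y);
  dim3 grid((num_slices + block_y - 1) / block_y);
  // Inside the kernel, each row of the block sorts one slice:
  //   slice = blockIdx.x * blockDim.y + threadIdx.y;
  // batch_kernel<<<grid, block>>>(...);
  (void)grid;
  (void)block;
}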
pytorchmergebot committed Jun 22, 2022
1 parent 9244547 commit 61305cd
Showing 3 changed files with 106 additions and 92 deletions.
83 changes: 52 additions & 31 deletions aten/src/ATen/native/cuda/Sort.cu
@@ -15,6 +15,19 @@

namespace at { namespace native {

template <typename T>
static int minimum_grid_for_occupancy(T kernel, int max_block_size) {
int minGridSize;
int blockSize;
C10_CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(
&minGridSize,
&blockSize,
kernel,
/*dynamicSMemSize=*/0,
max_block_size));
return minGridSize;
}

// In alignment with default sort on a c++ map, this function
// will permute key and value tensors identically, and
// in such a way that the 'key' tensor is ordered numerically
@@ -45,43 +58,51 @@ void sortKeyValueInplace(const TensorBase& key,
// vectorized (key, value) sort by slice segment
TORCH_INTERNAL_ASSERT(ceilPowerOf2 <= 2048, "sortKeyValueInplace only works for sizes <= 2048 at present");

// The grid is based on the number of independent slices that we
// have to sort; one block per slice
dim3 grid;
TORCH_INTERNAL_ASSERT(getGridFromTiles(keySlices, grid), "Too many slices to sort");
const auto stream = c10::cuda::getCurrentCUDAStream();

#define HANDLE_CASE(TYPE, A, SIZE) \
#define HANDLE_CASE(TYPE, A, SIZE, BATCH) \
do { \
int blockSize = SIZE / 2; \
if (blockSize < 1) { \
blockSize = 1; \
} \
constexpr int items_per_thread = 2; \
static_assert(SIZE % items_per_thread == 0, ""); \
constexpr int block_x = SIZE / items_per_thread; \
constexpr int max_block_y = BATCH; \
\
/* Scale batch size down if the grid would be too small */ \
const auto min_grid = minimum_grid_for_occupancy( \
bitonicSortKVInPlace< \
A, -1, block_x, max_block_y, \
scalar_t, int64_t, LTOp<scalar_t, true>, TYPE>, \
block_x * max_block_y); \
const auto max_batch = std::max(int64_t{1}, keySlices / min_grid); \
const int block_y = std::min(int64_t{max_block_y}, max_batch); \
dim3 block(block_x, block_y); \
\
dim3 block(blockSize); \
dim3 grid; \
const int grid_count = (keySlices + block_y - 1) / block_y; \
TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid), \
"Too many slices to sort"); \
\
if (dir) { \
bitonicSortKVInPlace<scalar_t, int64_t, A, -1, \
GTOp<scalar_t, true>, TYPE, SIZE> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>( \
bitonicSortKVInPlace<A, -1, block_x, max_block_y> \
<<<grid, block, 0, stream>>>( \
keyInfo, \
keySlices, \
(TYPE) keySlices, \
(TYPE) keySliceSize, \
(TYPE) keyInfo.strides[collapseKeyDim], \
valueInfo, \
(TYPE) valueInfo.strides[collapseValueDim], \
GTOp<scalar_t, true>()); \
GTOp<scalar_t, true>()); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
} else { \
bitonicSortKVInPlace<scalar_t, int64_t, A, -1, \
LTOp<scalar_t, true>, TYPE, SIZE> \
<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>( \
bitonicSortKVInPlace<A, -1, block_x, max_block_y> \
<<<grid, block, 0, stream>>>( \
keyInfo, \
keySlices, \
(TYPE) keySlices, \
(TYPE) keySliceSize, \
(TYPE) keyInfo.strides[collapseKeyDim], \
valueInfo, \
(TYPE) valueInfo.strides[collapseValueDim], \
LTOp<scalar_t, true>()); \
LTOp<scalar_t, true>()); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
} \
} while (0)
@@ -90,29 +111,29 @@ void sortKeyValueInplace(const TensorBase& key,
{ \
switch (ceilPowerOf2) { \
case 2048: \
HANDLE_CASE(TYPE, A, 2048); \
break; \
HANDLE_CASE(TYPE, A, 2048, 1); \
break; \
case 1024: \
case 512: \
case 256: \
HANDLE_CASE(TYPE, A, 1024); \
break; \
HANDLE_CASE(TYPE, A, 1024, 1); \
break; \
case 128: \
case 64: \
HANDLE_CASE(TYPE, A, 128); \
break; \
HANDLE_CASE(TYPE, A, 128, 4); \
break; \
case 32: \
case 16: \
case 8: \
case 4: \
case 2: \
HANDLE_CASE(TYPE, A, 32); \
break; \
HANDLE_CASE(TYPE, A, 32, 16); \
break; \
case 1: \
/* Nothing to do, data already sorted */ \
break; \
/* Nothing to do, data already sorted */ \
break; \
default: \
TORCH_INTERNAL_ASSERT(false); \
TORCH_INTERNAL_ASSERT(false); \
} \
}

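As an aside on the new `minimum_grid_for_occupancy` helper above: it wraps CUDA's `cudaOccupancyMaxPotentialBlockSize` query, and the returned minimum grid size is what `HANDLE_CASE` uses to scale the per-block batch (`block_y`) down so that enough blocks are launched to keep the device busy. A standalone sketch of the query, using a hypothetical `dummy_kernel`, might look like this:

#include <cuda_runtime.h>

__global__ void dummy_kernel() {}

// Sketch only: query the minimum grid size needed to reach maximum
// occupancy for `dummy_kernel`, given an upper bound on the block size.
int min_grid_for(int max_block_size) {
  int min_grid = 0;
  int block_size = 0;
  cudaOccupancyMaxPotentialBlockSize(
      &min_grid, &block_size, dummy_kernel,
      /*dynamicSMemSize=*/0, /*blockSizeLimit=*/max_block_size);
  return min_grid;
}
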
113 changes: 53 additions & 60 deletions aten/src/ATen/native/cuda/SortUtils.cuh
@@ -30,11 +30,11 @@ __device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
}
};

template <typename Comparator, typename K, typename V,
typename IndexType, int Power2SortSize>
__device__ inline void bitonicSort(K keys[Power2SortSize],
V values[Power2SortSize],
bool valid[Power2SortSize],
template <int Power2SortSize, typename IndexType, typename Comparator,
typename K, typename V>
__device__ inline void bitonicSort(K *keys,
V *values,
bool *valid,
const Comparator& comp) {
#if !defined(USE_ROCM)
#pragma unroll
@@ -78,10 +78,9 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
// at::cuda::detail::TensorInfo version
// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
// modifies the input `keys` and `values`
template <typename K, typename V,
int KeyDims, int ValueDims,
typename Comparator, typename IndexType, int Power2SortSize>
C10_LAUNCH_BOUNDS_1(1024)
template <int KeyDims, int ValueDims, int block_dim_x, int max_block_dim_y,
typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y)
__global__ void
bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
IndexType keySlices,
@@ -91,69 +90,63 @@ bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
IndexType valueSliceStride,
Comparator comp) {
// Find the slice of the tensor that we are sorting
const IndexType linearIndex = getLinearBlockId<IndexType>();
// Tiling the slices could have us be out of bounds, if there are a
// lot of slices to sort
if (linearIndex >= keySlices) {
// NOTE: blockDim.y may be less than max_block_dim_y
const IndexType blockIndex = getLinearBlockId<IndexType>();
const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

// If the entire block is out of bounds, exit early
if (blockIndex * blockDim.y >= keySlices) {
return;
}
// It's also possible for some rows of a block to be out of bounds
// but all threads need to run for __syncthreads to work.
const bool row_valid = linearIndex < keySlices;

constexpr int items_per_thread = 2;
constexpr int Power2SortSize = block_dim_x * items_per_thread;

// Storage for max_block_dim_y sorts performed in parallel
__shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize];
__shared__ V blockSharedValues[max_block_dim_y][Power2SortSize];
__shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize];

__shared__ K sharedKeys[Power2SortSize];
__shared__ V sharedValues[Power2SortSize];
__shared__ bool sharedValid[Power2SortSize];
auto sharedKeys = blockSharedKeys[threadIdx.y];
auto sharedValues = blockSharedValues[threadIdx.y];
auto sharedValid = blockSharedValid[threadIdx.y];

const IndexType keyStartOffset =
at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
const IndexType valueStartOffset =
at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

// If the sort size is 1, the data is already sorted
if (Power2SortSize == 1) {
return;
} else {
// Otherwise, each thread is responsible for loading and storing 2
// elements. The sort size is guaranteed to be >= 2
const int elem1 = threadIdx.x;
const int elem2 = threadIdx.x + (Power2SortSize / 2);

bool valid1 = (elem1 < keySliceSize);
K k1 = valid1 ?
keys.data[keyStartOffset + elem1 * keySliceStride] : static_cast<K>(0);
V v1 = valid1 ?
values.data[valueStartOffset + elem1 * valueSliceStride] : static_cast<V>(0);

sharedKeys[elem1] = k1;
sharedValues[elem1] = v1;
sharedValid[elem1] = valid1;

bool valid2 = (elem2 < keySliceSize);
K k2 = valid2 ?
keys.data[keyStartOffset + elem2 * keySliceStride] : static_cast<K>(0);
V v2 = valid2 ?
values.data[valueStartOffset + elem2 * valueSliceStride] : static_cast<V>(0);

sharedKeys[elem2] = k2;
sharedValues[elem2] = v2;
sharedValid[elem2] = valid2;

// Sort!
bitonicSort<Comparator, K, V, IndexType, Power2SortSize>(
// Load 2 values per thread into the shared workspace
#pragma unroll
for (int k = 0; k < items_per_thread; ++k) {
auto idx = threadIdx.x + k * blockDim.x;
bool valid = row_valid && idx < keySliceSize;

sharedKeys[idx] = valid ?
keys.data[idx * keySliceStride + keyStartOffset] : K{};
sharedValues[idx] = valid ?
values.data[idx * valueSliceStride + valueStartOffset] : V{};
sharedValid[idx] = valid;
}

// Sort!
bitonicSort<Power2SortSize, IndexType>(
sharedKeys, sharedValues, sharedValid, comp);

// elem1 and elem2 values might be out-of-range, if the data size we are
// sorting is smaller than half the power2 size
if (valid1) {
keys.data[keyStartOffset + elem1 * keySliceStride] =
sharedKeys[elem1];
values.data[valueStartOffset + elem1 * valueSliceStride] =
sharedValues[elem1];
}
if (!row_valid) {
return;
}

if (valid2) {
keys.data[keyStartOffset + elem2 * keySliceStride] =
sharedKeys[elem2];
values.data[valueStartOffset + elem2 * valueSliceStride] =
sharedValues[elem2];
// Store outputs
#pragma unroll
for (int k = 0; k < items_per_thread; ++k) {
auto idx = threadIdx.x + k * blockDim.x;
if (idx < keySliceSize) {
keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx];
values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx];
}
}
}
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/TensorTopK.cu
@@ -3,14 +3,14 @@
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/ScanUtils.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/SortingRadixSelect.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/cuda/cub.cuh>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/detail/KernelUtils.h>
