Skip to content

Commit

Permalink
Speedup bincount and histc on CUDA (#97090)
Browse files Browse the repository at this point in the history
This is to speed up torch.bincount and torch.histc on CUDA.

1. Speed up int64_t gpuAtomicAdd,
2. and optimize the histogram kernel.

# Fixes #96626
After speedup, time cost in #96626 would be

```
... (run 2 times and ignore the first run)
case 1 CPU  0.0003631114959716797 seconds
case 1 CUDA 0.0005860328674316406 seconds
case 2 CPU  0.0013742446899414062 seconds
case 2 CUDA 0.0008623600006103516 seconds
```

Note that in "*case 1 CUDA*", the **max** op takes the most time, i.e., https://github.com/pytorch/pytorch/blob/5ee5a164ffeb7b7a167c53009fb8fe5f5bd439d9/aten/src/ATen/native/cuda/SummaryOps.cu#L334-L335, which is not to be optimized in this PR.

# Benchmark

Time is measured on i7-10700 + RTX 3080, Ubuntu 22.04 (in WSL). The baseline is PyTorch 2.0.0+cu117. My dev version of PyTorch is compiled with CUDA 11.8. Each case is measured 15 times to take the median.

## torch.bincount
#elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup
-- | -- | -- | -- | -- | -- | --
2**20 | 80 | random.uniform | 0.000834 | 0.005783 | 0.000266 | 21.8x
2**20 | 80 | narrow in 1 bin | 0.001576 | 0.003967 | 0.000563 | 7.0x
2**20 | 500 | random.uniform | 0.000852 | 0.003641 | 0.000334 | 10.9x
2**20 | 500 | narrow in 1% bins | 0.000894 | 0.001878 | 0.000349 | 5.4x
2**20 | 2048 | random.uniform | 0.000891 | 0.000820 | 0.000298 | 2.8x
2**20 | 2048 | narrow in 1% bins | 0.000958 | 1.043251 | 0.000335 | 3,116.6x
2**26 | 80 | random.uniform | 0.067715 | 0.322409 | 0.003032 | 106.3x
2**26 | 80 | narrow in 1 bin | 0.110940 | 0.194644 | 0.017651 | 11.0x
2**26 | 500 | random.uniform | 0.066666 | 0.192302 | 0.002535 | 75.8x
2**26 | 500 | narrow in 1% bins | 0.066130 | 0.092237 | 0.005462 | 16.9x
2**26 | 2048 | random.uniform | 0.066371 | 0.035308 | 0.002476 | 14.3x
2**26 | 2048 | narrow in 1% bins | 0.068453 | 72.122858 | 0.003185 | 22,644.3x

## torch.histc (float32)
#elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup
-- | -- | -- | -- | -- | -- | --
2**20 | 80 | random.uniform | 0.001261 | 0.000145 | 9.47E-05 | 1.5x
2**20 | 80 | narrow in 1 bin | 0.001074 | 0.000356 | 0.000311 | 1.1x
2**20 | 500 | random.uniform | 0.001162 | 0.000227 | 9.18E-05 | 2.5x
2**20 | 500 | narrow in 1% bins | 0.001082 | 0.000201 | 0.000152 | 1.3x
2**20 | 2048 | random.uniform | 0.001100 | 0.000203 | 0.000118 | 1.7x
2**20 | 2048 | narrow in 1% bins | 0.001089 | 0.000396 | 0.000107 | 3.7x
2**26 | 80 | random.uniform | 0.064219 | 0.001170 | 0.000786 | 1.5x
2**26 | 80 | narrow in 1 bin | 0.056471 | 0.013283 | 0.011939 | 1.1x
2**26 | 500 | random.uniform | 0.078183 | 0.003411 | 0.000562 | 6.1x
2**26 | 500 | narrow in 1% bins | 0.056711 | 0.002763 | 0.002738 | 1.0x
2**26 | 2048 | random.uniform | 0.059296 | 0.003503 | 0.000533 | 6.6x
2**26 | 2048 | narrow in 1% bins | 0.061754 | 0.015703 | 0.000962 | 16.3x

## torch.histc (int64)
#elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup
-- | -- | -- | -- | -- | -- | --
2**20 | 80 | random.uniform | N/A | 0.005614 | 9.47E-05 | 59.3x
2**20 | 80 | narrow in 1 bin | N/A | 0.003799 | 0.000395 | 9.6x
2**20 | 500 | random.uniform | N/A | 0.003665 | 9.58E-05 | 38.2x
2**20 | 500 | narrow in 1% bins | N/A | 0.001760 | 0.000178 | 9.9x
2**20 | 2048 | random.uniform | N/A | 0.000693 | 0.000111 | 6.2x
2**20 | 2048 | narrow in 1% bins | N/A | 1.082904 | 0.000123 | 8,802.4x
2**26 | 80 | random.uniform | N/A | 0.320400 | 0.001145 | 279.9x
2**26 | 80 | narrow in 1 bin | N/A | 0.193668 | 0.015229 | 12.7x
2**26 | 500 | random.uniform | N/A | 0.182897 | 0.000823 | 222.2x
2**26 | 500 | narrow in 1% bins | N/A | 0.089363 | 0.00376 | 23.8x
2**26 | 2048 | random.uniform | N/A | 0.033190 | 0.000832 | 39.9x
2**26 | 2048 | narrow in 1% bins | N/A | 71.721012 | 0.001525 | 47,017.8x

## Banchmark code

Here is the benchmark code:

```python3
import time
import torch

cases = [
    ("bincount    bins=80   wide  ", torch.randint(80, [2**20]),   lambda x: torch.bincount(x, minlength=80)),
    ("bincount    bins=80   narrow", torch.randint(1, [2**20]),    lambda x: torch.bincount(x, minlength=80)),
    ("bincount    bins=500  wide  ", torch.randint(500, [2**20]),  lambda x: torch.bincount(x, minlength=500)),
    ("bincount    bins=500  narrow", torch.randint(5, [2**20]),    lambda x: torch.bincount(x, minlength=500)),
    ("bincount    bins=2048 wide  ", torch.randint(2048, [2**20]), lambda x: torch.bincount(x, minlength=2048)),
    ("bincount    bins=2048 narrow", torch.randint(20, [2**20]),   lambda x: torch.bincount(x, minlength=2048)),
    ("histc_float bins=80   wide  ", torch.rand(2**20),            lambda x: torch.histc(x, bins=80, min=0., max=1.)),
    ("histc_float bins=80   narrow", torch.rand(2**20)*.01,        lambda x: torch.histc(x, bins=80, min=0., max=1.)),
    ("histc_float bins=500  wide  ", torch.rand(2**20),            lambda x: torch.histc(x, bins=500, min=0., max=1.)),
    ("histc_float bins=500  narrow", torch.rand(2**20)*.01,        lambda x: torch.histc(x, bins=500, min=0., max=1.)),
    ("histc_float bins=2048 wide  ", torch.rand(2**20),            lambda x: torch.histc(x, bins=2048, min=0., max=1.)),
    ("histc_float bins=2048 narrow", torch.rand(2**20)*.01,        lambda x: torch.histc(x, bins=2048, min=0., max=1.)),
    ("histc_int   bins=80   wide  ", torch.randint(80, [2**20]),   lambda x: torch.histc(x, bins=80, min=0., max=80.)),
    ("histc_int   bins=80   narrow", torch.randint(1, [2**20]),    lambda x: torch.histc(x, bins=80, min=0., max=80.)),
    ("histc_int   bins=500  wide  ", torch.randint(500, [2**20]),  lambda x: torch.histc(x, bins=500, min=0., max=500.)),
    ("histc_int   bins=500  narrow", torch.randint(5, [2**20]),    lambda x: torch.histc(x, bins=500, min=0., max=500.)),
    ("histc_int   bins=2048 wide  ", torch.randint(2048, [2**20]), lambda x: torch.histc(x, bins=2048, min=0., max=2048.)),
    ("histc_int   bins=2048 narrow", torch.randint(20, [2**20]),   lambda x: torch.histc(x, bins=2048, min=0., max=2048.)),
]

def test(case, device):
    name, x, func = case
    x = x.to(device)
    time_samples = []
    for _ in range(15):
        torch.cuda.synchronize()
        t1 = time.time()
        func(x)
        torch.cuda.synchronize()
        t2 = time.time()
        time_samples.append(t2 - t1)
    median = sorted(time_samples)[len(time_samples) // 2]
    print(device, name, median)

for case in cases:
    test(case, device="cuda")

# for case in cases:
#     test(case, device="cpu")
```
Pull Request resolved: #97090
Approved by: https://github.com/ngimel
  • Loading branch information
yuantailing authored and pytorchmergebot committed Mar 24, 2023
1 parent f3cf3d7 commit 63e1f12
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 91 deletions.
7 changes: 2 additions & 5 deletions aten/src/ATen/cuda/Atomic.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,8 @@ static inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
#if defined(USE_ROCM)
__atomic_fetch_add(address, val, __ATOMIC_RELAXED);
#else
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address,
val,
[](int64_t a, int64_t b) {
return a + b;
});
static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
atomicAdd(reinterpret_cast<unsigned long long int *>(address), static_cast<unsigned long long int>(val));
#endif
}

Expand Down
97 changes: 19 additions & 78 deletions aten/src/ATen/native/cuda/SummaryOps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

namespace at {
namespace cuda {
#define THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM 100
#define THRESH_NUMBER_BINS_FOR_GLOBAL_MEM 1000
#define RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD 8
#define FOR_KERNEL_LOOP(i, lim) \
for (IndexType i = blockIdx.x * blockDim.x + threadIdx.x; i < lim; \
i += gridDim.x * blockDim.x)
Expand All @@ -30,7 +29,7 @@ namespace cuda {
Memory types used for the 3 histogram implementations.
See `CUDA_tensor_histogram` below.
*/
enum class CUDAHistogramMemoryType { SHARED, MULTI_BLOCK, GLOBAL };
enum class CUDAHistogramMemoryType { SHARED, GLOBAL };
namespace {
template <typename input_t, typename IndexType>
__device__ static IndexType getBin(
Expand Down Expand Up @@ -60,7 +59,7 @@ template <
int ADims,
int PDims,
int BDims,
CUDAHistogramMemoryType MemoryType = CUDAHistogramMemoryType::MULTI_BLOCK,
CUDAHistogramMemoryType MemoryType,
typename Op>
C10_LAUNCH_BOUNDS_1(cuda::getApplyBlockSize())
__global__ void kernelHistogram1D(
Expand Down Expand Up @@ -106,39 +105,6 @@ __global__ void kernelHistogram1D(
gpuAtomicAddNoReturn(&a.data[aOffset], smem[i]);
}

} else if (MemoryType == CUDAHistogramMemoryType::MULTI_BLOCK) {
////////////////////////// Multi Block memory //////////////////////////
// atomically add to block specific global tensor
// then atomically add to the global output tensor
// compute histogram for the block
FOR_KERNEL_LOOP(linearIndex, totalElements) {
// Convert `linearIndex` into an offset of `b`
const IndexType bOffset =
detail::IndexToOffset<input_t, IndexType, BDims>::get(linearIndex, b);
const auto bVal = b.data[bOffset];
if (bVal >= minvalue && bVal <= maxvalue) {
// Use value at `b` as an offset of `p`
const IndexType bin =
getBin<input_t, IndexType>(bVal, minvalue, maxvalue, nbins);
const IndexType pIdx = p.strides[0] * blockIdx.x + bin;
const IndexType pOffset =
detail::IndexToOffset<output_t, IndexType, PDims>::get(pIdx, p);
gpuAtomicAddNoReturn(&p.data[pOffset], getOp(linearIndex));
}
}
__syncthreads();
// NOTE: atomically update output bin count.
// Atomic update is imp since __syncthread() will only synchronize threads
// in a given block, not across blocks.
const IndexType pIdx = p.strides[0] * blockIdx.x;
const IndexType pOffset =
detail::IndexToOffset<output_t, IndexType, PDims>::get(pIdx, p);
for (IndexType i = threadIdx.x; i < a.sizes[0]; i += blockDim.x) {
const IndexType aOffset =
detail::IndexToOffset<output_t, IndexType, ADims>::get(i, a);
gpuAtomicAddNoReturn(&a.data[aOffset], p.data[pOffset + i]);
}

} else {
////////////////////////// Global memory //////////////////////////
// atomically add to the output tensor
Expand Down Expand Up @@ -184,23 +150,10 @@ __global__ void kernelHistogram1D(
case CUDAHistogramMemoryType::SHARED: \
HANDLE_CASE(CUDAHistogramMemoryType::SHARED, getOp, sharedMem); \
break; \
case CUDAHistogramMemoryType::MULTI_BLOCK: \
HANDLE_CASE(CUDAHistogramMemoryType::MULTI_BLOCK, getOp, 0); \
break; \
default: \
HANDLE_CASE(CUDAHistogramMemoryType::GLOBAL, getOp, 0); \
}

inline int64_t getFreeGlobalMemory() {
// no need to use `cudaSetDevice`
size_t free_mem, total_mem;
cudaMemGetInfo(&free_mem, &total_mem);
TORCH_INTERNAL_ASSERT(
cudaGetLastError() == cudaSuccess,
"CUDA_tensor_histogram failed to get free global memory");
return static_cast<int64_t>(free_mem);
}

/*
Calculate the frequency of the input values.
Expand All @@ -210,13 +163,10 @@ inline int64_t getFreeGlobalMemory() {
See `help torch.bincount` for details on the math.
3 implementations based of input size and memory usage:
case: #bins < THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM and enough shared mem
case: enough shared mem
SHARED: Each block atomically adds to it's own **shared** hist copy,
then atomically updates the global tensor.
case: #bins < THRESH_NUMBER_BINS_FOR_GLOBAL_MEM and enough global mem
MULTI_BLOCK: Each block atomically adds to it's own **global** hist
copy, then atomically updates the global tensor.
case: THRESH_NUMBER_BINS_FOR_GLOBAL_MEM <= #bins
case: no enough shared mem
GLOBAL: all threads atomically update to a single **global** hist copy.
*/
template <typename output_t, typename input_t, bool HasWeights>
Expand Down Expand Up @@ -250,35 +200,27 @@ bool CUDA_tensor_histogram(
CUDAHistogramMemoryType memType = CUDAHistogramMemoryType::GLOBAL;
auto maxSharedMem = getCurrentDeviceProperties()->sharedMemPerBlock;
auto sharedMem = nbins * sizeof(output_t) + 8; // 8 guard bytes
auto maxGlobalMem = getFreeGlobalMemory();
auto multiBlockMem = nbins * grid.x * sizeof(output_t) + 8; // 8 guard bytes
// determine memory type to use in the kernel
if (nbins < THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM &&
sharedMem < maxSharedMem) {
if (sharedMem < maxSharedMem) {
// Solve equations:
// (1) #(smem atomicAdd per SM) = totalElements / min(grid.x, #SM)
// (2) #(gmem atomicAdd) = grid.x * nbins
// (3) RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD = #(gmem atomicAdd) / #(smem atomicAdd per SM)
unsigned optimalGrid = ceil_div<size_t>(RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD * totalElements,
nbins * getCurrentDeviceProperties()->multiProcessorCount);
if (optimalGrid < (unsigned)getCurrentDeviceProperties()->multiProcessorCount) {
optimalGrid = 1 + (unsigned)std::sqrt(RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD * totalElements / nbins);
}
auto optimalSteps = ceil_div<size_t>(totalElements, optimalGrid * block.x);
optimalGrid = ceil_div<size_t>(totalElements, optimalSteps * block.x);
grid.x = std::min(grid.x, optimalGrid);
memType = CUDAHistogramMemoryType::SHARED;
} else if (
nbins < THRESH_NUMBER_BINS_FOR_GLOBAL_MEM &&
multiBlockMem < static_cast<size_t>(maxGlobalMem / 2)) {
// check against half of free mem to be extra safe
// due to cached allocator, we may anyway have slightly more free mem
memType = CUDAHistogramMemoryType::MULTI_BLOCK;
}

// alloc memory for MULTI_BLOCK
using IndexType = int64_t;
auto aInfo = detail::getTensorInfo<output_t, IndexType>(a);
auto bInfo = detail::getTensorInfo<input_t, IndexType>(b);
detail::TensorInfo<output_t, IndexType> pInfo(nullptr, 0, {}, {});
Tensor partial_output;
if (memType == CUDAHistogramMemoryType::MULTI_BLOCK) {
partial_output = at::zeros(
{grid.x, nbins},
optTypeMetaToScalarType(a.options().dtype_opt()),
a.options().layout_opt(),
a.options().device_opt(),
a.options().pinned_memory_opt());
pInfo = detail::getTensorInfo<output_t, IndexType>(partial_output);
}

if (HasWeights) {
auto cInfo = detail::getTensorInfo<output_t, IndexType>(c);
Expand All @@ -298,8 +240,7 @@ bool CUDA_tensor_histogram(
#undef HANDLE_CASE
#undef HANDLE_SWITCH_CASE
#undef FOR_KERNEL_LOOP
#undef THRESH_NUMBER_BINS_FOR_GLOBAL_MEM
#undef THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM
#undef RATIO_OF_GMEM_ATOMIC_ADD_TO_SMEM_ATOMIC_ADD
} // namespace cuda

namespace {
Expand Down
12 changes: 4 additions & 8 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,21 +1709,17 @@ def test_nvtx(self):

def test_bincount_ext(self):
# ensure CUDA code coverage
input_size = (5000,)
input_size = (100000,)
w = torch.randn(input_size, dtype=torch.double, device='cuda')
w_cpu = w.cpu()
# test shared memory impl
t = torch.randint(50, input_size, dtype=torch.int8, device='cuda')
self.assertEqual(t.cpu().bincount(), t.bincount())
self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
# test multi block memory impl
# see `THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM` in SummaryOps.cu
t = torch.randint(500, input_size, dtype=torch.int64, device='cuda')
self.assertEqual(t.cpu().bincount(), t.bincount())
self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
# test global memory impl
# see `THRESH_NUMBER_BINS_FOR_GLOBAL_MEM` in SummaryOps.cu
t = torch.randint(2000, input_size, dtype=torch.int64, device='cuda')
# see `CUDAHistogramMemoryType` in SummaryOps.cu
# 50000 * sizeof(int64_t) == 390 KiB, which should exceed smem of any known GPU
t = torch.randint(50000, input_size, dtype=torch.int64, device='cuda')
self.assertEqual(t.cpu().bincount(), t.bincount())
self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

Expand Down

0 comments on commit 63e1f12

Please sign in to comment.