From f79687a5903478033cae1adbe93c921fd5b412e2 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 12 Nov 2025 22:56:39 +0000 Subject: [PATCH 1/5] opt embedding_bounds_check_v2 --- .../utils/embedding_bounds_check_v2.cu | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu index 2f7bfc8bb7..ed21c8c5e3 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu @@ -31,6 +31,10 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( index_t invalid_i = -1, invalid_idx = -1; int32_t invalid_b_t = -1; int64_t warning_inc = 0; + extern __shared__ int64_t block_warning_buffer[]; + const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) + + threadIdx.y * blockDim.x + threadIdx.x; + const int active_threads = blockDim.x * blockDim.y * blockDim.z; // Check the last element if (b_t_start == 0 && threadIdx.x == 0) { @@ -142,9 +146,21 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( } } // for b_t - if (warning_inc > 0) { - gpuAtomicAdd(&warning[0], warning_inc); + // Accumulate per-thread warning counts in shared memory and reduce once per block. + block_warning_buffer[linear_tid] = warning_inc; + __syncthreads(); + + if (linear_tid == 0) { + int64_t block_warning_sum = 0; + for (int idx = 0; idx < active_threads; ++idx) { + block_warning_sum += block_warning_buffer[idx]; + } + block_warning_buffer[0] = block_warning_sum; + if (block_warning_sum > 0) { + gpuAtomicAdd(&warning[0], block_warning_sum); + } } + __syncthreads(); if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 && static_cast(atomicAdd( reinterpret_cast(&warning[0]), 0)) == 0) { @@ -227,7 +243,7 @@ void _bounds_check_indices_cuda_v2( grid_dim, \ dim3( \ fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), \ - 0, \ + sizeof(int64_t) * kNumThreads, \ at::cuda::getCurrentCUDAStream(), \ PTA_B(rows_per_table, int64_t, 1, 32), \ PTA_B(indices, index_t, 1, 32), \ From 25b936177930e64259ddbed04a0ff5d3f4202e68 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 13 Nov 2025 17:42:24 +0000 Subject: [PATCH 2/5] fixed size of the shared memory --- fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu index ed21c8c5e3..f4f8d74363 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu @@ -31,7 +31,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( index_t invalid_i = -1, invalid_idx = -1; int32_t invalid_b_t = -1; int64_t warning_inc = 0; - extern __shared__ int64_t block_warning_buffer[]; + __shared__ int64_t block_warning_buffer[kMaxThreads]; const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) + threadIdx.y * blockDim.x + threadIdx.x; const int active_threads = blockDim.x * blockDim.y * blockDim.z; @@ -243,7 +243,7 @@ void _bounds_check_indices_cuda_v2( grid_dim, \ dim3( \ fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), \ - sizeof(int64_t) * kNumThreads, \ + 0 , \ at::cuda::getCurrentCUDAStream(), \ PTA_B(rows_per_table, int64_t, 1, 32), \ PTA_B(indices, index_t, 1, 32), \ From b6956747f078a49883b15095f29c251b1c5d0358 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 13 Nov 2025 13:23:09 -0600 Subject: [PATCH 3/5] use parallel tree reduction for shared warning buffer within thread block --- .../codegen/utils/embedding_bounds_check_v2.cu | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu index f4f8d74363..f1e2420601 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu @@ -150,12 +150,17 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( block_warning_buffer[linear_tid] = warning_inc; __syncthreads(); - if (linear_tid == 0) { - int64_t block_warning_sum = 0; - for (int idx = 0; idx < active_threads; ++idx) { - block_warning_sum += block_warning_buffer[idx]; + // Parallel tree reduction + for (int stride = active_threads / 2; stride > 0; stride >>= 1) { + if (linear_tid < stride) { + block_warning_buffer[linear_tid] += block_warning_buffer[linear_tid + stride]; } - block_warning_buffer[0] = block_warning_sum; + __syncthreads(); + } + + // Thread 0 has the final sum + if (linear_tid == 0) { + int64_t block_warning_sum = block_warning_buffer[0]; if (block_warning_sum > 0) { gpuAtomicAdd(&warning[0], block_warning_sum); } From 4cdac24e74a87896dd2a2a66073d577a9a53f74b Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 13 Nov 2025 19:44:18 +0000 Subject: [PATCH 4/5] minor format fix --- fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu index f1e2420601..ae13a5afcb 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu @@ -248,7 +248,7 @@ void _bounds_check_indices_cuda_v2( grid_dim, \ dim3( \ fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), \ - 0 , \ + 0, \ at::cuda::getCurrentCUDAStream(), \ PTA_B(rows_per_table, int64_t, 1, 32), \ PTA_B(indices, index_t, 1, 32), \ From 2aaaee361af464331760c15b402ed69a586856fd Mon Sep 17 00:00:00 2001 From: Li Li Date: Fri, 21 Nov 2025 05:05:47 +0000 Subject: [PATCH 5/5] guard atomic operation optimization in /workspace/FBGEMM/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu by #ifdef USE_ROCM --- .../utils/embedding_bounds_check_v2.cu | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu index ae13a5afcb..27811cb0ec 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu @@ -28,13 +28,15 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( const index_t num_indices = indices.size(0); const auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y; - index_t invalid_i = -1, invalid_idx = -1; - int32_t invalid_b_t = -1; - int64_t warning_inc = 0; - __shared__ int64_t block_warning_buffer[kMaxThreads]; - const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) + - threadIdx.y * blockDim.x + threadIdx.x; - const int active_threads = blockDim.x * blockDim.y * blockDim.z; + #ifdef USE_ROCM + index_t invalid_i = -1, invalid_idx = -1; + int32_t invalid_b_t = -1; + int64_t warning_inc = 0; + __shared__ int64_t block_warning_buffer[kMaxThreads]; + const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) + + threadIdx.y * blockDim.x + threadIdx.x; + const int active_threads = blockDim.x * blockDim.y * blockDim.z; + #endif // Check the last element if (b_t_start == 0 && threadIdx.x == 0) { @@ -146,6 +148,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( } } // for b_t +#ifdef USE_ROCM // Accumulate per-thread warning counts in shared memory and reduce once per block. block_warning_buffer[linear_tid] = warning_inc; __syncthreads(); @@ -166,6 +169,11 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2( } } __syncthreads(); +#else + if (warning_inc > 0) { + gpuAtomicAdd(&warning[0], warning_inc); + } +#endif if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 && static_cast(atomicAdd( reinterpret_cast(&warning[0]), 0)) == 0) {