From f35c58346f3b5050f2858ad108348d023ae3940b Mon Sep 17 00:00:00 2001
From: Li Li
Date: Mon, 24 Nov 2025 17:41:50 -0800
Subject: [PATCH] minimize gpuAtomicAdd overhead in
bounds_check_indices_kernel_v2 (#5171)

Summary:
Every thread that recorded out-of-bounds warnings previously issued its own
gpuAtomicAdd on the global warning counter. On ROCm, stage the per-thread
warning counts in shared memory and combine them with a block-level tree
reduction instead, so that only thread 0 of each block performs a single
gpuAtomicAdd.
X-link: https://github.com/facebookresearch/FBGEMM/pull/2168
Differential Revision: D87008101
Pulled By: q10
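For readers, a minimal standalone sketch of the pattern this patch adopts.
The kernel name `reduce_warnings_sketch` and the `per_thread_inc` input are
hypothetical and exist only for illustration; the shared buffer bound
`kMaxThreads`, the tree reduction, and the single per-block atomic mirror
the kernel change below. The sketch uses a 1-D block and plain atomicAdd
for self-containedness, where the real kernel linearizes a 3-D block index
and uses the gpuAtomicAdd wrapper:

// sketch_block_warning_reduction.cu -- illustrative only.
#include <cstdint>

constexpr int kMaxThreads = 1024; // matches the shared buffer bound in the patch

__global__ void reduce_warnings_sketch(
    int64_t* warning,               // global warning counter, as in the kernel
    const int64_t* per_thread_inc)  // hypothetical per-thread counts
{
  __shared__ int64_t buf[kMaxThreads];
  const int tid = threadIdx.x;
  const int n = blockDim.x; // assumed to be a power of two

  // Stage each thread's private count in shared memory.
  buf[tid] = per_thread_inc[blockIdx.x * n + tid];
  __syncthreads();

  // Tree reduction: halve the number of active threads each step.
  for (int stride = n / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      buf[tid] += buf[tid + stride];
    }
    __syncthreads();
  }

  // One atomic per block instead of one per thread.
  if (tid == 0 && buf[0] > 0) {
    atomicAdd(reinterpret_cast<unsigned long long*>(warning),
              static_cast<unsigned long long>(buf[0]));
  }
}

Note that a tree reduction of this shape visits every slot only when the
active thread count is a power of two; the kernel's launch configuration is
assumed to satisfy that.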
---
.../utils/embedding_bounds_check_v2.cu | 31 +++++++++++++++++++
1 file changed, 31 insertions(+)

diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
index 2f7bfc8bb7..31e3385559 100644
--- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
+++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu
@@ -31,6 +31,12 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
index_t invalid_i = -1, invalid_idx = -1;
int32_t invalid_b_t = -1;
int64_t warning_inc = 0;
+#ifdef USE_ROCM
+ __shared__ int64_t block_warning_buffer[kMaxThreads];
+ const int linear_tid = threadIdx.z * (blockDim.y * blockDim.x) +
+ threadIdx.y * blockDim.x + threadIdx.x;
+ const int active_threads = blockDim.x * blockDim.y * blockDim.z;
+#endif

// Check the last element
if (b_t_start == 0 && threadIdx.x == 0) {
@@ -142,9 +148,34 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
}
} // for b_t

+#ifdef USE_ROCM
+ // Accumulate per-thread warning counts in shared memory and reduce once per
+ // block.
+ block_warning_buffer[linear_tid] = warning_inc;
+ __syncthreads();
+
+ // Parallel tree reduction
+ for (int stride = active_threads / 2; stride > 0; stride >>= 1) {
+ if (linear_tid < stride) {
+ block_warning_buffer[linear_tid] +=
+ block_warning_buffer[linear_tid + stride];
+ }
+ __syncthreads();
+ }
+
+ // Thread 0 has the final sum
+ if (linear_tid == 0) {
+ int64_t block_warning_sum = block_warning_buffer[0];
+ if (block_warning_sum > 0) {
+ gpuAtomicAdd(&warning[0], block_warning_sum);
+ }
+ }
+ __syncthreads();
+#else
if (warning_inc > 0) {
gpuAtomicAdd(&warning[0], warning_inc);
}
+#endif
if (bounds_check_mode == BoundsCheckMode::WARNING && invalid_i != -1 &&
static_cast<int64_t>(atomicAdd(
reinterpret_cast<unsigned long long int*>(&warning[0]), 0)) == 0) {