From f3ee1de2665fabe5b040e050f2981f520c133ef8 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Thu, 9 May 2024 22:20:36 -0700
Subject: [PATCH] Add memchecks to embedding_inplace_update ops

Summary:
- Add memchecks to embedding_inplace_update ops

Differential Revision: D57191119
---
 .../embedding_inplace_update.cu               | 95 +++++++++----------
 1 file changed, 45 insertions(+), 50 deletions(-)

diff --git a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu
index fcd9073d97..857b4eb9a6 100644
--- a/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu
+++ b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu
@@ -6,13 +6,12 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <...>
 #include <...>
 #include <...>
-
-#include <...>
-
 #include "fbgemm_gpu/embedding_inplace_update.h"
 #include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
 
 using Tensor = at::Tensor;
 
@@ -22,28 +21,28 @@ constexpr int32_t kCacheLocationMissing = -1;
 
 template <typename index_t>
 __launch_bounds__(kMaxThreads) __global__ void embedding_inplace_update_kernel(
-    at::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> dev_weights,
-    at::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> uvm_weights,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> dev_weights,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> uvm_weights,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         weights_placements,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         weights_offsets,
-    const at::PackedTensorAccessor32<uint8_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<uint8_t, 1, at::RestrictPtrTraits>
        weights_tys,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         D_offsets,
-    const at::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits>
         update_weights,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         update_table_idx,
-    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        update_row_idx,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         update_offsets,
     const int64_t row_alignment,
-    at::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits>
         lxu_cache_weights,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         lxu_cache_locations) {
   // each row is updated by one warp of threads
   // blockIdx.x: block idx, threadIdx.x: thread idx in the warp,
@@ -151,32 +150,29 @@ void embedding_inplace_update_cuda(
   AT_DISPATCH_INDEX_TYPES(
       update_row_idx.scalar_type(), "embedding_inplace_update_kernel", [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+        const auto func_name = "embedding_inplace_update_kernel";
+#endif
         embedding_inplace_update_kernel<<<
             nbit::div_round_up(N, warpsPerBlock), // number of blocks needed
             dim3(kWarpSize, warpsPerBlock), // shape of each block
             0,
             at::cuda::getCurrentCUDAStream()>>>(
-            dev_weights.packed_accessor64<uint8_t, 1, at::RestrictPtrTraits>(),
-            uvm_weights.packed_accessor64<uint8_t, 1, at::RestrictPtrTraits>(),
-            weights_placements
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            weights_offsets
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            weights_tys.packed_accessor32<uint8_t, 1, at::RestrictPtrTraits>(),
-            D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            update_weights
-                .packed_accessor64<uint8_t, 1, at::RestrictPtrTraits>(),
-            update_table_idx
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            update_row_idx
-                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            update_offsets
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+            MAKE_PTA_WITH_NAME(func_name, dev_weights, uint8_t, 1, 64),
+            MAKE_PTA_WITH_NAME(func_name, uvm_weights, uint8_t, 1, 64),
+            MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_weights, uint8_t, 1, 64),
+            MAKE_PTA_WITH_NAME(func_name, update_table_idx, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_row_idx, index_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_offsets, int64_t, 1, 32),
             row_alignment,
-            lxu_cache_weights_value
-                .packed_accessor64<uint8_t, 2, at::RestrictPtrTraits>(),
-            lxu_cache_locations_value
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+            MAKE_PTA_WITH_NAME(
+                func_name, lxu_cache_weights_value, uint8_t, 2, 64),
+            MAKE_PTA_WITH_NAME(
+                func_name, lxu_cache_locations_value, int32_t, 1, 32));
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }
 
@@ -184,15 +180,15 @@
 
 template <typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void pruned_array_lookup_from_row_idx_kernel(
-    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         update_row_indices,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        update_table_indices,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        index_remappings,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        index_remappings_offsets,
-    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        dense_indices) {
   const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx >= update_row_indices.size(0)) {
@@ -255,21 +251,20 @@ Tensor pruned_array_lookup_from_row_idx_cuda(
       update_row_indices.scalar_type(),
       "pruned_array_lookup_from_row_idx_kernel",
       [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+        const auto func_name = "pruned_array_lookup_from_row_idx_kernel";
+#endif
         pruned_array_lookup_from_row_idx_kernel<<<
             nbit::div_round_up(num_indices, kForwardMaxThreads),
             kForwardMaxThreads,
             0,
             at::cuda::getCurrentCUDAStream()>>>(
-            update_row_indices
-                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            update_table_indices
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            index_remappings
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            index_remappings_offsets
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            dense_indices
-                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>());
+            MAKE_PTA_WITH_NAME(func_name, update_row_indices, index_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_table_indices, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(
+                func_name, index_remappings_offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
   return dense_indices;
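
Note on the pattern, for reviewers: MAKE_PTA_WITH_NAME replaces the raw
packed_accessor32/packed_accessor64 calls at the kernel launch site. The intent
is that, in builds with FBGEMM_GPU_MEMCHECK defined, each accessor carries the
tensor's name and the enclosing kernel's name (func_name), so an out-of-bounds
index inside the kernel can be reported against a specific tensor and kernel
instead of surfacing as silent corruption or an anonymous device assert. The
sketch below is a minimal, self-contained illustration of that idea only;
CheckedAccessor, MAKE_CHECKED, and demo_kernel are hypothetical stand-ins, not
FBGEMM's actual pta:: accessor types or the real MAKE_PTA_WITH_NAME
implementation.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical bounds-checked view over a 1-D device buffer.
template <typename T>
struct CheckedAccessor {
  T* data;
  int64_t size;
  const char* tensor_name; // captured at the call site via stringification
  const char* func_name;   // kernel name, for the error report

  __device__ T& operator[](int64_t i) const {
    if (i < 0 || i >= size) {
      printf(
          "[%s] out-of-bounds access on tensor %s: index %lld, size %lld\n",
          func_name,
          tensor_name,
          static_cast<long long>(i),
          static_cast<long long>(size));
      __trap(); // abort the kernel, similar to a device-side assert
    }
    return data[i];
  }
};

// Stringify the tensor argument so the accessor carries its own name. A real
// implementation would compile the check away when memchecks are disabled.
#define MAKE_CHECKED(func_name, tensor, size) \
  (CheckedAccessor<int32_t>{tensor, size, #tensor, func_name})

__global__ void demo_kernel(CheckedAccessor<int32_t> rows) {
  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  rows[idx] = static_cast<int32_t>(idx); // OOB threads trigger a named report
}

int main() {
  int32_t* rows = nullptr;
  cudaMalloc(&rows, 64 * sizeof(int32_t));
  const auto func_name = "demo_kernel";
  // Launch more threads than elements to force an out-of-bounds report.
  demo_kernel<<<1, 128>>>(MAKE_CHECKED(func_name, rows, 64));
  cudaDeviceSynchronize();
  cudaFree(rows);
  return 0;
}

Because the wrapping happens entirely at the launch site, only the kernel
signatures (at:: to pta::) and the launch arguments need to change; the kernel
bodies themselves are untouched, which is why the diff stays small.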