Add memchecks to embedding_inplace_update ops
Summary:
- Add memchecks to embedding_inplace_update ops

Differential Revision: D57191119
q10 authored and facebook-github-bot committed May 10, 2024
1 parent a7b73a4 commit f3ee1de
Showing 1 changed file with 45 additions and 50 deletions.
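
The change follows FBGEMM's standard memcheck pattern: kernel parameters switch from raw at::PackedTensorAccessor types to pta:: accessors, and call sites build them with MAKE_PTA_WITH_NAME, which in builds that define FBGEMM_GPU_MEMCHECK attaches the tensor's name and the launching kernel's name for bounds-checked access, and in regular builds reduces to a plain packed accessor. A minimal sketch of the idea, assuming a hypothetical helper make_checked_pta (the real definitions live in fbgemm_gpu/fbgemm_tensor_accessor.h):

// Minimal sketch only -- the real macro is defined in
// fbgemm_gpu/fbgemm_tensor_accessor.h; make_checked_pta and its exact
// signature are assumptions for illustration.
#ifdef FBGEMM_GPU_MEMCHECK
// Build an accessor that records the tensor name and kernel name, so a bad
// index can be reported with context before the kernel traps.
#define MAKE_PTA_WITH_NAME(FUNC_NAME, TENSOR, T, N, INDEX_NBITS) \
  make_checked_pta<T, N, INDEX_NBITS>(TENSOR, #TENSOR, FUNC_NAME)
#else
// Plain packed accessor: no per-access bounds checks, no name overhead.
// Note FUNC_NAME is never evaluated here.
#define MAKE_PTA_WITH_NAME(FUNC_NAME, TENSOR, T, N, INDEX_NBITS) \
  (TENSOR).packed_accessor##INDEX_NBITS<T, N, at::RestrictPtrTraits>()
#endif

This is also why func_name is declared only under #ifdef FBGEMM_GPU_MEMCHECK at the launch sites in the diff below: the non-memcheck variant of the macro discards its name argument, so the identifier never reaches the compiler in regular builds.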
95 changes: 45 additions & 50 deletions fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu
@@ -6,13 +6,12 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <c10/cuda/CUDAGuard.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 
-#include <c10/cuda/CUDAGuard.h>
-
 #include "fbgemm_gpu/embedding_inplace_update.h"
-#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
 
 using Tensor = at::Tensor;

@@ -22,28 +21,28 @@ constexpr int32_t kCacheLocationMissing = -1;
 
 template <typename index_t>
 __launch_bounds__(kMaxThreads) __global__ void embedding_inplace_update_kernel(
-    at::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> dev_weights,
-    at::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> uvm_weights,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> dev_weights,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> uvm_weights,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         weights_placements,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         weights_offsets,
-    const at::PackedTensorAccessor32<uint8_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<uint8_t, 1, at::RestrictPtrTraits>
         weights_tys,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         D_offsets,
-    const at::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits>
         update_weights,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         update_table_idx,
-    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         update_row_idx,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         update_offsets,
     const int64_t row_alignment,
-    at::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits>
         lxu_cache_weights,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         lxu_cache_locations) {
   // each row is updated by one warp of threads
   // blockIdx.x: block idx, threadIdx.x: thread idx in the warp,
@@ -151,48 +150,45 @@ void embedding_inplace_update_cuda(
 
   AT_DISPATCH_INDEX_TYPES(
       update_row_idx.scalar_type(), "embedding_inplace_update_kernel", [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+        const auto func_name = "embedding_inplace_update_kernel";
+#endif
         embedding_inplace_update_kernel<<<
             nbit::div_round_up(N, warpsPerBlock), // number of blocks needed
             dim3(kWarpSize, warpsPerBlock), // shape of each block
             0,
             at::cuda::getCurrentCUDAStream()>>>(
-            dev_weights.packed_accessor64<uint8_t, 1, at::RestrictPtrTraits>(),
-            uvm_weights.packed_accessor64<uint8_t, 1, at::RestrictPtrTraits>(),
-            weights_placements
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            weights_offsets
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            weights_tys.packed_accessor32<uint8_t, 1, at::RestrictPtrTraits>(),
-            D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            update_weights
-                .packed_accessor64<uint8_t, 1, at::RestrictPtrTraits>(),
-            update_table_idx
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            update_row_idx
-                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            update_offsets
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+            MAKE_PTA_WITH_NAME(func_name, dev_weights, uint8_t, 1, 64),
+            MAKE_PTA_WITH_NAME(func_name, uvm_weights, uint8_t, 1, 64),
+            MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_weights, uint8_t, 1, 64),
+            MAKE_PTA_WITH_NAME(func_name, update_table_idx, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_row_idx, index_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_offsets, int64_t, 1, 32),
             row_alignment,
-            lxu_cache_weights_value
-                .packed_accessor64<uint8_t, 2, at::RestrictPtrTraits>(),
-            lxu_cache_locations_value
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+            MAKE_PTA_WITH_NAME(
+                func_name, lxu_cache_weights_value, uint8_t, 2, 64),
+            MAKE_PTA_WITH_NAME(
+                func_name, lxu_cache_locations_value, int32_t, 1, 32));
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }
 
 template <typename index_t>
 __global__
 __launch_bounds__(kMaxThreads) void pruned_array_lookup_from_row_idx_kernel(
-    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         update_row_indices,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         update_table_indices,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         index_remappings,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
-    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         dense_indices) {
   const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx >= update_row_indices.size(0)) {
@@ -255,21 +251,20 @@ Tensor pruned_array_lookup_from_row_idx_cuda(
       update_row_indices.scalar_type(),
       "pruned_array_lookup_from_row_idx_kernel",
       [&] {
+#ifdef FBGEMM_GPU_MEMCHECK
+        const auto func_name = "pruned_array_lookup_from_row_idx_kernel";
+#endif
         pruned_array_lookup_from_row_idx_kernel<<<
             nbit::div_round_up(num_indices, kForwardMaxThreads),
             kForwardMaxThreads,
             0,
             at::cuda::getCurrentCUDAStream()>>>(
-            update_row_indices
-                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            update_table_indices
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            index_remappings
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            index_remappings_offsets
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            dense_indices
-                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>());
+            MAKE_PTA_WITH_NAME(func_name, update_row_indices, index_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, update_table_indices, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(
+                func_name, index_remappings_offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
   return dense_indices;

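For a sense of what the memcheck build buys: a checked accessor can validate every index against the tensor's extent and name the offending kernel and tensor before trapping, instead of silently reading or writing out of bounds. The sketch below is illustrative only, assuming a simplified 1-D accessor; member names and the report format are not FBGEMM's actual fbgemm_tensor_accessor.h code.

// Illustrative sketch of a bounds-checked 1-D accessor under
// FBGEMM_GPU_MEMCHECK. All names here are assumptions for the sketch.
template <typename T>
struct CheckedAccessor1D {
  T* data_;
  int64_t numel_;
  const char* tensor_name_; // captured from #TENSOR at the call site
  const char* func_name_;   // the func_name declared under #ifdef above

  __device__ T& operator[](int64_t i) {
    if (i < 0 || i >= numel_) {
      // Report tensor and kernel context, then stop the kernel rather than
      // corrupt memory.
      printf(
          "[%s] idx %lld out of bounds for %s (numel=%lld)\n",
          func_name_,
          static_cast<long long>(i),
          tensor_name_,
          static_cast<long long>(numel_));
      __trap();
    }
    return data_[i];
  }
};

Builds that never define FBGEMM_GPU_MEMCHECK compile the plain packed accessors instead, so the per-access branch and the name bookkeeping cost nothing in production.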