Add/modify LXU cache lookup ops for pipeline prefetching (#2154)
Summary:

This diff adds/updates LXU cache APIs for pipeline prefetching:

- Update `lxu_cache_lookup` to support looking up only the unique linear cache indices and to accept an external output tensor (a usage sketch follows this list)
- Update `lxu_cache_locations_update` to support updating only the unique cache locations
- Add a Python binding for `get_unique_indices`
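A rough sketch of how these ops are meant to compose for pipeline prefetching. The op signatures follow the schemas changed in this diff; the tensor names, shapes, and the helper function itself are hypothetical:

```python
# Illustrative sketch only -- not part of the commit.
import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

def prefetch_unique_lookup(
    linear_cache_indices,        # int64 CUDA tensor of linearized indices
    lxu_cache_state,             # cache state tensor owned by the TBE module
    total_cache_hash_size,       # upper bound on the linearized index values
    lxu_cache_locations_output,  # preallocated int32 tensor, same length as indices
):
    # Deduplicate the indices; the third output is only populated when compute_count=True.
    uniq_indices, uniq_length, _ = torch.ops.fbgemm.get_unique_indices(
        linear_cache_indices, total_cache_hash_size, False
    )
    # Look up only the first *uniq_length unique indices and write the locations
    # into the caller-provided buffer instead of allocating a new one.
    return torch.ops.fbgemm.lxu_cache_lookup(
        uniq_indices,
        lxu_cache_state,
        -1,     # invalid_index
        False,  # gather_cache_stats (not supported together with unique lookup)
        None,   # uvm_cache_stats
        num_uniq_cache_indices=uniq_length,
        lxu_cache_locations_output=lxu_cache_locations_output,
    )

# After cache insertion assigns locations for the missed rows, the stale unique
# locations can be refreshed in place:
#   torch.ops.fbgemm.lxu_cache_locations_update(
#       lxu_cache_locations, lxu_cache_locations_new, num_uniq_cache_indices=uniq_length)
```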

Reviewed By: levythu

Differential Revision: D51532548
Sarunya Pumma authored and facebook-github-bot committed Nov 27, 2023
1 parent b40f419 commit 4da1f0c
Showing 8 changed files with 103 additions and 32 deletions.
5 changes: 4 additions & 1 deletion fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
@@ -497,7 +497,10 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         lxu_cache_state.value(),
         total_cache_hash_size.value(),
         gather_uvm_stats,
-        uvm_cache_stats);
+        uvm_cache_stats,
+        c10::optional<Tensor>(), // num_uniq_cache_indices
+        c10::optional<Tensor>() // lxu_cache_locations_output
+    );
 
 #ifdef FBCODE_CAFFE2
   if (FLAGS_tbe_uvm_cache_enforced_misses > 0) {
7 changes: 5 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -174,7 +174,9 @@ at::Tensor lxu_cache_lookup_cuda(
     at::Tensor lxu_cache_state,
     int64_t invalid_index,
     bool gather_cache_stats,
-    c10::optional<at::Tensor> uvm_cache_stats);
+    c10::optional<at::Tensor> uvm_cache_stats,
+    c10::optional<at::Tensor> num_uniq_cache_indices,
+    c10::optional<at::Tensor> lxu_cache_locations_output);
 
 at::Tensor emulate_cache_miss(
     at::Tensor lxu_cache_locations,
@@ -240,4 +242,5 @@ void lxu_cache_locking_counter_decrement_cuda(
 /// and lxu_cache_locations_new[i] >= 0
 void lxu_cache_locations_update_cuda(
     at::Tensor lxu_cache_locations,
-    at::Tensor lxu_cache_locations_new);
+    at::Tensor lxu_cache_locations_new,
+    c10::optional<at::Tensor> num_uniq_cache_indices);
4 changes: 3 additions & 1 deletion fbgemm_gpu/src/split_embeddings_cache/common.h
@@ -89,7 +89,9 @@ Tensor lxu_cache_lookup_cpu(
     Tensor lxu_cache_state,
     int64_t invalid_index,
     bool gather_cache_stats,
-    c10::optional<Tensor> uvm_cache_stats);
+    c10::optional<Tensor> uvm_cache_stats,
+    c10::optional<Tensor> num_uniq_cache_indices,
+    c10::optional<Tensor> lxu_cache_locations_output);
 
 Tensor direct_mapped_lxu_cache_lookup_cpu(
     Tensor linear_cache_indices,
14 changes: 8 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cpp
@@ -14,12 +14,14 @@ namespace fbgemm_gpu {
 
 DLL_PUBLIC Tensor lxu_cache_lookup_cpu(
     Tensor linear_cache_indices,
-    Tensor lxu_cache_state,
-    int64_t invalid_index,
-    bool gather_cache_stats,
-    c10::optional<Tensor> uvm_cache_stats) {
-  return empty_like(
-      linear_cache_indices, linear_cache_indices.options().dtype(at::kInt));
+    Tensor /* lxu_cache_state */,
+    int64_t /* invalid_index */,
+    bool /* gather_cache_stats */,
+    c10::optional<Tensor> /* uvm_cache_stats */,
+    c10::optional<Tensor> /* num_uniq_cache_indices */,
+    c10::optional<Tensor> lxu_cache_locations_output) {
+  return lxu_cache_locations_output.value_or(empty_like(
+      linear_cache_indices, linear_cache_indices.options().dtype(at::kInt)));
 }
 
 DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cpu(
93 changes: 74 additions & 19 deletions fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu
@@ -254,9 +254,11 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel(
         lxu_cache_locations,
     const bool gather_cache_stats,
     pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
-        uvm_cache_stats) {
+        uvm_cache_stats,
+    const int32_t* N_unique) {
   const int32_t C = lxu_cache_state.size(0);
-  const int32_t N = linear_cache_indices.size(0);
+  const int32_t N =
+      N_unique == nullptr ? linear_cache_indices.size(0) : *N_unique;
   const int32_t n0 =
       blockIdx.x * blockDim.y * blockDim.x + threadIdx.y * blockDim.x;
   if (n0 >= N) {
@@ -368,14 +370,56 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel(
 } // namespace
 /// Lookup the cache locations for each linear cache indices in
 /// linear_cache_indices and return lxu_cache_locations
+///
+/// lxu_cache_locations A 1D tensor with the same length as
+///                     linear_cache_indices. It contains the cache locations
+///                     (the row indices in the cache) of the corresponding
+///                     indices in linear_cache_indices, i.e.,
+///                     lxu_cache_locations[i] is the cache location for
+///                     linear_cache_indices[i], where 0 <= i <
+///                     linear_cache_indices.numel().
+///
+/// @param linear_cache_indices       Linear cache indices tensor (1D)
+/// @param lxu_cache_state            LXU cache state tensor (2D tensor of
+///                                   shape (# of cache sets, # of cache
+///                                   slots per set)). It contains linear
+///                                   indices of rows that are in the
+///                                   corresponding cache slots. If the cache
+///                                   slot is empty, a sentinel value is
+///                                   stored.
+/// @param invalid_index              A sentinel value for linear cache
+///                                   indices. A cache index is skipped if it
+///                                   is a sentinel value.
+/// @param gather_cache_stats         A flag to enable/disable cache stats
+///                                   collection.
+/// @param uvm_cache_stats            A tensor for storing cache stats.
+/// @param num_uniq_cache_indices     An optional GPU tensor that contains the
+///                                   number of unique cache indices. If this
+///                                   tensor is passed, the kernel will only
+///                                   lookup num_uniq_cache_indices number of
+///                                   indices instead of looking up the entire
+///                                   linear_cache_indices.
+/// @param lxu_cache_locations_output An optional output tensor. If the
+///                                   tensor is passed, the operator will not
+///                                   allocate a new output tensor and use
+///                                   this tensor as an output tensor.
 DLL_PUBLIC Tensor lxu_cache_lookup_cuda(
-    Tensor linear_cache_indices,
-    Tensor lxu_cache_state,
-    int64_t invalid_index,
-    bool gather_cache_stats,
-    c10::optional<Tensor> uvm_cache_stats) {
+    const Tensor linear_cache_indices,
+    const Tensor lxu_cache_state,
+    const int64_t invalid_index,
+    const bool gather_cache_stats,
+    const c10::optional<Tensor> uvm_cache_stats,
+    const c10::optional<Tensor> num_uniq_cache_indices,
+    const c10::optional<Tensor> lxu_cache_locations_output) {
+  const auto uniq_lookup = num_uniq_cache_indices.has_value();
+  // TODO: Support gather_cache_stats=true when uniq_lookup=true
+  TORCH_CHECK(
+      !uniq_lookup || !gather_cache_stats,
+      "Unique lxu_cache_locations generation does not support gather_cache_stats=true");
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      linear_cache_indices, lxu_cache_state);
+      linear_cache_indices, lxu_cache_state, num_uniq_cache_indices);
   Tensor uvm_cache_stats_ =
       at::empty({0}, linear_cache_indices.options().dtype(at::kInt));
   if (gather_cache_stats) {
@@ -386,9 +430,12 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda(
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(linear_cache_indices.get_device());
 
+  const auto lxu_cache_locations =
+      lxu_cache_locations_output.value_or(empty_like(
+          linear_cache_indices,
+          linear_cache_indices.options().dtype(at::kInt)));
+
   const auto N = linear_cache_indices.numel();
-  auto lxu_cache_locations = empty_like(
-      linear_cache_indices, linear_cache_indices.options().dtype(at::kInt));
   if (linear_cache_indices.numel() == 0) {
     // nothing to do
     return lxu_cache_locations;
@@ -412,10 +459,12 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda(
         invalid_index,
         MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
         gather_cache_stats,
-        MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats_, int32_t, 1, 32));
+        MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats_, int32_t, 1, 32),
+        num_uniq_cache_indices.has_value()
+            ? num_uniq_cache_indices.value().data_ptr<int32_t>()
+            : nullptr);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
 
   return lxu_cache_locations;
 }
@@ -479,11 +528,13 @@ __launch_bounds__(kMaxThreads) void lxu_cache_locations_update_kernel(
     pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         lxu_cache_locations,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
-        lxu_cache_locations_new) {
-  const int32_t N = lxu_cache_locations.size(0);
+        lxu_cache_locations_new,
+    const int32_t* N_unique) {
+  const auto N = N_unique == nullptr ? lxu_cache_locations.size(0) : *N_unique;
   CUDA_KERNEL_LOOP(n, N) {
-    if (lxu_cache_locations[n] == kCacheLocationMissing &&
-        lxu_cache_locations_new[n] >= 0) {
+    if (N_unique != nullptr ||
+        (lxu_cache_locations[n] == kCacheLocationMissing &&
+         lxu_cache_locations_new[n] >= 0)) {
       lxu_cache_locations[n] = lxu_cache_locations_new[n];
     }
   }
@@ -493,9 +544,10 @@ __launch_bounds__(kMaxThreads) void lxu_cache_locations_update_kernel(
 
 DLL_PUBLIC void lxu_cache_locations_update_cuda(
     Tensor lxu_cache_locations,
-    Tensor lxu_cache_locations_new) {
+    Tensor lxu_cache_locations_new,
+    c10::optional<Tensor> num_uniq_cache_indices) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      lxu_cache_locations, lxu_cache_locations_new);
+      lxu_cache_locations, lxu_cache_locations_new, num_uniq_cache_indices);
 
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(lxu_cache_locations.get_device());
@@ -520,7 +572,10 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda(
       0,
       at::cuda::getCurrentCUDAStream()>>>(
       MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations_new, int32_t, 1, 32));
+      MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations_new, int32_t, 1, 32),
+      num_uniq_cache_indices.has_value()
+          ? num_uniq_cache_indices.value().data_ptr<int32_t>()
+          : nullptr);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return;
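For reference, the `N_unique` handling added to the kernels above boils down to the following CPU-side model (a minimal sketch, not part of the commit; `kCacheLocationMissing` is assumed to be -1 here):

```python
# Illustrative mental model of the N_unique semantics in the kernels above.
CACHE_LOCATION_MISSING = -1  # stands in for kCacheLocationMissing (assumed -1)

def locations_update_reference(locs, locs_new, n_unique=None):
    # Without n_unique the kernel scans every element and only fills in
    # locations that were previously missing; with n_unique it copies the
    # first n_unique entries unconditionally.
    n = len(locs) if n_unique is None else n_unique
    for i in range(n):
        if n_unique is not None or (
            locs[i] == CACHE_LOCATION_MISSING and locs_new[i] >= 0
        ):
            locs[i] = locs_new[i]
    return locs

# Example: with n_unique=2 only the first two slots are overwritten.
print(locations_update_reference([-1, 3, -1, 7], [5, 6, 8, 9], n_unique=2))
# -> [5, 6, -1, 7]
```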
@@ -276,7 +276,10 @@ DLL_PUBLIC void reset_weight_momentum_cuda(
       lxu_cache_state,
       total_cache_hash_size,
       false, // gather_cache_stats
-      uvm_cache_stats);
+      uvm_cache_stats,
+      c10::optional<Tensor>(), // num_uniq_cache_indices
+      c10::optional<Tensor>() // lxu_cache_locations_output
+  );
 }
 
 // Reset weight and momentum of pruned rows
@@ -26,7 +26,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "lfu_cache_populate_byte(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state, int row_alignment=16) -> ()");
   m.def(
-      "lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
+      "lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None, Tensor? num_uniq_cache_indices=None, Tensor(b!)? lxu_cache_locations_output=None) -> Tensor");
   m.def(
       "direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor");
   m.def(
@@ -37,7 +37,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "lxu_cache_locking_counter_decrement(Tensor(a!) lxu_cache_locking_counter, Tensor lxu_cache_locations) -> ()");
   m.def(
-      "lxu_cache_locations_update(Tensor(a!) lxu_cache_locations, Tensor lxu_cache_locations_new) -> ()");
+      "lxu_cache_locations_update(Tensor(a!) lxu_cache_locations, Tensor lxu_cache_locations_new, Tensor? num_uniq_cache_indices=None) -> ()");
+  m.def(
+      "get_unique_indices(Tensor linear_indices, int max_indices, bool compute_count) -> (Tensor, Tensor, Tensor?)");
 }
 
 using namespace fbgemm_gpu;
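A small sketch of exercising the newly registered `get_unique_indices` schema from Python (assuming a CUDA build of fbgemm_gpu; the names given to the outputs are descriptive guesses, only their order and optionality come from the schema string):

```python
import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.get_unique_indices

indices = torch.tensor([3, 7, 3, 1, 7, 7], device="cuda")

# compute_count=True requests the optional third output (per-index occurrence counts).
unique_indices, unique_length, unique_count = torch.ops.fbgemm.get_unique_indices(
    indices, 8, True  # max_indices: an upper bound on the index values
)
# Presumably only the first `unique_length` entries of the outputs are meaningful.
```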
@@ -33,6 +33,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       lxu_cache_locking_counter_decrement_cuda);
   DISPATCH_TO_CUDA(
       "lxu_cache_locations_update", lxu_cache_locations_update_cuda);
+  DISPATCH_TO_CUDA("get_unique_indices", get_unique_indices_cuda);
 }
 
 } // namespace