Add helper ops to support cache conflict misses (#2571)
Summary:

This diff adds helper operators to enable cache conflict miss support
in SSD TBE.  Changes include:
- Extend `get_unique_indices_cuda` to compute and return inverse
  linear indices (the tensor that contains the original positions of
  linear indices before sorting); see the sketch after this list
- Extend `lru_cache_find_uncached_cuda` to compute and return the
  inverse cache sets (the tensor that contains the original positions
  of cache sets of unique indices before sorting)
- Update SSD backend to support cache conflict misses instead of
  failing. The rows that experience conflict misses will be stored in
  a scratch pad for TBE kernels to consume. They will be evicted to
  SSD once the backward+optimizer step of TBE is completed.
- Add `ssd_generate_row_addrs` for generating row addresses of data
  that is fetched from SSD (data can be in either a scratch pad or LXU
  cache).
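
To make the inverse-indices idea in the first two bullets concrete, the
following is a minimal standalone sketch (illustrative C++, not FBGEMM
code): sorting an iota position array by the values it indexes yields,
for each element of the sorted output, the position it occupied in the
input.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  // Toy stand-ins for linearized embedding indices.
  std::vector<int64_t> linear_indices{7, 3, 7, 1, 3};
  // positions[i] == i before sorting.
  std::vector<int32_t> positions(linear_indices.size());
  std::iota(positions.begin(), positions.end(), 0);
  // Sort the positions by the index values they refer to. Afterwards,
  // positions[i] is the original position of the i-th smallest index;
  // this is the role played by the returned inverse-positions tensor.
  std::stable_sort(
      positions.begin(), positions.end(), [&](int32_t a, int32_t b) {
        return linear_indices[a] < linear_indices[b];
      });
  // The gathered values are now in sorted order.
  for (size_t i = 1; i < positions.size(); ++i) {
    assert(linear_indices[positions[i - 1]] <= linear_indices[positions[i]]);
  }
  return 0;
}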

Differential Revision: D55926421
sryap authored and facebook-github-bot committed May 8, 2024
1 parent 7d15c59 commit c8c316d
Showing 12 changed files with 556 additions and 203 deletions.
@@ -364,6 +364,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
evicted_indices,
assigned_cache_slots,
actions_count_gpu,
_,
_,
_,
_,
) = torch.ops.fbgemm.ssd_cache_populate_actions(
linear_cache_indices,
self.total_hash_size,
@@ -874,6 +878,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
evicted_indices,
assigned_cache_slots,
actions_count_gpu,
_,
_,
_,
_,
) = torch.ops.fbgemm.ssd_cache_populate_actions(
linear_cache_indices,
self.total_hash_size,
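For the `ssd_generate_row_addrs` bullet in the summary, the sketch below
shows one plausible shape of row-address generation: each row consumed by
the TBE kernels resolves to a pointer into either the scratch pad (rows
fetched from SSD, including conflict misses) or the LXU cache (rows
already cached). All names, the float row type, the contiguous layout,
and the -1 sentinel are assumptions for illustration, not the operator's
actual implementation.

#include <cstdint>
#include <vector>

// Byte address of row `row` in a contiguous [num_rows, D] float buffer.
uintptr_t row_address(const float* base, int64_t row, int64_t D) {
  return reinterpret_cast<uintptr_t>(base + row * D);
}

// For each requested row, pick its source buffer: the LXU cache when the
// row has a cache slot, otherwise the next row of the SSD scratch pad.
std::vector<uintptr_t> generate_row_addrs(
    const float* scratch_pad,
    const float* lxu_cache,
    const std::vector<int64_t>& cache_locations, // -1 means not in cache
    int64_t D) {
  std::vector<uintptr_t> addrs(cache_locations.size());
  int64_t scratch_pad_row = 0;
  for (size_t i = 0; i < cache_locations.size(); ++i) {
    addrs[i] = cache_locations[i] >= 0
        ? row_address(lxu_cache, cache_locations[i], D)
        : row_address(scratch_pad, scratch_pad_row++, D);
  }
  return addrs;
}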
15 changes: 11 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -28,16 +28,22 @@ enum uvm_cache_stats_index {

///@ingroup table-batched-embed-cuda
/// Deduplicate indices.
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>>
get_unique_indices_cuda(
at::Tensor linear_indices,
int64_t max_indices,
bool compute_count);
bool compute_count,
const bool compute_inverse_indices);

///@ingroup table-batched-embed-cuda
/// Lookup LRU cache to find uncached indices, and then sort them based on the
/// set.
std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
lru_cache_find_uncached_cuda(
at::Tensor unique_indices,
at::Tensor unique_indices_length,
int64_t max_indices,
@@ -47,7 +53,8 @@ std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
bool gather_cache_stats,
at::Tensor uvm_cache_stats,
bool lock_cache_line,
at::Tensor lxu_cache_locking_counter);
at::Tensor lxu_cache_locking_counter,
const bool compute_inverse_indices);

///@ingroup table-batched-embed-cuda
/// Map index to cache_set. h_in: linear_indices; C: #cache_sets.
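As a usage note for the extended declarations above, a hypothetical call
site could request the inverse positions as follows; the
structured-binding style mirrors the call-site changes later in this
commit, while the wrapper function and its arguments are assumed.

#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

// Hypothetical call site (sketch only).
void dedup_with_inverse(
    at::Tensor linear_cache_indices,
    int64_t total_cache_hash_size) {
  auto [unique_indices,
        unique_indices_length,
        unique_indices_count,
        linear_index_positions_sorted] =
      get_unique_indices_cuda(
          linear_cache_indices,
          total_cache_hash_size,
          /*compute_count=*/true,
          /*compute_inverse_indices=*/true);
  if (linear_index_positions_sorted.has_value()) {
    // Present only when compute_inverse_indices is true: the original
    // positions (int32) of the sorted linear indices.
  }
}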
16 changes: 10 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
@@ -271,12 +271,16 @@ DLL_PUBLIC void lfu_cache_populate_cuda(
}
// get unique indices
Tensor unique_indices;
Tensor unique_indices_length;
c10::optional<Tensor> unique_indices_count;
std::tie(unique_indices, unique_indices_length, unique_indices_count) =
get_unique_indices_cuda(
linear_cache_indices, total_cache_hash_size, true);
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_cache_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);
// update lfu counts
lfu_update_counts_cuda(
16 changes: 10 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
@@ -240,12 +240,16 @@ DLL_PUBLIC void lfu_cache_populate_byte_cuda(
}
// get unique indices
Tensor unique_indices;
Tensor unique_indices_length;
c10::optional<Tensor> unique_indices_count;
std::tie(unique_indices, unique_indices_length, unique_indices_count) =
get_unique_indices_cuda(
linear_cache_indices, total_cache_hash_size, true);
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);
// update lfu counts
lfu_update_counts_cuda(
170 changes: 97 additions & 73 deletions fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
@@ -200,11 +200,13 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda(
return linear_cache_indices;
}
DLL_PUBLIC std::tuple<Tensor, Tensor, c10::optional<Tensor>>
DLL_PUBLIC
std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
get_unique_indices_cuda(
Tensor linear_indices,
int64_t max_indices,
bool compute_count) {
bool compute_count,
const bool compute_inverse_indices) {
TENSOR_ON_CUDA_GPU(linear_indices);
CUDA_DEVICE_GUARD(linear_indices);
@@ -216,90 +218,112 @@ get_unique_indices_cuda(
auto unique_indices_length =
at::empty({1}, linear_indices.options().dtype(at::kInt));
c10::optional<Tensor> unique_indices_count = c10::nullopt;
c10::optional<Tensor> linear_index_positions_sorted = c10::nullopt;
Tensor linear_index_positions;
if (compute_inverse_indices) {
linear_index_positions = at::arange(
{linear_indices.numel()}, linear_indices.options().dtype(at::kInt));
linear_index_positions_sorted = at::empty_like(linear_index_positions);
}
if (compute_count) {
unique_indices_count = at::empty(
{linear_indices.numel()}, linear_indices.options().dtype(at::kInt));
}
#define INVOKE_CUB_SORT_PAIRS(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
linear_indices.data_ptr<index_t>(), \
sorted_indices.data_ptr<index_t>(), \
linear_index_positions.data_ptr<int32_t>(), \
linear_index_positions_sorted->data_ptr<int32_t>(), \
N, \
0, \
int(log2(float(max_indices + 1)) + 1), \
at::cuda::getCurrentCUDAStream(), \
false))
#define INVOKE_CUB_SORT_KEYS(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
linear_indices.data_ptr<index_t>(), \
sorted_indices.data_ptr<index_t>(), \
N, \
0, \
int(log2(float(max_indices + 1)) + 1), \
at::cuda::getCurrentCUDAStream(), \
false))
#define INVOKE_CUB_ENCODE(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
sorted_indices.data_ptr<index_t>(), \
unique_indices.data_ptr<index_t>(), \
unique_indices_count->data_ptr<int32_t>(), \
unique_indices_length.data_ptr<int32_t>(), \
N, \
at::cuda::getCurrentCUDAStream(), \
false))
#define INVOKE_CUB_UNIQUE(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
sorted_indices.data_ptr<index_t>(), \
unique_indices.data_ptr<index_t>(), \
unique_indices_length.data_ptr<int32_t>(), \
N, \
at::cuda::getCurrentCUDAStream(), \
false))
AT_DISPATCH_INDEX_TYPES(
linear_indices.scalar_type(), "get_unique_indices_cuda", [&] {
// sort indices
size_t temp_storage_bytes_0 = 0;
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys(
nullptr,
temp_storage_bytes_0,
linear_indices.data_ptr<index_t>(),
sorted_indices.data_ptr<index_t>(),
N,
0,
int(log2(float(max_indices + 1)) + 1),
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_0 = at::empty(
{static_cast<index_t>(temp_storage_bytes_0)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys(
temp_storage_0.data_ptr(),
temp_storage_bytes_0,
linear_indices.data_ptr<index_t>(),
sorted_indices.data_ptr<index_t>(),
N,
0,
int(log2(float(max_indices + 1)) + 1),
at::cuda::getCurrentCUDAStream(),
false));
if (compute_inverse_indices) {
size_t temp_storage_bytes = 0;
INVOKE_CUB_SORT_PAIRS(nullptr);
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
INVOKE_CUB_SORT_PAIRS(temp_storage.data_ptr());
} else {
size_t temp_storage_bytes = 0;
INVOKE_CUB_SORT_KEYS(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
INVOKE_CUB_SORT_KEYS(temp_storage.data_ptr());
}
// get unique indices
if (compute_count) {
size_t temp_storage_bytes_1 = 0;
AT_CUDA_CHECK(
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode(
nullptr,
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_count->data_ptr<int32_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_1 = at::empty(
{static_cast<index_t>(temp_storage_bytes_1)},
size_t temp_storage_bytes = 0;
INVOKE_CUB_ENCODE(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode(
temp_storage_1.data_ptr(),
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_count->data_ptr<int32_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
INVOKE_CUB_ENCODE(temp_storage.data_ptr());
} else {
size_t temp_storage_bytes_1 = 0;
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique(
nullptr,
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_1 = at::empty(
{static_cast<index_t>(temp_storage_bytes_1)},
size_t temp_storage_bytes = 0;
INVOKE_CUB_UNIQUE(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique(
temp_storage_1.data_ptr(),
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
INVOKE_CUB_UNIQUE(temp_storage.data_ptr());
}
});
return std::make_tuple(
unique_indices, unique_indices_length, unique_indices_count);
unique_indices,
unique_indices_length,
unique_indices_count,
linear_index_positions_sorted);
#undef INVOKE_CUB_SORT_PAIRS
#undef INVOKE_CUB_SORT_KEYS
#undef INVOKE_CUB_ENCODE
#undef INVOKE_CUB_UNIQUE
}
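
The INVOKE_CUB_* macros above follow CUB's standard two-phase convention:
a first call with a null temp-storage pointer only computes the required
scratch size, and a second call with real storage does the work. A
minimal standalone sketch of that pattern (error checking elided; raw
cudaMalloc used for brevity, whereas the code above allocates a byte
tensor instead):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Sort int64 keys and carry int32 payloads, in CUB's two-phase style.
void radix_sort_pairs(
    const int64_t* keys_in,
    int64_t* keys_out,
    const int32_t* vals_in,
    int32_t* vals_out,
    int num_items,
    cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Phase 1: null temp storage pointer; only temp_bytes is written.
  cub::DeviceRadixSort::SortPairs(
      nullptr, temp_bytes, keys_in, keys_out, vals_in, vals_out,
      num_items, /*begin_bit=*/0, /*end_bit=*/64, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Phase 2: same arguments with real storage; performs the sort.
  cub::DeviceRadixSort::SortPairs(
      d_temp, temp_bytes, keys_in, keys_out, vals_in, vals_out,
      num_items, /*begin_bit=*/0, /*end_bit=*/64, stream);
  cudaFree(d_temp);
}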