Add helper ops to support cache conflict misses (pytorch#2571)
Summary:

This diff adds helper operators to enable cache conflict miss support
in SSD TBE (an illustrative usage sketch follows the list below).
Changes include:
- Extend `get_unique_indices_cuda` to compute and return inverse
  linear indices (the tensor that contains the original positions of
  linear indices before sorting)
- Extend `lru_cache_find_uncached_cuda` to compute and return the
  inverse cache sets (the tensor that contains the original positions
  of cache sets of unique indices before sorting)
- Update SSD backend to support cache conflict misses instead of
  failing. The rows that experience conflict misses will be stored in
  a scratch pad for TBE kernels to consume. They will be evicted to
  SSD once the backward+optimizer step of TBE is completed.
- Add `ssd_generate_row_addrs` for generating row addresses of data
  that is fetched from SSD (data can be in either a scratch pad or LXU
  cache).
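
A minimal usage sketch of the extended Python-facing op (illustrative only: the tensor values, `max_indices`, and the CUDA device are assumptions, and it presumes an fbgemm_gpu build that includes this diff):

```python
import torch
import fbgemm_gpu.split_embeddings_cache_ops  # noqa: F401 (registers the wrapper)

# Illustrative linearized indices; the values and max_indices are made up.
linear_indices = torch.tensor([7, 3, 7, 1], dtype=torch.long, device="cuda")

(
    unique_indices,         # deduplicated indices
    unique_indices_length,  # 1-element int tensor with the unique count
    unique_indices_count,   # per-unique-index occurrence counts
    inverse_positions,      # original positions of the indices before sorting
) = torch.ops.fbgemm.get_unique_indices(
    linear_indices,
    max_indices=8,
    compute_count=True,
    compute_inverse_indices=True,
)
```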

Reviewed By: q10

Differential Revision: D55926421
sryap authored and facebook-github-bot committed May 9, 2024
1 parent c216005 commit 5dea02b
Showing 16 changed files with 618 additions and 204 deletions.
48 changes: 48 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_embeddings_cache_ops.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Union

import torch

lib = torch.library.Library("fbgemm", "FRAGMENT")
lib.define(
    """
    get_unique_indices(
        Tensor linear_indices,
        int max_indices,
        bool compute_count=False,
        bool compute_inverse_indices=False
    ) -> (Tensor, Tensor, Tensor?, Tensor?)
    """
)


@torch.library.impl(lib, "get_unique_indices", "CUDA")
def get_unique_indices(
    linear_indices: torch.Tensor,
    max_indices: int,
    compute_count: bool = False,
    compute_inverse_indices: bool = False,
) -> Union[
    Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
    Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
]:
    """
    A wrapper for get_unique_indices for overloading the return type
    based on inputs
    """
    ret = torch.ops.fbgemm.get_unique_indices_internal(
        linear_indices,
        max_indices,
        compute_count,
        compute_inverse_indices,
    )
    if not compute_inverse_indices:
        # Return only 3 tensors
        return ret[:-1]
    # Return all tensors
    return ret
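
A quick sketch (not part of the diff) of the arity handling above: with `compute_inverse_indices` left at its default, callers still receive the original 3-tuple, so existing call sites keep working. The tensor values and `max_indices` below are illustrative.

```python
import torch
import fbgemm_gpu.split_embeddings_cache_ops  # noqa: F401

linear_indices = torch.tensor([4, 0, 4, 2], dtype=torch.long, device="cuda")

# Legacy-style call: compute_inverse_indices defaults to False,
# so only three tensors come back (see the wrapper above).
unique_indices, unique_indices_length, unique_indices_count = (
    torch.ops.fbgemm.get_unique_indices(linear_indices, 8, True)
)
```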
@@ -13,6 +13,8 @@
from dataclasses import dataclass
from typing import List, NamedTuple

import fbgemm_gpu.split_embeddings_cache_ops # noqa


# Maximum number of times prefetch() can be called without
# a corresponding forward() call
@@ -364,6 +364,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
evicted_indices,
assigned_cache_slots,
actions_count_gpu,
_,
_,
_,
_,
) = torch.ops.fbgemm.ssd_cache_populate_actions(
linear_cache_indices,
self.total_hash_size,
@@ -874,6 +878,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
evicted_indices,
assigned_cache_slots,
actions_count_gpu,
_,
_,
_,
_,
) = torch.ops.fbgemm.ssd_cache_populate_actions(
linear_cache_indices,
self.total_hash_size,
15 changes: 11 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -28,16 +28,22 @@ enum uvm_cache_stats_index {

///@ingroup table-batched-embed-cuda
/// Deduplicate indices.
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>>
get_unique_indices_cuda(
at::Tensor linear_indices,
int64_t max_indices,
bool compute_count);
bool compute_count,
const bool compute_inverse_indices);

///@ingroup table-batched-embed-cuda
/// Lookup LRU cache to find uncached indices, and then sort them based on the
/// set.
std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
lru_cache_find_uncached_cuda(
at::Tensor unique_indices,
at::Tensor unique_indices_length,
int64_t max_indices,
@@ -47,7 +53,8 @@ std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
bool gather_cache_stats,
at::Tensor uvm_cache_stats,
bool lock_cache_line,
at::Tensor lxu_cache_locking_counter);
at::Tensor lxu_cache_locking_counter,
const bool compute_inverse_indices);

///@ingroup table-batched-embed-cuda
/// Map index to cache_set. h_in: linear_indices; C: #cache_sets.
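
For intuition only, a hedged CPU sketch (not the CUDA kernel) of what the new optional output of `lru_cache_find_uncached_cuda` represents per the summary: the original positions of the uncached indices' cache sets before they are sorted by set. The `cache_sets` values here are hypothetical.

```python
import torch

# Hypothetical cache-set IDs for four uncached unique indices; the real op
# derives these from the unique indices and the LRU cache state.
cache_sets = torch.tensor([5, 2, 5, 0])

# Sort by cache set and keep each entry's pre-sort position.
sorted_sets, inverse_cache_sets = torch.sort(cache_sets, stable=True)
# sorted_sets        -> [0, 2, 5, 5]
# inverse_cache_sets -> [3, 1, 0, 2]  (original positions before sorting)
```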
16 changes: 10 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
@@ -271,12 +271,16 @@ DLL_PUBLIC void lfu_cache_populate_cuda(
}
// get unique indices
Tensor unique_indices;
Tensor unique_indices_length;
c10::optional<Tensor> unique_indices_count;
std::tie(unique_indices, unique_indices_length, unique_indices_count) =
get_unique_indices_cuda(
linear_cache_indices, total_cache_hash_size, true);
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_cache_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);
// update lfu counts
lfu_update_counts_cuda(
16 changes: 10 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
@@ -240,12 +240,16 @@ DLL_PUBLIC void lfu_cache_populate_byte_cuda(
}
// get unique indices
Tensor unique_indices;
Tensor unique_indices_length;
c10::optional<Tensor> unique_indices_count;
std::tie(unique_indices, unique_indices_length, unique_indices_count) =
get_unique_indices_cuda(
linear_cache_indices, total_cache_hash_size, true);
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_indices_postions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);
// update lfu counts
lfu_update_counts_cuda(
170 changes: 97 additions & 73 deletions fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
@@ -200,11 +200,13 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda(
return linear_cache_indices;
}
DLL_PUBLIC std::tuple<Tensor, Tensor, c10::optional<Tensor>>
DLL_PUBLIC
std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
get_unique_indices_cuda(
Tensor linear_indices,
int64_t max_indices,
bool compute_count) {
bool compute_count,
const bool compute_inverse_indices) {
TENSOR_ON_CUDA_GPU(linear_indices);
CUDA_DEVICE_GUARD(linear_indices);
@@ -216,90 +218,112 @@ get_unique_indices_cuda(
auto unique_indices_length =
at::empty({1}, linear_indices.options().dtype(at::kInt));
c10::optional<Tensor> unique_indices_count = c10::nullopt;
c10::optional<Tensor> linear_index_positions_sorted = c10::nullopt;
Tensor linear_index_positions;
if (compute_inverse_indices) {
linear_index_positions = at::arange(
{linear_indices.numel()}, linear_indices.options().dtype(at::kInt));
linear_index_positions_sorted = at::empty_like(linear_index_positions);
}
if (compute_count) {
unique_indices_count = at::empty(
{linear_indices.numel()}, linear_indices.options().dtype(at::kInt));
}
#define INVOKE_CUB_SORT_PAIRS(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
linear_indices.data_ptr<index_t>(), \
sorted_indices.data_ptr<index_t>(), \
linear_index_positions.data_ptr<int32_t>(), \
linear_index_positions_sorted->data_ptr<int32_t>(), \
N, \
0, \
int(log2(float(max_indices + 1)) + 1), \
at::cuda::getCurrentCUDAStream(), \
false))
#define INVOKE_CUB_SORT_KEYS(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
linear_indices.data_ptr<index_t>(), \
sorted_indices.data_ptr<index_t>(), \
N, \
0, \
int(log2(float(max_indices + 1)) + 1), \
at::cuda::getCurrentCUDAStream(), \
false))
#define INVOKE_CUB_ENCODE(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
sorted_indices.data_ptr<index_t>(), \
unique_indices.data_ptr<index_t>(), \
unique_indices_count->data_ptr<int32_t>(), \
unique_indices_length.data_ptr<int32_t>(), \
N, \
at::cuda::getCurrentCUDAStream(), \
false))
#define INVOKE_CUB_UNIQUE(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
sorted_indices.data_ptr<index_t>(), \
unique_indices.data_ptr<index_t>(), \
unique_indices_length.data_ptr<int32_t>(), \
N, \
at::cuda::getCurrentCUDAStream(), \
false))
AT_DISPATCH_INDEX_TYPES(
linear_indices.scalar_type(), "get_unique_indices_cuda", [&] {
// sort indices
size_t temp_storage_bytes_0 = 0;
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys(
nullptr,
temp_storage_bytes_0,
linear_indices.data_ptr<index_t>(),
sorted_indices.data_ptr<index_t>(),
N,
0,
int(log2(float(max_indices + 1)) + 1),
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_0 = at::empty(
{static_cast<index_t>(temp_storage_bytes_0)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys(
temp_storage_0.data_ptr(),
temp_storage_bytes_0,
linear_indices.data_ptr<index_t>(),
sorted_indices.data_ptr<index_t>(),
N,
0,
int(log2(float(max_indices + 1)) + 1),
at::cuda::getCurrentCUDAStream(),
false));
if (compute_inverse_indices) {
size_t temp_storage_bytes = 0;
INVOKE_CUB_SORT_PAIRS(nullptr);
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
INVOKE_CUB_SORT_PAIRS(temp_storage.data_ptr());
} else {
size_t temp_storage_bytes = 0;
INVOKE_CUB_SORT_KEYS(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
INVOKE_CUB_SORT_KEYS(temp_storage.data_ptr());
}
// get unique indices
if (compute_count) {
size_t temp_storage_bytes_1 = 0;
AT_CUDA_CHECK(
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode(
nullptr,
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_count->data_ptr<int32_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_1 = at::empty(
{static_cast<index_t>(temp_storage_bytes_1)},
size_t temp_storage_bytes = 0;
INVOKE_CUB_ENCODE(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode(
temp_storage_1.data_ptr(),
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_count->data_ptr<int32_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
INVOKE_CUB_ENCODE(temp_storage.data_ptr());
} else {
size_t temp_storage_bytes_1 = 0;
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique(
nullptr,
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_1 = at::empty(
{static_cast<index_t>(temp_storage_bytes_1)},
size_t temp_storage_bytes = 0;
INVOKE_CUB_UNIQUE(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique(
temp_storage_1.data_ptr(),
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
INVOKE_CUB_UNIQUE(temp_storage.data_ptr());
}
});
return std::make_tuple(
unique_indices, unique_indices_length, unique_indices_count);
unique_indices,
unique_indices_length,
unique_indices_count,
linear_index_positions_sorted);
#undef INVOKE_CUB_SORT_PAIRS
#undef INVOKE_CUB_SORT_KEYS
#undef INVOKE_CUB_ENCODE
#undef INVOKE_CUB_UNIQUE
}
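
To make the inverse-indices path concrete, here is a small CPU reference (an illustration, not the kernel above): the `SortPairs` call sorts the linear indices as keys while carrying an `arange` payload, so `linear_index_positions_sorted` is the permutation a stable sort would produce.

```python
import torch

linear_indices = torch.tensor([7, 3, 7, 1])

# Pair each index with its original position (the arange payload) and sort.
positions = torch.arange(linear_indices.numel(), dtype=torch.int32)
sorted_indices, perm = torch.sort(linear_indices, stable=True)
linear_index_positions_sorted = positions[perm]
# sorted_indices                -> [1, 3, 7, 7]
# linear_index_positions_sorted -> [3, 1, 0, 2]  (positions before sorting)
```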
