Add helper ops to support cache conflict misses #2571

3 changes: 2 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/__init__.py
@@ -25,4 +25,5 @@
from fbgemm_gpu.docs.version import __version__ # noqa: F401, E402

# Trigger meta operator registrations
from . import sparse_ops # noqa: F401, E402

from . import sparse_ops, split_embeddings_cache_ops # noqa: F401, E402
48 changes: 48 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_embeddings_cache_ops.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Union

import torch

lib = torch.library.Library("fbgemm", "FRAGMENT")
lib.define(
"""
get_unique_indices(
Tensor linear_indices,
int max_indices,
bool compute_count=False,
bool compute_inverse_indices=False
) -> (Tensor, Tensor, Tensor?, Tensor?)
"""
)


@torch.library.impl(lib, "get_unique_indices", "CUDA")
def get_unique_indices(
linear_indices: torch.Tensor,
max_indices: int,
compute_count: bool = False,
compute_inverse_indices: bool = False,
) -> Union[
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
]:
"""
A wrapper for get_unique_indices for overloading the return type
based on inputs
"""
ret = torch.ops.fbgemm.get_unique_indices_internal(
linear_indices,
max_indices,
compute_count,
compute_inverse_indices,
)
if not compute_inverse_indices:
# Return only 3 tensors
return ret[:-1]
# Return all tensors
return ret
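
For context, a minimal usage sketch of the new op follows. Assumptions: fbgemm_gpu is built with CUDA support, a GPU is available, and the tensor values are illustrative; per the wrapper's docstring, the fourth output is only returned when compute_inverse_indices=True.

import torch

import fbgemm_gpu  # noqa: F401  # importing fbgemm_gpu registers torch.ops.fbgemm.*

linear_indices = torch.tensor([3, 1, 3, 7, 1], dtype=torch.long, device="cuda")

# Three outputs when compute_inverse_indices is left at its default (False).
unique, length, count = torch.ops.fbgemm.get_unique_indices(
    linear_indices, 8, compute_count=True
)

# Four outputs when compute_inverse_indices=True; the extra tensor holds the
# original positions of the indices after they are sorted by index value.
unique, length, count, positions_sorted = torch.ops.fbgemm.get_unique_indices(
    linear_indices, 8, compute_count=True, compute_inverse_indices=True
)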
@@ -364,6 +364,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
evicted_indices,
assigned_cache_slots,
actions_count_gpu,
_,
_,
_,
_,
) = torch.ops.fbgemm.ssd_cache_populate_actions(
linear_cache_indices,
self.total_hash_size,
@@ -874,6 +878,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor:
evicted_indices,
assigned_cache_slots,
actions_count_gpu,
_,
_,
_,
_,
) = torch.ops.fbgemm.ssd_cache_populate_actions(
linear_cache_indices,
self.total_hash_size,
15 changes: 11 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
@@ -28,16 +28,22 @@ enum uvm_cache_stats_index {

///@ingroup table-batched-embed-cuda
/// Deduplicate indices.
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
std::tuple<
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>>
get_unique_indices_cuda(
at::Tensor linear_indices,
int64_t max_indices,
bool compute_count);
bool compute_count,
const bool compute_inverse_indices);

///@ingroup table-batched-embed-cuda
/// Lookup LRU cache to find uncached indices, and then sort them based on the
/// set.
std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
lru_cache_find_uncached_cuda(
at::Tensor unique_indices,
at::Tensor unique_indices_length,
int64_t max_indices,
@@ -47,7 +53,8 @@ std::pair<at::Tensor, at::Tensor> lru_cache_find_uncached_cuda(
bool gather_cache_stats,
at::Tensor uvm_cache_stats,
bool lock_cache_line,
at::Tensor lxu_cache_locking_counter);
at::Tensor lxu_cache_locking_counter,
const bool compute_inverse_indices);

///@ingroup table-batched-embed-cuda
/// Map index to cache_set. h_in: linear_indices; C: #cache_sets.
16 changes: 10 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu
@@ -271,12 +271,16 @@ DLL_PUBLIC void lfu_cache_populate_cuda(
}

// get unique indices
Tensor unique_indices;
Tensor unique_indices_length;
c10::optional<Tensor> unique_indices_count;
std::tie(unique_indices, unique_indices_length, unique_indices_count) =
get_unique_indices_cuda(
linear_cache_indices, total_cache_hash_size, true);
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_cache_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);

// update lfu counts
lfu_update_counts_cuda(
16 changes: 10 additions & 6 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
@@ -240,12 +240,16 @@ DLL_PUBLIC void lfu_cache_populate_byte_cuda(
}

// get unique indices
Tensor unique_indices;
Tensor unique_indices_length;
c10::optional<Tensor> unique_indices_count;
std::tie(unique_indices, unique_indices_length, unique_indices_count) =
get_unique_indices_cuda(
linear_cache_indices, total_cache_hash_size, true);
auto
[unique_indices,
unique_indices_length,
unique_indices_count,
linear_indices_positions_sorted] =
get_unique_indices_cuda(
linear_cache_indices,
total_cache_hash_size,
/*compute_count=*/true,
/*compute_inverse_indices=*/false);

// update lfu counts
lfu_update_counts_cuda(
170 changes: 97 additions & 73 deletions fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu
@@ -200,11 +200,13 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda(
return linear_cache_indices;
}

DLL_PUBLIC std::tuple<Tensor, Tensor, c10::optional<Tensor>>
DLL_PUBLIC
std::tuple<Tensor, Tensor, c10::optional<Tensor>, c10::optional<Tensor>>
get_unique_indices_cuda(
Tensor linear_indices,
int64_t max_indices,
bool compute_count) {
bool compute_count,
const bool compute_inverse_indices) {
TENSOR_ON_CUDA_GPU(linear_indices);

CUDA_DEVICE_GUARD(linear_indices);
@@ -216,90 +218,112 @@
auto unique_indices_length =
at::empty({1}, linear_indices.options().dtype(at::kInt));
c10::optional<Tensor> unique_indices_count = c10::nullopt;
c10::optional<Tensor> linear_index_positions_sorted = c10::nullopt;

Tensor linear_index_positions;
if (compute_inverse_indices) {
linear_index_positions = at::arange(
{linear_indices.numel()}, linear_indices.options().dtype(at::kInt));
linear_index_positions_sorted = at::empty_like(linear_index_positions);
}
if (compute_count) {
unique_indices_count = at::empty(
{linear_indices.numel()}, linear_indices.options().dtype(at::kInt));
}

#define INVOKE_CUB_SORT_PAIRS(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
linear_indices.data_ptr<index_t>(), \
sorted_indices.data_ptr<index_t>(), \
linear_index_positions.data_ptr<int32_t>(), \
linear_index_positions_sorted->data_ptr<int32_t>(), \
N, \
0, \
int(log2(float(max_indices + 1)) + 1), \
at::cuda::getCurrentCUDAStream(), \
false))

#define INVOKE_CUB_SORT_KEYS(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
linear_indices.data_ptr<index_t>(), \
sorted_indices.data_ptr<index_t>(), \
N, \
0, \
int(log2(float(max_indices + 1)) + 1), \
at::cuda::getCurrentCUDAStream(), \
false))

#define INVOKE_CUB_ENCODE(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
sorted_indices.data_ptr<index_t>(), \
unique_indices.data_ptr<index_t>(), \
unique_indices_count->data_ptr<int32_t>(), \
unique_indices_length.data_ptr<int32_t>(), \
N, \
at::cuda::getCurrentCUDAStream(), \
false))

#define INVOKE_CUB_UNIQUE(TEMP_STORAGE_PTR) \
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( \
TEMP_STORAGE_PTR, \
temp_storage_bytes, \
sorted_indices.data_ptr<index_t>(), \
unique_indices.data_ptr<index_t>(), \
unique_indices_length.data_ptr<int32_t>(), \
N, \
at::cuda::getCurrentCUDAStream(), \
false))

AT_DISPATCH_INDEX_TYPES(
linear_indices.scalar_type(), "get_unique_indices_cuda", [&] {
// sort indices
size_t temp_storage_bytes_0 = 0;
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys(
nullptr,
temp_storage_bytes_0,
linear_indices.data_ptr<index_t>(),
sorted_indices.data_ptr<index_t>(),
N,
0,
int(log2(float(max_indices + 1)) + 1),
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_0 = at::empty(
{static_cast<index_t>(temp_storage_bytes_0)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys(
temp_storage_0.data_ptr(),
temp_storage_bytes_0,
linear_indices.data_ptr<index_t>(),
sorted_indices.data_ptr<index_t>(),
N,
0,
int(log2(float(max_indices + 1)) + 1),
at::cuda::getCurrentCUDAStream(),
false));
if (compute_inverse_indices) {
size_t temp_storage_bytes = 0;
INVOKE_CUB_SORT_PAIRS(nullptr);
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
INVOKE_CUB_SORT_PAIRS(temp_storage.data_ptr());
} else {
size_t temp_storage_bytes = 0;
INVOKE_CUB_SORT_KEYS(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
INVOKE_CUB_SORT_KEYS(temp_storage.data_ptr());
}
// get unique indices
if (compute_count) {
size_t temp_storage_bytes_1 = 0;
AT_CUDA_CHECK(
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode(
nullptr,
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_count->data_ptr<int32_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_1 = at::empty(
{static_cast<index_t>(temp_storage_bytes_1)},
size_t temp_storage_bytes = 0;
INVOKE_CUB_ENCODE(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode(
temp_storage_1.data_ptr(),
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_count->data_ptr<int32_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
INVOKE_CUB_ENCODE(temp_storage.data_ptr());
} else {
size_t temp_storage_bytes_1 = 0;
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique(
nullptr,
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
auto temp_storage_1 = at::empty(
{static_cast<index_t>(temp_storage_bytes_1)},
size_t temp_storage_bytes = 0;
INVOKE_CUB_UNIQUE(nullptr);
auto temp_storage = at::empty(
{static_cast<index_t>(temp_storage_bytes)},
linear_indices.options().dtype(at::kByte));
AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique(
temp_storage_1.data_ptr(),
temp_storage_bytes_1,
sorted_indices.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
unique_indices_length.data_ptr<int32_t>(),
N,
at::cuda::getCurrentCUDAStream(),
false));
INVOKE_CUB_UNIQUE(temp_storage.data_ptr());
}
});

return std::make_tuple(
unique_indices, unique_indices_length, unique_indices_count);
unique_indices,
unique_indices_length,
unique_indices_count,
linear_index_positions_sorted);

#undef INVOKE_CUB_SORT_PAIRS
#undef INVOKE_CUB_SORT_KEYS
#undef INVOKE_CUB_ENCODE
#undef INVOKE_CUB_UNIQUE
}
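
For intuition only (not part of this PR), the following is a simplified pure-PyTorch sketch of the semantics the kernel above implements. It returns trimmed tensors rather than the preallocated buffers the CUDA path uses, omits max_indices (which only bounds the radix-sort bit width), and all names are illustrative.

import torch
from typing import Optional, Tuple


def get_unique_indices_reference(
    linear_indices: torch.Tensor,
    compute_count: bool = False,
    compute_inverse_indices: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Sort step: order the keys; the permutation plays the role that
    # cub::DeviceRadixSort::SortPairs gives to linear_index_positions_sorted.
    sorted_indices, positions_sorted = torch.sort(linear_indices)
    # Run-length-encode / unique step on the sorted keys.
    unique_indices, counts = torch.unique_consecutive(
        sorted_indices, return_counts=True
    )
    unique_indices_length = torch.tensor([unique_indices.numel()], dtype=torch.int32)
    return (
        unique_indices,
        unique_indices_length,
        counts.to(torch.int32) if compute_count else None,
        positions_sorted.to(torch.int32) if compute_inverse_indices else None,
    )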