diff --git a/fbgemm_gpu/fbgemm_gpu/split_embeddings_cache_ops.py b/fbgemm_gpu/fbgemm_gpu/split_embeddings_cache_ops.py new file mode 100644 index 0000000000..002418a4a0 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/split_embeddings_cache_ops.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Union + +import torch + +lib = torch.library.Library("fbgemm", "FRAGMENT") +lib.define( + """ + get_unique_indices( + Tensor linear_indices, + int max_indices, + bool compute_count=False, + bool compute_inverse_indices=False + ) -> (Tensor, Tensor, Tensor?, Tensor?) + """ +) + + +@torch.library.impl(lib, "get_unique_indices", "CUDA") +def get_unique_indices( + linear_indices: torch.Tensor, + max_indices: int, + compute_count: bool = False, + compute_inverse_indices: bool = False, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], +]: + """ + A wrapper for get_unique_indices for overloading the return type + based on inputs + """ + ret = torch.ops.fbgemm.get_unique_indices_internal( + linear_indices, + max_indices, + compute_count, + compute_inverse_indices, + ) + if not compute_inverse_indices: + # Return only 3 tensors + return ret[:-1] + # Return all tensors + return ret diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 52d3d9005f..24e906e9ac 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -13,6 +13,8 @@ from dataclasses import dataclass from typing import List, NamedTuple +import fbgemm_gpu.split_embeddings_cache_ops # noqa + # Maximum number of times prefetch() can be called without # a corresponding forward() call diff --git a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py index 5906254274..64316bc6c8 100644 --- a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py @@ -364,6 +364,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: evicted_indices, assigned_cache_slots, actions_count_gpu, + _, + _, + _, + _, ) = torch.ops.fbgemm.ssd_cache_populate_actions( linear_cache_indices, self.total_hash_size, @@ -874,6 +878,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor: evicted_indices, assigned_cache_slots, actions_count_gpu, + _, + _, + _, + _, ) = torch.ops.fbgemm.ssd_cache_populate_actions( linear_cache_indices, self.total_hash_size, diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh index dc302d67c7..537b33725d 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh @@ -28,16 +28,22 @@ enum uvm_cache_stats_index { ///@ingroup table-batched-embed-cuda /// Deduplicate indices. 
-std::tuple> +std::tuple< + at::Tensor, + at::Tensor, + c10::optional, + c10::optional> get_unique_indices_cuda( at::Tensor linear_indices, int64_t max_indices, - bool compute_count); + bool compute_count, + const bool compute_inverse_indices); ///@ingroup table-batched-embed-cuda /// Lookup LRU cache to find uncached indices, and then sort them based on the /// set. -std::pair lru_cache_find_uncached_cuda( +std::tuple> +lru_cache_find_uncached_cuda( at::Tensor unique_indices, at::Tensor unique_indices_length, int64_t max_indices, @@ -47,7 +53,8 @@ std::pair lru_cache_find_uncached_cuda( bool gather_cache_stats, at::Tensor uvm_cache_stats, bool lock_cache_line, - at::Tensor lxu_cache_locking_counter); + at::Tensor lxu_cache_locking_counter, + const bool compute_inverse_indices); ///@ingroup table-batched-embed-cuda /// Map index to cache_set. h_in: linear_indices; C: #cache_sets. diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu index d8a913cb47..cb1675bc32 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu @@ -271,12 +271,16 @@ DLL_PUBLIC void lfu_cache_populate_cuda( } // get unqiue indices - Tensor unique_indices; - Tensor unique_indices_length; - c10::optional unique_indices_count; - std::tie(unique_indices, unique_indices_length, unique_indices_count) = - get_unique_indices_cuda( - linear_cache_indices, total_cache_hash_size, true); + auto + [unique_indices, + unique_indices_length, + unique_indices_count, + linear_cache_indices_positions_sorted] = + get_unique_indices_cuda( + linear_cache_indices, + total_cache_hash_size, + /*compute_count=*/true, + /*compute_inverse_indices=*/false); // update lfu counts lfu_update_counts_cuda( diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu index be659ecd2f..251dd1bf8a 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu @@ -240,12 +240,16 @@ DLL_PUBLIC void lfu_cache_populate_byte_cuda( } // get unqiue indices - Tensor unique_indices; - Tensor unique_indices_length; - c10::optional unique_indices_count; - std::tie(unique_indices, unique_indices_length, unique_indices_count) = - get_unique_indices_cuda( - linear_cache_indices, total_cache_hash_size, true); + auto + [unique_indices, + unique_indices_length, + unique_indices_count, + linear_indices_postions_sorted] = + get_unique_indices_cuda( + linear_cache_indices, + total_cache_hash_size, + /*compute_count=*/true, + /*compute_inverse_indices=*/false); // update lfu counts lfu_update_counts_cuda( diff --git a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu index 2ac076d1cb..e09af1462e 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu @@ -200,11 +200,13 @@ DLL_PUBLIC Tensor linearize_cache_indices_from_row_idx_cuda( return linear_cache_indices; } -DLL_PUBLIC std::tuple> +DLL_PUBLIC +std::tuple, c10::optional> get_unique_indices_cuda( Tensor linear_indices, int64_t max_indices, - bool compute_count) { + bool compute_count, + const bool compute_inverse_indices) { TENSOR_ON_CUDA_GPU(linear_indices); CUDA_DEVICE_GUARD(linear_indices); @@ -216,90 +218,112 @@ 
get_unique_indices_cuda( auto unique_indices_length = at::empty({1}, linear_indices.options().dtype(at::kInt)); c10::optional unique_indices_count = c10::nullopt; + c10::optional linear_index_positions_sorted = c10::nullopt; + + Tensor linear_index_positions; + if (compute_inverse_indices) { + linear_index_positions = at::arange( + {linear_indices.numel()}, linear_indices.options().dtype(at::kInt)); + linear_index_positions_sorted = at::empty_like(linear_index_positions); + } if (compute_count) { unique_indices_count = at::empty( {linear_indices.numel()}, linear_indices.options().dtype(at::kInt)); } + +#define INVOKE_CUB_SORT_PAIRS(TEMP_STORAGE_PTR) \ + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ + TEMP_STORAGE_PTR, \ + temp_storage_bytes, \ + linear_indices.data_ptr(), \ + sorted_indices.data_ptr(), \ + linear_index_positions.data_ptr(), \ + linear_index_positions_sorted->data_ptr(), \ + N, \ + 0, \ + int(log2(float(max_indices + 1)) + 1), \ + at::cuda::getCurrentCUDAStream(), \ + false)) + +#define INVOKE_CUB_SORT_KEYS(TEMP_STORAGE_PTR) \ + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( \ + TEMP_STORAGE_PTR, \ + temp_storage_bytes, \ + linear_indices.data_ptr(), \ + sorted_indices.data_ptr(), \ + N, \ + 0, \ + int(log2(float(max_indices + 1)) + 1), \ + at::cuda::getCurrentCUDAStream(), \ + false)) + +#define INVOKE_CUB_ENCODE(TEMP_STORAGE_PTR) \ + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( \ + TEMP_STORAGE_PTR, \ + temp_storage_bytes, \ + sorted_indices.data_ptr(), \ + unique_indices.data_ptr(), \ + unique_indices_count->data_ptr(), \ + unique_indices_length.data_ptr(), \ + N, \ + at::cuda::getCurrentCUDAStream(), \ + false)) + +#define INVOKE_CUB_UNIQUE(TEMP_STORAGE_PTR) \ + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( \ + TEMP_STORAGE_PTR, \ + temp_storage_bytes, \ + sorted_indices.data_ptr(), \ + unique_indices.data_ptr(), \ + unique_indices_length.data_ptr(), \ + N, \ + at::cuda::getCurrentCUDAStream(), \ + false)) + AT_DISPATCH_INDEX_TYPES( linear_indices.scalar_type(), "get_unique_indices_cuda", [&] { // sort indices - size_t temp_storage_bytes_0 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( - nullptr, - temp_storage_bytes_0, - linear_indices.data_ptr(), - sorted_indices.data_ptr(), - N, - 0, - int(log2(float(max_indices + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_0 = at::empty( - {static_cast(temp_storage_bytes_0)}, - linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( - temp_storage_0.data_ptr(), - temp_storage_bytes_0, - linear_indices.data_ptr(), - sorted_indices.data_ptr(), - N, - 0, - int(log2(float(max_indices + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); + if (compute_inverse_indices) { + size_t temp_storage_bytes = 0; + INVOKE_CUB_SORT_PAIRS(nullptr); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + linear_indices.options().dtype(at::kByte)); + INVOKE_CUB_SORT_PAIRS(temp_storage.data_ptr()); + } else { + size_t temp_storage_bytes = 0; + INVOKE_CUB_SORT_KEYS(nullptr); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + linear_indices.options().dtype(at::kByte)); + INVOKE_CUB_SORT_KEYS(temp_storage.data_ptr()); + } // get unique indices if (compute_count) { - size_t temp_storage_bytes_1 = 0; - AT_CUDA_CHECK( - FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( - nullptr, - 
temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_count->data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_1 = at::empty( - {static_cast(temp_storage_bytes_1)}, + size_t temp_storage_bytes = 0; + INVOKE_CUB_ENCODE(nullptr); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK( - FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( - temp_storage_1.data_ptr(), - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_count->data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); + INVOKE_CUB_ENCODE(temp_storage.data_ptr()); } else { - size_t temp_storage_bytes_1 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( - nullptr, - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_1 = at::empty( - {static_cast(temp_storage_bytes_1)}, + size_t temp_storage_bytes = 0; + INVOKE_CUB_UNIQUE(nullptr); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( - temp_storage_1.data_ptr(), - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); + INVOKE_CUB_UNIQUE(temp_storage.data_ptr()); } }); + return std::make_tuple( - unique_indices, unique_indices_length, unique_indices_count); + unique_indices, + unique_indices_length, + unique_indices_count, + linear_index_positions_sorted); + +#undef INVOKE_CUB_SORT_PAIRS +#undef INVOKE_CUB_SORT_KEYS +#undef INVOKE_CUB_ENCODE +#undef INVOKE_CUB_UNIQUE } diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu index 3ea4cd3498..a0d08d0ee9 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu @@ -150,7 +150,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( } // namespace -DLL_PUBLIC std::pair lru_cache_find_uncached_cuda( +DLL_PUBLIC std::tuple> +lru_cache_find_uncached_cuda( Tensor unique_indices, Tensor unique_indices_length, int64_t max_indices, @@ -160,7 +161,8 @@ DLL_PUBLIC std::pair lru_cache_find_uncached_cuda( bool gather_cache_stats, Tensor uvm_cache_stats, bool lock_cache_line, - Tensor lxu_cache_locking_counter) { + Tensor lxu_cache_locking_counter, + const bool compute_inverse_indices) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( unique_indices, unique_indices_length, @@ -180,6 +182,33 @@ DLL_PUBLIC std::pair lru_cache_find_uncached_cuda( auto sorted_cache_sets = empty_like(cache_sets); auto cache_set_sorted_unique_indices = empty_like(unique_indices); + Tensor cache_sets_positions; + c10::optional cache_set_inverse_indices = c10::nullopt; + if (compute_inverse_indices) { + TORCH_CHECK( + cache_sets.numel() <= + static_cast(std::numeric_limits::max()), + "Number of elements in cache_sets is larger than int32_t max"); + cache_sets_positions = + at::arange({cache_sets.numel()}, cache_sets.options().dtype(at::kInt)); + cache_set_inverse_indices = empty_like(cache_sets_positions); + } + +#define 
INVOKE_CUB_SORT_PAIRS( \ + TEMP_STORAGE_PTR, VALUE_TENSOR, SORTED_VALUE_TENSOR) \ + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ + TEMP_STORAGE_PTR, \ + temp_storage_bytes, \ + cache_sets.data_ptr(), \ + sorted_cache_sets.data_ptr(), \ + VALUE_TENSOR, \ + SORTED_VALUE_TENSOR, \ + N, \ + 0, \ + int(log2(float(lxu_cache_state.size(0) + 1)) + 1), \ + at::cuda::getCurrentCUDAStream(), \ + false)) + AT_DISPATCH_INDEX_TYPES( unique_indices.scalar_type(), "lru_cache_find_uncached_cuda", [&] { #ifdef FBGEMM_GPU_MEMCHECK @@ -208,33 +237,37 @@ DLL_PUBLIC std::pair lru_cache_find_uncached_cuda( C10_CUDA_KERNEL_LAUNCH_CHECK(); // Sort the cache sets and ids size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + INVOKE_CUB_SORT_PAIRS( nullptr, - temp_storage_bytes, - cache_sets.data_ptr(), - sorted_cache_sets.data_ptr(), unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); + cache_set_sorted_unique_indices.data_ptr()); auto temp_storage = at::empty( {static_cast(temp_storage_bytes)}, unique_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + INVOKE_CUB_SORT_PAIRS( temp_storage.data_ptr(), - temp_storage_bytes, - cache_sets.data_ptr(), - sorted_cache_sets.data_ptr(), unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); + cache_set_sorted_unique_indices.data_ptr()); + + if (compute_inverse_indices) { + INVOKE_CUB_SORT_PAIRS( + nullptr, + cache_sets_positions.data_ptr(), + cache_set_inverse_indices->data_ptr()); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + unique_indices.options().dtype(at::kByte)); + INVOKE_CUB_SORT_PAIRS( + temp_storage.data_ptr(), + cache_sets_positions.data_ptr(), + cache_set_inverse_indices->data_ptr()); + } }); - return {sorted_cache_sets, cache_set_sorted_unique_indices}; + + return { + sorted_cache_sets, + cache_set_sorted_unique_indices, + cache_set_inverse_indices}; + +#undef INVOKE_CUB_SORT_PAIRS } diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu index 1087f64aeb..72f586f9d1 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu @@ -325,26 +325,33 @@ DLL_PUBLIC void lru_cache_populate_cuda( } // Get unqiue indices - Tensor unique_indices; - Tensor unique_indices_length; - c10::optional unique_indices_count; - std::tie(unique_indices, unique_indices_length, unique_indices_count) = - get_unique_indices_cuda( - linear_cache_indices, total_cache_hash_size, false); - - auto cache_sets_and_unique_indices = lru_cache_find_uncached_cuda( - unique_indices, - unique_indices_length, - total_cache_hash_size, - lxu_cache_state, - time_stamp, - lru_state, - gather_cache_stats, - uvm_cache_stats_, - lock_cache_line, - lxu_cache_locking_counter_); - auto sorted_cache_sets = cache_sets_and_unique_indices.first; - auto cache_set_sorted_unique_indices = cache_sets_and_unique_indices.second; + auto + [unique_indices, + unique_indices_length, + unique_indices_count, + linear_cache_indices_positions_sorted] = + get_unique_indices_cuda( + linear_cache_indices, + total_cache_hash_size, + /*compute_count=*/false, + 
/*compute_inverse_indices=*/false); + + auto + [sorted_cache_sets, + cache_set_sorted_unique_indices, + cache_set_inverse_indices] = + lru_cache_find_uncached_cuda( + unique_indices, + unique_indices_length, + total_cache_hash_size, + lxu_cache_state, + time_stamp, + lru_state, + gather_cache_stats, + uvm_cache_stats_, + lock_cache_line, + lxu_cache_locking_counter_, + /*compute_inverse_indices=*/false); // insert caching weights lru_cache_insert_cuda( diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu index fdb052d009..cc90385961 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate_byte.cu @@ -549,29 +549,36 @@ DLL_PUBLIC void lru_cache_populate_byte_cuda( } // Get unqiue indices - Tensor unique_indices; - Tensor unique_indices_length; - c10::optional unique_indices_count; - std::tie(unique_indices, unique_indices_length, unique_indices_count) = - get_unique_indices_cuda( - linear_cache_indices, total_cache_hash_size, false); + auto + [unique_indices, + unique_indices_length, + unique_indices_count, + linear_cache_indices_positions_sorted] = + get_unique_indices_cuda( + linear_cache_indices, + total_cache_hash_size, + /*compute_count=*/false, + /*compute_inverse_indices=*/false); // Find uncached indices Tensor lxu_cache_locking_counter = at::empty({0, 0}, lxu_cache_state.options().dtype(at::kInt)); - auto cache_sets_and_unique_indices = lru_cache_find_uncached_cuda( - unique_indices, - unique_indices_length, - total_cache_hash_size, - lxu_cache_state, - time_stamp, - lru_state, - gather_cache_stats, - uvm_cache_stats_, - false, // lock_cache_line - lxu_cache_locking_counter); - auto sorted_cache_sets = cache_sets_and_unique_indices.first; - auto cache_set_sorted_unique_indices = cache_sets_and_unique_indices.second; + auto + [sorted_cache_sets, + cache_set_sorted_unique_indices, + cache_set_inverse_indices] = + lru_cache_find_uncached_cuda( + unique_indices, + unique_indices_length, + total_cache_hash_size, + lxu_cache_state, + time_stamp, + lru_state, + gather_cache_stats, + uvm_cache_stats_, + /*lock_cache_line=*/false, + lxu_cache_locking_counter, + /*compute_inverse_indices=*/false); // insert caching weights lru_cache_insert_byte_cuda( diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp index ec079ee3cc..e049c2efd2 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp @@ -39,7 +39,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "lxu_cache_locations_update(Tensor(a!) lxu_cache_locations, Tensor lxu_cache_locations_new, Tensor? 
num_uniq_cache_indices=None) -> ()"); m.def( - "get_unique_indices(Tensor linear_indices, int max_indices, bool compute_count) -> (Tensor, Tensor, Tensor?)"); + "get_unique_indices_internal(" + " Tensor linear_indices, " + " int max_indices, " + " bool compute_count, " + " bool compute_inverse_indices=False" + ") -> (Tensor, Tensor, Tensor?, Tensor?)"); } using namespace fbgemm_gpu; diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu index fdf5c2ccc7..b6f3504bff 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu @@ -33,7 +33,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { lxu_cache_locking_counter_decrement_cuda); DISPATCH_TO_CUDA( "lxu_cache_locations_update", lxu_cache_locations_update_cuda); - DISPATCH_TO_CUDA("get_unique_indices", get_unique_indices_cuda); + DISPATCH_TO_CUDA("get_unique_indices_internal", get_unique_indices_cuda); } } // namespace diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index ec284aebd9..50fa4ca466 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -15,6 +15,7 @@ #include #include "fbgemm_gpu/dispatch_macros.h" #include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/sparse_ops.h" #include "fbgemm_gpu/sparse_ops_utils.h" #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh" #include "fbgemm_gpu/split_embeddings_utils.cuh" @@ -40,6 +41,9 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_put_kernel( return; } const auto idx = indices[n]; + if (idx < 0) { + return; + } const auto D = self.size(1); for (int32_t d = threadIdx.x; d * 4 < D; d += blockDim.x) { Vec4T::copy((&values[n][0]) + d * 4, (&self[idx][0]) + d * 4); @@ -63,6 +67,9 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_put_kernel( return; } const auto idx = indices[n]; + if (idx < 0) { + return; + } const auto D = self.size(1); // each row is padded with row_alignment (16 bytes on GPUs), so each row will // be multiple of 16 bytes (uint4 = 32bit x 4 = 16 bytes). @@ -163,7 +170,7 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( if (cache_set >= C) { if (threadIdx.x == 0) { - // ignore the already-existing elements + // Ignore the already-existing elements evicted_indices[n] = -1; assigned_cache_slots[n] = -1; } @@ -172,7 +179,6 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( // check if this warp is responsible for this whole segment. const bool segment_start = (n == 0 || sorted_cache_sets[n - 1] != cache_set); - if (!segment_start) { // don't have *warp* divergence since we launch full warps in blockDim.x, // so @@ -185,10 +191,6 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( SL += 1; } - // This will mean that we can't insert all the indices for our segment, - // which will break the guarantees required for the SSD embedding. - // If you hit this, increase the cache size. - CUDA_KERNEL_ASSERT2(SL <= kWarpSize); // now, we need to insert the (unique!) values in indices[n:n + SL] into // our slots. 
const int32_t slot = threadIdx.x; @@ -201,33 +203,51 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel( const int64_t sorted_time = costs[0]; auto l = threadIdx.x; - if (l >= SL) { - return; - } - const int32_t insert_slot = sorted_slot; - const int64_t insert_time = sorted_time; - - const int64_t insert_idx = cache_set_sorted_indices[n + l]; - const int64_t current_idx = lxu_cache_state[cache_set][insert_slot]; - - // Only check insert_time if tag is for valid entry - if (current_idx != -1) { - // We need to ensure if prefetching (prefetch_dist) batches ahead - // No entries that are younger than (time_stamp - prefetch_dist) are - // evicted from the cache. This will break the guarantees required - // for the SSD embedding. - // If you hit this assert, increase the cache size. - CUDA_KERNEL_ASSERT2(insert_time < (time_stamp - prefetch_dist)); + // Insert rows + if (l < SL) { + // Insert indices + const int32_t insert_slot = sorted_slot; + const int64_t insert_time = sorted_time; + + const int64_t insert_idx = cache_set_sorted_indices[n + l]; + const int64_t current_idx = lxu_cache_state[cache_set][insert_slot]; + +#if 0 + // TODO: Check whether to uncomment this + // Only check insert_time if tag is for valid entry + if (current_idx != -1) { + // We need to ensure if prefetching (prefetch_dist) batches ahead + // No entries that are younger than (time_stamp - prefetch_dist) are + // evicted from the cache. This will break the guarantees required + // for the SSD embedding. + // If you hit this assert, increase the cache size. + CUDA_KERNEL_ASSERT2(insert_time < (time_stamp - prefetch_dist)); + } +#endif + + if (current_idx != -1 && insert_time == time_stamp) { + // Skip this slot as the inserted row was a cache hit + // This is conflict miss + evicted_indices[n + l] = -1; + assigned_cache_slots[n + l] = -1; + } else { + evicted_indices[n + l] = current_idx; // -1 if not set, >= 0 if valid. + assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot; + lxu_cache_state[cache_set][insert_slot] = insert_idx; + lru_state[cache_set][insert_slot] = time_stamp; + } } - evicted_indices[n + l] = current_idx; // -1 if not set, >= 0 if valid. 
- assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot; - lxu_cache_state[cache_set][insert_slot] = insert_idx; - lru_state[cache_set][insert_slot] = time_stamp; + // Conflict misses + for (auto l = kWarpSize + threadIdx.x; l < SL; l += kWarpSize) { + evicted_indices[n + l] = -1; + assigned_cache_slots[n + l] = -1; + } } -std::tuple ssd_cache_populate_actions_cuda( +std::tuple +ssd_cache_populate_actions_cuda( Tensor linear_indices, int64_t total_hash_size, Tensor lxu_cache_state, @@ -240,11 +260,22 @@ std::tuple ssd_cache_populate_actions_cuda( CUDA_DEVICE_GUARD(linear_indices); // Get unique indices - Tensor unique_indices; - Tensor unique_indices_length; - c10::optional unique_indices_count; - std::tie(unique_indices, unique_indices_length, unique_indices_count) = - get_unique_indices_cuda(linear_indices, total_hash_size, false); + auto + [unique_indices, + unique_indices_length, + unique_indices_count, + linear_index_inverse_indices] = + get_unique_indices_cuda( + linear_indices, + total_hash_size, + /*compute_count=*/true, + /*compute_inverse_indices=*/true); + + TORCH_CHECK(linear_index_inverse_indices.has_value()); + TORCH_CHECK(unique_indices_count.has_value()); + const auto unique_indices_count_cumsum = + asynchronous_complete_cumsum_gpu(unique_indices_count.value()) + .to(at::kInt); TORCH_CHECK_LT(unique_indices.numel(), std::numeric_limits::max()); const int32_t N = unique_indices.numel(); @@ -260,26 +291,36 @@ std::tuple ssd_cache_populate_actions_cuda( empty_like(unique_indices), evicted_indices, assigned_cache_slots, - actions_count); + actions_count, + /*linear_index_inverse_indices=*/at::empty({0}, int_options), + /*unique_indices_count_cumsum=*/at::empty({0}, int_options), + /*cache_set_inverse_indices=*/at::empty({0}, int_options), + /*cache_set_inverse_indices=*/at::empty({0}, int_options)); } auto actions_count = at::empty({1}, int_options); // Find uncached indices Tensor uvm_cache_stats = at::empty({0}, int_options); Tensor lxu_cache_locking_counter = at::empty({0, 0}, int_options); - auto cache_sets_and_unique_indices = lru_cache_find_uncached_cuda( - unique_indices, - unique_indices_length, - total_hash_size, - lxu_cache_state, - time_stamp, - lru_state, - false, // gather_cache_stats - uvm_cache_stats, - false, // lock_cache_line - lxu_cache_locking_counter); - auto sorted_cache_sets = cache_sets_and_unique_indices.first; - auto cache_set_sorted_unique_indices = cache_sets_and_unique_indices.second; + auto + [sorted_cache_sets, + cache_set_sorted_unique_indices, + cache_set_inverse_indices] = + lru_cache_find_uncached_cuda( + unique_indices, + unique_indices_length, + total_hash_size, + lxu_cache_state, + time_stamp, + lru_state, + /*gather_cache_stats=*/false, + uvm_cache_stats, + /*lock_cache_line=*/false, + lxu_cache_locking_counter, + /*compute_inverse_indices=*/true); + + TORCH_CHECK(cache_set_inverse_indices.has_value()); + TORCH_DSA_KERNEL_LAUNCH( ssd_cache_actions_insert_kernel, div_round_up(N, kMaxThreads / kWarpSize), @@ -302,5 +343,132 @@ std::tuple ssd_cache_populate_actions_cuda( cache_set_sorted_unique_indices, evicted_indices, assigned_cache_slots, - actions_count); + actions_count, + linear_index_inverse_indices.value(), + unique_indices_count_cumsum, + cache_set_inverse_indices.value(), + unique_indices_length); +} + +__global__ __launch_bounds__(kMaxThreads) void ssd_generate_row_addrs_kernel( + at::PackedTensorAccessor32 ssd_row_addrs, + at::PackedTensorAccessor32 + post_bwd_evicted_indices, + const at::PackedTensorAccessor32 + 
lxu_cache_locations, + const at::PackedTensorAccessor32 + assigned_cache_slots, + const at::PackedTensorAccessor32 + linear_index_inverse_indices, + // TODO: Use int64_t here + const at::PackedTensorAccessor32 + unique_indices_count_cumsum, + const at::PackedTensorAccessor32 + cache_set_inverse_indices, + const at::PackedTensorAccessor32 + cache_set_sorted_unique_indices, + const uint64_t lxu_cache_weights_addr, + const uint64_t inserted_ssd_weights_addr, + const int* N_unique, + const uint64_t cache_row_bytes // has to be 64 bits to prevent overflow +) { + const auto n = blockDim.y * blockIdx.x + threadIdx.y; + if (n >= *N_unique) { + return; + } + + const auto cache_set_id = cache_set_inverse_indices[n]; + const auto segment_start = unique_indices_count_cumsum[cache_set_id]; + const auto segment_end = unique_indices_count_cumsum[cache_set_id + 1]; + // Cache locations + const auto cache_loc = + lxu_cache_locations[linear_index_inverse_indices[segment_start]]; + + const uint64_t ptr_addr = (cache_loc == -1) + // Conflict miss + ? (inserted_ssd_weights_addr + (n * cache_row_bytes)) + // Not conflict miss + : (lxu_cache_weights_addr + (cache_loc * cache_row_bytes)); + + // Set post backward evicted indices + if (assigned_cache_slots[n] == -1 && cache_loc == -1) { + post_bwd_evicted_indices[n] = cache_set_sorted_unique_indices[n]; + } else { + post_bwd_evicted_indices[n] = -1; + } + + // Set pointer address + for (auto l = segment_start + threadIdx.x; l < segment_end; l += blockDim.x) { + auto dst = linear_index_inverse_indices[l]; + *reinterpret_cast(&ssd_row_addrs[dst]) = ptr_addr; + } +} + +std::tuple ssd_generate_row_addrs_cuda( + const Tensor& lxu_cache_locations, + const Tensor& assigned_cache_slots, + const Tensor& linear_index_inverse_indices, + const Tensor& unique_indices_count_cumsum, + const Tensor& cache_set_inverse_indices, + const Tensor& lxu_cache_weights, + const Tensor& inserted_ssd_weights, + const Tensor& unique_indices_length, + const Tensor& cache_set_sorted_unique_indices) { + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( + lxu_cache_locations, + assigned_cache_slots, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + lxu_cache_weights, + inserted_ssd_weights, + unique_indices_length, + cache_set_sorted_unique_indices); + + CUDA_DEVICE_GUARD(lxu_cache_locations); + + const auto ssd_row_addrs = at::zeros( + {lxu_cache_locations.numel()}, + lxu_cache_locations.options().dtype(at::kLong)); + const auto post_bwd_evicted_indices = at::empty_like(ssd_row_addrs); + + constexpr auto kNumWarps = kMaxThreads / kWarpSize; + const auto cache_row_bytes = + lxu_cache_weights.size(1) * lxu_cache_weights.element_size(); + const auto lxu_cache_weights_addr = + reinterpret_cast(lxu_cache_weights.data_ptr()); + + // All rows are hit in the cache + if (lxu_cache_locations.numel() == 0) { + // TODO: make this more efficient + return {ssd_row_addrs, post_bwd_evicted_indices}; + } + + ssd_generate_row_addrs_kernel<<< + div_round_up(lxu_cache_locations.numel(), kNumWarps), + dim3(kWarpSize, kNumWarps), + 0, + at::cuda::getCurrentCUDAStream()>>>( + ssd_row_addrs.packed_accessor32(), + post_bwd_evicted_indices + .packed_accessor32(), + lxu_cache_locations + .packed_accessor32(), + assigned_cache_slots + .packed_accessor32(), + linear_index_inverse_indices + .packed_accessor32(), + unique_indices_count_cumsum + .packed_accessor32(), + cache_set_inverse_indices + .packed_accessor32(), + cache_set_sorted_unique_indices + .packed_accessor32(), + 
lxu_cache_weights_addr, + reinterpret_cast(inserted_ssd_weights.data_ptr()), + unique_indices_length.data_ptr(), + cache_row_bytes); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return {ssd_row_addrs, post_bwd_evicted_indices}; } diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 266567427d..d2a2836745 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +/// @defgroup embedding-ssd Embedding SSD Operators + #include #include #include @@ -17,7 +19,8 @@ using namespace at; -std::tuple ssd_cache_populate_actions_cuda( +std::tuple +ssd_cache_populate_actions_cuda( Tensor linear_indices, int64_t total_hash_size, Tensor lxu_cache_state, @@ -34,6 +37,63 @@ Tensor masked_index_put_byte_cuda( Tensor values, Tensor count); +/// @ingroup embedding-ssd +/// +/// @brief Generate memory addresses for SSD TBE data +/// +/// The data retrieved from SSD can be stored in either a scratch pad +/// (HBM) or LXU cache (also HBM). `lxu_cache_locations` is used to +/// specify the location of the data. If the location is -1, the data +/// for the associated index is in the scratch pad; otherwise, it is +/// in the cache. To enable TBE kernels to access the data +/// conveniently, this operator generates memory addresses of the +/// first byte for each index. When accessing data, a TBE kernel only +/// needs to convert addresses into pointers. +/// +/// Moreover, this operator also generate the list of post backward +/// evicted indices which are basically the indices that their data +/// is in the scratch pad. +/// +/// @param lxu_cache_locations The tensor that contains cache slots +/// where data is stored for the *full* list +/// of indices. -1 is a sentinel value that +/// indicates that data is not in cache. +/// @param assigned_cache_slots The tensor that contains cache slots +/// for the *unique* list of indices. 
-1 +/// indicates that data is not in cache +/// @param linear_index_inverse_indices The tensor that contains +/// the original position of +/// linear indices before being +/// sorted +/// @param unique_indices_count_cumsum The tensor that contains the +/// the exclusive prefix sum +/// results of the counts of unique +/// indices +/// @param cache_set_inverse_indices The tensor that contains the +/// original positions of cache sets +/// before being sorted +/// @param lxu_cache_weights The LXU cache tensor +/// @param inserted_ssd_weights The scratch pad tensor +/// @param unique_indices_length The tensor that contains the number +/// of unique indices (GPU tensor) +/// @param cache_set_sorted_unique_indices The tensor that contains +/// associated unique indices +/// for the sorted unique cache +/// sets +/// +/// @return A tuple of tensors (the SSD row address tensor and the +/// post backward evicted index tensor) +std::tuple ssd_generate_row_addrs_cuda( + const Tensor& lxu_cache_locations, + const Tensor& assigned_cache_slots, + const Tensor& linear_index_inverse_indices, + const Tensor& unique_indices_count_cumsum, + const Tensor& cache_set_inverse_indices, + const Tensor& lxu_cache_weights, + const Tensor& inserted_ssd_weights, + const Tensor& unique_indices_length, + const Tensor& cache_set_sorted_unique_indices); + namespace { class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { public: @@ -127,11 +187,36 @@ static auto embedding_rocks_db_wrapper = TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( - "masked_index_put(Tensor self, Tensor indices, Tensor values, Tensor count) -> Tensor"); + "masked_index_put(" + " Tensor self, " + " Tensor indices, " + " Tensor values, " + " Tensor count" + ") -> Tensor"); DISPATCH_TO_CUDA("masked_index_put", masked_index_put_cuda); m.def( - "ssd_cache_populate_actions(Tensor linear_indices, int total_hash_size, Tensor lxu_cache_state, int time_stamp, int prefetch_dist, Tensor lru_state) -> (Tensor, Tensor, Tensor, Tensor)"); + "ssd_cache_populate_actions(" + " Tensor linear_indices, " + " int total_hash_size, " + " Tensor lxu_cache_state, " + " int time_stamp, " + " int prefetch_dist, " + " Tensor lru_state" + ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); DISPATCH_TO_CUDA( "ssd_cache_populate_actions", ssd_cache_populate_actions_cuda); + m.def( + "ssd_generate_row_addrs(" + " Tensor lxu_cache_locations, " + " Tensor assigned_cache_slots, " + " Tensor linear_index_inverse_indices, " + " Tensor unique_indices_count_cumsum, " + " Tensor cache_set_inverse_indices, " + " Tensor lxu_cache_weights, " + " Tensor inserted_ssd_weights, " + " Tensor unique_indices_length, " + " Tensor cache_set_sorted_unique_indices" + ") -> (Tensor, Tensor)"); + DISPATCH_TO_CUDA("ssd_generate_row_addrs", ssd_generate_row_addrs_cuda); } } // namespace diff --git a/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py b/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py index 637d5ffff4..5f3c7ae337 100644 --- a/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py @@ -410,6 +410,10 @@ def test_ssd_cache( _, _, actions_count_gpu, + _, + _, + _, + _, ) = torch.ops.fbgemm.ssd_cache_populate_actions( # noqa linear_cache_indices, emb.total_hash_size, @@ -798,6 +802,10 @@ def test_nbit_ssd_cache( _, _, actions_count_gpu, + _, + _, + _, + _, ) = torch.ops.fbgemm.ssd_cache_populate_actions( # noqa linear_cache_indices, emb.total_hash_size, diff --git 
a/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json b/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json index 6e59307187..a18ec6b9e9 100644 --- a/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json +++ b/fbgemm_gpu/test/tbe/cache/failures_dict_fast.json @@ -26,6 +26,10 @@ "LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": { "comment": "", "status": "xfail" + }, + "LXUCacheTest.test_schema__test_unique_lxu_cache_lookup": { + "comment": "", + "status": "xfail" } }, "fbgemm::int_nbit_split_embedding_codegen_lookup_function": {
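Usage sketch: the snippet below illustrates how the Python-visible pieces of this patch fit together — the `get_unique_indices` wrapper from `fbgemm_gpu/split_embeddings_cache_ops.py`, the eight outputs now returned by `torch.ops.fbgemm.ssd_cache_populate_actions`, and the new `torch.ops.fbgemm.ssd_generate_row_addrs` op. All shapes, sizes, and placeholder state below (`C`, `W`, `D`, `total_hash_size`, `time_stamp`, `prefetch_dist`, and the stand-in `lxu_cache_locations` / `inserted_ssd_weights` tensors) are illustrative assumptions, not values taken from the patch, and a CUDA device with fbgemm_gpu built from this change is assumed.

import torch

# Registers the fbgemm::get_unique_indices wrapper added by this patch.
import fbgemm_gpu.split_embeddings_cache_ops  # noqa: F401

# The SSD cache ops are assumed to be loaded already (e.g. by importing
# fbgemm_gpu.ssd_split_table_batched_embeddings_ops).

device = "cuda"
linear_indices = torch.tensor([3, 1, 3, 7, 1], dtype=torch.long, device=device)

# Deduplicate indices. With compute_inverse_indices=True the wrapper returns
# four tensors; with it False, it trims the tuple to three.
unique_indices, unique_indices_length, counts, inverse_positions = (
    torch.ops.fbgemm.get_unique_indices(
        linear_indices,
        max_indices=1024,
        compute_count=True,
        compute_inverse_indices=True,
    )
)

# Placeholder LRU/SSD cache state: C cache sets of W slots, D-dim rows.
C, W, D = 4, 32, 16
lxu_cache_state = torch.full((C, W), -1, dtype=torch.long, device=device)
lru_state = torch.zeros((C, W), dtype=torch.long, device=device)
lxu_cache_weights = torch.zeros((C * W, D), device=device)
total_hash_size, time_stamp, prefetch_dist = 1024, 1, 1

(
    cache_set_sorted_unique_indices,
    evicted_indices,
    assigned_cache_slots,
    actions_count_gpu,
    linear_index_inverse_indices,
    unique_indices_count_cumsum,
    cache_set_inverse_indices,
    unique_indices_length,
) = torch.ops.fbgemm.ssd_cache_populate_actions(
    linear_indices,
    total_hash_size,
    lxu_cache_state,
    time_stamp,
    prefetch_dist,
    lru_state,
)

# Placeholders standing in for the cache-lookup result (-1 == not in cache)
# and the scratch pad filled from SSD; in the real flow these come from the
# SSD TBE prefetch path.
lxu_cache_locations = torch.full(
    (linear_indices.numel(),), -1, dtype=torch.int32, device=device
)
inserted_ssd_weights = torch.zeros((linear_indices.numel(), D), device=device)

# Convert cache locations / scratch-pad slots into per-index row addresses
# and the post-backward evicted indices.
ssd_row_addrs, post_bwd_evicted_indices = torch.ops.fbgemm.ssd_generate_row_addrs(
    lxu_cache_locations,
    assigned_cache_slots,
    linear_index_inverse_indices,
    unique_indices_count_cumsum,
    cache_set_inverse_indices,
    lxu_cache_weights,
    inserted_ssd_weights,
    unique_indices_length,
    cache_set_sorted_unique_indices,
)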