From 0830111f84132cd28c9e852a09ea93df0f310159 Mon Sep 17 00:00:00 2001
From: Emma Lin
Date: Sun, 19 Oct 2025 15:55:26 -0700
Subject: [PATCH] disable random init in inference operator for embedding cache (#5026)

Summary:
X-link: https://github.com/meta-pytorch/torchrec/pull/3466
X-link: https://github.com/facebookresearch/FBGEMM/pull/2040

For embedding cache mode, we do not expect random values when a lookup misses the cache. This diff passes the embedding cache mode to the inference operator and uses it to disable the backend's random initialization.

Differential Revision: D84367061
---
 .../tbe/cache/kv_embedding_ops_inference.py  |  7 +-
 .../dram_kv_embedding_inference_wrapper.cpp  | 30 +++++--
 .../dram_kv_embedding_inference_wrapper.h    |  4 +-
 .../dram_kv_inference_embedding.h            | 56 ++++++++-----
 .../tbe/dram_kv/dram_kv_inference_test.py    | 84 +++++++++++++++++--
 5 files changed, 144 insertions(+), 37 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py index dc536dccdb..b887f2ba6c 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py @@ -76,6 +76,7 @@ def __init__( # noqa C901 reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparam at begnning of each row. feature_names_per_table: Optional[list[list[str]]] = None, indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64). + embedding_cache_mode: bool = False, # True for zero initialization, False for randomized initialization ) -> None: # noqa C901 # tuple of (rows, dims,) super(KVEmbeddingInference, self).__init__( embedding_specs=embedding_specs, @@ -114,9 +115,13 @@ def __init__( # noqa C901 num_shards = 32 uniform_init_lower: float = -0.01 uniform_init_upper: float = 0.01 + # pyre-fixme[4]: Attribute must be annotated. 
self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + embedding_cache_mode, # in embedding_cache_mode, we disable random init ) self.specs: list[tuple[int, int, int]] = [ diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp index 787e0fc027..9500593435 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp @@ -15,10 +15,16 @@ namespace fbgemm_gpu { DramKVEmbeddingInferenceWrapper::DramKVEmbeddingInferenceWrapper( int64_t num_shards, double uniform_init_lower, - double uniform_init_upper) + double uniform_init_upper, + bool disable_random_init) : num_shards_(num_shards), uniform_init_lower_(uniform_init_lower), - uniform_init_upper_(uniform_init_upper) {} + uniform_init_upper_(uniform_init_upper), + disable_random_init_(disable_random_init) { + LOG(INFO) + << "DramKVEmbeddingInferenceWrapper created with disable_random_init = " + << disable_random_init_; +} void DramKVEmbeddingInferenceWrapper::init( const std::vector& specs, @@ -70,8 +76,8 @@ void DramKVEmbeddingInferenceWrapper::init( 8 /* row_storage_bitwidth */, false /* enable_async_update */, std::nullopt /* table_dims */, - hash_size_cumsum); - return; + hash_size_cumsum, + disable_random_init_); } std::shared_ptr> @@ -96,7 +102,6 @@ void DramKVEmbeddingInferenceWrapper::set_embeddings( } folly::coro::blockingWait(dram_kv_->inference_set_kv_db_async( indices, weights, count, inplacee_update_ts)); - return; } at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( @@ -113,7 +118,7 @@ at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( } void DramKVEmbeddingInferenceWrapper::log_inplace_update_stats() { - return dram_kv_->log_inplace_update_stats(); + dram_kv_->log_inplace_update_stats(); } void DramKVEmbeddingInferenceWrapper::trigger_evict( @@ -159,8 +164,17 @@ static auto dram_kv_embedding_inference_wrapper = torch::class_( "fbgemm", "DramKVEmbeddingInferenceWrapper") - .def(torch::init()) - .def("init", &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::init) + .def(torch::init()) + .def( + "init", + &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::init, + "", + { + torch::arg("specs"), + torch::arg("row_alignment"), + torch::arg("scale_bias_size_in_bytes"), + torch::arg("hash_size_cumsum"), + }) .def( "set_embeddings", &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::set_embeddings, diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h index aada01c47a..68fde2d205 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h @@ -18,7 +18,8 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder { DramKVEmbeddingInferenceWrapper( int64_t num_shards = 32, double uniform_init_lower = 0.0, - double uniform_init_upper = 0.0); + double uniform_init_upper = 0.0, + bool disable_random_init = false); using SerializedSepcType = std::tuple; // (rows, dime, sparse_type) @@ -55,6 +56,7 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder { int64_t num_shards_ = 32; double uniform_init_lower_ = 
0.0; double uniform_init_upper_ = 0.0; + bool disable_random_init_ = false; std::shared_ptr> dram_kv_; int64_t max_row_bytes_ = 0; diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h index 29ebb714aa..f723a14e2b 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h @@ -108,7 +108,8 @@ class DramKVInferenceEmbedding { int64_t row_storage_bitwidth = 32, bool enable_async_update = false, std::optional table_dims = std::nullopt, - std::optional hash_size_cumsum = std::nullopt) + std::optional hash_size_cumsum = std::nullopt, + bool disable_random_init = false) : max_D_(max_D), num_shards_(num_shards), block_size_(FixedBlockPool::calculate_block_size(max_D)), @@ -120,7 +121,8 @@ class DramKVInferenceEmbedding { block_alignment_, /*blocks_per_chunk=*/8192)), elem_size_(row_storage_bitwidth / 8), - feature_evict_config_(std::move(feature_evict_config)) { + feature_evict_config_(std::move(feature_evict_config)), + disable_random_init_(disable_random_init) { executor_ = std::make_unique(std::max( num_threads, facebook::Proc::getCpuInfo().numCpuCores)); initialize_initializers( @@ -151,6 +153,8 @@ class DramKVInferenceEmbedding { sub_table_hash_cumsum_, false /* is_train */); } + LOG(INFO) << "DramKVInferenceEmbedding initialized: disable_random_init " + << disable_random_init_; } void initialize_initializers( @@ -327,14 +331,25 @@ class DramKVInferenceEmbedding { } return folly::collect(std::move(futures)) .via(executor_.get()) - .thenValue( - [this](const std::vector>& tuples) { - for (const auto& pair : tuples) { - inplace_update_hit_cnt_ += std::get<0>(pair); - inplace_update_miss_cnt_ += std::get<1>(pair); - } - return std::vector(tuples.size()); - }); + .thenValue([this](const std::vector>& + tuples) { + auto hit_cnt = 0; + auto miss_cnt = 0; + for (const auto& pair : tuples) { + hit_cnt += std::get<0>(pair); + miss_cnt += std::get<1>(pair); + } + inplace_update_hit_cnt_ += hit_cnt; + inplace_update_miss_cnt_ += miss_cnt; + auto total_count = hit_cnt + miss_cnt; + LOG_EVERY_MS(INFO, 5000) << fmt::format( + "inference_set_kv_db_async: hit count {}, miss count {}, inplace update hit rate {}", + hit_cnt, + miss_cnt, + total_count ? static_cast(hit_cnt) / total_count : 0.0); + + return std::vector(tuples.size()); + }); } /// Get embeddings from kvstore. 
@@ -391,7 +406,7 @@ class DramKVInferenceEmbedding { int64_t local_read_missing_load = 0; FBGEMM_DISPATCH_INTEGRAL_TYPES( indices.scalar_type(), - "dram_kvstore_set", + "get_kv_db_async_impl", [this, shard_id, indexes, @@ -419,7 +434,7 @@ class DramKVInferenceEmbedding { facebook::WallClockUtil::NowInUsecFast() - before_read_lock_ts; - if (!wlmap->empty()) { + if (!wlmap->empty() && !disable_random_init_) { // Simple block-based randomization using get_block with // cursor auto* pool = kv_store_.pool_by(shard_id); @@ -665,13 +680,12 @@ class DramKVInferenceEmbedding { auto inplace_update_hit_cnt = inplace_update_hit_cnt_.exchange(reset_val); auto inplace_update_miss_cnt = inplace_update_miss_cnt_.exchange(reset_val); - LOG(INFO) << "inplace update stats: hit count: " << inplace_update_hit_cnt - << ", miss count: " << inplace_update_miss_cnt - << ", total count: " - << inplace_update_hit_cnt + inplace_update_miss_cnt - << ", hit ratio: " - << (double)inplace_update_hit_cnt / - (inplace_update_hit_cnt + inplace_update_miss_cnt); + auto total_cnt = inplace_update_hit_cnt + inplace_update_miss_cnt; + LOG_EVERY_MS(INFO, 5000) + << "inplace update stats: hit count: " << inplace_update_hit_cnt + << ", miss count: " << inplace_update_miss_cnt + << ", total count: " << total_cnt << ", hit ratio: " + << (total_cnt > 0 ? (double)inplace_update_hit_cnt / total_cnt : 0.0); } std::optional get_feature_evict_metric() const { @@ -740,6 +754,7 @@ class DramKVInferenceEmbedding { auto copied_bytes = elem_size_ * copied_width; int64_t start_offset_bytes = elem_size_ * width_offset; int64_t row_index = 0; + // TODO: fill the opt state as zeros for init value? std::copy( &(row_storage_data_ptr @@ -901,6 +916,7 @@ class DramKVInferenceEmbedding { std::optional> feature_evict_config_; std::unique_ptr> feature_evict_; int current_iter_ = 0; + bool disable_random_init_ = false; // perf stats std::atomic read_total_duration_{0}; @@ -928,8 +944,6 @@ class DramKVInferenceEmbedding { std::atomic inplace_update_hit_cnt_{0}; std::atomic inplace_update_miss_cnt_{0}; - - bool disable_random_init_; }; // class DramKVInferenceEmbedding } // namespace kv_mem diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py index 3877b4b1c1..796d453681 100644 --- a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py +++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py @@ -33,7 +33,10 @@ def test_serialize(self) -> None: uniform_init_upper: float = 0.01 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) serialized_result = kv_embedding_cache.serialize() @@ -48,12 +51,15 @@ def test_serialize_deserialize(self) -> None: uniform_init_upper: float = 0.01 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) serialized_result = kv_embedding_cache.serialize() kv_embedding_cache_2 = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - 0, 0.0, 0.0 + 0, 0.0, 0.0, False # disable_random_init ) kv_embedding_cache_2.deserialize(serialized_result) @@ -65,7 +71,10 @@ def test_set_get_embeddings(self) -> None: uniform_init_upper: float = 0.0 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, 
uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) kv_embedding_cache.init( [(20, 4, SparseType.INT8.as_int())], @@ -122,7 +131,10 @@ def test_inplace_update(self) -> None: uniform_init_upper: float = 0.0 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) kv_embedding_cache.init( [(20, 4, SparseType.INT8.as_int())], @@ -258,7 +270,10 @@ def test_randomized_cache_miss_initialization(self) -> None: # Create DRAM KV inference cache kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) kv_embedding_cache.init( [(32, 4, SparseType.FP16.as_int())], @@ -313,3 +328,60 @@ def test_randomized_cache_miss_initialization(self) -> None: torch.any(result[:, :4] != 0), "Cache miss results should contain non-zero values when cache has data", ) + + def test_zero_cache_miss_initialization_with_embedding_cache_mode(self) -> None: + """Test that cache misses return all zero values when embedding_cache_mode=True.""" + num_shards = 8 + uniform_init_lower: float = -0.01 + uniform_init_upper: float = 0.01 + + # Setup: Create DRAM KV inference cache with embedding_cache_mode=True (zero initialization) + kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, + uniform_init_lower, + uniform_init_upper, + True, # embedding_cache_mode=True for zero initialization + ) + kv_embedding_cache.init( + [(32, 4, SparseType.FP16.as_int())], + 32, + 4, + torch.tensor([0, 100], dtype=torch.int64), + ) + + # Populate the cache with some initial non-zero values to ensure zero initialization + # is not just due to empty cache + setup_indices = torch.arange(0, 50, dtype=torch.int64) + setup_weights = torch.randint( + 1, 255, (50, 32), dtype=torch.uint8 + ) # Non-zero values + kv_embedding_cache.set_embeddings(setup_indices, setup_weights) + + # Execute: Request cache misses - these should get zero initialization due to embedding_cache_mode=True + # Use indices outside the range [0, 49] to ensure they are actual cache misses + miss_indices = torch.tensor([100, 101, 102, 103, 104], dtype=torch.int64) + results = [] + + # Get cache miss results multiple times to ensure consistent behavior + for _ in range(3): + current_output = kv_embedding_cache.get_embeddings(miss_indices) + results.append(current_output.clone()) + + # Assert: Verify that all cache miss results are zeros when embedding_cache_mode=True + expected_zeros = torch.zeros((5, 32), dtype=torch.uint8) + + for i, result in enumerate(results): + # Check that all cache miss results are zero + self.assertTrue( + torch.equal(result, expected_zeros), + f"Cache miss results should be all zeros when embedding_cache_mode=True, " + f"but got non-zero values in iteration {i}: {result[:, :4]}", + ) + + # Additional verification: all results should be identical since they're all zeros + for i in range(1, len(results)): + self.assertTrue( + torch.equal(results[0], results[i]), + f"All zero cache miss results should be identical across calls, " + f"but results[0] != results[{i}]", + )
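
For reference, a minimal usage sketch distilled from the new test (not part of the patch). It assumes an fbgemm_gpu build that includes this change and that the torch.classes.fbgemm operators are loaded; the table spec, init arguments, and index values are borrowed from test_zero_cache_miss_initialization_with_embedding_cache_mode and are illustrative only.

    import torch
    import fbgemm_gpu  # noqa: F401  # assumed to register the torch.classes.fbgemm operators
    from fbgemm_gpu.split_embedding_configs import SparseType

    # The last constructor argument is the new flag: True puts the backend in
    # embedding-cache mode, i.e. rows missing from the cache are returned as
    # zeros instead of being drawn from uniform(-0.01, 0.01).
    cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
        8,      # num_shards
        -0.01,  # uniform_init_lower (unused once random init is disabled)
        0.01,   # uniform_init_upper (unused once random init is disabled)
        True,   # embedding_cache_mode / disable_random_init
    )
    cache.init(
        [(32, 4, SparseType.FP16.as_int())],        # serialized table specs
        32,                                         # row_alignment
        4,                                          # scale_bias_size_in_bytes
        torch.tensor([0, 100], dtype=torch.int64),  # hash_size_cumsum
    )

    # Populate some rows, then look up indices that were never written.
    cache.set_embeddings(
        torch.arange(0, 50, dtype=torch.int64),
        torch.randint(1, 255, (50, 32), dtype=torch.uint8),
    )
    misses = cache.get_embeddings(
        torch.tensor([100, 101, 102, 103, 104], dtype=torch.int64)
    )
    assert torch.equal(misses, torch.zeros_like(misses))  # zeros, not random rows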