diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py index dc536dccdb..b887f2ba6c 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py @@ -76,6 +76,7 @@ def __init__( # noqa C901 reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparam at begnning of each row. feature_names_per_table: Optional[list[list[str]]] = None, indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64). + embedding_cache_mode: bool = False, # True for zero initialization, False for randomized initialization ) -> None: # noqa C901 # tuple of (rows, dims,) super(KVEmbeddingInference, self).__init__( embedding_specs=embedding_specs, @@ -114,9 +115,13 @@ def __init__( # noqa C901 num_shards = 32 uniform_init_lower: float = -0.01 uniform_init_upper: float = 0.01 + # pyre-fixme[4]: Attribute must be annotated. 
self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + embedding_cache_mode, # in embedding_cache_mode, we disable random init ) self.specs: list[tuple[int, int, int]] = [ diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp index 787e0fc027..9500593435 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp @@ -15,10 +15,16 @@ namespace fbgemm_gpu { DramKVEmbeddingInferenceWrapper::DramKVEmbeddingInferenceWrapper( int64_t num_shards, double uniform_init_lower, - double uniform_init_upper) + double uniform_init_upper, + bool disable_random_init) : num_shards_(num_shards), uniform_init_lower_(uniform_init_lower), - uniform_init_upper_(uniform_init_upper) {} + uniform_init_upper_(uniform_init_upper), + disable_random_init_(disable_random_init) { + LOG(INFO) + << "DramKVEmbeddingInferenceWrapper created with disable_random_init = " + << disable_random_init_; +} void DramKVEmbeddingInferenceWrapper::init( const std::vector& specs, @@ -70,8 +76,8 @@ void DramKVEmbeddingInferenceWrapper::init( 8 /* row_storage_bitwidth */, false /* enable_async_update */, std::nullopt /* table_dims */, - hash_size_cumsum); - return; + hash_size_cumsum, + disable_random_init_); } std::shared_ptr> @@ -96,7 +102,6 @@ void DramKVEmbeddingInferenceWrapper::set_embeddings( } folly::coro::blockingWait(dram_kv_->inference_set_kv_db_async( indices, weights, count, inplacee_update_ts)); - return; } at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( @@ -113,7 +118,7 @@ at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( } void DramKVEmbeddingInferenceWrapper::log_inplace_update_stats() { - return 
dram_kv_->log_inplace_update_stats(); + dram_kv_->log_inplace_update_stats(); } void DramKVEmbeddingInferenceWrapper::trigger_evict( @@ -159,8 +164,17 @@ static auto dram_kv_embedding_inference_wrapper = torch::class_( "fbgemm", "DramKVEmbeddingInferenceWrapper") - .def(torch::init()) - .def("init", &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::init) + .def(torch::init()) + .def( + "init", + &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::init, + "", + { + torch::arg("specs"), + torch::arg("row_alignment"), + torch::arg("scale_bias_size_in_bytes"), + torch::arg("hash_size_cumsum"), + }) .def( "set_embeddings", &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::set_embeddings, diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h index aada01c47a..68fde2d205 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h @@ -18,7 +18,8 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder { DramKVEmbeddingInferenceWrapper( int64_t num_shards = 32, double uniform_init_lower = 0.0, - double uniform_init_upper = 0.0); + double uniform_init_upper = 0.0, + bool disable_random_init = false); using SerializedSepcType = std::tuple; // (rows, dime, sparse_type) @@ -55,6 +56,7 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder { int64_t num_shards_ = 32; double uniform_init_lower_ = 0.0; double uniform_init_upper_ = 0.0; + bool disable_random_init_ = false; std::shared_ptr> dram_kv_; int64_t max_row_bytes_ = 0; diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h index 29ebb714aa..f723a14e2b 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h +++ 
b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h @@ -108,7 +108,8 @@ class DramKVInferenceEmbedding { int64_t row_storage_bitwidth = 32, bool enable_async_update = false, std::optional table_dims = std::nullopt, - std::optional hash_size_cumsum = std::nullopt) + std::optional hash_size_cumsum = std::nullopt, + bool disable_random_init = false) : max_D_(max_D), num_shards_(num_shards), block_size_(FixedBlockPool::calculate_block_size(max_D)), @@ -120,7 +121,8 @@ class DramKVInferenceEmbedding { block_alignment_, /*blocks_per_chunk=*/8192)), elem_size_(row_storage_bitwidth / 8), - feature_evict_config_(std::move(feature_evict_config)) { + feature_evict_config_(std::move(feature_evict_config)), + disable_random_init_(disable_random_init) { executor_ = std::make_unique(std::max( num_threads, facebook::Proc::getCpuInfo().numCpuCores)); initialize_initializers( @@ -151,6 +153,8 @@ class DramKVInferenceEmbedding { sub_table_hash_cumsum_, false /* is_train */); } + LOG(INFO) << "DramKVInferenceEmbedding initialized: disable_random_init " + << disable_random_init_; } void initialize_initializers( @@ -327,14 +331,25 @@ class DramKVInferenceEmbedding { } return folly::collect(std::move(futures)) .via(executor_.get()) - .thenValue( - [this](const std::vector>& tuples) { - for (const auto& pair : tuples) { - inplace_update_hit_cnt_ += std::get<0>(pair); - inplace_update_miss_cnt_ += std::get<1>(pair); - } - return std::vector(tuples.size()); - }); + .thenValue([this](const std::vector>& + tuples) { + int64_t hit_cnt = 0; /* not auto: the literal 0 would deduce int */ + int64_t miss_cnt = 0; + for (const auto& pair : tuples) { + hit_cnt += std::get<0>(pair); + miss_cnt += std::get<1>(pair); + } + inplace_update_hit_cnt_ += hit_cnt; + inplace_update_miss_cnt_ += miss_cnt; + auto total_count = hit_cnt + miss_cnt; + LOG_EVERY_MS(INFO, 5000) << fmt::format( + "inference_set_kv_db_async: hit count {}, miss count {}, inplace update hit rate {}", + hit_cnt, + miss_cnt, 
+ total_count ? 
static_cast(hit_cnt) / total_count : 0.0); + + return std::vector(tuples.size()); + }); } /// Get embeddings from kvstore. @@ -391,7 +406,7 @@ class DramKVInferenceEmbedding { int64_t local_read_missing_load = 0; FBGEMM_DISPATCH_INTEGRAL_TYPES( indices.scalar_type(), - "dram_kvstore_set", + "get_kv_db_async_impl", [this, shard_id, indexes, @@ -419,7 +434,7 @@ class DramKVInferenceEmbedding { facebook::WallClockUtil::NowInUsecFast() - before_read_lock_ts; - if (!wlmap->empty()) { + if (!wlmap->empty() && !disable_random_init_) { // Simple block-based randomization using get_block with // cursor auto* pool = kv_store_.pool_by(shard_id); @@ -665,13 +680,12 @@ class DramKVInferenceEmbedding { auto inplace_update_hit_cnt = inplace_update_hit_cnt_.exchange(reset_val); auto inplace_update_miss_cnt = inplace_update_miss_cnt_.exchange(reset_val); - LOG(INFO) << "inplace update stats: hit count: " << inplace_update_hit_cnt - << ", miss count: " << inplace_update_miss_cnt - << ", total count: " - << inplace_update_hit_cnt + inplace_update_miss_cnt - << ", hit ratio: " - << (double)inplace_update_hit_cnt / - (inplace_update_hit_cnt + inplace_update_miss_cnt); + auto total_cnt = inplace_update_hit_cnt + inplace_update_miss_cnt; + LOG_EVERY_MS(INFO, 5000) + << "inplace update stats: hit count: " << inplace_update_hit_cnt + << ", miss count: " << inplace_update_miss_cnt + << ", total count: " << total_cnt << ", hit ratio: " + << (total_cnt > 0 ? (double)inplace_update_hit_cnt / total_cnt : 0.0); } std::optional get_feature_evict_metric() const { @@ -740,6 +754,7 @@ class DramKVInferenceEmbedding { auto copied_bytes = elem_size_ * copied_width; int64_t start_offset_bytes = elem_size_ * width_offset; int64_t row_index = 0; + // TODO: fill the opt state as zeros for init value? 
std::copy( &(row_storage_data_ptr @@ -901,6 +916,7 @@ class DramKVInferenceEmbedding { std::optional> feature_evict_config_; std::unique_ptr> feature_evict_; int current_iter_ = 0; + bool disable_random_init_ = false; // perf stats std::atomic read_total_duration_{0}; @@ -928,8 +944,6 @@ class DramKVInferenceEmbedding { std::atomic inplace_update_hit_cnt_{0}; std::atomic inplace_update_miss_cnt_{0}; - - bool disable_random_init_; }; // class DramKVInferenceEmbedding } // namespace kv_mem diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py index 3877b4b1c1..796d453681 100644 --- a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py +++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py @@ -33,7 +33,10 @@ def test_serialize(self) -> None: uniform_init_upper: float = 0.01 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) serialized_result = kv_embedding_cache.serialize() @@ -48,12 +51,15 @@ def test_serialize_deserialize(self) -> None: uniform_init_upper: float = 0.01 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) serialized_result = kv_embedding_cache.serialize() kv_embedding_cache_2 = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - 0, 0.0, 0.0 + 0, 0.0, 0.0, False # disable_random_init ) kv_embedding_cache_2.deserialize(serialized_result) @@ -65,7 +71,10 @@ def test_set_get_embeddings(self) -> None: uniform_init_upper: float = 0.0 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) 
kv_embedding_cache.init( [(20, 4, SparseType.INT8.as_int())], @@ -122,7 +131,10 @@ def test_inplace_update(self) -> None: uniform_init_upper: float = 0.0 kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) kv_embedding_cache.init( [(20, 4, SparseType.INT8.as_int())], @@ -258,7 +270,10 @@ def test_randomized_cache_miss_initialization(self) -> None: # Create DRAM KV inference cache kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( - num_shards, uniform_init_lower, uniform_init_upper + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init ) kv_embedding_cache.init( [(32, 4, SparseType.FP16.as_int())], @@ -313,3 +328,60 @@ def test_randomized_cache_miss_initialization(self) -> None: torch.any(result[:, :4] != 0), "Cache miss results should contain non-zero values when cache has data", ) + + def test_zero_cache_miss_initialization_with_embedding_cache_mode(self) -> None: + """Test that cache misses return all zero values when embedding_cache_mode=True.""" + num_shards = 8 + uniform_init_lower: float = -0.01 + uniform_init_upper: float = 0.01 + + # Setup: Create DRAM KV inference cache with embedding_cache_mode=True (zero initialization) + kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, + uniform_init_lower, + uniform_init_upper, + True, # embedding_cache_mode=True for zero initialization + ) + kv_embedding_cache.init( + [(32, 4, SparseType.FP16.as_int())], + 32, + 4, + torch.tensor([0, 100], dtype=torch.int64), + ) + + # Populate the cache with some initial non-zero values to ensure zero initialization + # is not just due to empty cache + setup_indices = torch.arange(0, 50, dtype=torch.int64) + setup_weights = torch.randint( + 1, 255, (50, 32), dtype=torch.uint8 + ) # Non-zero values + 
kv_embedding_cache.set_embeddings(setup_indices, setup_weights) + + # Execute: Request cache misses - these should get zero initialization due to embedding_cache_mode=True + # Use indices outside the range [0, 49] to ensure they are actual cache misses + miss_indices = torch.tensor([100, 101, 102, 103, 104], dtype=torch.int64) + results = [] + + # Get cache miss results multiple times to ensure consistent behavior + for _ in range(3): + current_output = kv_embedding_cache.get_embeddings(miss_indices) + results.append(current_output.clone()) + + # Assert: Verify that all cache miss results are zeros when embedding_cache_mode=True + expected_zeros = torch.zeros((5, 32), dtype=torch.uint8) + + for i, result in enumerate(results): + # Check that all cache miss results are zero + self.assertTrue( + torch.equal(result, expected_zeros), + f"Cache miss results should be all zeros when embedding_cache_mode=True, " + f"but got non-zero values in iteration {i}: {result[:, :4]}", + ) + + # Additional verification: all results should be identical since they're all zeros + for i in range(1, len(results)): + self.assertTrue( + torch.equal(results[0], results[i]), + f"All zero cache miss results should be identical across calls, " + f"but results[0] != results[{i}]", + )