From bf524d39af80ab96498c1885c19e7d13b25493c4 Mon Sep 17 00:00:00 2001 From: Emma Lin Date: Sun, 5 Oct 2025 18:35:17 -0700 Subject: [PATCH] change from first element to a random element for cache missing items (#4955) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1974 In inference zch backend, we cannot use initializer to randomized init value for cache missing items, as the intializer does not work in parallel read and write mode. The current behavior is to always get the first item in hash map, but that has less randmization. This diff added a randmization for cache missing ids, also add a log to show the missing ids in every batch. Reviewed By: EddyLXJ Differential Revision: D83612329 --- .../dram_kv_inference_embedding.h | 38 +++++++++-- .../tbe/dram_kv/dram_kv_inference_test.py | 64 +++++++++++++++++++ 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h index 9856c550ad..29ebb714aa 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "common/time/Time.h" #include "../ssd_split_embeddings_cache/initializer.h" @@ -419,9 +420,36 @@ class DramKVInferenceEmbedding { before_read_lock_ts; if (!wlmap->empty()) { - row_storage_data_ptr = - FixedBlockPool::data_ptr( - wlmap->begin()->second); + // Simple block-based randomization using get_block with + // cursor + auto* pool = kv_store_.pool_by(shard_id); + + // Random starting cursor based on map size for good + // entropy + size_t random_start = + folly::Random::rand32(wlmap->size()); + + // Try to find a used block starting from random + // position + weight_type* block = nullptr; + for (int attempts = 0; attempts < 16; ++attempts) { + block = pool->template get_block( + random_start + attempts); + if (block != nullptr) { + // Block is used (not null) + row_storage_data_ptr = + FixedBlockPool::data_ptr(block); + break; + } + } + + // Fallback: if no used block found, use first element + // from map + if (block == nullptr) { + row_storage_data_ptr = + FixedBlockPool::data_ptr( + wlmap->begin()->second); + } } else { const auto& init_storage = initializers_[shard_id]->row_storage_; @@ -526,7 +554,9 @@ class DramKVInferenceEmbedding { read_lookup_cache_total_duration / num_shards_; read_acquire_lock_avg_duration_ += read_acquire_lock_total_duration / num_shards_; - read_missing_load_avg_ += read_missing_load / num_shards_; + LOG_EVERY_MS(INFO, 5000) + << "get_kv_db_async total read_missing_load per batch: " + << read_missing_load; return std::vector(results.size()); }); }; diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py index 34d996ea6f..3877b4b1c1 100644 --- a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py +++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py @@ -249,3 +249,67 @@ def reader_thread() -> None: # pyre-ignore self.assertTrue(equal_one_of(embs[5, :4], possible_embs)) reader_thread.join() self.assertFalse(reader_failed_event.is_set()) + + def test_randomized_cache_miss_initialization(self) -> None: + """Test that cache misses use randomized data from existing blocks.""" + num_shards = 8 + uniform_init_lower: float = -0.01 + uniform_init_upper: float = 0.01 + + # Create DRAM KV inference cache + kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, uniform_init_lower, uniform_init_upper + ) + kv_embedding_cache.init( + [(32, 4, SparseType.FP16.as_int())], + 32, + 4, + torch.tensor([0, 100], dtype=torch.int64), + ) + + # Setup: Populate the cache with many initial values for better randomization diversity + # Use 400 setup items to ensure each shard (8 shards) gets ~50 entries for good randomization + setup_indices = torch.arange(0, 400, dtype=torch.int64) # 400 setup items + setup_weights = torch.randint( + 1, 255, (400, 32), dtype=torch.uint8 + ) # Non-zero values to ensure randomization source + print(f"setup_weights: {setup_weights}") + + # Populate cache + kv_embedding_cache.set_embeddings(setup_indices, setup_weights) + + # Execute: Request cache misses multiple times - these should get randomized initialization + # Use indices outside the range [0, 399] to ensure they are actual cache misses + miss_indices = torch.tensor([500, 501, 502, 503, 504], dtype=torch.int64) + + # Get the cache miss results multiple times to check for randomization + results = [] + for _ in range(5): + current_output = kv_embedding_cache.get_embeddings(miss_indices) + results.append(current_output.clone()) + + # Assert: Verify that randomization occurs + # The results should not all be identical if randomization is working + all_identical = True + for i in range(1, len(results)): + if not torch.equal( + results[0][:, :4], results[i][:, :4] + ): # Only check first 4 columns (actual data) + all_identical = False + break + + # Since we're using randomization, results should be different + # Note: There's a small chance they could be identical by random chance, + # but with 5 trials of 5 vectors of 4 bytes, this is extremely unlikely + self.assertFalse( + all_identical, + "Randomized cache miss initialization should produce different results", + ) + + # All results should be non-zero (since we populated the cache with non-zero random values) + for result in results: + # Check that at least some values are non-zero (indicating data came from existing blocks) + self.assertTrue( + torch.any(result[:, :4] != 0), + "Cache miss results should contain non-zero values when cache has data", + )