diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h
index 9856c550ad..29ebb714aa 100644
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include <folly/Random.h>
 #include "common/time/Time.h"
 #include "../ssd_split_embeddings_cache/initializer.h"
@@ -419,9 +420,36 @@ class DramKVInferenceEmbedding {
                       before_read_lock_ts;
                   if (!wlmap->empty()) {
-                    row_storage_data_ptr =
-                        FixedBlockPool::data_ptr<weight_type>(
-                            wlmap->begin()->second);
+                    // Simple block-based randomization using get_block with
+                    // cursor
+                    auto* pool = kv_store_.pool_by(shard_id);
+
+                    // Random starting cursor based on map size for good
+                    // entropy
+                    size_t random_start =
+                        folly::Random::rand32(wlmap->size());
+
+                    // Try to find a used block starting from random
+                    // position
+                    weight_type* block = nullptr;
+                    for (int attempts = 0; attempts < 16; ++attempts) {
+                      block = pool->template get_block<weight_type>(
+                          random_start + attempts);
+                      if (block != nullptr) {
+                        // Block is used (not null)
+                        row_storage_data_ptr =
+                            FixedBlockPool::data_ptr<weight_type>(block);
+                        break;
+                      }
+                    }
+
+                    // Fallback: if no used block found, use first element
+                    // from map
+                    if (block == nullptr) {
+                      row_storage_data_ptr =
+                          FixedBlockPool::data_ptr<weight_type>(
+                              wlmap->begin()->second);
+                    }
                   } else {
                     const auto& init_storage =
                         initializers_[shard_id]->row_storage_;
@@ -526,7 +554,9 @@ class DramKVInferenceEmbedding {
                 read_lookup_cache_total_duration / num_shards_;
             read_acquire_lock_avg_duration_ +=
                 read_acquire_lock_total_duration / num_shards_;
-            read_missing_load_avg_ += read_missing_load / num_shards_;
+            LOG_EVERY_MS(INFO, 5000)
+                << "get_kv_db_async total read_missing_load per batch: "
+                << read_missing_load;
             return std::vector<folly::Unit>(results.size());
           });
 };
diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py
index 34d996ea6f..3877b4b1c1 100644
--- a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py
+++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py
@@ -249,3 +249,67 @@ def reader_thread() -> None:
             # pyre-ignore
             self.assertTrue(equal_one_of(embs[5, :4], possible_embs))
         reader_thread.join()
         self.assertFalse(reader_failed_event.is_set())
+
+    def test_randomized_cache_miss_initialization(self) -> None:
+        """Test that cache misses use randomized data from existing blocks."""
+        num_shards = 8
+        uniform_init_lower: float = -0.01
+        uniform_init_upper: float = 0.01
+
+        # Create the DRAM KV inference cache
+        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
+            num_shards, uniform_init_lower, uniform_init_upper
+        )
+        kv_embedding_cache.init(
+            [(32, 4, SparseType.FP16.as_int())],
+            32,
+            4,
+            torch.tensor([0, 100], dtype=torch.int64),
+        )
+
+        # Setup: populate the cache with many initial values for randomization diversity.
+        # 400 setup items gives each of the 8 shards ~50 entries to draw from.
+        setup_indices = torch.arange(0, 400, dtype=torch.int64)  # 400 setup items
+        setup_weights = torch.randint(
+            1, 255, (400, 32), dtype=torch.uint8
+        )  # Non-zero values so the randomization source is detectable
+        print(f"setup_weights: {setup_weights}")
+
+        # Populate the cache
+        kv_embedding_cache.set_embeddings(setup_indices, setup_weights)
+
+        # Execute: request cache misses repeatedly; each should get randomized initialization.
+        # Indices outside the range [0, 399] are guaranteed to be actual cache misses.
+        miss_indices = torch.tensor([500, 501, 502, 503, 504], dtype=torch.int64)
+
+        # Fetch the cache-miss results multiple times to check for randomization
+        results = []
+        for _ in range(5):
+            current_output = kv_embedding_cache.get_embeddings(miss_indices)
+            results.append(current_output.clone())
+
+        # Assert: verify that randomization occurs.
+        # If randomization is working, the results should not all be identical.
+        all_identical = True
+        for i in range(1, len(results)):
+            if not torch.equal(
+                results[0][:, :4], results[i][:, :4]
+            ):  # Only check the first 4 columns (actual data)
+                all_identical = False
+                break
+
+        # With randomization, the results should differ.
+        # Note: there is a small chance they are identical by random chance,
+        # but with 5 trials of 5 vectors of 4 bytes each, this is extremely unlikely.
+        self.assertFalse(
+            all_identical,
+            "Randomized cache miss initialization should produce different results",
+        )
+
+        # Every result should contain non-zero values, since the cache was populated with non-zero data.
+        for result in results:
+            # At least some values must be non-zero, showing the data came from existing blocks.
+            self.assertTrue(
+                torch.any(result[:, :4] != 0),
+                "Cache miss results should contain non-zero values when cache has data",
+            )
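
Note (not part of the patch): a minimal standalone sketch of the random-cursor probing strategy the header change introduces. It assumes a pool whose get_block(i) wraps the cursor modulo capacity and returns nullptr for unused slots; ToyPool and pick_random_row are hypothetical names for illustration, not the FBGEMM FixedBlockPool API. One deliberate difference: the sketch draws the starting cursor over the full pool capacity, whereas the patch seeds it from the shard map size.

// Illustrative sketch only. A toy fixed-slot pool with the contract the
// patch relies on: get_block(i) wraps the cursor modulo capacity and
// returns nullptr for unused slots.
#include <cstddef>
#include <iostream>
#include <random>
#include <vector>

struct ToyPool {  // hypothetical stand-in for FixedBlockPool
  std::vector<float*> slots;  // slot i holds row data, or nullptr if free
  explicit ToyPool(std::size_t n) : slots(n, nullptr) {}
  float* get_block(std::size_t cursor) const {
    return slots[cursor % slots.size()];  // ring-style cursor wrap-around
  }
};

// Probe up to max_attempts consecutive slots from a random start; fall back
// to `fallback` if every probed slot is free (mirrors the patch's fallback
// to the first map entry).
float* pick_random_row(const ToyPool& pool, float* fallback, int max_attempts = 16) {
  static std::mt19937 gen{std::random_device{}()};
  std::uniform_int_distribution<std::size_t> dist(0, pool.slots.size() - 1);
  const std::size_t start = dist(gen);
  for (int attempt = 0; attempt < max_attempts; ++attempt) {
    if (float* block = pool.get_block(start + attempt)) {
      return block;  // found a used slot
    }
  }
  return fallback;  // all probes hit free slots
}

int main() {
  ToyPool pool(64);
  float row_a = 1.0f;
  float row_b = 2.0f;
  pool.slots[3] = &row_a;
  pool.slots[40] = &row_b;
  std::cout << *pick_random_row(pool, &row_a) << std::endl;
  return 0;
}

The bounded probe keeps the miss path O(1) while the random start spreads reads across the pool; the nullptr fallback preserves the pre-patch behavior of returning the first map entry when every probe lands on a free slot.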