@@ -22,6 +22,7 @@
#include <thrift/lib/cpp2/protocol/CompactProtocol.h>
#include <thrift/lib/cpp2/protocol/Serializer.h>
#include <torch/script.h>
+#include <random>
#include "common/time/Time.h"

#include "../ssd_split_embeddings_cache/initializer.h"
@@ -419,9 +420,36 @@ class DramKVInferenceEmbedding {
before_read_lock_ts;

if (!wlmap->empty()) {
-row_storage_data_ptr =
-    FixedBlockPool::data_ptr<weight_type>(
-        wlmap->begin()->second);
+// Simple block-based randomization using get_block with
+// cursor
+auto* pool = kv_store_.pool_by(shard_id);
+
+// Random starting cursor based on map size for good
+// entropy
+size_t random_start =
+    folly::Random::rand32(wlmap->size());
+
+// Try to find a used block starting from random
+// position
+weight_type* block = nullptr;
+for (int attempts = 0; attempts < 16; ++attempts) {
+  block = pool->template get_block<weight_type>(
+      random_start + attempts);
+  if (block != nullptr) {
+    // Block is used (not null)
+    row_storage_data_ptr =
+        FixedBlockPool::data_ptr<weight_type>(block);
+    break;
+  }
+}
+
+// Fallback: if no used block found, use first element
+// from map
+if (block == nullptr) {
+  row_storage_data_ptr =
+      FixedBlockPool::data_ptr<weight_type>(
+          wlmap->begin()->second);
+}
} else {
const auto& init_storage =
initializers_[shard_id]->row_storage_;
@@ -526,7 +554,9 @@ class DramKVInferenceEmbedding {
read_lookup_cache_total_duration / num_shards_;
read_acquire_lock_avg_duration_ +=
read_acquire_lock_total_duration / num_shards_;
-read_missing_load_avg_ += read_missing_load / num_shards_;
+LOG_EVERY_MS(INFO, 5000)
+    << "get_kv_db_async total read_missing_load per batch: "
+    << read_missing_load;
return std::vector<folly::Unit>(results.size());
});
};
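
The probing loop above caps the per-miss overhead at 16 pool reads and falls back to the map's first entry when every probe lands on a free slot. Below is a minimal sketch of the same pattern, written in Python to match the test file that follows; the list-of-rows pool and pick_random_used_block are illustrative stand-ins for FixedBlockPool and its cursor-based get_block, not the FBGEMM API:

# Standalone sketch of random-cursor probing over a block pool.
# The pool is modeled as a list where slot i holds row data (a list
# of floats) if the block is used, or None if the slot is free.
import random

def pick_random_used_block(pool, rng, max_attempts=16):
    """Probe up to max_attempts slots from a random cursor; fall back
    to the first used slot if every probe lands on a free one."""
    if not pool:
        return None  # empty pool: nothing to pick
    start = rng.randrange(len(pool))
    for attempt in range(max_attempts):
        cursor = start + attempt
        if cursor < len(pool) and pool[cursor] is not None:
            return pool[cursor]  # found a used block
    # Fallback scan, mirroring the wlmap->begin() fallback above
    return next((row for row in pool if row is not None), None)

rng = random.Random(42)
# Slots 0, 3, 6 are free; the rest hold distinct row data
pool = [[float(i)] * 4 if i % 3 else None for i in range(8)]
for _ in range(4):
    print("picked row:", pick_random_used_block(pool, rng))

Capping the probe count keeps the miss path O(1); the cost is a mild bias toward used slots that sit just after runs of free ones, which is acceptable when the goal is cheap variety rather than uniform sampling.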
fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py (64 additions, 0 deletions)
@@ -249,3 +249,67 @@ def reader_thread() -> None:  # pyre-ignore
self.assertTrue(equal_one_of(embs[5, :4], possible_embs))
reader_thread.join()
self.assertFalse(reader_failed_event.is_set())
+
+    def test_randomized_cache_miss_initialization(self) -> None:
+        """Test that cache misses use randomized data from existing blocks."""
+        num_shards = 8
+        uniform_init_lower: float = -0.01
+        uniform_init_upper: float = 0.01
+
+        # Create DRAM KV inference cache
+        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
+            num_shards, uniform_init_lower, uniform_init_upper
+        )
+        kv_embedding_cache.init(
+            [(32, 4, SparseType.FP16.as_int())],
+            32,
+            4,
+            torch.tensor([0, 100], dtype=torch.int64),
+        )
+
+        # Setup: Populate the cache with many initial values for better randomization diversity
+        # Use 400 setup items to ensure each shard (8 shards) gets ~50 entries for good randomization
+        setup_indices = torch.arange(0, 400, dtype=torch.int64)  # 400 setup items
+        setup_weights = torch.randint(
+            1, 255, (400, 32), dtype=torch.uint8
+        )  # Non-zero values to ensure randomization source
+        print(f"setup_weights: {setup_weights}")
+
+        # Populate cache
+        kv_embedding_cache.set_embeddings(setup_indices, setup_weights)
+
+        # Execute: Request cache misses multiple times - these should get randomized initialization
+        # Use indices outside the range [0, 399] to ensure they are actual cache misses
+        miss_indices = torch.tensor([500, 501, 502, 503, 504], dtype=torch.int64)
+
+        # Get the cache miss results multiple times to check for randomization
+        results = []
+        for _ in range(5):
+            current_output = kv_embedding_cache.get_embeddings(miss_indices)
+            results.append(current_output.clone())
+
+        # Assert: Verify that randomization occurs
+        # The results should not all be identical if randomization is working
+        all_identical = True
+        for i in range(1, len(results)):
+            if not torch.equal(
+                results[0][:, :4], results[i][:, :4]
+            ):  # Only check first 4 columns (actual data)
+                all_identical = False
+                break
+
+        # Since we're using randomization, results should be different
+        # Note: There's a small chance they could be identical by random chance,
+        # but with 5 trials of 5 vectors of 4 bytes, this is extremely unlikely
+        self.assertFalse(
+            all_identical,
+            "Randomized cache miss initialization should produce different results",
+        )
+
+        # All results should be non-zero (since we populated the cache with non-zero random values)
+        for result in results:
+            # Check that at least some values are non-zero (indicating data came from existing blocks)
+            self.assertTrue(
+                torch.any(result[:, :4] != 0),
+                "Cache miss results should contain non-zero values when cache has data",
+            )
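
To put a rough number on "extremely unlikely" above: assume, as an idealization, that each of the 5 miss rows independently copies its 4 data bytes from one of the ~400 distinct setup rows on every trial. The real per-shard selection is coarser than this, so treat the following purely as an order-of-magnitude sketch with hypothetical parameters:

# Back-of-envelope bound on the chance the randomization assert
# passes vacuously. Idealized model, not the real sampler: each of
# the 5 miss rows independently copies one of n_sources distinct
# source rows, uniformly at random, on every trial.
from math import comb

n_sources = 400  # setup rows written into the cache
rows_per_trial = 5  # miss indices per get_embeddings call
n_trials = 5  # repeated lookups in the test

# Chance that two given trials produce byte-identical output:
p_pair = (1.0 / n_sources) ** rows_per_trial

# Union bound over all pairs of trials; all_identical requires every
# trial to match the first, so the true probability is smaller still.
p_any_pair = comb(n_trials, 2) * p_pair
print(f"P(two given trials identical) <= {p_pair:.3e}")  # ~9.8e-14
print(f"P(any pair identical) <= {p_any_pair:.3e}")  # ~9.8e-13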