@@ -76,6 +76,7 @@ def __init__( # noqa C901
reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparams at beginning of each row.
feature_names_per_table: Optional[list[list[str]]] = None,
indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64).
embedding_cache_mode: bool = False, # True for zero initialization, False for randomized initialization
) -> None: # noqa C901 # tuple of (rows, dims,)
super(KVEmbeddingInference, self).__init__(
embedding_specs=embedding_specs,
@@ -114,9 +115,13 @@ def __init__( # noqa C901
num_shards = 32
uniform_init_lower: float = -0.01
uniform_init_upper: float = 0.01

# pyre-fixme[4]: Attribute must be annotated.
self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
num_shards, uniform_init_lower, uniform_init_upper
num_shards,
uniform_init_lower,
uniform_init_upper,
embedding_cache_mode, # in embedding_cache_mode, we disable random init
)

self.specs: list[tuple[int, int, int]] = [
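For context, the new embedding_cache_mode flag is forwarded unchanged as the wrapper's fourth constructor argument (disable_random_init). A minimal sketch of the two modes, assuming the fbgemm_gpu inference ops are already loaded as in the tests further down; the values are illustrative:

import torch

num_shards = 32
uniform_init_lower, uniform_init_upper = -0.01, 0.01

# embedding_cache_mode=True -> random init disabled, cache misses come back zero-initialized
zero_init_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
    num_shards, uniform_init_lower, uniform_init_upper, True
)

# embedding_cache_mode=False -> previous behavior, cache misses get randomized values
random_init_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
    num_shards, uniform_init_lower, uniform_init_upper, False
)

Since torch::init<int64_t, double, double, bool>() registers the new argument without a default, the fourth value presumably has to be passed explicitly from Python, which is why every construction site in the tests below now supplies it.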
@@ -15,10 +15,16 @@ namespace fbgemm_gpu {
DramKVEmbeddingInferenceWrapper::DramKVEmbeddingInferenceWrapper(
int64_t num_shards,
double uniform_init_lower,
double uniform_init_upper)
double uniform_init_upper,
bool disable_random_init)
: num_shards_(num_shards),
uniform_init_lower_(uniform_init_lower),
uniform_init_upper_(uniform_init_upper) {}
uniform_init_upper_(uniform_init_upper),
disable_random_init_(disable_random_init) {
LOG(INFO)
<< "DramKVEmbeddingInferenceWrapper created with disable_random_init = "
<< disable_random_init_;
}

void DramKVEmbeddingInferenceWrapper::init(
const std::vector<SerializedSepcType>& specs,
@@ -70,8 +76,8 @@ void DramKVEmbeddingInferenceWrapper::init(
8 /* row_storage_bitwidth */,
false /* enable_async_update */,
std::nullopt /* table_dims */,
hash_size_cumsum);
return;
hash_size_cumsum,
disable_random_init_);
}

std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>>
@@ -96,7 +102,6 @@ void DramKVEmbeddingInferenceWrapper::set_embeddings(
}
folly::coro::blockingWait(dram_kv_->inference_set_kv_db_async(
indices, weights, count, inplacee_update_ts));
return;
}

at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings(
@@ -113,7 +118,7 @@ at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings(
}

void DramKVEmbeddingInferenceWrapper::log_inplace_update_stats() {
return dram_kv_->log_inplace_update_stats();
dram_kv_->log_inplace_update_stats();
}

void DramKVEmbeddingInferenceWrapper::trigger_evict(
@@ -159,8 +164,17 @@ static auto dram_kv_embedding_inference_wrapper =
torch::class_<fbgemm_gpu::DramKVEmbeddingInferenceWrapper>(
"fbgemm",
"DramKVEmbeddingInferenceWrapper")
.def(torch::init<int64_t, double, double>())
.def("init", &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::init)
.def(torch::init<int64_t, double, double, bool>())
.def(
"init",
&fbgemm_gpu::DramKVEmbeddingInferenceWrapper::init,
"",
{
torch::arg("specs"),
torch::arg("row_alignment"),
torch::arg("scale_bias_size_in_bytes"),
torch::arg("hash_size_cumsum"),
})
.def(
"set_embeddings",
&fbgemm_gpu::DramKVEmbeddingInferenceWrapper::set_embeddings,
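Registering argument names on init should also allow keyword-argument calls from Python; a sketch under that assumption (whether the TorchScript custom-class binding actually accepts keywords is not shown in this diff), with values mirroring the tests below:

import torch
from fbgemm_gpu.split_embedding_configs import SparseType

cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(8, -0.01, 0.01, False)
cache.init(
    specs=[(32, 4, SparseType.FP16.as_int())],
    row_alignment=32,
    scale_bias_size_in_bytes=4,
    hash_size_cumsum=torch.tensor([0, 100], dtype=torch.int64),
)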
@@ -18,7 +18,8 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder {
DramKVEmbeddingInferenceWrapper(
int64_t num_shards = 32,
double uniform_init_lower = 0.0,
double uniform_init_upper = 0.0);
double uniform_init_upper = 0.0,
bool disable_random_init = false);

using SerializedSepcType =
std::tuple<int64_t, int64_t, int64_t>; // (rows, dims, sparse_type)
@@ -55,6 +56,7 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder {
int64_t num_shards_ = 32;
double uniform_init_lower_ = 0.0;
double uniform_init_upper_ = 0.0;
bool disable_random_init_ = false;

std::shared_ptr<kv_mem::DramKVInferenceEmbedding<uint8_t>> dram_kv_;
int64_t max_row_bytes_ = 0;
@@ -108,7 +108,8 @@ class DramKVInferenceEmbedding {
int64_t row_storage_bitwidth = 32,
bool enable_async_update = false,
std::optional<at::Tensor> table_dims = std::nullopt,
std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
bool disable_random_init = false)
: max_D_(max_D),
num_shards_(num_shards),
block_size_(FixedBlockPool::calculate_block_size<weight_type>(max_D)),
@@ -120,7 +121,8 @@ class DramKVInferenceEmbedding {
block_alignment_,
/*blocks_per_chunk=*/8192)),
elem_size_(row_storage_bitwidth / 8),
feature_evict_config_(std::move(feature_evict_config)) {
feature_evict_config_(std::move(feature_evict_config)),
disable_random_init_(disable_random_init) {
executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
num_threads, facebook::Proc::getCpuInfo().numCpuCores));
initialize_initializers(
@@ -151,6 +153,8 @@ class DramKVInferenceEmbedding {
sub_table_hash_cumsum_,
false /* is_train */);
}
LOG(INFO) << "DramKVInferenceEmbedding initialized: disable_random_init "
<< disable_random_init_;
}

void initialize_initializers(
@@ -327,14 +331,25 @@ class DramKVInferenceEmbedding {
}
return folly::collect(std::move(futures))
.via(executor_.get())
.thenValue(
[this](const std::vector<std::tuple<int64_t, int64_t>>& tuples) {
for (const auto& pair : tuples) {
inplace_update_hit_cnt_ += std::get<0>(pair);
inplace_update_miss_cnt_ += std::get<1>(pair);
}
return std::vector<folly::Unit>(tuples.size());
});
.thenValue([this](const std::vector<std::tuple<int64_t, int64_t>>&
tuples) {
auto hit_cnt = 0;
auto miss_cnt = 0;
for (const auto& pair : tuples) {
hit_cnt += std::get<0>(pair);
miss_cnt += std::get<1>(pair);
}
inplace_update_hit_cnt_ += hit_cnt;
inplace_update_miss_cnt_ += miss_cnt;
auto total_count = hit_cnt + miss_cnt;
LOG_EVERY_MS(INFO, 5000) << fmt::format(
"inference_set_kv_db_async: hit count {}, miss count {}, inplace update hit rate {}",
hit_cnt,
miss_cnt,
total_count ? static_cast<double>(hit_cnt) / total_count : 0.0);

return std::vector<folly::Unit>(tuples.size());
});
}

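The aggregation above now sums per-shard hit and miss counts locally before touching the atomics, and logs a rate-limited summary with a guard against dividing by zero. The hit-rate arithmetic is equivalent to this small sketch:

def inplace_update_hit_rate(per_shard_counts):
    # per_shard_counts holds the (hit, miss) tuples collected from each shard future.
    hit = sum(h for h, _ in per_shard_counts)
    miss = sum(m for _, m in per_shard_counts)
    total = hit + miss
    return hit / total if total else 0.0

assert inplace_update_hit_rate([(3, 1), (2, 2)]) == 0.625
assert inplace_update_hit_rate([]) == 0.0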
/// Get embeddings from kvstore.
@@ -391,7 +406,7 @@ class DramKVInferenceEmbedding {
int64_t local_read_missing_load = 0;
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(),
"dram_kvstore_set",
"get_kv_db_async_impl",
[this,
shard_id,
indexes,
@@ -419,7 +434,7 @@ class DramKVInferenceEmbedding {
facebook::WallClockUtil::NowInUsecFast() -
before_read_lock_ts;

if (!wlmap->empty()) {
if (!wlmap->empty() && !disable_random_init_) {
// Simple block-based randomization using get_block with
// cursor
auto* pool = kv_store_.pool_by(shard_id);
@@ -665,13 +680,12 @@ class DramKVInferenceEmbedding {

auto inplace_update_hit_cnt = inplace_update_hit_cnt_.exchange(reset_val);
auto inplace_update_miss_cnt = inplace_update_miss_cnt_.exchange(reset_val);
LOG(INFO) << "inplace update stats: hit count: " << inplace_update_hit_cnt
<< ", miss count: " << inplace_update_miss_cnt
<< ", total count: "
<< inplace_update_hit_cnt + inplace_update_miss_cnt
<< ", hit ratio: "
<< (double)inplace_update_hit_cnt /
(inplace_update_hit_cnt + inplace_update_miss_cnt);
auto total_cnt = inplace_update_hit_cnt + inplace_update_miss_cnt;
LOG_EVERY_MS(INFO, 5000)
<< "inplace update stats: hit count: " << inplace_update_hit_cnt
<< ", miss count: " << inplace_update_miss_cnt
<< ", total count: " << total_cnt << ", hit ratio: "
<< (total_cnt > 0 ? (double)inplace_update_hit_cnt / total_cnt : 0.0);
}

std::optional<FeatureEvictMetricTensors> get_feature_evict_metric() const {
@@ -740,6 +754,7 @@ class DramKVInferenceEmbedding {
auto copied_bytes = elem_size_ * copied_width;
int64_t start_offset_bytes = elem_size_ * width_offset;
int64_t row_index = 0;

// TODO: fill the opt state as zeros for init value?
std::copy(
&(row_storage_data_ptr
@@ -901,6 +916,7 @@ class DramKVInferenceEmbedding {
std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;
std::unique_ptr<FeatureEvict<weight_type>> feature_evict_;
int current_iter_ = 0;
bool disable_random_init_ = false;

// perf stats
std::atomic<int64_t> read_total_duration_{0};
@@ -928,8 +944,6 @@ class DramKVInferenceEmbedding {

std::atomic<int64_t> inplace_update_hit_cnt_{0};
std::atomic<int64_t> inplace_update_miss_cnt_{0};

bool disable_random_init_;
}; // class DramKVInferenceEmbedding

} // namespace kv_mem
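Taken together, disable_random_init only changes what a cache miss returns on the read path: with randomized init left enabled (the default) a miss is filled from the pre-initialized random blocks once the shard map is non-empty, while with the flag set a miss is returned zero-filled. The tests below exercise exactly that contract; in outline (a sketch mirroring their setup, assuming the fbgemm_gpu library is loaded):

import torch
from fbgemm_gpu.split_embedding_configs import SparseType

def make_cache(disable_random_init):
    cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
        8, -0.01, 0.01, disable_random_init
    )
    cache.init(
        [(32, 4, SparseType.FP16.as_int())],
        32,  # row_alignment
        4,   # scale_bias_size_in_bytes
        torch.tensor([0, 100], dtype=torch.int64),  # hash_size_cumsum
    )
    # Populate some rows so misses are not zero merely because the cache is empty.
    cache.set_embeddings(
        torch.arange(0, 50, dtype=torch.int64),
        torch.randint(1, 255, (50, 32), dtype=torch.uint8),
    )
    return cache

miss_indices = torch.tensor([100, 101, 102], dtype=torch.int64)
zero_mode = make_cache(True).get_embeddings(miss_indices)    # expected: all zeros
random_mode = make_cache(False).get_embeddings(miss_indices)  # expected: non-zero rows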
84 changes: 78 additions & 6 deletions fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py
@@ -33,7 +33,10 @@ def test_serialize(self) -> None:
uniform_init_upper: float = 0.01

kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
num_shards, uniform_init_lower, uniform_init_upper
num_shards,
uniform_init_lower,
uniform_init_upper,
False, # disable_random_init
)
serialized_result = kv_embedding_cache.serialize()

@@ -48,12 +51,15 @@ def test_serialize_deserialize(self) -> None:
uniform_init_upper: float = 0.01

kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
num_shards, uniform_init_lower, uniform_init_upper
num_shards,
uniform_init_lower,
uniform_init_upper,
False, # disable_random_init
)
serialized_result = kv_embedding_cache.serialize()

kv_embedding_cache_2 = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
0, 0.0, 0.0
0, 0.0, 0.0, False # disable_random_init
)
kv_embedding_cache_2.deserialize(serialized_result)

@@ -65,7 +71,10 @@ def test_set_get_embeddings(self) -> None:
uniform_init_upper: float = 0.0

kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
num_shards, uniform_init_lower, uniform_init_upper
num_shards,
uniform_init_lower,
uniform_init_upper,
False, # disable_random_init
)
kv_embedding_cache.init(
[(20, 4, SparseType.INT8.as_int())],
@@ -122,7 +131,10 @@ def test_inplace_update(self) -> None:
uniform_init_upper: float = 0.0

kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
num_shards, uniform_init_lower, uniform_init_upper
num_shards,
uniform_init_lower,
uniform_init_upper,
False, # disable_random_init
)
kv_embedding_cache.init(
[(20, 4, SparseType.INT8.as_int())],
@@ -258,7 +270,10 @@ def test_randomized_cache_miss_initialization(self) -> None:

# Create DRAM KV inference cache
kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
num_shards, uniform_init_lower, uniform_init_upper
num_shards,
uniform_init_lower,
uniform_init_upper,
False, # disable_random_init
)
kv_embedding_cache.init(
[(32, 4, SparseType.FP16.as_int())],
@@ -313,3 +328,60 @@ def test_randomized_cache_miss_initialization(self) -> None:
torch.any(result[:, :4] != 0),
"Cache miss results should contain non-zero values when cache has data",
)

def test_zero_cache_miss_initialization_with_embedding_cache_mode(self) -> None:
"""Test that cache misses return all zero values when embedding_cache_mode=True."""
num_shards = 8
uniform_init_lower: float = -0.01
uniform_init_upper: float = 0.01

# Setup: Create DRAM KV inference cache with embedding_cache_mode=True (zero initialization)
kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
num_shards,
uniform_init_lower,
uniform_init_upper,
True, # embedding_cache_mode=True for zero initialization
)
kv_embedding_cache.init(
[(32, 4, SparseType.FP16.as_int())],
32,
4,
torch.tensor([0, 100], dtype=torch.int64),
)

# Populate the cache with some initial non-zero values to ensure zero initialization
# is not just due to empty cache
setup_indices = torch.arange(0, 50, dtype=torch.int64)
setup_weights = torch.randint(
1, 255, (50, 32), dtype=torch.uint8
) # Non-zero values
kv_embedding_cache.set_embeddings(setup_indices, setup_weights)

# Execute: Request cache misses - these should get zero initialization due to embedding_cache_mode=True
# Use indices outside the range [0, 49] to ensure they are actual cache misses
miss_indices = torch.tensor([100, 101, 102, 103, 104], dtype=torch.int64)
results = []

# Get cache miss results multiple times to ensure consistent behavior
for _ in range(3):
current_output = kv_embedding_cache.get_embeddings(miss_indices)
results.append(current_output.clone())

# Assert: Verify that all cache miss results are zeros when embedding_cache_mode=True
expected_zeros = torch.zeros((5, 32), dtype=torch.uint8)

for i, result in enumerate(results):
# Check that all cache miss results are zero
self.assertTrue(
torch.equal(result, expected_zeros),
f"Cache miss results should be all zeros when embedding_cache_mode=True, "
f"but got non-zero values in iteration {i}: {result[:, :4]}",
)

# Additional verification: all results should be identical since they're all zeros
for i in range(1, len(results)):
self.assertTrue(
torch.equal(results[0], results[i]),
f"All zero cache miss results should be identical across calls, "
f"but results[0] != results[{i}]",
)