Add KVZCH inference read-time hit rate metrics via fb303 ODS counters #5745
Closed
hy-NJU wants to merge 1 commit into
Summary:

X-link: https://github.com/facebookresearch/FBGEMM/pull/2675

Re-land of D101879296 (reverted by D104116880; see S659921), C++-only this time.

## Summary

KVZCH embedding cache on the serving/inference side lacked structured metrics for read-time (forward pass) hit rate. The C++ backend already tracked per-shard miss counts internally but never emitted them as ODS counters, and there was no total read count, making hit rate calculation impossible.

This diff adds:

- Atomic `read_hit_count_` / `read_miss_count_` counters in `DramKVInferenceEmbedding`
- Per-batch fb303 ODS counter emission (`kvzch.inference.read_hit_count`, `kvzch.inference.read_miss_count`, `kvzch.inference.read_total_count`) in the `get_kv_db_async_impl` aggregation callback
- `get_read_hit_rate_stats()` virtual method on `KVInferenceEmbeddingInterface`, implemented in `DramKVInferenceEmbedding`, no-op stub on `SSDKVInferenceEmbedding`
- `DramKVEmbeddingInferenceWrapper::get_read_hit_rate_stats` C++ method + TorchScript registration so the predictor binary recognizes any model that exports the method
- `kvzch.inference.*` regex added to `PredictorXUtils.cpp` fb303Collector allowlist so ODS scrapes the new counters
- `//fb303:service_data` BUCK dep on `dram_kv_embedding_inference`
- One-time `addStatExportType(SUM)` registration via `folly::call_once` so counters surface in ODS as `*.sum.60` time-series

Performance impact is minimal: one stack-local increment per lookup in the hit branch, plus two atomic fetch-adds and three `addStatValue` calls per batch.

## Why this is C++-only (vs original D101879296)

Original D101879296 also added a Python `torch.jit.export def get_read_hit_rate_stats(...)` on `KVEmbeddingInference`. Models trained after the diff baked the method into their TorchScript graph, but stale predictor binaries on `vg_worker_byoc_t16_gti_mrs_trunk_health_prod` (and `sigrid_predictor_gpu.persistent:prod` from Apr 22) lacked the C++ TorchScript registration. Models loaded against those predictors crashed with:

`torch::jit::ErrorReport: __torch__.torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper object has no attribute or method get_read_hit_rate_stats`

Root cause: `D101879296` shipped the Python TorchScript-export side and the C++ TorchScript-registration side in a single diff with no coordinated predictor-binary rollout. Newly-trained models required the new C++ class registration, but no predictor binary anywhere had it yet: neither `:prod` (v5529, Apr 22) nor `:LATEST` (v5704, Apr 29) of `sigrid_predictor_gpu.persistent` was built at a revision >= the landing commit. This caused S659921 (mvai/video_udd_lsr serving_eval blocked on R5449-R5450).

This re-land drops the Python `torch.jit.export` side entirely. ODS counter emission lives in C++ unconditionally; it does NOT depend on any Python caller invoking `get_read_hit_rate_stats()`. The fb303 `addStatValue` calls fire on every batch automatically. So:

- ODS gets the metrics: same end goal as the original diff
- No model TorchScript graph changes, so no stale-predictor crash is possible
- No predictor-binary wait gate needed: even predictor binaries built before this diff will load every existing and future model normally, because no model exports a new method that requires the new C++ registration

If a Python `get_read_hit_rate_stats()` programmatic API is ever needed, it can ship in a follow-up diff AFTER predictor binaries on every KVZCH serving tier are rolled out at a revision >= this diff's landing commit.

## Note on counter consistency (per AI reviewer feedback on D101879296)

The two `exchange(0)` calls on `read_hit_count_` and `read_miss_count_` in `get_read_hit_rate_stats()` are not a combined snapshot: between the two exchanges, a concurrent batch callback may add to `read_miss_count_` before it is exchanged, so the returned hit and miss values may correspond to slightly different time windows. This is acceptable for ODS hit-rate aggregation (where individual sub-second slices don't matter) and is now documented in code. If a strict snapshot is needed in the future, both exchanges can be wrapped in a small mutex.

Build verification:

- `buck build fbcode//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference`
- `buck build fbcode//caffe2/caffe2/fb/predictor/embedding_db/kv_embedding_table:SSDKVInferenceEmbedding`
- `buck build fbcode//fblearner/predictor/model_publishing_service/deployment:predictor_x_utils`

S659921 regression check:

- After landing, run a Vanguard serving_eval against `vg_worker_byoc_t16_gti_mrs_trunk_health_prod` on a freshly-trained model and confirm no TorchScript crash on the previously-failing test case: https://www.internalfb.com/vanguard/serving_test_cases/1302024888565107

## Related

- Original (reverted): D101879296
- Revert: D104116880
- SEV: S659921: [mvai/video_udd_lsr] Vanguard serving_eval predictor crash, DramKVEmbeddingInferenceWrapper missing get_read_hit_rate_stats on trunk health predictor
- Pull Request resolved: pytorch#5730
- Pull Request resolved: facebookresearch/FBGEMM#2659

Reviewed By: EddyLXJ, emlin

Differential Revision: D104246537
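For readers who want to see the counter-consistency trade-off concretely, here is a minimal standard-library sketch. It is not the code from this diff: the class names `ReadCounters`, `StrictReadCounters`, and `HitRateStats`, and the methods `record_batch` and `read_and_reset`, are hypothetical, and the fb303 emission is left out. Only the member names `read_hit_count_` and `read_miss_count_` and the `exchange(0)` read-and-reset pattern come from the description above. The first class mirrors the relaxed behavior, where the two exchanges may cover slightly different windows. The second shows one way a strict snapshot could look; note that it assumes the writer takes the same lock, since locking only the reader would still let a batch land between the two reads.

```cpp
#include <atomic>
#include <cstdint>
#include <mutex>

// Hypothetical result type: counts drained by one read-and-reset call.
struct HitRateStats {
  int64_t hits;
  int64_t misses;

  // Hit rate over the drained window; 0.0 when nothing was read.
  double hit_rate() const {
    const int64_t total = hits + misses;
    return total == 0 ? 0.0
                      : static_cast<double>(hits) / static_cast<double>(total);
  }
};

// Relaxed variant: lock-free, but hits and misses are drained separately.
class ReadCounters {
 public:
  // Called once per batch with totals accumulated in local variables,
  // so the hot lookup loop touches no shared state.
  void record_batch(int64_t hits, int64_t misses) {
    read_hit_count_.fetch_add(hits, std::memory_order_relaxed);
    read_miss_count_.fetch_add(misses, std::memory_order_relaxed);
  }

  // Two independent exchanges: a batch recorded between them can have its
  // hits counted in the next window and its misses in this one.
  HitRateStats read_and_reset() {
    const int64_t hits = read_hit_count_.exchange(0, std::memory_order_relaxed);
    const int64_t misses =
        read_miss_count_.exchange(0, std::memory_order_relaxed);
    return HitRateStats{hits, misses};
  }

 private:
  std::atomic<int64_t> read_hit_count_{0};
  std::atomic<int64_t> read_miss_count_{0};
};

// Strict variant (assumption: writers and readers share one mutex), so each
// batch is attributed to exactly one window as a whole.
class StrictReadCounters {
 public:
  void record_batch(int64_t hits, int64_t misses) {
    std::lock_guard<std::mutex> guard(mutex_);
    read_hit_count_ += hits;
    read_miss_count_ += misses;
  }

  HitRateStats read_and_reset() {
    std::lock_guard<std::mutex> guard(mutex_);
    const HitRateStats stats{read_hit_count_, read_miss_count_};
    read_hit_count_ = 0;
    read_miss_count_ = 0;
    return stats;
  }

 private:
  std::mutex mutex_;
  int64_t read_hit_count_ = 0;
  int64_t read_miss_count_ = 0;
};
```

In both variants, calling `record_batch(3, 1)` and then `read_and_reset()` yields 3 hits and 1 miss, and a second `read_and_reset()` yields zeros. The relaxed variant keeps the per-batch cost at two atomic adds, which matches the performance note in the summary; the strict variant trades that for one lock acquisition per batch.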
@hy-NJU has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104246537.

This pull request has been merged in e21ad44.