Fix QNN runner KV cache bitwidth in Android JNI#17622
Conversation
The QNN runner was hardcoded to use Runner<uint16_t>, but all current Llama quantization recipes use annotate_kv_8bit for an 8-bit KV cache. This mismatch caused the KV cache data to be misinterpreted, resulting in degenerate output (repetitive real words like "Nigeria Nigeria...") while the model otherwise ran correctly on the HTP NPU.

Authored with Claude
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
🔗 Test artifacts and rendered results: hud.pytorch.org/pr/pytorch/executorch/17622
CI status as of commit af1f892 (merge base 298311e): ❌ 7 new failures, 1 cancelled job, 5 unrelated failures; some jobs failed due to flakiness present on trunk.
     executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
     std::string decoder_model = "llama3"; // use llama3 for now
-    runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
+    runner_ = std::make_unique<example::Runner<uint8_t>>( // QNN runner (8-bit KV cache)
I think we need a better way to handle this. I remember Llama uses a 16-bit KV cache and Qwen uses an 8-bit KV cache. cc: @haowhsu-quic
Yes, probably need a branch here to dispatch runner correctly.
@haowhsu-quic Can we detect it by dynamically querying get_kv_io_bit_width from the model if the method exists, and do something like this (defaulting to 8-bit)?
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
kv_bitwidth = static_cast<example::KvBitWidth>(
module->get("get_kv_io_bit_width")
.get()
.toScalar()
.to<int64_t>());
}
if (kv_bitwidth == example::KvBitWidth::kWidth16) {
runner_ = std::make_unique<example::Runner<uint16_t>>(...)
} else {
runner_ = std::make_unique<example::Runner<uint8_t>>(...)
}
@haowhsu-quic is there an update on dynamically querying the bit-width? Is there anything you'd recommend the PR's author do?
Hi @abhinaykukkadapu, we actually do this in our runner (https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp):
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
kv_bitwidth = static_cast<example::KvBitWidth>(
module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
}

Maybe the PR's author can try to follow this approach.
@infil00p could you try this approach and request review once done?
Summary: The QNN runner in the Android JNI layer was hardcoded to use Runner<uint16_t>, but models can be exported with either 8-bit or 16-bit KV caches. This mismatch caused the KV cache data to be misinterpreted, resulting in gibberish output in the Android demo app while the same model worked correctly via the CLI runner.

This change mirrors the dynamic KV bitwidth detection already present in qnn_llama_runner.cpp by querying the model's get_kv_io_bit_width method and instantiating the correct Runner<uint8_t> or Runner<uint16_t> accordingly. It also passes temperature_ to the Runner constructor, which was previously omitted.

Fixes #18571
Closes #17622

Test Plan:
- Built the Android AAR with QNN support (SDK 2.37); jni_layer_llama.cpp compiles cleanly with both Runner<uint8_t> and Runner<uint16_t> template instantiations
- Unit tests pass (gradlew testDebugUnitTest)
Test plan
I built the AAR and used it in the LlamaDemo app in executorch-examples. Testing requires QAIRT 2.43.0, which is newer than the QAIRT version used to build the last ExecuTorch release. Tested on a OnePlus 15 running a Snapdragon 8 Elite Gen 5 (SM8850).
cc @kirklandsign @cbilgin @cccclai