
Commit abd21e7

[Executorch] Add quantized kv cache to oss ci
Pull Request resolved: #6997

Fixes to make sure quantized kv cache works in oss

ghstack-source-id: 256616977
@exported-using-ghexport

Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/)
1 parent 28e2a90 commit abd21e7

8 files changed: +50 -11 lines

.ci/scripts/test_llama.sh

Lines changed: 9 additions & 0 deletions
@@ -110,6 +110,12 @@ else
   COREML=OFF
 fi

+if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
+  QUANTIZE_KV_CACHE=ON
+else
+  QUANTIZE_KV_CACHE=OFF
+fi
+
 echo "COREML option ${COREML}"

 if [[ "${MODE}" =~ .*qnn.* ]]; then
@@ -249,6 +255,9 @@ if [[ "${QNN}" == "ON" ]]; then
     EXPORT_ARGS+=" --tokenizer_path tokenizer.model --pt2e_quantize qnn_16a16w --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --calibration_data Once "
   fi
 fi
+if [[ "${QUANTIZE_KV_CACHE}" == "ON" ]]; then
+  EXPORT_ARGS="${EXPORT_ARGS} --quantize_kv_cache"
+fi
 # Add dynamically linked library location
 $PYTHON_EXECUTABLE -m examples.models.llama.export_llama ${EXPORT_ARGS}
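
The new modes only flip QUANTIZE_KV_CACHE and pass --quantize_kv_cache to export_llama. As a rough sketch of what that flag turns on (the helper name and its hookup into export_llama are assumptions for illustration, not the committed export code), the source transformation touched later in this commit would be applied along these lines:

# Illustrative sketch only. replace_kv_cache_with_quantized_kv_cache is the real
# function touched by this commit; maybe_quantize_kv_cache and its call site are
# hypothetical.
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)


def maybe_quantize_kv_cache(model, quantize_kv_cache: bool):
    if quantize_kv_cache:
        # Modifies the model in place (see the warning logged by the transform below).
        replace_kv_cache_with_quantized_kv_cache(model)
    return model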

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        mode: [portable, xnnpack+custom, xnnpack+custom+qe]
+        mode: [portable, xnnpack+custom, xnnpack+custom+qe,xnnpack+custom+quantize_kv,xnnpack+quantize_kv]
         include:
           - dtype: bf16
             mode: portable

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        mode: [portable, xnnpack+kv+custom, mps, coreml]
+        mode: [portable, xnnpack+kv+custom, mps, coreml, xnnpack+custom+quantize_kv]
         include:
           - dtype: bf16
             mode: portable

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 11 additions & 0 deletions
@@ -10,9 +10,18 @@
 import torch
 import torch.nn as nn
 from executorch.examples.models.llama.llama_transformer import KVCache
+
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401


+try:
+    op = torch.ops.quantized_decomposed.quantize_per_token.out
+    assert op is not None
+except:
+    import executorch.kernels.quantized  # noqa: F401
+    op = torch.ops.quantized_decomposed.quantize_per_token.out
+    assert op is not None
+
 """
 Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
 """
@@ -221,6 +230,8 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):


 def replace_kv_cache_with_quantized_kv_cache(module):
+    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
     logging.warning(
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
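
For context, a minimal sketch of how this conversion is typically invoked: QuantizedKVCache.from_float matches the signature visible in the hunk header above, while the QuantizedCacheType member name used here is an assumption.

# Sketch, not committed code: swap a float KVCache for its quantized counterpart.
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)


def quantize_existing_cache(kv_cache):
    # kv_cache is the float KVCache from llama_transformer; the enum value
    # AffineAsymmetric is assumed for illustration.
    return QuantizedKVCache.from_float(kv_cache, QuantizedCacheType.AffineAsymmetric)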

examples/models/llama/source_transformation/sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def forward(

         k_cache = self.kv_cache.k_cache
         v_cache = self.kv_cache.v_cache
-        if isinstance(self.kv_cache, QuantizedKVCache):
+        if hasattr(self.kv_cache, "quantized_cache_dtype"):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
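
The isinstance check above becomes a duck-typed attribute check, presumably so this module does not need to import QuantizedKVCache (and trigger its quantized-op registration) just to branch on the cache type. A minimal sketch of the same pattern:

# Sketch of the duck typing used in the diff: anything exposing a
# quantized_cache_dtype attribute is treated as a quantized KV cache.
def is_quantized_kv_cache(kv_cache) -> bool:
    return hasattr(kv_cache, "quantized_cache_dtype")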

extension/llm/custom_ops/custom_ops.py

Lines changed: 7 additions & 1 deletion
@@ -26,7 +26,13 @@

 import executorch

-executorch_package_path = executorch.__path__[0]
+# This is needed to ensure that custom ops are registered
+from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
+
+# Ideally the package is installed in only one location, but usage of
+# PYTHONPATH can result in multiple locations.
+# ATM this is mainly used in CI for the qnn runner. Will need to revisit this.
+executorch_package_path = executorch.__path__[-1]
 logging.info(f"Looking for libcustom_ops_aot_lib.so in {executorch_package_path}")
 libs = list(
     glob.glob(
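
Importing this module now has two side effects: the portable_lib import registers the custom ops, and the glob loads libcustom_ops_aot_lib from the last entry of executorch.__path__ (PYTHONPATH overlays can create several entries). A hedged usage sketch, where the specific op name checked is an assumption rather than something stated in this diff:

# Sketch: import for side effects, then confirm a custom op is visible.
import torch
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

# llama::sdpa_with_kv_cache is assumed here purely as an example of a custom op
# that this library is expected to register.
assert torch.ops.llama.sdpa_with_kv_cache is not None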

kernels/quantized/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -60,14 +60,17 @@ if(NOT CMAKE_GENERATOR STREQUAL "Xcode"
   set(_quantized_aot_ops
       "quantized_decomposed::add.out"
       "quantized_decomposed::choose_qparams.Tensor_out"
+      "quantized_decomposed::choose_qparams_per_token_asymmetric.out"
       "quantized_decomposed::dequantize_per_channel.out"
       "quantized_decomposed::dequantize_per_tensor.out"
       "quantized_decomposed::dequantize_per_tensor.Tensor_out"
+      "quantized_decomposed::dequantize_per_token.out"
       "quantized_decomposed::mixed_linear.out"
       "quantized_decomposed::mixed_mm.out"
       "quantized_decomposed::quantize_per_channel.out"
       "quantized_decomposed::quantize_per_tensor.out"
       "quantized_decomposed::quantize_per_tensor.Tensor_out"
+      "quantized_decomposed::quantize_per_token.out"
   )
   gen_selected_ops(
     LIB_NAME "quantized_ops_aot_lib" ROOT_OPS ${_quantized_aot_ops}
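
With the per-token ops added to quantized_ops_aot_lib, they become loadable from Python through the kernels/quantized package shown next. A quick check, sketched under the assumption that the library builds and is found by that package:

# Sketch: importing executorch.kernels.quantized (next file) loads the AOT library,
# after which the out variants listed above should resolve.
import torch
import executorch.kernels.quantized  # noqa: F401

for op_name in (
    "quantize_per_token",
    "dequantize_per_token",
    "choose_qparams_per_token_asymmetric",
):
    assert getattr(torch.ops.quantized_decomposed, op_name).out is not None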

kernels/quantized/__init__.py

Lines changed: 17 additions & 7 deletions
@@ -4,16 +4,26 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
+
 try:
-    from pathlib import Path
+    import glob

-    libs = list(Path(__file__).parent.resolve().glob("**/libquantized_ops_aot_lib.*"))
-    del Path
-    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
-    import torch as _torch
+    import torch
+    import executorch

-    _torch.ops.load_library(libs[0])
-    del _torch
+    # Ideally the package is installed in only one location, but usage of
+    # PYTHONPATH can result in multiple locations.
+    # ATM this is mainly used in CI for the qnn runner. Will need to revisit this.
+    executorch_package_path = executorch.__path__[-1]
+    libs = list(
+        glob.glob(
+            f"{executorch_package_path}/**/libquantized_ops_aot_lib.*", recursive=True
+        )
+    )
+    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
 except:
     import logging
