Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
48dc7ce
[Executorch] Add quantized kv cache to oss ci
kimishpatel Nov 20, 2024
1c93b02
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 21, 2024
0c7d99e
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 21, 2024
128e461
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 21, 2024
a943088
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
73c277a
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
5cf2c56
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
f38e8b4
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
37f0480
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
05232bb
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
47cfa2a
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
f37d62b
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 22, 2024
5063dcf
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 23, 2024
fcc6efd
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 23, 2024
cee6d9f
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 23, 2024
982bc6a
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 23, 2024
10d0292
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 27, 2024
ba4083d
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 28, 2024
456b904
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Nov 29, 2024
98f223b
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 4, 2024
c984a6e
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 4, 2024
b95fea7
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 4, 2024
9d97753
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 4, 2024
753f87f
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 4, 2024
643086c
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
6f1efc5
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
082b308
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
e49b3ad
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
b015d80
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
b39b8a7
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
e605bf2
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
1e2b684
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
42ad61a
Update on "[Executorch] Add quantized kv cache to oss ci"
kimishpatel Dec 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .ci/scripts/test_llama.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ else
COREML=OFF
fi

if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
QUANTIZE_KV_CACHE=ON
else
QUANTIZE_KV_CACHE=OFF
fi

echo "COREML option ${COREML}"

if [[ "${MODE}" =~ .*qnn.* ]]; then
Expand Down Expand Up @@ -249,6 +255,9 @@ if [[ "${QNN}" == "ON" ]]; then
EXPORT_ARGS+=" --tokenizer_path tokenizer.model --pt2e_quantize qnn_16a16w --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --calibration_data Once "
fi
fi
if [[ "${QUANTIZE_KV_CACHE}" == "ON" ]]; then
EXPORT_ARGS="${EXPORT_ARGS} --quantize_kv_cache"
fi
# Add dynamically linked library location
$PYTHON_EXECUTABLE -m examples.models.llama.export_llama ${EXPORT_ARGS}

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ jobs:
strategy:
matrix:
dtype: [fp32]
mode: [portable, xnnpack+custom, xnnpack+custom+qe]
mode: [portable, xnnpack+custom, xnnpack+custom+qe,xnnpack+custom+quantize_kv,xnnpack+quantize_kv]
include:
- dtype: bf16
mode: portable
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/trunk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ jobs:
strategy:
matrix:
dtype: [fp32]
mode: [portable, xnnpack+kv+custom, mps, coreml]
mode: [portable, xnnpack+kv+custom, mps, coreml, xnnpack+custom+quantize_kv]
include:
- dtype: bf16
mode: portable
Expand Down
28 changes: 28 additions & 0 deletions examples/models/llama/source_transformation/quantized_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
import torch.nn as nn
from executorch.examples.models.llama.llama_transformer import KVCache

from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401


Expand Down Expand Up @@ -221,6 +222,33 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):


def replace_kv_cache_with_quantized_kv_cache(module):
try:
op = torch.ops.quantized_decomposed.quantize_per_token.out
assert op is not None
except:
import glob

import executorch
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

# Ideally the package is installed in only one location, but use of
# PYTHONPATH can result in multiple locations.
# ATM this is mainly used in CI for the qnn runner. Will need to revisit this.
executorch_package_path = executorch.__path__[-1]
libs = list(
glob.glob(
f"{executorch_package_path}/**/libquantized_ops_aot_lib.*",
recursive=True,
)
)
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
logging.info(f"Loading custom ops library: {libs[0]}")
torch.ops.load_library(libs[0])
op = torch.ops.quantized_decomposed.quantize_per_token.out
assert op is not None
# This is needed to ensure that custom ops are registered
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401

logging.warning(
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
)
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llama/source_transformation/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def forward(

k_cache = self.kv_cache.k_cache
v_cache = self.kv_cache.v_cache
if isinstance(self.kv_cache, QuantizedKVCache):
if hasattr(self.kv_cache, "quantized_cache_dtype"):
# updated quantize cache, scale and zero points
# returns dequantized kv cache
# Not most optimal. Optimizations to follow next
Expand Down
8 changes: 7 additions & 1 deletion extension/llm/custom_ops/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@

import executorch

executorch_package_path = executorch.__path__[0]
# This is needed to ensure that custom ops are registered
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

# Ideally the package is installed in only one location, but use of
# PYTHONPATH can result in multiple locations.
# ATM this is mainly used in CI for the qnn runner. Will need to revisit this.
executorch_package_path = executorch.__path__[-1]
logging.info(f"Looking for libcustom_ops_aot_lib.so in {executorch_package_path}")
libs = list(
glob.glob(
Expand Down
3 changes: 3 additions & 0 deletions kernels/quantized/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,17 @@ if(NOT CMAKE_GENERATOR STREQUAL "Xcode"
set(_quantized_aot_ops
"quantized_decomposed::add.out"
"quantized_decomposed::choose_qparams.Tensor_out"
"quantized_decomposed::choose_qparams_per_token_asymmetric.out"
"quantized_decomposed::dequantize_per_channel.out"
"quantized_decomposed::dequantize_per_tensor.out"
"quantized_decomposed::dequantize_per_tensor.Tensor_out"
"quantized_decomposed::dequantize_per_token.out"
"quantized_decomposed::mixed_linear.out"
"quantized_decomposed::mixed_mm.out"
"quantized_decomposed::quantize_per_channel.out"
"quantized_decomposed::quantize_per_tensor.out"
"quantized_decomposed::quantize_per_tensor.Tensor_out"
"quantized_decomposed::quantize_per_token.out"
)
gen_selected_ops(
LIB_NAME "quantized_ops_aot_lib" ROOT_OPS ${_quantized_aot_ops}
Expand Down