From 734078c7f8ec5626ef91f8c9d79d86ef5507b834 Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Mon, 9 Sep 2024 12:58:26 +0800
Subject: [PATCH 1/2] Qualcomm AI Engine Direct - Add quantized llama IO

- Add a general function to tag IO nodes that obtain/generate quantized tensors
- Add an IO-quantizing function to llama2.py
---
 .../qualcomm/quantizer/custom_annotation.py  | 34 +++++++++++++++++++
 backends/qualcomm/utils/utils.py             | 10 ++++++
 examples/models/llama/export_llama_lib.py    | 19 +++++++++--
 3 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 881d24bbb5e..79b17758a3f 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -12,6 +12,8 @@
     QuantizationConfig,
 )
 from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.extension.llm.export.builder import LLMEdgeManager
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
@@ -144,3 +146,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
+
+
+def get_custom_quant_ios_dtype(
+    cache_shape: torch.Size,
+    node: torch.fx.Node,
+    kv_dtype=torch.uint8,
+    sharding_dtype=torch.uint16,
+):
+    """
+    This function is specific to llama inputs and outputs.
+    """
+    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+        return kv_dtype
+
+    # Tag the index_put node before the copy node, because copy is a skipped node in QNN
+    if (
+        exir_ops.edge.aten.index_put.default == node.target
+        and node.meta["val"].shape == cache_shape
+    ):
+        return kv_dtype
+
+    # Tag sharding IO
+    if exir_ops.edge.llama.fallback.default in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
+
+    # Tag index ops as quantized tensors; they are introduced by sharding
+    if exir_ops.edge.aten.index.Tensor in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 88a84f2f9a6..30e04750b58 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -71,6 +71,7 @@
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
     QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
+    QCOM_QUANTIZED_IO,
 )

 from executorch.exir import ExirExportedProgram
@@ -876,3 +877,12 @@ def get_soc_to_chipset_map():
         "SM8475": QcomChipset.SM8475,
         "SM8450": QcomChipset.SM8450,
     }
+
+
+def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
+    """
+    Tag IO nodes which get or output quantized tensors.
+    No need to insert Q/DQ in qnn_preprocess for these nodes.
+    """
+    for node in gm.graph.nodes:
+        if dtype := get_quant_io_dtype_fn(node):
+            node.meta[QCOM_QUANTIZED_IO] = dtype
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 04bd5bddaaf..e4ed14ae648 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -643,8 +643,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-        from executorch.backends.qualcomm.utils.utils import _transform
-
+        from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
@@ -656,6 +655,22 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
                 shares=args.num_sharding,
             )

+        from functools import partial
+
+        from executorch.backends.qualcomm.quantizer.custom_annotation import (
+            get_custom_quant_ios_dtype,
+        )
+
+        tag_quant_io(
+            builder_exported_to_edge.edge_manager.exported_program().graph_module,
+            partial(
+                get_custom_quant_ios_dtype,
+                builder_exported_to_edge.model.layers[
+                    0
+                ].attention.kv_cache.past_k_caches.shape,
+            ),
+        )
+
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

From 57846dab14c3bc31be3df1192db353d408f138a5 Mon Sep 17 00:00:00 2001
From: Joey Tsai
Date: Mon, 21 Oct 2024 11:44:00 +0800
Subject: [PATCH 2/2] [Fix lint]

---
 backends/qualcomm/quantizer/custom_annotation.py | 1 -
 examples/models/llama/export_llama_lib.py        | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 79b17758a3f..db82172a9e2 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -13,7 +13,6 @@
 )
 from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.extension.llm.export.builder import LLMEdgeManager
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index e4ed14ae648..d691ffe7433 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -644,6 +644,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
+
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
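
A minimal sketch of the callback contract that the new tag_quant_io helper expects: it walks the graph and stores whatever dtype the callback returns in node.meta[QCOM_QUANTIZED_IO]. The callback name, the "kv_cache" substring, and the edge_program variable below are illustrative placeholders, not part of these patches.

# Sketch only: shows a custom get_quant_io_dtype_fn for tag_quant_io.
from typing import Optional

import torch


def example_quant_io_dtype(node: torch.fx.Node) -> Optional[torch.dtype]:
    # Hypothetical rule: treat KV-cache placeholders as already-quantized uint8 IO;
    # returning None leaves the node untagged.
    if node.op == "placeholder" and "kv_cache" in node.name:
        return torch.uint8
    return None


# Usage against an exported edge program (edge_program is a placeholder name):
#     tag_quant_io(edge_program.graph_module, example_quant_io_dtype)
# Matched nodes then carry node.meta[QCOM_QUANTIZED_IO], so qnn_preprocess can
# treat those IOs as quantized instead of inserting Q/DQ ops around them.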