From 4980305f9a76dd42ccc7f9d2700620f08b5e3ce5 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Sun, 14 Apr 2024 11:13:54 -0700 Subject: [PATCH 1/3] qnn end to end flow Patch a few changes including: - support bool tensor type - support fp16 and fix the 8w8a quantization. - add two non-supported ops (slice_scatter and index_put) in common_defs.py stories model working end to end: AOT: fp16: ``` python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json ``` quantize: ``` python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json ``` Runtime: ``` /llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once" ``` Output: ``` Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet. ``` Stories model is too small and sensitive to qunatization. Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/) [ghstack-poisoned] --- backends/qualcomm/builders/node_visitor.py | 1 + backends/qualcomm/partition/common_defs.py | 2 ++ examples/models/llama2/export_llama_lib.py | 12 ++++++++---- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 3dae32f882e..3f40dc56737 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -29,6 +29,7 @@ QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, } QNN_TENSOR_TYPE_MAP = { + torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8, torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16, diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index b06a5766a63..36a2986f09a 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -13,6 +13,8 @@ exir_ops.edge.aten.clone.default, exir_ops.edge.aten.index.Tensor, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.slice_scatter.default, + exir_ops.edge.aten.index_put.default, ] allow_list_operator = [ diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 6bfe53de208..47f7095ec52 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -19,6 +19,7 @@ import pkg_resources import torch +import torch.nn.functional as F from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackDynamicallyQuantizedPartitioner, @@ -34,7 +35,6 @@ from executorch.sdk.etrecord import generate_etrecord from executorch.util.activation_memory_profiler import generate_memory_trace from sentencepiece import SentencePieceProcessor -from torch.nn import functional as F from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers @@ -607,6 +607,8 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: if args.use_sdpa_with_kv_cache: transforms.append(replace_sdpa_with_custom_op) + if args.qnn and args.use_kv_cache: + transforms.append(replace_sdpa_with_simple_sdpa) return ( load_llama_model( checkpoint=checkpoint_path, @@ -629,7 +631,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901 # export_to_edge pt2e_quant_params = _get_pt2e_quantization_params(args) quantizers = get_pt2e_quantizers(pt2e_quant_params, args) - if args.qnn: + if args.qnn and args.pt2e_quantize: assert ( args.quantization_mode is None ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" @@ -763,7 +765,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901 ) # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` - backend_options = generate_htp_compiler_spec(use_fp16=False) + backend_options = generate_htp_compiler_spec( + use_fp16=False if args.pt2e_quantize else True + ) partitioners.append( # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` QnnPartitioner( @@ -780,7 +784,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901 ) ) # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` - _transform(builder_exported_to_edge.export_program()) + _transform(builder_exported_to_edge.edge_manager.exported_program()) if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: From 0c114593a5e6dad79f89606bedda16c928bb5e8f Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Sun, 14 Apr 2024 11:23:09 -0700 Subject: [PATCH 2/3] Update on "qnn end to end flow" Patch a few changes including: - support bool tensor type - support fp16 and fix the 8w8a quantization. - add two non-supported ops (slice_scatter and index_put) in common_defs.py stories model working end to end: AOT: fp16: ``` python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json ``` quantize: ``` python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json ``` Runtime: ``` /llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once" ``` Output: ``` Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet. ``` Stories model is too small and sensitive to qunatization. Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/) [ghstack-poisoned] --- examples/models/llama2/export_llama_lib.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 47f7095ec52..d76f25d21ab 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -765,8 +765,13 @@ def _export_llama(modelname, args) -> str: # noqa: C901 ) # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` - backend_options = generate_htp_compiler_spec( - use_fp16=False if args.pt2e_quantize else True + use_fp16 = False if args.pt2e_quantize else True + if use_fp16: + logging.info("Using fp16 for QNN backend, expect performance degradation") + backend_options = generate_htp_compiler_spec(use_fp16=use_fp16) + soc_model = QcomChipset.SM8650 + logging.info( + f"Default to soc {soc_model}, other available options can be found in {QcomChipset}" ) partitioners.append( # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` @@ -774,7 +779,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm` generate_qnn_executorch_compiler_spec( # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - soc_model=QcomChipset.SM8650, # default to SM8650 + soc_model=soc_model, # default to SM8650 backend_options=backend_options, debug=False, saver=False, From 39bbecf67ef07ef3b27e0b6e871bf759ce6c4b7b Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Sun, 14 Apr 2024 11:24:44 -0700 Subject: [PATCH 3/3] Update on "qnn end to end flow" Patch a few changes including: - support bool tensor type - support fp16 and fix the 8w8a quantization. - add two non-supported ops (slice_scatter and index_put) in common_defs.py stories model working end to end: AOT: fp16: ``` python -m examples.models.llama2.export_llama -kv --qnn -c stories110M.pt -p params.json ``` quantize: ``` python -m examples.models.llama2.export_llama -kv --qnn --pt2e_quantize -c stories110M.pt -p params.json ``` Runtime: ``` /llama_main --model_path=llama2_fp16_qnn_2.21.pte --tokenizer_path=tokenizer.bin --prompt="Once" ``` Output: ``` Once upon a time, there was a boy named Tim. Tim had a pet dog named Max. Max was a big, strong dog. They liked to play and run in the park. One day, Tim and Max went to the park to play. They saw a cat. The cat was up in a tree. Max wanted to help the cat. He tried to climb the tree, but he could not. Then, something unexpected happened. Max started to climb the tree! He was very strong. Max helped the cat come down. The cat was happy. Tim was so proud of his pet. ``` Stories model is too small and sensitive to qunatization. Differential Revision: [D56119738](https://our.internmc.facebook.com/intern/diff/D56119738/) [ghstack-poisoned] --- examples/models/llama2/export_llama_lib.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index d76f25d21ab..76bb2d8db61 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -649,7 +649,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. qnn_quantizer = QnnQuantizer() - # more custom quantization are supported including 16a4w etc. default to 8bit quantized + logging.info( + "More custom quantization are supported including 16a4w etc. default to 8bit quantized" + ) custom_annotations = () qnn_quantizer.add_custom_quant_annotations(custom_annotations) quantizers.append(qnn_quantizer)