From e5cd4187c1e8dd191903144e788808366b39aa76 Mon Sep 17 00:00:00 2001 From: haowhsu-quic Date: Mon, 21 Jul 2025 09:13:21 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - PTQ for llama3.2 1b/3b - add ptq recipe for llama3.2 1b/3b - add seq_mse support for helping quantizing 1b model - complement qnn_llama_runner for smollm2 --- backends/qualcomm/_passes/__init__.py | 2 + backends/qualcomm/_passes/seq_mse.py | 223 ++++++++++++++++++ .../qualcomm/quantizer/custom_annotation.py | 19 ++ .../observers/per_block_param_observer.py | 1 + backends/qualcomm/tests/test_qnn_delegate.py | 71 ++++-- examples/qualcomm/oss_scripts/llama/README.md | 12 +- .../qualcomm/oss_scripts/llama/__init__.py | 76 ++++-- .../oss_scripts/llama/decoder_constants.py | 3 +- .../oss_scripts/llama/decoder_utils.py | 22 ++ examples/qualcomm/oss_scripts/llama/llama.py | 52 +++- .../oss_scripts/llama/qnn_llama_runner.cpp | 2 +- .../oss_scripts/llama/runner/runner.cpp | 4 +- 12 files changed, 437 insertions(+), 50 deletions(-) create mode 100644 backends/qualcomm/_passes/seq_mse.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index bbfb18b1851..15fce79ea12 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -39,6 +39,7 @@ from .remove_redundancy import RemoveRedundancy from .replace_arange_args import ReplaceArangeArgs from .replace_inf_values import ReplaceInfValues +from .seq_mse import SeqMSE from .tag_quant_io import TagQuantIO @@ -78,5 +79,6 @@ RemoveRedundancy, ReplaceArangeArgs, ReplaceInfValues, + SeqMSE, TagQuantIO, ] diff --git a/backends/qualcomm/_passes/seq_mse.py b/backends/qualcomm/_passes/seq_mse.py new file mode 100644 index 00000000000..a1d8343c928 --- /dev/null +++ b/backends/qualcomm/_passes/seq_mse.py @@ -0,0 +1,223 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
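+
+# Sequential MSE (SeqMSE) in brief: during calibration, InsertSeqMse attaches a
+# SeqMseModule next to each supported op (currently aten.conv2d.default). The
+# module sweeps `num_candidates` multipliers in (0, 1] over the observer's
+# nominal weight scale and keeps the one whose fake-quantized layer output has
+# the lowest MSE against the floating-point output. RemoveSeqMse then folds the
+# winning multiplier back into the observer's calculate_qparams() so that
+# convert_pt2e emits the adjusted encodings.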
+import types +from contextlib import contextmanager + +import torch +import torchao +from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import ( + PerBlockParamObserver, +) +from executorch.exir.pass_base import ExportPass, PassResult +from torchao.quantization.pt2e import PerChannelMinMaxObserver + + +class SeqMseModule(torch.nn.Module): + """ + Args: + nominal_weight: Tensor + nominal parameters from operator + nominal_bias: Tensor + nominal parameters from operator + operator: fx.Node + operator to be executed + observer: UniformQuantizationObserverBase + parameter observer (specific for weight) + num_candidates: int + grids to search minimal mse loss + """ + + def __init__( + self, + nominal_weight, + nominal_bias, + operator, + observer, + num_candidates, + ): + super().__init__() + self.nominal_weight = nominal_weight + self.nominal_bias = nominal_bias + self.observer = observer + self.steps = torch.linspace( + 1 / num_candidates, 1, steps=num_candidates + ).tolist() + self.operator = self._make_operator(operator) + self.best_candidate_step = 1.0 + + def _make_operator(self, aten_op): + if aten_op.target == torch.ops.aten.conv2d.default: + stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3] + padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4] + dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5] + groups = 1 if len(aten_op.args) < 7 else aten_op.args[6] + has_bias = self.nominal_bias is not None + module = torch.nn.Conv2d( + in_channels=self.nominal_weight.shape[1], + out_channels=self.nominal_weight.shape[0], + kernel_size=self.nominal_weight.shape[-2:], + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=has_bias, + ) + module.weight.data = self.nominal_weight + if has_bias: + module.bias.data = self.nominal_bias + return module + else: + raise NotImplementedError(f"target of {aten_op.target} is not implemented") + + def _per_block_qdq(self, scale, zero_point): + return torchao.quantization.quant_primitives._fake_quantize_affine( + input=self.nominal_weight, + block_size=self.observer.block_size, + scale=scale, + zero_point=zero_point, + quant_dtype=self.observer.dtype, + quant_min=self.observer.quant_min, + quant_max=self.observer.quant_max, + ) + + def _per_channel_qdq(self, scale, zero_point): + return torch.fake_quantize_per_channel_affine( + input=self.nominal_weight, + scale=scale, + zero_point=zero_point, + axis=0, + quant_min=self.observer.quant_min, + quant_max=self.observer.quant_max, + ) + + def _fake_quant(self, scale, zero_point): + dispatcher = { + PerChannelMinMaxObserver: self._per_channel_qdq, + PerBlockParamObserver: self._per_block_qdq, + } + return dispatcher[type(self.observer)](scale, zero_point) + + def _find_best_candidate(self, nominal_input, nominal_output): + # calculate current baseline + scale, zero_point = self.observer.calculate_qparams() + zero_point = zero_point.to(torch.int32) + self.operator.weight.data = self._fake_quant(scale, zero_point) + candidate, current_loss = ( + 1, + torch.nn.functional.mse_loss( + self.operator(nominal_input), nominal_output + ).item(), + ) + for step in self.steps: + self.operator.weight.data = self._fake_quant(scale * step, zero_point) + loss = torch.nn.functional.mse_loss( + self.operator(nominal_input), nominal_output + ).item() + if loss < current_loss: + candidate, current_loss = step, loss + return candidate + + def forward(self, nominal_input, nominal_output): + self.best_candidate_step = self._find_best_candidate( + 
nominal_input=nominal_input, nominal_output=nominal_output + ) + + +class InsertSeqMse(ExportPass): + """ + Insert Seq Mse Observer to find the best quant config for certain node's weight. + """ + + seq_mse_ops = {torch.ops.aten.conv2d.default} + + def __init__(self, num_candidates=1000): + super(InsertSeqMse, self).__init__() + self.num_candidates = num_candidates + + def _insert_seq_mse( + self, graph_module: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + count = 0 + for node in graph_module.graph.nodes: + if node.target in self.seq_mse_ops: + # extract observer + weight_node_obs = node.args[1] + observer = getattr(graph_module, weight_node_obs.name) + # extract parameters + weight_node = weight_node_obs.args[0] + weight_tensor = graph_module.get_parameter(weight_node.target).detach() + bias_tensor = None + if len(node.args) > 2 and node.args[2] is not None: + bias_tensor = graph_module.get_parameter( + node.args[2].args[0].target + ).detach() + + with graph_module.graph.inserting_after(node): + seq_mse_mod = SeqMseModule( + nominal_weight=weight_tensor, + nominal_bias=bias_tensor, + operator=node, + observer=observer, + num_candidates=self.num_candidates, + ) + module_name = f"seq_mse_{count}" + count += 1 + setattr(graph_module, module_name, seq_mse_mod) + input_nodes = (node.args[0], node) + graph_module.graph.create_node( + "call_module", module_name, input_nodes, {} + ) + + def call(self, graph_module: torch.fx.GraphModule): + self._insert_seq_mse(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) + + +class RemoveSeqMse(ExportPass): + """ + Remove Seq Mse before invoking convert_pt2e and update final quantization encoding. + """ + + def __init__(self): + super(RemoveSeqMse, self).__init__() + + def _remove_seq_mse( + self, graph_module: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + node_to_erase = [] + for node in graph_module.graph.nodes: + if node.op == "call_module": + # try extracting SeqMse module + module = getattr(graph_module, node.target) + if isinstance(module, SeqMseModule): + # rewrite observer method for pre-calculated scale + scale, zero_point = module.observer.calculate_qparams() + module.observer.updated_encoding = ( + scale * module.best_candidate_step, + zero_point, + ) + module.observer.calculate_qparams = types.MethodType( + lambda s: s.updated_encoding, module.observer + ) + node_to_erase.append(node) + + for node in node_to_erase: + graph_module.graph.erase_node(node) + + def call(self, graph_module: torch.fx.GraphModule): + self._remove_seq_mse(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) + + +@contextmanager +def SeqMSE(prepared_gm, num_candidates): + prepared_gm = InsertSeqMse(num_candidates)(prepared_gm).graph_module + try: + yield + finally: + prepared_gm = RemoveSeqMse()(prepared_gm).graph_module diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index ef988424cc1..e3bf48056eb 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -31,6 +31,25 @@ ) +def annotate_down_proj( + gm: torch.fx.GraphModule, quantization_config: QuantizationConfig +): + for node in gm.graph.nodes: + if ( + node.target == torch.ops.aten.conv2d.default + and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"]) + and node.args[0].target == torch.ops.aten.mul.Tensor + ): + input_qspec_map = {} + input_qspec_map[node.args[0]] = quantization_config.input_activation 
+ input_qspec_map[node.args[1]] = quantization_config.weight + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_eurobert(gm: torch.fx.GraphModule): """ QNN does not support int32 -> signed 16bit quant diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index fa91a600a02..b3f854db527 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -34,6 +34,7 @@ def __init__( eps=eps, **kwargs, ) + self.dtype = dtype self.block_size = block_size # TODO: expand this when QNN starts to support more configurations self.bitwidth_of_scale = 4 diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 7e492fa3e30..22e050a0471 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -4636,6 +4636,33 @@ def test_qnn_backend_generate_optrace(self): qhas_data = json.load(qhas_file) self.assertIn("data", qhas_data) + def test_qnn_backend_seq_mse(self): + from executorch.backends.qualcomm._passes.seq_mse import SeqMSE + + o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0 + module = Conv2dSingle( # noqa: F405 + in_channel=i_ch, + out_channel=o_ch, + kernel_size=kernel, + padding=padding, + ) + sample_input = (torch.randn(1, i_ch, 1, o_ch),) + # per-channel / per-block + quantizers = [ + make_quantizer(), + make_quantizer(quant_dtype=QuantDtype.use_16a4w_block), + ] + quantizers[-1].set_block_size_map({"conv2d": (1, 32, 1, 1)}) + + for i, quantizer in enumerate(quantizers): + with self.subTest(i=i): + ep = torch.export.export(module, sample_input).module() + prepared = prepare_pt2e(ep, quantizer) + with SeqMSE(prepared, 100): + prepared(*sample_input) + converted = convert_pt2e(prepared) + self.lower_module_and_test_output(converted, sample_input) + class TestExampleLLMScript(TestQNN): def test_static_gemma3_1b(self): @@ -4709,7 +4736,7 @@ def test_static_gemma3_1b(self): msg["inference_speed"], inference_speed_ref[self.model] ) - def test_llama3_2_1b(self): + def test_llama3_2_instruct(self): if not self.required_envs(): self.skipTest("missing required envs") assert ( @@ -4741,13 +4768,16 @@ def test_llama3_2_1b(self): "--temperature", "0", "--decoder_model", - "llama3_2", + "llama3_2-1b_instruct", "--model_mode", - "hybrid", - "--prefill_ar_len", - "32", + "kv", "--max_seq_len", - "512", + "1024", + "--eval_perplexity", + "--tasks", + "wikitext", + "--limit", + "1", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -4760,7 +4790,6 @@ def test_llama3_2_1b(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - golden_start_with = "<|start_header_id|>user<|end_header_id|>" p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -4769,19 +4798,17 @@ def test_llama3_2_1b(self): if "Error" in msg: self.fail(msg["Error"]) else: - if not self.compile_only: - model_out = msg["result"][0] - self.assertTrue( - model_out.startswith(golden_start_with), - f"Expected Output: {golden_start_with}. 
Actual Output: {model_out}", + inference_speed_ref = {"SM8650": 37, "SM8750": 49} + if ( + not self.compile_only + and not self.enable_x86_64 + and self.model in inference_speed_ref + ): + self.assertLessEqual(msg["pte_size"], 1_500_000_000) + self.assertLessEqual(msg["wiki_ppl"], 15) + self.assertGreaterEqual( + msg["inference_speed"], inference_speed_ref[self.model] ) - # x86 does not allow weight sharing, so we don't check pte size. - # Inference speed on x86 is slow, so we only check when running on Android - if not self.enable_x86_64: - pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 1_300_000_000) # 1.3GB - if not self.compile_only and not self.enable_x86_64: - self.assertGreaterEqual(msg["inference_speed"], 66) # Lanai def test_llama_stories_260k(self): if not self.required_envs(): @@ -4976,12 +5003,6 @@ def test_static_phi4(self): cmds.extend(["--enable_x86_64"]) if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - cmds.extend( - [ - "--quant_attrs_path", - f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json", - ] - ) p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 5ce15cabaa9..5a4f622b320 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -65,10 +65,16 @@ At the end of this step, users should have the following files ready: `consolida python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time" ``` -#### LLAMA3.2 -Default example using hybrid mode. +#### LLAMA3.2 1B Instruct +Default example using kv mode. +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +``` + +#### LLAMA3.2 3B Instruct +Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 ``` #### Gemma3 1B diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index ad74754708c..a0db11a5407 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -13,6 +13,7 @@ import torch from executorch.backends.qualcomm.quantizer.custom_annotation import ( + annotate_down_proj, annotate_kv_8bit, annotate_output_16a8w, annotate_wv_sha, @@ -70,6 +71,8 @@ class LLMModelConfig(ABC): masked_softmax: The MaskedSoftmax feature is designed to optimize the LLMs accuracy and performance executed on HTP backend. MaskedSoftmax is used to replace the Softmax(Add(In, Mask)) structure in attention block in LLMs during backend optimization. For more details, please refer to QNN documents. Note that it is only supported starting from QNN 2.35. + seq_mse_candidates: Number of steps to sequentially search for optimum scales for quantized parameters which will minimize + the MSE of activation value between floating point golden & fake quantization. r1: Enable SpinQuant R1 quantization optimization. r2: Enable SpinQuant R2 quantization optimization. r3: Enable SpinQuant R3 quantization optimization. @@ -87,6 +90,7 @@ class LLMModelConfig(ABC): ptq: QuantDtype group_size: int masked_softmax: bool + seq_mse_candidates: int r1: bool r2: bool r3: bool @@ -179,6 +183,7 @@ class LlamaStories260K(LLMModelConfig): ptq = QuantDtype.use_16a4w group_size = None masked_softmax = False + seq_mse_candidates = 0 r1 = False r2 = False r3 = False @@ -209,6 +214,7 @@ class LlamaStories110M(LLMModelConfig): ptq = QuantDtype.use_16a4w group_size = None masked_softmax = False + seq_mse_candidates = 0 r1 = False r2 = False r3 = False @@ -225,9 +231,9 @@ class LlamaStories110M(LLMModelConfig): ) -@register_llm_model("llama3_2") +@register_llm_model("llama3_2-1b_instruct") @dataclass(init=False, frozen=True) -class Llama3_2(LLMModelConfig): +class Llama3_2_1B_Instruct(LLMModelConfig): repo_id = None params_path = None convert_weights = None @@ -235,23 +241,49 @@ class Llama3_2(LLMModelConfig): # The Llama3_2 enabled should be instruct, however, Llama's tokenizer does not provide utility to apply chat template. instruct_model = False - num_sharding = 4 + num_sharding = 1 # quant config - ptq = QuantDtype.use_16a4w - group_size = None + ptq = QuantDtype.use_16a4w_block + group_size = 32 masked_softmax = False + seq_mse_candidates = 1000 r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, + quantization_config_down_proj_16a8w = get_ptq_per_channel_quant_config( + torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver ) custom_annotation = ( annotate_kv_8bit, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), + annotate_output_16a8w, + partial( + annotate_down_proj, quantization_config=quantization_config_down_proj_16a8w + ), + ) + + +@register_llm_model("llama3_2-3b_instruct") +@dataclass(init=False, frozen=True) +class Llama3_2_3B_Instruct(LLMModelConfig): + repo_id = None + params_path = None + convert_weights = None + transform_weight = True + # The Llama3_2 enabled should be instruct, however, Llama's tokenizer does not provide utility to apply chat template. 
+ instruct_model = False + + num_sharding = 4 + # quant config + ptq = QuantDtype.use_16a4w_block + group_size = 32 + masked_softmax = False + seq_mse_candidates = 0 + r1 = False + r2 = False + r3 = False + custom_annotation = ( + annotate_kv_8bit, + annotate_output_16a8w, ) @@ -271,6 +303,7 @@ class Gemma3(LLMModelConfig): ptq = QuantDtype.use_16a4w_block group_size = 64 masked_softmax = True + seq_mse_candidates = 0 r1 = False r2 = False r3 = False @@ -299,6 +332,7 @@ class Phi4Mini(LLMModelConfig): ptq = QuantDtype.use_16a4w_block group_size = 16 masked_softmax = False + seq_mse_candidates = 0 r1 = False r2 = False r3 = False @@ -331,6 +365,7 @@ class Qwen2_5_0_5B(LLMModelConfig): ptq = QuantDtype.use_16a4w_block group_size = 16 masked_softmax = True + seq_mse_candidates = 0 r1 = False r2 = False r3 = True @@ -353,6 +388,7 @@ class Qwen2_5_1_5B(LLMModelConfig): ptq = QuantDtype.use_16a4w_block group_size = 16 masked_softmax = True + seq_mse_candidates = 0 r1 = False r2 = False r3 = True @@ -372,13 +408,21 @@ class Qwen3_0_6B(LLMModelConfig): num_sharding = 1 # quant config - ptq = QuantDtype.use_16a8w - group_size = None + ptq = QuantDtype.use_16a4w_block + group_size = 32 masked_softmax = True + seq_mse_candidates = 1000 r1 = False r2 = False - r3 = True - custom_annotation = () + r3 = False + quantization_config_down_proj_16a8w = get_ptq_per_channel_quant_config( + torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver + ) + custom_annotation = ( + partial( + annotate_down_proj, quantization_config=quantization_config_down_proj_16a8w + ), + ) @register_llm_model("qwen3-1_7b") @@ -397,6 +441,7 @@ class Qwen3_1_7B(LLMModelConfig): ptq = QuantDtype.use_16a4w_block group_size = 16 masked_softmax = True + seq_mse_candidates = 0 r1 = False r2 = False r3 = True @@ -422,6 +467,7 @@ class Smollm2_135M(LLMModelConfig): ptq = QuantDtype.use_16a8w group_size = None masked_softmax = False + seq_mse_candidates = 0 r1 = False r2 = False r3 = False diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py index 03bf5043d60..a115106bd86 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py @@ -14,9 +14,10 @@ DECODER_MODEL_VERSION = { "stories260k": "llama2", "stories110m": "llama2", - "llama3_2": "llama3", "gemma3-1b": "gemma3", "phi_4_mini": "phi_4_mini", + "llama3_2-1b_instruct": "llama3", + "llama3_2-3b_instruct": "llama3", "qwen2_5-0_5b": "qwen2_5", "qwen2_5-1_5b": "qwen2_5", "qwen3-0_6b": "qwen3", diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index eaa25698e90..76cf85c6e9c 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -12,6 +12,7 @@ import numpy as np import torch +from executorch.backends.qualcomm._passes import SeqMSE from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import ( DECODER_MODEL_VERSION, @@ -60,6 +61,7 @@ def __init__( get_example_inputs: Callable, kv_updater: Callable, use_i64_token: bool, + seq_mse_candidates: int, ): # n seq len = n-1 cache len, so we len(inps) = n-1 during _model_call assert max_seq_length is not None, "max_seq_length must be provided" @@ -73,6 +75,7 @@ def __init__( self.max_seq_length = max_seq_length self.kv_updater = kv_updater self.use_i64_token 
= use_i64_token + self.seq_mse_candidates = seq_mse_candidates def _model_call(self, inps): all_logits = None @@ -80,6 +83,8 @@ def _model_call(self, inps): if self._use_kv_cache: kwargs["ar_len"] = self.ar_len kwargs["kv_updater"] = self.kv_updater + kwargs["seq_mse_candidates"] = self.seq_mse_candidates + all_logits = INFERENCE_REGISTRY[self._use_kv_cache]( self.get_example_inputs, inps, @@ -90,6 +95,8 @@ def _model_call(self, inps): collect_logits=True, **kwargs, ) + # one shot is enough for seq mse + self.seq_mse_candidates = 0 return all_logits @@ -296,6 +303,7 @@ def kv_inference( kv_updater=smart_mask_updater, use_i64_token=False, collect_logits=False, + seq_mse_candidates=0, ): _, atten_mask, _, k_caches, v_caches = get_example_inputs(use_kv_cache=True) @@ -355,6 +363,17 @@ def kv_inference( if collect_logits: result_logits.append(logits[:, :num_tokens_in_chunk]) + # We should have enough calibration data when generating last token if task was specified + if seq_mse_candidates != 0 and pos == num_prompt_tokens - 1: + with SeqMSE(module, seq_mse_candidates): + module( + tmp_token_list, + *atten_mask, + tmp_pos, + *k_caches, + *v_caches, + ) + # Update the pos, KV cache and attention mask. pos, k_caches, v_caches = kv_updater( num_tokens_in_chunk, @@ -414,6 +433,7 @@ def kv_inference( torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item() ) num_tokens = len(total_token_list) + logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}") if collect_logits: result_logits = torch.cat(result_logits, dim=1) @@ -493,6 +513,7 @@ def graph_module_inference( num_fewshot=None, use_i64_token=False, event_name: Optional[str] = None, + seq_mse_candidates: int = 0, ): """ This function supports model execution from static nn.Module decoder model @@ -528,6 +549,7 @@ def graph_module_inference( get_example_inputs=get_example_inputs, kv_updater=kv_updater, use_i64_token=use_i64_token, + seq_mse_candidates=seq_mse_candidates, ) # Evaluate the model with torch.no_grad(): diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 73a804219d5..79f713c048a 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -142,7 +142,7 @@ def __init__( self.inputs = ( inputs[0], # tokens *inputs[1], # attn_mask - *(inputs[2] if self.llama_meta["get_use_kv_cache"] else []), # pos_ids + *((inputs[2],) if self.llama_meta["get_use_kv_cache"] else []), # pos_ids *(inputs[3] if self.llama_meta["get_use_kv_cache"] else []), # k_caches *(inputs[4] if self.llama_meta["get_use_kv_cache"] else []), # v_caches ) @@ -268,6 +268,7 @@ def quantize( num_fewshot=args.num_fewshot, use_i64_token=args.embedding_quantize is not None, event_name="prepare_pt2e_tasks", + seq_mse_candidates=self.decoder_model_config.seq_mse_candidates, ) # Check user's prompt, helps calibrate special token @@ -435,6 +436,7 @@ def compile( kv_config.use_kv_cache = True kv_config.enable_r3 = decoder_model_config.r3 kv_config.kv_io_bit_width = decoder_model_config.get_kv_io_bit_width() + if decoder_model_config.masked_softmax: if is_qnn_sdk_version_less_than("2.35"): logging.warning( @@ -686,11 +688,11 @@ def permute(w, heads): llama_instance_list[i] = SingleLlama( llama_instance_list[i].eval(), decoder_model_config, pte_filename ) - if args.embedding_quantize: llama_instance_list[i].passes_job[I64toI32][ QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY ]["skip_node"] = {"tokens"} + if decoder_model_config.ptq: start_quantize_ts = time.time() 
custom_annotations = decoder_model_config.custom_annotation @@ -723,6 +725,48 @@ def permute(w, heads): llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ "get_quant_io_dtype_fn" ] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type) + + # force overriding frozen parameters here for model quantizing under seq mse scenario + # this will make weight sharing work properly + if decoder_model_config.seq_mse_candidates != 0 and args.model_mode in [ + "hybrid", + "lookahead", + ]: + decode, prompt = [ + instance.llama_graph_module for instance in llama_instance_list + ] + override_nodes = { + str(node.meta["nn_module_stack"].values()): node + for node in prompt.graph.nodes + if node.target == torch.ops.aten.conv2d.default + } + indices_map = { + # (affine_tensor, group_size, scales, zero_points, dtype, min, max) + torch.ops.torchao.dequantize_affine: [0, 2, 3], + # (per_channel_tensor, scales, zero_points, dim, dtype, min, max) + torch.ops.quantized_decomposed.dequantize_per_channel.default: [ + 0, + 1, + 2, + ], + # should not need to worry about per-tensor case + } + for node in decode.graph.nodes: + if node.target == torch.ops.aten.conv2d.default: + if target_node := override_nodes.get( + str(node.meta["nn_module_stack"].values()) + ): + # arguments of conv: (input, weight, bias) + for i, dq_node in enumerate(node.args[1:]): + for index in indices_map[dq_node.target]: + setattr( + prompt, + target_node.args[i + 1].args[index].target, + getattr(decode, dq_node.args[index].target), + ) + else: + raise RuntimeError("failed to override quantization attribute") + end_quantize_ts = time.time() logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") @@ -967,8 +1011,8 @@ def post_process(): # No pregen inputs, input_list is not required adb.push(inputs=[], files=[runtime_tokenizer_path]) adb.execute(custom_runner_cmd=runner_cmd) - adb.pull(output_path=args.artifact, callback=post_process) + if args.ip and args.port != -1: inference_speed = 0 with open( @@ -1224,7 +1268,7 @@ def export_llama(args) -> None: args.tokenizer_bin is not None ), "Please provide tokenizer_bin for stories." 
runtime_tokenizer_path = args.tokenizer_bin - elif args.decoder_model == "llama3_2": + elif "llama3_2" in args.decoder_model: tokenizer = get_tokenizer(args.tokenizer_model) assert isinstance( tokenizer, TiktokenTokenizer diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 80d92beb099..e143d314d06 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -155,7 +155,7 @@ std::string get_formatted_prompt( if (!system_prompt.empty()) { formatted_prompt.append("<|im_start|>system\n"); formatted_prompt.append(system_prompt); - formatted_prompt.append("<|im_end|>\n\n"); + formatted_prompt.append("<|im_end|>\n"); } formatted_prompt.append("<|im_start|>user\n"); formatted_prompt.append(prompt); diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 83ea0b88ad0..dfba5fbb677 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -191,7 +191,9 @@ Error Runner::load() { eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); } else if (decoder_model_version_ == DecoderModelVersion::kPhi4) { eos_ids->insert(tokenizer_->encode("<|end|>", 0, 0).get()[0]); - } else if (decoder_model_version_ == DecoderModelVersion::kQwen3) { + } else if ( + decoder_model_version_ == DecoderModelVersion::kQwen3 || + decoder_model_version_ == DecoderModelVersion::kSmollm2_135m) { eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]); } else if (decoder_model_version_ == DecoderModelVersion::kGemma3) { eos_ids->insert(tokenizer_->encode("", 0, 0).get()[0]);
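The intended calibration flow for the new `SeqMSE` context manager mirrors `test_qnn_backend_seq_mse` above. Below is a minimal sketch; the `SeqMSE` import comes from this patch, while the `make_quantizer` / `prepare_pt2e` / `convert_pt2e` import locations are assumptions that vary across ExecuTorch and torchao releases.

```python
import torch
from executorch.backends.qualcomm._passes import SeqMSE

# Assumed import locations; adjust to your checkout.
from executorch.backends.qualcomm.quantizer.quantizer import make_quantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Any module whose weight-bearing ops lower to aten.conv2d qualifies for the search.
model = torch.nn.Sequential(torch.nn.Conv2d(512, 32, kernel_size=1)).eval()
sample_input = (torch.randn(1, 512, 4, 4),)

# The quantizer must produce per-channel (or per-block) weight observers, which
# is what make_quantizer() yields in the unit test.
quantizer = make_quantizer()
prepared = prepare_pt2e(torch.export.export(model, sample_input).module(), quantizer)

# InsertSeqMse runs on enter; calibration inside the block drives the scale
# search; RemoveSeqMse folds the best candidates back into the observers on exit.
with SeqMSE(prepared, num_candidates=100):
    prepared(*sample_input)

converted = convert_pt2e(prepared)
```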
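Under the hood the search is a plain grid over scale multipliers in (0, 1]. The standalone illustration below mirrors the selection criterion of `SeqMseModule._find_best_candidate`, reduced to a single 1x1 conv; the helper name, shapes, and int4-style qparams are made up for the example.

```python
import torch
import torch.nn.functional as F


def find_best_scale_step(weight, scale, zero_point, x,
                         quant_min=-8, quant_max=7, num_candidates=100):
    """Return the multiplier in (0, 1] on `scale` whose fake-quantized weight
    gives the lowest conv-output MSE against the floating-point output."""
    fp_out = F.conv2d(x, weight)
    best_step, best_loss = 1.0, float("inf")
    for step in torch.linspace(1 / num_candidates, 1, steps=num_candidates).tolist():
        w_fq = torch.fake_quantize_per_channel_affine(
            weight, scale * step, zero_point, axis=0,
            quant_min=quant_min, quant_max=quant_max,
        )
        loss = F.mse_loss(F.conv2d(x, w_fq), fp_out).item()
        if loss < best_loss:
            best_step, best_loss = step, loss
    return best_step


# Toy per-channel symmetric parameters for a 1x1 conv.
w = torch.randn(32, 512, 1, 1)
per_channel_scale = w.abs().amax(dim=(1, 2, 3)) / 7.0
zp = torch.zeros(32, dtype=torch.int32)
x = torch.randn(1, 512, 4, 4)
print(find_best_scale_step(w, per_channel_scale, zp, x))
```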
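The 1B recipe pairs the block-quantized weights with a 16a8w per-channel override on the MLP down-projection convs. The sketch below shows how that `custom_annotation` tuple is assembled; the tuple itself matches `Llama3_2_1B_Instruct` in this diff, while the `qconfig` helper and observer import paths are assumptions.

```python
from functools import partial

import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_down_proj,
    annotate_kv_8bit,
    annotate_output_16a8w,
)

# Assumed locations; adjust to your checkout.
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_ptq_per_channel_quant_config,
)
from torch.ao.quantization.observer import MinMaxObserver

# 16-bit activations / 8-bit per-channel weights for the down-projection convs.
down_proj_cfg = get_ptq_per_channel_quant_config(
    torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)

# Tuple consumed by the quantizer as `custom_annotation`, as in the recipe.
custom_annotation = (
    annotate_kv_8bit,
    annotate_output_16a8w,
    partial(annotate_down_proj, quantization_config=down_proj_cfg),
)
```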