From d38eb08f3694eb3770260787f83d54e9916484c5 Mon Sep 17 00:00:00 2001 From: DannyYuyang-quic Date: Tue, 26 Aug 2025 14:12:43 +0800 Subject: [PATCH 1/2] Qualcomm AI Engine Direct - GA Static SmolLM3 3B Summary: - e2e script for GA Static SmolLM3-3B - perf: 16a4w block quant token rate in kv mode: ~= 30 tokens/sec(SM8750) - acc: PPL ~= (fp: 8.345 -> htp:8.976) in wikitext dataset - add model params file & model weight converter --- .../qualcomm/quantizer/custom_annotation.py | 44 ++++++- backends/qualcomm/tests/test_qnn_delegate.py | 99 +++++++++++++--- examples/models/smollm3/3b_config.json | 14 +++ examples/models/smollm3/__init__.py | 16 +++ examples/models/smollm3/convert_weights.py | 112 ++++++++++++++++++ examples/qualcomm/oss_scripts/llama/README.md | 13 +- .../qualcomm/oss_scripts/llama/__init__.py | 44 ++++++- .../oss_scripts/llama/decoder_constants.py | 1 + .../oss_scripts/llama/eval_llama_qnn.py | 7 +- examples/qualcomm/oss_scripts/llama/llama.py | 8 +- .../oss_scripts/llama/masking_utils.py | 19 +-- .../oss_scripts/llama/model/static_llama.py | 25 ++-- .../oss_scripts/llama/qnn_llama_runner.cpp | 15 ++- .../oss_scripts/llama/runner/runner.cpp | 7 +- .../oss_scripts/llama/runner/runner.h | 1 + 15 files changed, 376 insertions(+), 49 deletions(-) create mode 100644 examples/models/smollm3/3b_config.json create mode 100644 examples/models/smollm3/__init__.py create mode 100644 examples/models/smollm3/convert_weights.py diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index e3bf48056eb..3f10dbaa3fc 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from enum import Enum, unique from typing import Sequence import torch @@ -50,6 +51,17 @@ def annotate_down_proj( ) +@unique +class StaticLLMQuantConfig(Enum): + """ + Layer namespace configuration for Qualcomm's static LLaMA quantization. + """ + + wq_sha = "wq_sha" # Query weight (single head) + wk_sha = "wk_sha" # Key weight (single head) + wv_sha = "wv_sha" # Value weight (single head) + + def annotate_eurobert(gm: torch.fx.GraphModule): """ QNN does not support int32 -> signed 16bit quant @@ -185,11 +197,35 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): ) -def annotate_wv_sha(gm: torch.fx.GraphModule, quantization_config: QuantizationConfig): +def annotate_qkv_proj_sha( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + qkv_tags: set[StaticLLMQuantConfig], +): + """ + Annotates QKV projection layers in a GraphModule for quantization, + specifically layers defined in StaticLLMQuantConfig. + + Args: + qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers + (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in + StaticLLMQuantConfig are allowed. + + Raises: + ValueError: If any tag in `qkv_tags` is not among the allowed enum members. + """ + + # Get all valid tags from the StaticLLMQuantConfig enum + allowed_tags = set(StaticLLMQuantConfig) + invalid_tags = qkv_tags - allowed_tags + if invalid_tags: + raise ValueError( + f"Invalid qkv tags: {invalid_tags}. 
Allowed tags are: {allowed_tags}" + ) + for node in gm.graph.nodes: - if ( - node.target == torch.ops.aten.conv2d.default - and "wv_sha" in node.meta["stack_trace"] + if node.target == torch.ops.aten.conv2d.default and any( + tag.value in node.meta["stack_trace"] for tag in qkv_tags ): input_qspec_map = {} input_qspec_map[node.args[0]] = quantization_config.input_activation diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 22e050a0471..6ef4fa8fe13 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -5138,6 +5138,60 @@ def test_static_qwen3(self): msg["inference_speed"], inference_speed_ref[self.model] ) + def test_qwen2_5(self): + if not self.required_envs([]): + self.skipTest("missing required envs") + prompt = "My favourite condiment is " + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py", + "--prompt", + prompt, + "--decoder_model", + "qwen2.5_0.5B", + "--ptq", + "16a8w", + "--enable_spinquant_r3", + "--max_seq_len", + "128", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.compile_only: + cmds.extend(["--compile_only"]) + elif self.device: + cmds.extend(["--device", self.device]) + if self.host: + cmds.extend(["--host", self.host]) + elif self.enable_x86_64: + cmds.extend(["--enable_x86_64"]) + if self.pre_gen_pte: + cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + + golden_start_with = "My favourite condiment is iced tea." + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + if not self.compile_only: + model_out = msg["result"][0] + self.assertTrue( + model_out.startswith(golden_start_with), + f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'", + ) + def test_static_smollm2(self): if not self.required_envs(): self.skipTest("missing required envs") @@ -5171,6 +5225,8 @@ def test_static_smollm2(self): "--eval_perplexity", "--task", "wikitext", + "--limit", + "1", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -5194,22 +5250,14 @@ def test_static_smollm2(self): self.assertLessEqual(msg["wiki_ppl"], 25) self.assertGreaterEqual(msg["inference_speed"], 200) - def test_qwen2_5(self): - if not self.required_envs([]): + def test_static_smollm3(self): + if not self.required_envs(): self.skipTest("missing required envs") + prompt = "My favourite condiment is " cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py", - "--prompt", - prompt, - "--decoder_model", - "qwen2.5_0.5B", - "--ptq", - "16a8w", - "--enable_spinquant_r3", - "--max_seq_len", - "128", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", "--artifact", self.artifact_dir, "--build_folder", @@ -5220,6 +5268,21 @@ def test_qwen2_5(self): self.ip, "--port", str(self.port), + "--prompt", + f"{prompt}", + "--decoder_model", + "smollm3-3b", + "--model_mode", + "kv", + "--temperature", + "0", + "--max_seq_len", + "1024", + "--eval_perplexity", + "--task", + "wikitext", + "--limit", + "1", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -5232,7 +5295,6 @@ def test_qwen2_5(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - 
golden_start_with = "My favourite condiment is iced tea." p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -5241,11 +5303,12 @@ def test_qwen2_5(self): if "Error" in msg: self.fail(msg["Error"]) else: - if not self.compile_only: - model_out = msg["result"][0] - self.assertTrue( - model_out.startswith(golden_start_with), - f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'", + inference_speed_ref = {"SM8650": 23, "SM8750": 28} + self.assertLessEqual(msg["wiki_ppl"], 10) + self.assertLessEqual(msg["pte_size"], 2_600_000_000) # 2.6GB + if self.model in inference_speed_ref: + self.assertGreaterEqual( + msg["inference_speed"], inference_speed_ref[self.model] ) diff --git a/examples/models/smollm3/3b_config.json b/examples/models/smollm3/3b_config.json new file mode 100644 index 00000000000..76844dd85b5 --- /dev/null +++ b/examples/models/smollm3/3b_config.json @@ -0,0 +1,14 @@ +{ + "dim": 2048, + "ffn_dim_multiplier": 1, + "hidden_dim": 11008, + "n_heads": 16, + "n_kv_heads": 4, + "n_layers": 36, + "norm_eps": 1e-06, + "rope_theta": 5000000.0, + "use_scaled_rope": false, + "vocab_size": 128256, + "use_hf_rope": false, + "attention_qkv_bias": false + } diff --git a/examples/models/smollm3/__init__.py b/examples/models/smollm3/__init__.py new file mode 100644 index 00000000000..627cbc631f0 --- /dev/null +++ b/examples/models/smollm3/__init__.py @@ -0,0 +1,16 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.llama.model import Llama2Model +from executorch.examples.models.smollm3.convert_weights import convert_weights + + +class SmolLM3Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "SmolLM3Model", + "convert_weights", +] diff --git a/examples/models/smollm3/convert_weights.py b/examples/models/smollm3/convert_weights.py new file mode 100644 index 00000000000..df51674c3b0 --- /dev/null +++ b/examples/models/smollm3/convert_weights.py @@ -0,0 +1,112 @@ +import argparse +import json +import os +from typing import Dict + +import torch + +from safetensors.torch import load_file + +from torchtune.models.convert_weights import get_mapped_key + + +_SMOLLM_TO_META = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", +} + + +def smollm_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. 
+ + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ + converted_state_dict = {} + for key, value in state_dict.items(): + new_key = get_mapped_key(key, _SMOLLM_TO_META) + converted_state_dict[new_key] = value + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def load_checkpoint_from_safetensors(input_dir: str) -> Dict: + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + # Sharded checkpoint. + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + # Load all the shards into memory + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + # Merge tensors into consolidated state dict. + merged_state_dict = {} + for weight_name, shard in weight_map.items(): + tensor = shard_to_weights[shard][weight_name] + merged_state_dict[weight_name] = tensor + return merged_state_dict + else: + # Single checkpoint. + state_dict = load_file(os.path.join(input_dir, "model.safetensors")) + return state_dict + + +def load_checkpoint(input_dir: str) -> Dict: + pytorch_path = os.path.join(input_dir, "pytorch_model.bin") + if os.path.exists(pytorch_path): + print("Loading checkpoint from PyTorch .bin file") + return torch.load(pytorch_path, map_location="cpu", weights_only=True) + print("Loading checkpoint from safetensors directory") + return load_checkpoint_from_safetensors(input_dir) + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + sd = load_checkpoint(input_dir) + print("Converting checkpoint...") + sd = smollm_to_meta(sd) + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert SmolLM weights to Meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 5a4f622b320..1be94ec04d6 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -9,7 +9,8 @@ This file provides you the instructions to run LLM Decoder model with different 5. Phi4-mini-instruct 6. QWEN2.5 0.5B / 1.5B 7. QWEN3 0.6B / 1.7B - 8. SMOLLM2 135M + 8. SmolLM2 135M + 9. SmolLM3 3B We offer the following modes to execute the model: @@ -113,10 +114,16 @@ Default example using hybrid mode python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` -#### SMOLLM2 +#### SmolLM2 Default example using hybrid mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +``` + +#### SmolLM3 +Default example using kv mode. +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index a0db11a5407..99b0739f1f0 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -16,7 +16,8 @@ annotate_down_proj, annotate_kv_8bit, annotate_output_16a8w, - annotate_wv_sha, + annotate_qkv_proj_sha, + StaticLLMQuantConfig, ) from executorch.backends.qualcomm.quantizer.qconfig import ( get_ptq_per_channel_quant_config, @@ -34,6 +35,9 @@ from executorch.examples.models.smollm2 import ( convert_weights as convert_smollm2_weights, ) +from executorch.examples.models.smollm3 import ( + convert_weights as convert_smollm3_weights, +) from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import ( DECODER_MODEL_VERSION, @@ -51,6 +55,15 @@ LLM_VARIANT_ARCHS = { "gemma3-1b": MultiScopeAwareLlamaModel, } +annotate_wqkv_sha = partial( + annotate_qkv_proj_sha, + qkv_tags={ + StaticLLMQuantConfig.wq_sha, + StaticLLMQuantConfig.wk_sha, + StaticLLMQuantConfig.wv_sha, + }, +) +annotate_wv_sha = partial(annotate_qkv_proj_sha, qkv_tags={StaticLLMQuantConfig.wv_sha}) @dataclass(init=False, frozen=True) @@ -472,3 +485,32 @@ class Smollm2_135M(LLMModelConfig): r2 = False r3 = False custom_annotation = () + + +@register_llm_model("smollm3-3b") +@dataclass(init=False, frozen=True) +class Smollm3_3B(LLMModelConfig): + repo_id: str = "HuggingFaceTB/SmolLM3-3B" + params_path: str = os.path.join(BASE_DIR, "../../../models/smollm3/3b_config.json") + convert_weights = convert_smollm3_weights + transform_weight = False + instruct_model = True + + num_sharding = 4 + # quant config + ptq = QuantDtype.use_16a4w_block + group_size = 32 + masked_softmax = True + r1 = False + r2 = False + r3 = False + quantization_config_wqkv_sha_16a8w = get_ptq_per_channel_quant_config( + torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver + ) + custom_annotation = ( + annotate_kv_8bit, + annotate_output_16a8w, + partial( + annotate_wqkv_sha, quantization_config=quantization_config_wqkv_sha_16a8w + ), + ) diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py index a115106bd86..ac96770b889 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py @@ -23,4 +23,5 @@ "qwen3-0_6b": "qwen3", "qwen3-1_7b": "qwen3", "smollm2_135m": "smollm2_135m", + "smollm3-3b": "smollm3", } diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py index 73cf57336ed..b25e0cbdc7d 100644 --- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py +++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py @@ -16,7 +16,8 @@ from executorch.backends.qualcomm.quantizer.custom_annotation import ( 
annotate_kv_8bit, annotate_output_16a8w, - annotate_wv_sha, + annotate_qkv_proj_sha, + StaticLLMQuantConfig, ) from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( @@ -218,7 +219,9 @@ def gen_eval_wrapper(model_name, args): custom_annotations = ( annotate_kv_8bit, partial( - annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w + annotate_qkv_proj_sha, + qkv_tags={StaticLLMQuantConfig.wv_sha}, + quantization_config=quantization_config_wv_sha_8a4w, ), ) if args.llama_model == "stories110m": diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 79f713c048a..ed9415f7f8b 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -291,7 +291,6 @@ def quantize( use_i64_token=args.embedding_quantize is not None, event_name="prepare_pt2e_prompt", ) - if scales_state_dict: set_scales( fx_graph_module, scales_state_dict, self.llama_graph_module.head_dim @@ -446,6 +445,13 @@ def compile( else: kv_config.enable_masked_softmax = True + if args.decoder_model == "smollm3-3b": + from transformers import AutoConfig + + kv_config.apply_rope_layers = AutoConfig.from_pretrained( + decoder_model_config.repo_id + ).no_rope_layers + prefill_config = copy.copy(kv_config) prefill_config.use_kv_cache = ( False if args.max_seq_len == args.prefill_ar_len else True diff --git a/examples/qualcomm/oss_scripts/llama/masking_utils.py b/examples/qualcomm/oss_scripts/llama/masking_utils.py index bed81c894f0..8d9d9ead154 100644 --- a/examples/qualcomm/oss_scripts/llama/masking_utils.py +++ b/examples/qualcomm/oss_scripts/llama/masking_utils.py @@ -127,6 +127,7 @@ def mask(self): def smart_mask_update(self, pos, n_updates): """ Smart Mask mechanism for attention mask updating + Initial mask(5x15) layout (before any updates): Each row represents a query token in the autoregressive context. ● = activate (can attend), ○ = inactivate (masked) @@ -166,15 +167,15 @@ def shift_pointer_update(self, pos, n_updates): Each row represents a query token in the autoregressive context. ● = activate (can attend), ○ = inactivate (masked) - Init mask: - 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ - 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ - 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ - 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○ - 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ● + 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ● After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): Newly added tokens are unmasked (set to 0). + 0 ○ ○ ○ ○ ○ ● ● ● ● ● ● ○ ○ ○ ○ 1 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ○ ○ ○ 2 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ○ ○ @@ -182,6 +183,7 @@ def shift_pointer_update(self, pos, n_updates): 4 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ● ● After 2nd update (e.g., pos=5, n_updates=5): + 0 ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ ○ 1 ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ 2 ● ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ @@ -225,7 +227,6 @@ def smart_mask_update(self, pos, n_updates): 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● - After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): Newly added tokens are unmasked (set to 0). Earlier tokens lose access to older cache due to sliding window limits. @@ -236,7 +237,6 @@ def smart_mask_update(self, pos, n_updates): 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● - After 2nd update (e.g., pos=5, n_updates=5): Sliding window shifts again, masking older positions and activate new postion. 
@@ -269,7 +269,6 @@ def shift_pointer_update(self, pos, n_updates): Each row represents a query token in the autoregressive context. ● = activate (can attend), ○ = inactivate (masked) - Init mask: 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ @@ -277,6 +276,7 @@ def shift_pointer_update(self, pos, n_updates): 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): + 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ @@ -284,6 +284,7 @@ def shift_pointer_update(self, pos, n_updates): 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● After 2nd update (e.g., pos=5, n_updates=5): + 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index 32764eba985..caa5153c696 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -62,7 +62,7 @@ def apply_partial_rotary_emb_single( class LlamaAttention(nn.Module): - def __init__(self, config: ModelArgs, output_new_cache_only=False): + def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=False): super().__init__() self.config = config self.dim = config.dim @@ -75,6 +75,10 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False) self.use_qk_norm = config.use_qk_norm self.qk_norm_before_rope = config.qk_norm_before_rope + apply_rope_layers = getattr(config, "apply_rope_layers", None) + self.use_rope = ( + apply_rope_layers[layer_idx] if apply_rope_layers is not None else True + ) if self.use_qk_norm: q_norm_dim = self.head_dim @@ -226,7 +230,8 @@ def forward_sha( # noqa: C901 for i in range(len(q)): if self.use_qk_norm and self.qk_norm_before_rope: q[i] = self.q_norm_fn(q[i]) - q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin) + if self.use_rope: + q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin) if self.use_qk_norm and not self.qk_norm_before_rope: q[i] = self.q_norm_fn(q[i]) if getattr(self.config, "enable_r3", False): @@ -235,7 +240,8 @@ def forward_sha( # noqa: C901 for i in range(len(k)): if self.use_qk_norm and self.qk_norm_before_rope: k[i] = self.k_norm_fn(k[i]) - k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin) + if self.use_rope: + k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin) if self.use_qk_norm and not self.qk_norm_before_rope: k[i] = self.k_norm_fn(k[i]) if getattr(self.config, "enable_r3", False): @@ -301,8 +307,10 @@ def forward( q = self.q_norm_fn(q) k = self.k_norm_fn(k) - q = self.apply_rope_emb(q, freqs_cos, freqs_sin) - k = self.apply_rope_emb(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1) + if self.use_rope: + q = self.apply_rope_emb(q, freqs_cos, freqs_sin) + k = self.apply_rope_emb(k, freqs_cos, freqs_sin) + k = k.permute(0, 2, 3, 1) if self.use_qk_norm and not self.qk_norm_before_rope: q = self.q_norm_fn(q) @@ -394,10 +402,11 @@ def forward(self, x): class LlamaDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs, output_new_cache_only=False): + def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=False): super().__init__() self.dim = config.dim self.attention = LlamaAttention( + layer_idx=layer_idx, config=config, output_new_cache_only=output_new_cache_only, ) @@ -472,8 +481,8 @@ def __init__( 
self.layers = nn.ModuleList( [ - LlamaDecoderLayer(config, self.output_new_cache_only) - for _ in range(config.n_layers) + LlamaDecoderLayer(i, config, self.output_new_cache_only) + for i in range(config.n_layers) ] ) self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps) diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index e143d314d06..71eaea2b8d6 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -10,8 +10,8 @@ * @file * * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma3 1B, - * phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, Smollm2 135M with - * Qualcomm AI Engine Direct. + * phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, SmolLM2 135M, + * SmolLM3 3B with Qualcomm AI Engine Direct. * */ @@ -161,6 +161,17 @@ std::string get_formatted_prompt( formatted_prompt.append(prompt); formatted_prompt.append("<|im_end|>\n\n"); break; + case example::DecoderModelVersion::kSmollm3: + if (!system_prompt.empty()) { + formatted_prompt.append("<|im_start|>system\n"); + formatted_prompt.append(system_prompt); + formatted_prompt.append("\n\n"); + } + formatted_prompt.append("<|im_start|>user\n"); + formatted_prompt.append(prompt); + formatted_prompt.append("<|im_end|>\n"); + formatted_prompt.append("<|im_start|>assistant\n"); + break; default: ET_CHECK_MSG(false, "unsupported llama version"); break; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index dfba5fbb677..0c9be4d441d 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -129,8 +129,12 @@ Runner::Runner( decoder_model_version_ = DecoderModelVersion::kPhi4; } else if (decoder_model_version == "qwen2_5") { decoder_model_version_ = DecoderModelVersion::kQwen2_5; + } else if (decoder_model_version == "qwen3") { + decoder_model_version_ = DecoderModelVersion::kQwen3; } else if (decoder_model_version == "smollm2_135m") { decoder_model_version_ = DecoderModelVersion::kSmollm2_135m; + } else if (decoder_model_version == "smollm3") { + decoder_model_version_ = DecoderModelVersion::kSmollm3; } else { ET_CHECK_MSG(false, "Unsupported Decoder Model"); } @@ -193,7 +197,8 @@ Error Runner::load() { eos_ids->insert(tokenizer_->encode("<|end|>", 0, 0).get()[0]); } else if ( decoder_model_version_ == DecoderModelVersion::kQwen3 || - decoder_model_version_ == DecoderModelVersion::kSmollm2_135m) { + decoder_model_version_ == DecoderModelVersion::kSmollm2_135m || + decoder_model_version_ == DecoderModelVersion::kSmollm3) { eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]); } else if (decoder_model_version_ == DecoderModelVersion::kGemma3) { eos_ids->insert(tokenizer_->encode("", 0, 0).get()[0]); diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index cb6c08d9c87..30fba71ecef 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -37,6 +37,7 @@ enum DecoderModelVersion { kQwen2_5, kQwen3, kSmollm2_135m, + kSmollm3 }; enum KvBitWidth { From 4f3d12ec1ac94c8eb4221344f7cc69331f49becd Mon Sep 17 00:00:00 2001 From: DannyYuyang-quic Date: Fri, 12 Sep 2025 00:55:58 +0800 Subject: [PATCH 2/2] add no_rope_layer_interval into config --- examples/models/llama/model_args.py | 3 
+++ examples/models/smollm3/3b_config.json | 1 + examples/qualcomm/oss_scripts/llama/llama.py | 7 ------- examples/qualcomm/oss_scripts/llama/model/static_llama.py | 4 ++-- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 651047ecd96..04d29f91ac6 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -78,6 +78,9 @@ class ModelArgs: use_qk_norm: bool = False # apply normalization to q and k in the attention qk_norm_before_rope: bool = False # when to apply qk norm use_hf_rope: bool = False # Use HuggingFace's RoPE implementation + no_rope_layer_interval: Optional[int] = ( + None # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795). + ) partial_rotary_factor: float = 1.0 rope_theta: Optional[float] = ( None # The official name to override self.rope_freq_base. diff --git a/examples/models/smollm3/3b_config.json b/examples/models/smollm3/3b_config.json index 76844dd85b5..c44f0727919 100644 --- a/examples/models/smollm3/3b_config.json +++ b/examples/models/smollm3/3b_config.json @@ -10,5 +10,6 @@ "use_scaled_rope": false, "vocab_size": 128256, "use_hf_rope": false, + "no_rope_layer_interval": 4, "attention_qkv_bias": false } diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index ed9415f7f8b..273829d214e 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -445,13 +445,6 @@ def compile( else: kv_config.enable_masked_softmax = True - if args.decoder_model == "smollm3-3b": - from transformers import AutoConfig - - kv_config.apply_rope_layers = AutoConfig.from_pretrained( - decoder_model_config.repo_id - ).no_rope_layers - prefill_config = copy.copy(kv_config) prefill_config.use_kv_cache = ( False if args.max_seq_len == args.prefill_ar_len else True diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index caa5153c696..8dcfced95fb 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -75,9 +75,9 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False) self.use_qk_norm = config.use_qk_norm self.qk_norm_before_rope = config.qk_norm_before_rope - apply_rope_layers = getattr(config, "apply_rope_layers", None) self.use_rope = ( - apply_rope_layers[layer_idx] if apply_rope_layers is not None else True + config.no_rope_layer_interval + and (layer_idx + 1) % config.no_rope_layer_interval ) if self.use_qk_norm:
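
Usage notes (illustrative only, nothing below ships with this patch series):

The converter added under `examples/models/smollm3/` exposes `convert_weights(input_dir, output_file)` and can also be run as a script. A minimal sketch of both uses, with placeholder paths:

```python
# Illustrative only: convert a downloaded HuggingFaceTB/SmolLM3-3B checkpoint
# into the Meta-format state dict used by the static llama flow.
# The paths below are placeholders.
from executorch.examples.models.smollm3 import convert_weights

convert_weights("/path/to/SmolLM3-3B", "/path/to/smollm3_3b.pth")
# Equivalent CLI form:
#   python examples/models/smollm3/convert_weights.py /path/to/SmolLM3-3B /path/to/smollm3_3b.pth
```

`no_rope_layer_interval` in `3b_config.json` drives the per-layer RoPE decision now made in `LlamaAttention`. A small sketch of the intended selection, assuming the 36-layer / interval-4 values from that config and written defensively so a `None` interval keeps RoPE on every layer:

```python
# RoPE/NoPE interleaving (https://huggingface.co/papers/2501.18795):
# every no_rope_layer_interval-th layer skips RoPE, all other layers apply it.
n_layers = 36               # from 3b_config.json
no_rope_layer_interval = 4  # from 3b_config.json; None would mean "always apply RoPE"

use_rope = [
    no_rope_layer_interval is None or (layer_idx + 1) % no_rope_layer_interval != 0
    for layer_idx in range(n_layers)
]

print(sum(use_rope), "of", n_layers, "layers apply RoPE")            # 27 of 36
print("NoPE layers:", [i for i, r in enumerate(use_rope) if not r])  # layers 3, 7, ..., 35
```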