44 changes: 40 additions & 4 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum, unique
from typing import Sequence

import torch
@@ -50,6 +51,17 @@ def annotate_down_proj(
)


@unique
class StaticLLMQuantConfig(Enum):
"""
Layer namespace configuration for Qualcomm's static LLaMA quantization.
"""

wq_sha = "wq_sha" # Query weight (single head)
wk_sha = "wk_sha" # Key weight (single head)
wv_sha = "wv_sha" # Value weight (single head)


def annotate_eurobert(gm: torch.fx.GraphModule):
"""
QNN does not support int32 -> signed 16bit quant
@@ -185,11 +197,35 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
)


def annotate_wv_sha(gm: torch.fx.GraphModule, quantization_config: QuantizationConfig):
def annotate_qkv_proj_sha(
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
qkv_tags: set[StaticLLMQuantConfig],
):
"""
Annotates QKV projection layers in a GraphModule for quantization,
specifically layers defined in StaticLLMQuantConfig.

Args:
qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers
(e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in
StaticLLMQuantConfig are allowed.

Raises:
ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
"""

# Get all valid tags from the StaticLLMQuantConfig enum
allowed_tags = set(StaticLLMQuantConfig)
invalid_tags = qkv_tags - allowed_tags
if invalid_tags:
raise ValueError(
f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
)

for node in gm.graph.nodes:
if (
node.target == torch.ops.aten.conv2d.default
and "wv_sha" in node.meta["stack_trace"]
if node.target == torch.ops.aten.conv2d.default and any(
tag.value in node.meta["stack_trace"] for tag in qkv_tags
):
input_qspec_map = {}
input_qspec_map[node.args[0]] = quantization_config.input_activation
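For context, a minimal usage sketch of the generalized annotation API above. The wrapper function and the `quantization_config` argument are illustrative assumptions, not part of this diff.

```python
# Hypothetical sketch (not part of this diff): annotate only the query and
# key single-head projection convs, leaving the value projections alone.
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    StaticLLMQuantConfig,
    annotate_qkv_proj_sha,
)


def annotate_qk_only(gm, quantization_config):
    # qkv_tags must be StaticLLMQuantConfig members; anything else raises ValueError.
    annotate_qkv_proj_sha(
        gm,
        quantization_config,
        qkv_tags={StaticLLMQuantConfig.wq_sha, StaticLLMQuantConfig.wk_sha},
    )
```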
99 changes: 81 additions & 18 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -5138,6 +5138,60 @@ def test_static_qwen3(self):
msg["inference_speed"], inference_speed_ref[self.model]
)

def test_qwen2_5(self):
if not self.required_envs([]):
self.skipTest("missing required envs")
prompt = "My favourite condiment is "
cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py",
"--prompt",
prompt,
"--decoder_model",
"qwen2.5_0.5B",
"--ptq",
"16a8w",
"--enable_spinquant_r3",
"--max_seq_len",
"128",
"--artifact",
self.artifact_dir,
"--build_folder",
self.build_folder,
"--model",
self.model,
"--ip",
self.ip,
"--port",
str(self.port),
]
if self.compile_only:
cmds.extend(["--compile_only"])
elif self.device:
cmds.extend(["--device", self.device])
if self.host:
cmds.extend(["--host", self.host])
elif self.enable_x86_64:
cmds.extend(["--enable_x86_64"])
if self.pre_gen_pte:
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

golden_start_with = "My favourite condiment is iced tea."
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
p.communicate()
msg = json.loads(conn.recv())
if "Error" in msg:
self.fail(msg["Error"])
else:
if not self.compile_only:
model_out = msg["result"][0]
self.assertTrue(
model_out.startswith(golden_start_with),
f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'",
)

def test_static_smollm2(self):
if not self.required_envs():
self.skipTest("missing required envs")
@@ -5171,6 +5225,8 @@ def test_static_smollm2(self):
"--eval_perplexity",
"--task",
"wikitext",
"--limit",
"1",
]
if self.compile_only:
cmds.extend(["--compile_only"])
@@ -5194,22 +5250,14 @@ def test_static_smollm2(self):
self.assertLessEqual(msg["wiki_ppl"], 25)
self.assertGreaterEqual(msg["inference_speed"], 200)

def test_qwen2_5(self):
if not self.required_envs([]):
def test_static_smollm3(self):
if not self.required_envs():
self.skipTest("missing required envs")

prompt = "My favourite condiment is "
cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py",
"--prompt",
prompt,
"--decoder_model",
"qwen2.5_0.5B",
"--ptq",
"16a8w",
"--enable_spinquant_r3",
"--max_seq_len",
"128",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
"--artifact",
self.artifact_dir,
"--build_folder",
@@ -5220,6 +5268,21 @@ def test_qwen2_5(self):
self.ip,
"--port",
str(self.port),
"--prompt",
f"{prompt}",
"--decoder_model",
"smollm3-3b",
"--model_mode",
"kv",
"--temperature",
"0",
"--max_seq_len",
"1024",
"--eval_perplexity",
"--task",
"wikitext",
"--limit",
"1",
]
if self.compile_only:
cmds.extend(["--compile_only"])
@@ -5232,7 +5295,6 @@ def test_qwen2_5(self):
if self.pre_gen_pte:
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

golden_start_with = "My favourite condiment is iced tea."
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
@@ -5241,11 +5303,12 @@ def test_qwen2_5(self):
if "Error" in msg:
self.fail(msg["Error"])
else:
if not self.compile_only:
model_out = msg["result"][0]
self.assertTrue(
model_out.startswith(golden_start_with),
f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'",
inference_speed_ref = {"SM8650": 23, "SM8750": 28}
self.assertLessEqual(msg["wiki_ppl"], 10)
self.assertLessEqual(msg["pte_size"], 2_600_000_000) # 2.6GB
if self.model in inference_speed_ref:
self.assertGreaterEqual(
msg["inference_speed"], inference_speed_ref[self.model]
)


3 changes: 3 additions & 0 deletions examples/models/llama/model_args.py
@@ -78,6 +78,9 @@ class ModelArgs:
use_qk_norm: bool = False # apply normalization to q and k in the attention
qk_norm_before_rope: bool = False # when to apply qk norm
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
no_rope_layer_interval: Optional[int] = (
None # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795).
)
partial_rotary_factor: float = 1.0
rope_theta: Optional[float] = (
None # The official name to override self.rope_freq_base.
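The inline comment cites the RoPE-to-NoPE hybrid attention strategy; below is a hedged sketch of how `no_rope_layer_interval` could gate rotary embeddings per layer. The "every Nth layer skips RoPE" rule and the helper name are assumptions for illustration, not taken from this diff.

```python
# Hypothetical helper (assumption: with no_rope_layer_interval = N, every
# Nth layer, counting from 1, runs without rotary position embeddings).
from typing import Optional


def layer_uses_rope(layer_id: int, no_rope_layer_interval: Optional[int]) -> bool:
    if no_rope_layer_interval is None:
        return True  # default: RoPE in every layer
    return (layer_id + 1) % no_rope_layer_interval != 0


# With the SmolLM3 3B config below (no_rope_layer_interval = 4), 0-indexed
# layers 3, 7, 11, ... would skip RoPE.
```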
15 changes: 15 additions & 0 deletions examples/models/smollm3/3b_config.json
@@ -0,0 +1,15 @@
{
"dim": 2048,
"ffn_dim_multiplier": 1,
"hidden_dim": 11008,
"n_heads": 16,
"n_kv_heads": 4,
"n_layers": 36,
"norm_eps": 1e-06,
"rope_theta": 5000000.0,
"use_scaled_rope": false,
"vocab_size": 128256,
"use_hf_rope": false,
"no_rope_layer_interval": 4,
"attention_qkv_bias": false
}
16 changes: 16 additions & 0 deletions examples/models/smollm3/__init__.py
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.smollm3.convert_weights import convert_weights


class SmolLM3Model(Llama2Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)


__all__ = [
"SmolLM3Model",
"convert_weights",
]
112 changes: 112 additions & 0 deletions examples/models/smollm3/convert_weights.py
@@ -0,0 +1,112 @@
import argparse
import json
import os
from typing import Dict

import torch

from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key


_SMOLLM_TO_META = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.norm.weight": "norm.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
}


def smollm_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert a SmolLM state dict from Hugging Face's format to Meta's format. This function
doesn't handle any sharding or splitting of state dicts. It follows the
state_dict IN -> state_dict OUT pattern.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in Hugging Face SmolLM format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
converted_state_dict = {}
for key, value in state_dict.items():
new_key = get_mapped_key(key, _SMOLLM_TO_META)
converted_state_dict[new_key] = value
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
index_path = os.path.join(input_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
# Sharded checkpoint.
with open(index_path, "r") as f:
index = json.load(f)
weight_map = index["weight_map"]
checkpoint_shards = sorted(set(weight_map.values()))

# Load all the shards into memory
shard_to_weights = {}
for shard in checkpoint_shards:
shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

# Merge tensors into consolidated state dict.
merged_state_dict = {}
for weight_name, shard in weight_map.items():
tensor = shard_to_weights[shard][weight_name]
merged_state_dict[weight_name] = tensor
return merged_state_dict
else:
# Single checkpoint.
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
return state_dict


def load_checkpoint(input_dir: str) -> Dict:
pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
if os.path.exists(pytorch_path):
print("Loading checkpoint from PyTorch .bin file")
return torch.load(pytorch_path, map_location="cpu", weights_only=True)
print("Loading checkpoint from safetensors directory")
return load_checkpoint_from_safetensors(input_dir)


def convert_weights(input_dir: str, output_file: str) -> None:
print("Loading checkpoint...")
sd = load_checkpoint(input_dir)
print("Converting checkpoint...")
sd = smollm_to_meta(sd)
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")


def main():
parser = argparse.ArgumentParser(
description="Convert SmolLM weights to Meta format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()
convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
main()
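A possible way to invoke the converter programmatically; the checkpoint paths are placeholders and the snippet is an illustration, not part of this diff.

```python
# Hypothetical usage (paths are placeholders): consolidate a downloaded
# SmolLM3 Hugging Face checkpoint into a single Meta-format .pth file.
from executorch.examples.models.smollm3.convert_weights import convert_weights

convert_weights(
    input_dir="/path/to/SmolLM3-3B",     # contains model.safetensors or shards + index
    output_file="/path/to/smollm3.pth",  # consolidated output checkpoint
)
```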
13 changes: 10 additions & 3 deletions examples/qualcomm/oss_scripts/llama/README.md
@@ -9,7 +9,8 @@ This file provides you the instructions to run LLM Decoder model with different
5. Phi4-mini-instruct
6. QWEN2.5 0.5B / 1.5B
7. QWEN3 0.6B / 1.7B
8. SMOLLM2 135M
8. SmolLM2 135M
9. SmolLM3 3B


We offer the following modes to execute the model:
@@ -113,10 +114,16 @@ Default example using hybrid mode
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### SMOLLM2
#### SmolLM2
Default example using hybrid mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### SmolLM3
Default example using kv mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

