59 changes: 59 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -4762,6 +4762,65 @@ def test_qnn_backend_seq_mse(self):


class TestExampleLLMScript(TestQNN):
def test_static_gemma_2b(self):
if not self.required_envs():
self.skipTest("missing required envs")

prompt = "My favourite condiment is "
cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
"--artifact",
self.artifact_dir,
"--build_folder",
self.build_folder,
"--model",
self.model,
"--ip",
self.ip,
"--port",
str(self.port),
"--prompt",
f"{prompt}",
"--decoder_model",
"gemma-2b",
"--model_mode",
"kv",
"--max_seq_len",
"1024",
"--eval_perplexity",
"--tasks",
"wikitext",
"--limit",
"1",
]
if self.compile_only:
cmds.extend(["--compile_only"])
elif self.device:
cmds.extend(["--device", self.device])
if self.host:
cmds.extend(["--host", self.host])
elif self.enable_x86_64:
cmds.extend(["--enable_x86_64"])
if self.pre_gen_pte:
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
p.communicate()
msg = json.loads(conn.recv())
if "Error" in msg:
self.fail(msg["Error"])
else:
inference_speed_ref = {"SM8650": 32, "SM8750": 36}
self.assertLessEqual(msg["wiki_ppl"], 35)
self.assertLessEqual(msg["pte_size"], 2_700_000_000) # 2.7GB
if self.model in inference_speed_ref:
self.assertGreaterEqual(
msg["inference_speed"], inference_speed_ref[self.model]
)

def test_static_gemma3_1b(self):
if not self.required_envs():
self.skipTest("missing required envs")
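The test above launches `llama.py` as a subprocess and collects metrics over a `multiprocessing.connection` channel. A minimal sketch of the reporting side the `Listener` expects, under the assumption that the exported script connects back with `Client` and sends a JSON-encoded dict (metric names and values below are illustrative only):

```python
# Hypothetical reporting helper; mirrors what the test's Listener/json.loads pairing implies.
import json
from multiprocessing.connection import Client

def report_results(ip: str, port: int, results: dict) -> None:
    # Connect back to the test harness and send the metrics as a JSON string.
    with Client((ip, port)) as conn:
        conn.send(json.dumps(results))

# e.g. report_results("127.0.0.1", 5000,
#      {"wiki_ppl": 20.0, "pte_size": 2_000_000_000, "inference_speed": 40})
```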
16 changes: 16 additions & 0 deletions examples/models/gemma/__init__.py
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.gemma.convert_weights import convert_weights
from executorch.examples.models.llama.model import Llama2Model


class GemmaModel(Llama2Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)


__all__ = [
"GemmaModel",
"convert_weights",
]
19 changes: 19 additions & 0 deletions examples/models/gemma/config/2b_config.json
@@ -0,0 +1,19 @@
{
"dim": 2048,
"ffn_dim_multiplier": 1,
"hidden_dim": 16384,
"n_heads": 8,
"head_dim": 256,
"n_kv_heads": 1,
"n_layers": 18,
"act_fn": "gelu",
"norm_type": "gemma3",
"norm_eps": 1e-06,
"rope_theta": 10000.0,
"use_scaled_rope": false,
"apply_embedding": true,
"embedding_scale_factor": 45.254833995939045,
"vocab_size": 256000,
"use_hf_rope": true,
"attention_qkv_bias": false
}
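As a quick sanity check on the values above: `embedding_scale_factor` matches Gemma's convention of scaling token embeddings by the square root of the hidden dimension (√2048 ≈ 45.2548), and `dim` equals `n_heads * head_dim`. A small sketch, assuming the config is read from the path added in this PR:

```python
# Sketch only: cross-check a couple of relationships in the 2B config.
import json
import math

with open("examples/models/gemma/config/2b_config.json") as f:
    cfg = json.load(f)

assert math.isclose(cfg["embedding_scale_factor"], math.sqrt(cfg["dim"]))  # sqrt(2048)
assert cfg["dim"] == cfg["n_heads"] * cfg["head_dim"]  # 2048 == 8 * 256
```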
104 changes: 104 additions & 0 deletions examples/models/gemma/convert_weights.py
@@ -0,0 +1,104 @@
import argparse

import json
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key


# Weight mappings from Gemma's checkpoint to ExecuTorch's transformer parameters.
_GEMMA_TO_EXECUTORCH = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.norm.weight": "norm.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
}


def gemma_to_executorch(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
"""
converted_state_dict = {}
for key, value in state_dict.items():
new_key = get_mapped_key(key, _GEMMA_TO_EXECUTORCH)
converted_state_dict[new_key] = value
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]
return converted_state_dict


def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
index_path = os.path.join(input_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
# Sharded checkpoint.
with open(index_path, "r") as f:
index = json.load(f)
weight_map = index["weight_map"]
checkpoint_shards = sorted(set(weight_map.values()))

# Load all the shards into memory
shard_to_weights = {}
for shard in checkpoint_shards:
shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

# Merge tensors into consolidated state dict.
merged_state_dict = {}
for weight_name, shard in weight_map.items():
tensor = shard_to_weights[shard][weight_name]
merged_state_dict[weight_name] = tensor
return merged_state_dict
else:
# Single checkpoint.
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
return state_dict


def load_checkpoint(input_dir: str) -> Dict:
pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
if os.path.exists(pytorch_path):
print("Loading checkpoint from PyTorch .bin file")
return torch.load(pytorch_path, map_location="cpu", weights_only=True)
print("Loading checkpoint from safetensors directory")
return load_checkpoint_from_safetensors(input_dir)


def convert_weights(input_dir: str, output_file: str) -> None:
print("Loading checkpoint...")
sd = load_checkpoint(input_dir)
print("Converting checkpoint...")
sd = gemma_to_executorch(sd)
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")


def main():
parser = argparse.ArgumentParser(
description="Convert Gemma weights to ExecuTorch transformer format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()
convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
main()
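A usage note for the converter above: it can be invoked as a script (as the `argparse` entry point shows) or imported directly. For example, with placeholder paths:

```python
# Placeholder paths; the input directory is an HF-style Gemma 2B checkpoint
# (safetensors shards or pytorch_model.bin), the output is a torch.save()'d state dict.
from executorch.examples.models.gemma.convert_weights import convert_weights

convert_weights("/path/to/gemma-2b-it", "/path/to/gemma_2b_executorch.pth")
```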
20 changes: 14 additions & 6 deletions examples/qualcomm/oss_scripts/llama/README.md
@@ -5,12 +5,13 @@ This file provides you the instructions to run LLM Decoder model with different
1. LLAMA2 Stories 110M
2. LLAMA3.2 1B
3. LLAMA3.2 3B
4. Gemma3 1B
5. Phi4-mini-instruct
6. QWEN2.5 0.5B / 1.5B
7. QWEN3 0.6B / 1.7B
8. SmolLM2 135M
9. SmolLM3 3B
4. Gemma 2B
5. Gemma3 1B
6. Phi4-mini-instruct
7. QWEN2.5 0.5B / 1.5B
8. QWEN3 0.6B / 1.7B
9. SmolLM2 135M
10. SmolLM3 3B


We offer the following modes to execute the model:
@@ -78,6 +79,13 @@ Default example using kv mode.
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```

#### Gemma 2B
Default example using hybrid mode
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```


#### Gemma3 1B
Default example using hybrid mode
```bash
31 changes: 31 additions & 0 deletions examples/qualcomm/oss_scripts/llama/__init__.py
@@ -24,6 +24,7 @@
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype

from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
from executorch.examples.models.phi_4_mini import (
convert_weights as convert_phi_4_mini_weights,
@@ -300,6 +301,36 @@ class Llama3_2_3B_Instruct(LLMModelConfig):
)


@register_llm_model("gemma-2b")
@dataclass(init=False, frozen=True)
class Gemma_2B(LLMModelConfig):
repo_id: str = "google/gemma-2b-it"
params_path: str = os.path.join(
BASE_DIR, "../../../models/gemma/config/2b_config.json"
)
convert_weights = convert_gemma_weights
transform_weight = False
instruct_model = True

num_sharding = 4
# quant config
ptq = QuantDtype.use_16a4w_block
group_size = 64
masked_softmax = True
seq_mse_candidates = 0
r1 = False
r2 = False
r3 = False
quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config(
torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)
custom_annotation = (
annotate_kv_8bit,
annotate_output_16a8w,
partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w),
)


@register_llm_model("gemma3-1b")
@dataclass(init=False, frozen=True)
class Gemma3(LLMModelConfig):
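The `Gemma_2B` config above selects 16-bit activations with 4-bit block-wise weights (`QuantDtype.use_16a4w_block`, `group_size = 64`). An illustrative sketch of what block-wise 4-bit weight quantization means, not the QNN quantizer's actual implementation:

```python
# Illustration only: symmetric 4-bit fake-quantization with per-block scales (group_size=64).
import torch

def fake_quantize_4bit_blockwise(w: torch.Tensor, group_size: int = 64) -> torch.Tensor:
    out_features, in_features = w.shape
    blocks = w.reshape(out_features, in_features // group_size, group_size)
    # One scale per 64-element block; int4 symmetric range is [-8, 7].
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(blocks / scale), -8, 7)
    return (q * scale).reshape_as(w)

w_deq = fake_quantize_4bit_blockwise(torch.randn(256, 2048))
```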
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/decoder_constants.py
@@ -14,6 +14,7 @@
DECODER_MODEL_VERSION = {
"stories260k": "llama2",
"stories110m": "llama2",
"gemma-2b": "gemma",
"gemma3-1b": "gemma3",
"phi_4_mini": "phi_4_mini",
"llama3_2-1b_instruct": "llama3",
16 changes: 13 additions & 3 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -327,6 +327,13 @@ def quantize(
chat_template, args.prompt[0], args.system_prompt
)
)

# Gemma may produce unexpected output if the prompt contains an extra <bos> token.
# This can happen after applying a prompt template, which might inject <bos> unintentionally.
# To prevent decoding issues, we explicitly remove the <bos> token.
if chat_template and args.decoder_model in {"gemma-2b", "gemma3-1b"}:
prompt = prompt.replace("<bos>", "")

graph_module_inference(
use_kv_cache=self.llama_meta["get_use_kv_cache"],
get_example_inputs=self.get_example_inputs,
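The `<bos>` stripping a few lines above guards against a doubled BOS: Gemma chat templates emit a literal `<bos>`, and a tokenizer that prepends its own BOS id at encode time would otherwise see it twice. A tiny illustration (the template string is an assumption, not taken from the runner):

```python
# Illustration only: the literal "<bos>" from the template is dropped before tokenization,
# leaving BOS insertion entirely to the tokenizer.
templated = "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\n"
prompt = templated.replace("<bos>", "")
assert not prompt.startswith("<bos>")
```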
@@ -534,14 +541,13 @@ def compile(
state_dict = torch.load(
checkpoint, weights_only=True, map_location="cpu", mmap=True
)
if args.decoder_model == "gemma3-1b":
if args.decoder_model in {"gemma-2b", "gemma3-1b"}:
for k, v in state_dict.items():
if "norm" not in k:
continue
# Llama does x.to(float16) * w, whilst Gemma and Gemma3 do (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32)

else:
state_dict = torch.load(
args.checkpoint, weights_only=True, map_location="cpu", mmap=True
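A brief note on the `+ torch.ones(...)` adjustment above: Gemma-family RMSNorm scales the normalized activations by `(1 + weight)`, while the Llama-style norm used here scales by `weight` alone, so adding one to the checkpoint norm weights folds the offset in. A minimal sketch of the equivalence, assuming a standard RMSNorm rather than the exact on-device kernels:

```python
# Sketch: a Llama-style RMSNorm with weights (w + 1) matches Gemma's (1 + w) RMSNorm.
import torch

def llama_rmsnorm(x, w, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

def gemma_rmsnorm(x, w, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * (1.0 + w)

x, w = torch.randn(4, 2048), torch.randn(2048)
assert torch.allclose(gemma_rmsnorm(x, w), llama_rmsnorm(x, w + torch.ones_like(w)))
```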
@@ -1286,7 +1292,11 @@ def export_llama(args) -> None:
)
tokenizer_artifacts = tokenizer.save_pretrained(args.artifact)
tokenizer_config = tokenizer_artifacts[0]
runtime_tokenizer_path = tokenizer_artifacts[-1]
if args.decoder_model == "gemma-2b":
# For Gemma, use tokenizer.model, since its tokenizer.json does not provide a pre_tokenizer.
runtime_tokenizer_path = tokenizer_artifacts[-3]
else:
runtime_tokenizer_path = tokenizer_artifacts[-1]
tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config)

# TODO: Remove this once error is resolved.
3 changes: 2 additions & 1 deletion examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -9,7 +9,7 @@
/**
* @file
*
* This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma3 1B,
* This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma3 1B,
* phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, SmolLM2 135M,
* SmolLM3 3B with Qualcomm AI Engine Direct.
*
@@ -117,6 +117,7 @@ std::string get_formatted_prompt(
formatted_prompt.append(
"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
break;
case example::DecoderModelVersion::kGemma:
case example::DecoderModelVersion::kGemma3:
formatted_prompt.append("<start_of_turn>user\n");
formatted_prompt.append(prompt);
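For reference, the `<start_of_turn>` markers appended above follow the standard Gemma instruct turn format; the closing markers are cut off in the diff. A small sketch of the assumed full prompt layout:

```python
# Assumed Gemma instruct template; the exact string built in qnn_llama_runner.cpp
# is only partially visible in the hunk above.
def format_gemma_prompt(user_prompt: str) -> str:
    return (
        "<start_of_turn>user\n"
        f"{user_prompt}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

print(format_gemma_prompt("Could you teach me python with a simple example?"))
```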
6 changes: 5 additions & 1 deletion examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -122,6 +122,8 @@ Runner<T>::Runner(
decoder_model_version_ = DecoderModelVersion::kLlama2;
} else if (decoder_model_version == "llama3") {
decoder_model_version_ = DecoderModelVersion::kLlama3;
} else if (decoder_model_version == "gemma") {
decoder_model_version_ = DecoderModelVersion::kGemma;
} else if (decoder_model_version == "gemma3") {
decoder_model_version_ = DecoderModelVersion::kGemma3;
cache_mode_ = CacheMode::HybridCache;
@@ -199,7 +201,9 @@ Error Runner<T>::load() {
decoder_model_version_ == DecoderModelVersion::kSmollm2_135m ||
decoder_model_version_ == DecoderModelVersion::kSmollm3) {
eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]);
} else if (decoder_model_version_ == DecoderModelVersion::kGemma3) {
} else if (
decoder_model_version_ == DecoderModelVersion::kGemma ||
decoder_model_version_ == DecoderModelVersion::kGemma3) {
eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
}

1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -32,6 +32,7 @@ namespace example {
enum DecoderModelVersion {
kLlama2 = 0,
kLlama3,
kGemma,
kGemma3,
kPhi4,
kQwen2_5,