diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 0e4d6dfd538..ef3cd9b0534 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -4762,6 +4762,65 @@ def test_qnn_backend_seq_mse(self):
 
 
 class TestExampleLLMScript(TestQNN):
+    def test_static_gemma_2b(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--decoder_model",
+            "gemma-2b",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 32, "SM8750": 36}
+                self.assertLessEqual(msg["wiki_ppl"], 35)
+                self.assertLessEqual(msg["pte_size"], 2_700_000_000)  # 2.7GB
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_gemma3_1b(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
diff --git a/examples/models/gemma/__init__.py b/examples/models/gemma/__init__.py
new file mode 100644
index 00000000000..13a14ff0751
--- /dev/null
+++ b/examples/models/gemma/__init__.py
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.gemma.convert_weights import convert_weights
+from executorch.examples.models.llama.model import Llama2Model
+
+
+class GemmaModel(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "GemmaModel",
+    "convert_weights",
+]
diff --git a/examples/models/gemma/config/2b_config.json b/examples/models/gemma/config/2b_config.json
new file mode 100644
index 00000000000..20a40723c30
--- /dev/null
+++ b/examples/models/gemma/config/2b_config.json
@@ -0,0 +1,19 @@
+{
+  "dim": 2048,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 16384,
+  "n_heads": 8,
+  "head_dim": 256,
+  "n_kv_heads": 1,
+  "n_layers": 18,
+  "act_fn": "gelu",
+  "norm_type": "gemma3",
+  "norm_eps": 1e-06,
+  "rope_theta": 10000.0,
+  "use_scaled_rope": false,
+  "apply_embedding": true,
+  "embedding_scale_factor": 45.254833995939045,
+  "vocab_size": 256000,
+  "use_hf_rope": true,
+  "attention_qkv_bias": false
+}
diff --git a/examples/models/gemma/convert_weights.py b/examples/models/gemma/convert_weights.py
new file mode 100644
index 00000000000..09a17bc2266
--- /dev/null
+++ b/examples/models/gemma/convert_weights.py
@@ -0,0 +1,104 @@
+import argparse
+
+import json
+import os
+from typing import Dict
+
+import torch
+from safetensors.torch import load_file
+
+from torchtune.models.convert_weights import get_mapped_key
+
+
+# Weight mappings from Gemma's checkpoint to ExecuTorch's transformer parameters.
+_GEMMA_TO_EXECUTORCH = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+}
+
+
+def gemma_to_executorch(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
+    """
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, _GEMMA_TO_EXECUTORCH)
+        converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+    return converted_state_dict
+
+
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    print("Loading checkpoint...")
+    sd = load_checkpoint(input_dir)
+    print("Converting checkpoint...")
+    sd = gemma_to_executorch(sd)
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Gemma weights to ExecuTorch transformer format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md
index 1be94ec04d6..9bb76142362 100644
--- a/examples/qualcomm/oss_scripts/llama/README.md
+++ b/examples/qualcomm/oss_scripts/llama/README.md
@@ -5,12 +5,13 @@ This file provides you the instructions to run LLM Decoder model with different
  1. LLAMA2 Stories 110M
  2. LLAMA3.2 1B
  3. LLAMA3.2 3B
- 4. Gemma3 1B
- 5. Phi4-mini-instruct
- 6. QWEN2.5 0.5B / 1.5B
- 7. QWEN3 0.6B / 1.7B
- 8. SmolLM2 135M
- 9. SmolLM3 3B
+ 4. Gemma 2B
+ 5. Gemma3 1B
+ 6. Phi4-mini-instruct
+ 7. QWEN2.5 0.5B / 1.5B
+ 8. QWEN3 0.6B / 1.7B
+ 9. SmolLM2 135M
+ 10. SmolLM3 3B
 
 We offer the following modes to execute the model:
 
@@ -78,6 +79,13 @@ Default example using kv mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
+#### Gemma 2B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+```
+
+
 #### Gemma3 1B
 Default example using hybrid mode
 ```bash
diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py
index 5908fcf32a6..628defc1496 100644
--- a/examples/qualcomm/oss_scripts/llama/__init__.py
+++ b/examples/qualcomm/oss_scripts/llama/__init__.py
@@ -24,6 +24,7 @@
 )
 
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
 from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
 from executorch.examples.models.phi_4_mini import (
     convert_weights as convert_phi_4_mini_weights,
@@ -300,6 +301,36 @@ class Llama3_2_3B_Instruct(LLMModelConfig):
     )
 
 
+@register_llm_model("gemma-2b")
+@dataclass(init=False, frozen=True)
+class Gemma_2B(LLMModelConfig):
+    repo_id: str = "google/gemma-2b-it"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/gemma/config/2b_config.json"
+    )
+    convert_weights = convert_gemma_weights
+    transform_weight = False
+    instruct_model = True
+
+    num_sharding = 4
+    # quant config
+    ptq = QuantDtype.use_16a4w_block
+    group_size = 64
+    masked_softmax = True
+    seq_mse_candidates = 0
+    r1 = False
+    r2 = False
+    r3 = False
+    quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+    )
+    custom_annotation = (
+        annotate_kv_8bit,
+        annotate_output_16a8w,
+        partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w),
+    )
+
+
 @register_llm_model("gemma3-1b")
 @dataclass(init=False, frozen=True)
 class Gemma3(LLMModelConfig):
diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py
index ac96770b889..d43ceb8351a 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py
@@ -14,6 +14,7 @@
 DECODER_MODEL_VERSION = {
     "stories260k": "llama2",
     "stories110m": "llama2",
+    "gemma-2b": "gemma",
     "gemma3-1b": "gemma3",
     "phi_4_mini": "phi_4_mini",
     "llama3_2-1b_instruct": "llama3",
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index ae5ae63d509..887e680341f 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -327,6 +327,13 @@ def quantize(
                 chat_template, args.prompt[0], args.system_prompt
             )
         )
+
+        # Gemma may produce unexpected output if the prompt contains an extra <bos> token.
+        # This can happen after applying a prompt template, which might inject <bos> unintentionally.
+        # To prevent decoding issues, we explicitly remove the <bos> token.
+        if chat_template and args.decoder_model in {"gemma-2b", "gemma3-1b"}:
+            prompt = prompt.replace("<bos>", "")
+
         graph_module_inference(
             use_kv_cache=self.llama_meta["get_use_kv_cache"],
             get_example_inputs=self.get_example_inputs,
@@ -534,14 +541,13 @@ def compile(
         state_dict = torch.load(
             checkpoint, weights_only=True, map_location="cpu", mmap=True
         )
-        if args.decoder_model == "gemma3-1b":
+        if args.decoder_model in {"gemma-2b", "gemma3-1b"}:
             for k, v in state_dict.items():
                 if "norm" not in k:
                     continue
                 # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
                 # See https://github.com/huggingface/transformers/pull/29402
                 state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32)
-
     else:
         state_dict = torch.load(
             args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -1286,7 +1292,11 @@ def export_llama(args) -> None:
         )
         tokenizer_artifacts = tokenizer.save_pretrained(args.artifact)
         tokenizer_config = tokenizer_artifacts[0]
-        runtime_tokenizer_path = tokenizer_artifacts[-1]
+        if args.decoder_model == "gemma-2b":
+            # For Gemma, use tokenizer.model since its tokenizer.json does not provide a pre_tokenizer.
+            runtime_tokenizer_path = tokenizer_artifacts[-3]
+        else:
+            runtime_tokenizer_path = tokenizer_artifacts[-1]
         tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config)
 
         # TODO: Remove this once error is resolved.
diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
index 71eaea2b8d6..2bffb35852a 100644
--- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -9,7 +9,7 @@
 /**
  * @file
  *
- * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma3 1B,
+ * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma3 1B,
  * phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, SmolLM2 135M,
  * SmolLM3 3B with Qualcomm AI Engine Direct.
  *
@@ -117,6 +117,7 @@ std::string get_formatted_prompt(
       formatted_prompt.append(
           "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
       break;
+    case example::DecoderModelVersion::kGemma:
     case example::DecoderModelVersion::kGemma3:
       formatted_prompt.append("<start_of_turn>user\n");
       formatted_prompt.append(prompt);
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index fe45d4b6a67..0c4884bbccf 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -122,6 +122,8 @@ Runner::Runner(
     decoder_model_version_ = DecoderModelVersion::kLlama2;
   } else if (decoder_model_version == "llama3") {
     decoder_model_version_ = DecoderModelVersion::kLlama3;
+  } else if (decoder_model_version == "gemma") {
+    decoder_model_version_ = DecoderModelVersion::kGemma;
   } else if (decoder_model_version == "gemma3") {
     decoder_model_version_ = DecoderModelVersion::kGemma3;
     cache_mode_ = CacheMode::HybridCache;
@@ -199,7 +201,9 @@ Error Runner::load() {
       decoder_model_version_ == DecoderModelVersion::kSmollm2_135m ||
       decoder_model_version_ == DecoderModelVersion::kSmollm3) {
     eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]);
-  } else if (decoder_model_version_ == DecoderModelVersion::kGemma3) {
+  } else if (
+      decoder_model_version_ == DecoderModelVersion::kGemma ||
+      decoder_model_version_ == DecoderModelVersion::kGemma3) {
     eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
   }
 
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 9f290d79c75..1472093ab66 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -32,6 +32,7 @@ namespace example {
 enum DecoderModelVersion {
   kLlama2 = 0,
   kLlama3,
+  kGemma,
   kGemma3,
   kPhi4,
   kQwen2_5,
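
As a usage illustration (not taken from the patch itself): the converter added in `examples/models/gemma/convert_weights.py`, and re-exported from `examples/models/gemma/__init__.py`, can be called directly to produce the consolidated checkpoint. The paths below are placeholders; the input directory is assumed to hold a locally downloaded `google/gemma-2b-it` checkpoint in safetensors or `pytorch_model.bin` form.

```python
# Hypothetical invocation of the new Gemma weight converter; paths are placeholders.
from executorch.examples.models.gemma import convert_weights

# Reads *.safetensors (sharded or single-file) or pytorch_model.bin from the input
# directory, remaps keys to ExecuTorch's transformer parameter names, ties
# output.weight to tok_embeddings.weight, and writes a single .pth via torch.save.
convert_weights(
    input_dir="/path/to/gemma-2b-it",
    output_file="/path/to/gemma_2b.pth",
)
```

The same conversion is also exposed as a command-line entry point in the new module (`main()` with positional `input_dir` and `output` arguments).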