diff --git a/backends/vulkan/docs/android_demo.md b/backends/vulkan/docs/android_demo.md
index 2a4faacc0c8..9b45fcde9a8 100644
--- a/backends/vulkan/docs/android_demo.md
+++ b/backends/vulkan/docs/android_demo.md
@@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan.
 # The files will usually be downloaded to ~/.llama
 python -m examples.models.llama.export_llama \
   --disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
+  --model "llama3_2" \
   -c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
   -p ~/.llama/checkpoints/Llama3.2-1B/params.json \
   --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
diff --git a/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md b/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md
index 2a6ddbbfe09..087bd242608 100644
--- a/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md
+++ b/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md
@@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```
 
 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```
 
 ### For Llama 3.2 1B and 3B BF16 models
@@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```
 
 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
index 1968c550763..83b1ba76f85 100644
--- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
+++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```
 
 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```
 
 ### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```
 
 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md
index cfa0fe04b1b..e621ce5d49d 100644
--- a/examples/models/llama/README.md
+++ b/examples/models/llama/README.md
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 
 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/spinquant/params.json
 
 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   --use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
 LLAMA_PARAMS=path/to/qlora/params.json
 
 python -m examples.models.llama.export_llama \
+  --model "llama3_2" \
   --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
   --params "${LLAMA_PARAMS:?}" \
   -qat \
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 10d660d37a6..aaef3cd9804 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -113,6 +113,7 @@ class ModelArgs:
     )
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
+    rope_scale_factor: int = 8
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -155,7 +156,9 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = hf_precompute_freqs_cis
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 2385aba6d5d..9f7994916ab 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -145,6 +145,15 @@ def __init__(self, **kwargs):
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
+
+        if model_args.use_scaled_rope:
+            # Older models don't have use_scaled_rope configuration
+            assert self.args.model not in ["llama2", "stories110m"]
+
+            # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
+            if self.args.model not in ["llama3", "llama3_1"]:
+                model_args.rope_scale_factor = 32
+
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 1445787f5eb..cd3ddb0d3b8 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -8,16 +8,15 @@
 # Different RoPE implementations
 
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
 # ======================== Stock Implementation ========================
 
 
-def apply_scaling(freqs: torch.Tensor):
+def apply_scaling(freqs: torch.Tensor, scale_factor: int):
     # Values obtained from grid search
-    scale_factor = 8
     low_freq_factor = 1
     high_freq_factor = 4
     old_context_len = 8192  # original llama3 length
@@ -41,14 +40,19 @@ def apply_scaling(freqs: torch.Tensor):
 
 
 def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    use_scaled: bool = False,
+    scale_factor: Optional[int] = None,
 ):
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     if use_scaled:
-        freqs = apply_scaling(freqs)  # pyre-ignore
+        assert scale_factor is not None
+        freqs = apply_scaling(freqs, scale_factor)  # pyre-ignore
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
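
To make the RoPE change above concrete, here is a minimal sketch (not part of the patch) that calls the updated `precompute_freqs_cis` from `examples/models/llama/rope.py` with the new `scale_factor` argument. The `dim`, `end`, and `theta` values are illustrative placeholders rather than values from any real `params.json`, and the import assumes the ExecuTorch repo root is on `PYTHONPATH`.

```python
# Minimal sketch (not part of the patch): exercising the new scale_factor argument.
# dim/end/theta below are illustrative placeholders; real values come from ModelArgs,
# which llama_transformer.py forwards via functools.partial.
from examples.models.llama.rope import precompute_freqs_cis

# Llama 3 / 3.1 path: ModelArgs keeps the default rope_scale_factor of 8, which
# reproduces the value previously hard-coded inside apply_scaling().
freqs_cos, freqs_sin = precompute_freqs_cis(
    dim=64,          # head_dim (placeholder)
    end=2048,        # number of positions to precompute (placeholder)
    theta=500000.0,  # RoPE base frequency (placeholder)
    use_scaled=True,
    scale_factor=8,
)

# Llama 3.2 path: model.py now sets model_args.rope_scale_factor = 32, so the
# same call site ends up forwarding scale_factor=32 instead.
freqs_cos_32, freqs_sin_32 = precompute_freqs_cis(
    dim=64,
    end=2048,
    theta=500000.0,
    use_scaled=True,
    scale_factor=32,
)

# Calling with use_scaled=True but no scale_factor now trips the new assert,
# rather than silently falling back to the old hard-coded factor.
```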
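
Design-wise, keeping the `ModelArgs` default at 8 leaves Llama 3 / 3.1 exports with the same scaling behavior as before, while `model.py` opts `llama3_2` and newer model names into the larger factor of 32; the new assert in `precompute_freqs_cis` forces callers to state the factor explicitly whenever scaling is enabled.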