Merged
Changes from all commits
1 change: 1 addition & 0 deletions backends/vulkan/docs/android_demo.md
@@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan.
# The files will usually be downloaded to ~/.llama
python -m examples.models.llama.export_llama \
--disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
--model "llama3_2" \
-c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
-p ~/.llama/checkpoints/Llama3.2-1B/params.json \
--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
@@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an
Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
* Export Llama model and generate .pte file as below:
```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
```

### For Llama 3.2 1B and 3B QAT+LoRA models
Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
* Export Llama model and generate .pte file as below:
```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
```

### For Llama 3.2 1B and 3B BF16 models
@@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
* Export Llama model and generate .pte file as below:

```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
```

For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
* Export Llama model and generate .pte file as below:
```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
```

### For Llama 3.2 1B and 3B QAT+LoRA models
Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
* Export Llama model and generate .pte file as below:
```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
```

### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
* Export Llama model and generate .pte file as below:

```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
```

For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
3 changes: 3 additions & 0 deletions examples/models/llama/README.md
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
LLAMA_PARAMS=path/to/params.json

python -m examples.models.llama.export_llama \
--model "llama3_2" \
--checkpoint "${LLAMA_CHECKPOINT:?}" \
--params "${LLAMA_PARAMS:?}" \
-kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
LLAMA_PARAMS=path/to/spinquant/params.json

python -m examples.models.llama.export_llama \
--model "llama3_2" \
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
--params "${LLAMA_PARAMS:?}" \
--use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
LLAMA_PARAMS=path/to/qlora/params.json

python -m examples.models.llama.export_llama \
--model "llama3_2" \
--checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
--params "${LLAMA_PARAMS:?}" \
-qat \
5 changes: 4 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -113,6 +113,7 @@ class ModelArgs:
)
rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
+rope_scale_factor: int = 8
# Additional Model Metadata needed at runtime
bos_idx: int = 1
eos_idx: int = 3
@@ -155,7 +156,9 @@ def __init__(self, params: ModelArgs):
self.precompute_freqs_cis = hf_precompute_freqs_cis
else:
self.precompute_freqs_cis = partial(
-    precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+    precompute_freqs_cis,
+    use_scaled=self.params.use_scaled_rope,
+    scale_factor=self.params.rope_scale_factor,
)
freqs_cos, freqs_sin = self.precompute_freqs_cis(
self.params.head_dim,
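With this change, `partial` binds both RoPE options up front so later calls only pass the dimensions. A minimal sketch of that binding outside the class (the module path, `head_dim=64`, and the position count are assumptions for illustration; `scale_factor=32` is the Llama 3.2 value set in model.py below):

```python
from functools import partial

from examples.models.llama.rope import precompute_freqs_cis  # assumed module path

# Bind the RoPE options once, as the model __init__ now does.
precompute = partial(precompute_freqs_cis, use_scaled=True, scale_factor=32)

# Later calls only supply head_dim and the number of positions to precompute.
freqs_cos, freqs_sin = precompute(64, 2048)
```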
9 changes: 9 additions & 0 deletions examples/models/llama/model.py
@@ -145,6 +145,15 @@ def __init__(self, **kwargs):
enable_dynamic_shape=self.enable_dynamic_shape,
**params,
)

+if model_args.use_scaled_rope:
+    # Older models don't have use_scaled_rope configuration
+    assert self.args.model not in ["llama2", "stories110m"]
+
+    # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
+    if self.args.model not in ["llama3", "llama3_1"]:
+        model_args.rope_scale_factor = 32

if kwargs.get("verbose", False):
print("============= weights ================")
print("{key} : {weights.numel()} : {weights.size()}")
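The added block encodes a per-model default. Restated as a small standalone helper for readability (the function name is invented here, not part of the PR; the default of 8 comes from the new `rope_scale_factor` field in ModelArgs):

```python
from typing import Optional


def rope_scale_factor_for(model_name: str, use_scaled_rope: bool) -> Optional[int]:
    """Paraphrase of the added block: which RoPE scale factor a model should use."""
    if not use_scaled_rope:
        return None  # stock RoPE, no scaling applied
    # llama2 / stories110m checkpoints predate the use_scaled_rope option.
    assert model_name not in ("llama2", "stories110m")
    # llama3 / llama3_1 keep the ModelArgs default of 8; llama3_2 and newer use 32.
    return 8 if model_name in ("llama3", "llama3_1") else 32
```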
14 changes: 9 additions & 5 deletions examples/models/llama/rope.py
@@ -8,16 +8,15 @@
# Different RoPE implementations

import math
-from typing import Tuple
+from typing import Optional, Tuple

import torch

# ======================== Stock Implementation ========================


-def apply_scaling(freqs: torch.Tensor):
+def apply_scaling(freqs: torch.Tensor, scale_factor: int):
# Values obtained from grid search
-scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
@@ -41,14 +40,19 @@ def apply_scaling(freqs: torch.Tensor):


def precompute_freqs_cis(
-dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    use_scaled: bool = False,
+    scale_factor: Optional[int] = None,
):
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
)
t = torch.arange(end, device=freqs.device) # pyre-ignore
if use_scaled:
-freqs = apply_scaling(freqs) # pyre-ignore
+    assert scale_factor is not None
+    freqs = apply_scaling(freqs, scale_factor) # pyre-ignore
freqs = torch.outer(t, freqs).float()
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
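The loop body of `apply_scaling` is collapsed in this view. As a reference for what `scale_factor` feeds into, the sketch below follows the usual Llama 3.1-style wavelength interpolation that the visible constants (`low_freq_factor`, `high_freq_factor`, `old_context_len`) belong to; treat it as an illustration rather than a verbatim copy of the repo's implementation:

```python
import math

import torch


def apply_scaling_sketch(freqs: torch.Tensor, scale_factor: int) -> torch.Tensor:
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 context length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / freqs
    # Blend factor for wavelengths between the two thresholds: 1 at the
    # high-frequency boundary (keep freqs), 0 at the low-frequency boundary
    # (divide by scale_factor).
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return torch.where(
        wavelen > low_freq_wavelen,
        freqs / scale_factor,  # long wavelengths: fully scaled
        torch.where(
            wavelen < high_freq_wavelen,
            freqs,  # short wavelengths: untouched
            (1 - smooth) * freqs / scale_factor + smooth * freqs,
        ),
    )


# Example: compare the Llama 3/3.1 factor (8) with the Llama 3.2 factor (32)
# on a dim-64 frequency ladder.
base = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
print(apply_scaling_sketch(base, 8)[-3:])
print(apply_scaling_sketch(base, 32)[-3:])
```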