From 1d6b3806b58210c94bcab02906a1419c3ab31b39 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 24 Oct 2024 08:07:09 -0700 Subject: [PATCH] Support quantized llama models --- README.md | 2 +- examples/models/llama/README.md | 105 +++++++++++++++++++++++--------- 2 files changed, 76 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index b27845e9f55..be7ff32229d 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Check out the [Getting Started](https://pytorch.org/executorch/stable/getting-st Check out the examples of [Llama](./examples/models/llama/README.md), [Llava](./examples/models/llava/README.md) and [other models](./examples/README.md) running on edge devices using ExecuTorch. -**[UPDATE - 09/25]** We have added support for running [Llama 3.2 1B/3B](./examples/models/llama/README.md) models via ExecuTorch. +**[UPDATE - 10/24]** We have added support for running [Llama 3.2 Quantized 1B/3B](./examples/models/llama/README.md) models via ExecuTorch. ## Feedback diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index fe98561e091..4c1be82cfb6 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -4,6 +4,7 @@ This example demonstrates how to run [Llama models](https://www.llama.com/) on m Here are supported models: - Llama 3.2 1B and 3B +- Llama 3.2 Quantized 1B and 3B - Llama 3.1 8B - Llama 3 8B - [Llama 2 7B](../llama2/README.md) @@ -24,40 +25,54 @@ Please note that the models are subject to the [Llama 2 Acceptable Use Policy](h # Results -## Llama 3.2 1B/3B +## Llama 3.2 1B/3B and quantized 1B/3B models -For Llama 3.2 1B/3B models, we have enabled the original bf16 format and quantization to 4-bit, using SpinQuant, for enhanced performance. +For Llama 3.2 1B/3B models, we have enabled the original BF16 format and quantization to 4-bit, using SpinQuant and QAT+LoRA, for enhanced performance. -### 1. Enablement +The quantized models were optimized primarily for Arm CPU architecture by leveraging XNNPACK and Kleidi AI library. Work is underway to specifically enable quantization on mobile accelerators for Llama 1B/3B. + +### Enablement We have successfully verified performance on the following devices: iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+, S22 and OnePlus 12 (featuring 16GB RAM). -Note, the Llama 3.2 3B unquantized bf16 model was only tested on the OnePlus 12, which has sufficient memory (16GB RAM) to support its size requirements. +Note, the Llama 3.2 3B unquantized BF16 model was only tested on the OnePlus 12, which has sufficient memory (16GB RAM) to support its size requirements. + +### Quantization -### 2. Quantization +The 1B/3B models are sensitive to accuracy loss when regular post-training quantization (PTQ) is applied. To achieve a balance between accuracy, performance and memory, we utilized 4-bit quantization, using [SpinQuant](https://github.com/facebookresearch/SpinQuant/tree/main) and QAT+LoRA methods. -#### 2.1 SpinQuant +Our quantization scheme involves three parts, applicable to both methods: -The 1B/3B models are sensitive to accuracy loss when regular post-training quantization (PTQ) is applied. To achieve a balance between accuracy, performance and memory, we utilized 4-bit quantization with [SpinQuant](https://github.com/facebookresearch/SpinQuant/tree/main). With SpinQuant, we currently quantize 4-bit groupwise (with groupsize 32) weight, 8bit dynamic activation of all the linear layers of the model, except embedding and output layers. The embedding and output layers are quantized as 8-bit per-channel weight and 8-bit dynamic activation. +- We quantize all linear layers in all transformer blocks to a 4-bit groupwise scheme (with a group size of 32) for weights and 8-bit per-token dynamic quantization for activations. +- The classification layer is quantized to 8-bit per-channel for weight and 8-bit per token dynamic quantization for activation. +- We employ an 8-bit per channel quantization for embedding. + +#### SpinQuant The SpinQuant method takes the original weights and produces optimized quantized weights with minimal outliers, resulting in higher accuracy. This can be achieved without any finetuning of the weights and only requires 100 iterations on a single A100 node. SpinQuant can generate quantized weights that are [compatible with ExecuTorch](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch), specifically, it can be integrated with the existing optimized XNNPACK kernels (e.g., group-wise 4bit weight and 8bit dynamic activation). This allows developers to benefit from the higher accuracy of SpinQuant while also taking advantage of the strong performance of ExecuTorch acceleration. -### 3. Accuracy +#### Quantization-Aware Training and LoRA (QAT+LoRA) + +Quantization-Aware Training (QAT) is employed to simulate the effects of quantization during the training of Llama-3.2 models, enabling optimization of their performance in low precision environments. To initialize QAT, BF16 Llama-3.2 model checkpoints obtained after supervised fine-tuning (SFT) are utilized and an additional full round of SFT training with QAT is performed. The backbone of the QAT model is then frozen and another round of SFT is performed with low-rank adaptation (LoRA) adaptors applied to all layers within the transformer block. Meanwhile, the LoRA adaptors' weights and activations are maintained in BF16. + +### Accuracy Please see the [Llama 3.2 model card](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) for accuracy evalations. -### 4. Performance: +### Performance -Llama 3.2 1B and 3B performance was measured on Android OnePlus 12 device. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-5-run-benchmark-on) with prompt length of 64. +Llama 3.2 1B and 3B performance was measured on Android OnePlus 12 device. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-4-run-benchmark-on-android-phone) with prompt length of 64. It is measured with KleidiAI library. KleidiAI is not enabled by default yet. Use `-DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON` to enable it in the build. -|Model | decode (tokens/s) | prefill (tokens/s) | Memory size (RSS in MiB) | -|-------|------------------------ |------------------ | ------------------ | -|1B bf16 | 19.2 | 60.3 | 3,185 | -|1B SpinQuant | 50.2 | 260.5 | 1,921 | -|3B bf16 | 7.6 | 21.2 | 7,419 | -|3B SpinQuant | 19.7 | 89.7 | 3,726 | +|Model | Decode (tokens/s) | Time-to-first-token (sec) | Prefill (tokens/s) | Model size (PTE file size in MiB) | Memory size (RSS in MiB) | +|-------|------------------:|--------------------------:| ------------------:|----------------------------------:| ------------------------:| +|1B BF16 (baseline) | 19.2 | 1.0 | 60.3 | 2,358 | 3,185 | +|1B SpinQuant | 50.2 (2.6x) | 0.3 (-76.9%) | 260.5 (4.3x) | 1,083 (-54.1%) | 1,921 (-39.7%) | +|1B QAT+LoRA | 45.8 (2.4x) | 0.3 (-76.0%) | 252.0 (4.2x) | 1,127 (-52.2%) | 2,255 (-29.2%) | +|3B BF16 (baseline) | 7.6 | 3.0 | 21.2 | 6,129 | 7,419 | +|3B SpinQuant | 19.7 (2.6x) | 0.7 (-76.4%) | 89.7 (4.2x) | 2,435 (-60.3%) | 3,726 (-49.8%) | +|3B QAT+LoRA | 18.5 (2.4x) | 0.7 (-76.1%) | 88.8 (4.2x) | 2,529 (-58.7%) | 4,060 (-45.3%) | @@ -65,7 +80,7 @@ Llama 3.2 1B and 3B performance was measured on Android OnePlus 12 device. The p

- Llama3.2 1B, unquantized, bf16 on Android phone. + Llama3.2 1B, unquantized, BF16 on Android phone.
@@ -80,15 +95,15 @@ Llama 3.2 1B and 3B performance was measured on Android OnePlus 12 device. The p ## Llama 3/3.1 8B Since Llama 3 8B model needs at least 4-bit quantization to fit even within some of the highend phones, results presented here correspond to 4-bit groupwise post-training quantized (PTQ) model. -### 1. Enablement +### Enablement For Llama 3 8B and Llama3.1 8B, we have verified so far on iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+ and OnePlus 12 (with 16GB RAM) by quantizing to 4bit. -### 2. Quantization +### Quantization We employed PTQ 4-bit groupwise per token dynamic quantization of all the linear layers of the model. Dynamic quantization refers to quantizating activations dynamically, such that quantization parameters for activations are calculated, from min/max range, at runtime. Here we quantized activations with 8bits (signed integer). Furthermore, weights are statically quantized. In our case weights were per-channel groupwise quantized with 4bit signed integer. Due to Llama3's vocabulary size, we had to quantize embedding lookup table as well. For these results embedding lookup table was groupwise quantized with 4-bits and group size of 32. -### 3. Accuracy +### Accuracy We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different groupsizes, with max_seq_length 2048, and limit 1000. @@ -98,9 +113,9 @@ We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/l Please note that LM Eval reports perplexity normalized by word count instead of token count. You may see different perplexity for WikiText from other sources if they implement it differently. More details could be found [here](https://github.com/EleutherAI/lm-evaluation-harness/issues/2301). -### 4. Performance +### Performance -Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus 12 devices. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-5-run-benchmark-on). +Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus 12 devices. The performance measurement is expressed in terms of tokens per second using an [adb binary-based approach](#step-4-run-benchmark-on-android-phone). |Device | Groupwise 4-bit (128) | Groupwise 4-bit (256) |--------| ---------------------- | --------------- @@ -137,9 +152,11 @@ Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus 1. Download `consolidated.00.pth`, `params.json` and `tokenizer.model` from [Llama website](https://www.llama.com/llama-downloads/) or [Hugging Face](https://huggingface.co/meta-llama/Llama-3.2-1B). For chat use-cases, download the instruct models. -2. Export model and generate `.pte` file. Use original bfloat16 version, without any quantization. +2. Export model and generate `.pte` file. +- Use **original BF16** version, without any quantization. ``` +# No quantization # Set these paths to point to the downloaded files LLAMA_CHECKPOINT=path/to/checkpoint.pth LLAMA_PARAMS=path/to/params.json @@ -155,20 +172,22 @@ python -m examples.models.llama.export_llama \ --output_name="llama3_2.pte" ``` -Optionally, we can apply SpinQuant to quantize the model without sacrifacing too much accuracy loss. - -To use SpinQuant, follow its [instruction](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch) for exporting checkpoint to ExecuTorch and then export the SpinQuant checkpoint. +- To use **SpinQuant**, here are two ways: + - Download directly from [Llama website](https://www.llama.com/llama-downloads). The model weights are prequantized and can be exported to `pte` file directly. + - Follow its [instruction](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch) for exporting checkpoint to ExecuTorch and then export the SpinQuant checkpoint. ``` +# SpinQuant # Set these paths to point to the exported files LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth -LLAMA_PARAMS=path/to/params.json +LLAMA_PARAMS=path/to/spinquant/params.json python -m examples.models.llama.export_llama \ --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \ --params "${LLAMA_PARAMS:?}" \ --use_sdpa_with_kv_cache \ -X \ + --xnnpack-extended-ops \ --preq_mode 8da4w_output_8da8w \ --preq_group_size 32 \ --max_seq_length 2048 \ @@ -180,6 +199,32 @@ python -m examples.models.llama.export_llama \ --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' ``` +- To use **QAT+LoRA**, download directly from [Llama website](https://www.llama.com/llama-downloads). The model weights are prequantized and can be exported to `pte` file directly by: + +``` +# QAT+LoRA +# Set these paths to point to the exported files +LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth +LLAMA_PARAMS=path/to/qlora/params.json + +python -m examples.models.llama.export_llama \ + --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + --params "${LLAMA_PARAMS:?}" \ + -qat \ + -lora 16 \ + --preq_mode 8da4w_output_8da8w \ + --preq_group_size 32 \ + --preq_embedding_quantize 8,0 \ + --use_sdpa_with_kv_cache \ + -kv \ + -X \ + --xnnpack-extended-ops \ + -d fp32 \ + --max_seq_length 2048 \ + --output_name "llama3_2.pte" \ + --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' +``` + ### Option B: Download and export Llama 3 8B instruct model You can export and run the original Llama 3 8B instruct model. @@ -193,7 +238,7 @@ You can export and run the original Llama 3 8B instruct model. Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size. -## Step 4: Run on your computer to validate +## Step 3: Run on your computer to validate 1. Build executorch with optimized CPU performance as follows. Build options available [here](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt#L59). ``` @@ -236,7 +281,7 @@ Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the To build for CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON` -## Step 5: Run benchmark on Android phone +## Step 4: Run benchmark on Android phone **1. Build llama runner binary for Android** @@ -301,7 +346,7 @@ adb push cmake-out-android/examples/models/llama/llama_main /data/local/tmp/llam **2.3 Run model** ``` -adb shell "cd /data/local/tmp/llama && ./llama_main --model_path --tokenizer_path --prompt \"Once upon a time\" --seq_len 120" +adb shell "cd /data/local/tmp/llama && ./llama_main --model_path --tokenizer_path --prompt \"What is the capital of France?\" --seq_len 120" --warmup=1 ``` ## Step 6: Build Mobile apps