Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backends/qualcomm/_passes/insert_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,15 @@ def _single_output_annotation(
requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
# {quant_attr: user_node_name_list}
group_quant_attr_dict = self._invert_dict(requantize_dict)
# TODO: If users of the node contain output node,
# we replace the node with to_copy op. However, it would
# be problem when the node has multiple to_copy ops
add_output = len(group_quant_attr_dict) == 1

for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
user_nodes_copy = user_nodes.copy()
if add_output:
user_nodes_copy.append("output")
self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
Expand Down
74 changes: 6 additions & 68 deletions backends/qualcomm/quantizer/custom_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,80 +14,17 @@
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec

weight = node.args[1]
input_qspec_map[weight] = quantization_config.weight

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)

quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
if "nn_module_stack" in node.meta:
module_values_list = list(node.meta["nn_module_stack"].values())
full_qualified_name = module_values_list[-1][0]
if full_qualified_name == "output.conv":
annotate_conv2d(
node, quantization_config=quantization_config_16a8w_per_channel
)


def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
for node in gm.graph.nodes:
if node.op == "output":
for index, prefill_output in enumerate(node.args[0]):
kv_quant_attr = kv_quant_attrs[index]
fixed_observer = FixedQParamsObserver.with_args(
scale=kv_quant_attr[0],
zero_point=kv_quant_attr[1],
quant_min=kv_quant_attr[2],
quant_max=kv_quant_attr[3],
dtype=kv_quant_attr[4],
qscheme=torch.torch.per_tensor_affine,
)

fixed_output_spec = QuantizationSpec(
quant_min=kv_quant_attr[2],
quant_max=kv_quant_attr[3],
dtype=kv_quant_attr[4],
ch_axis=0,
observer_or_fake_quant_ctr=fixed_observer,
)

input_qspec_map = {}
for input in prefill_output.args:
if isinstance(input, Node):
input_qspec_map[input] = fixed_output_spec

prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=fixed_output_spec,
_annotated=True,
)


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
def annotate_matmul_16a8w( # noqa: C901
gm: torch.fx.GraphModule, traverse_input1=True
) -> None:
"""
This function is specific for matmul op 16a8w.
For k, we will tag such as the below, and
Expand Down Expand Up @@ -205,7 +142,8 @@ def annotate_matmul_input1(node: Node):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)
annotate_matmul_input1(node.args[1])
if traverse_input1:
annotate_matmul_input1(node.args[1])


def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
Expand Down
4 changes: 1 addition & 3 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3529,7 +3529,7 @@ def test_stories_single_llama(self):

cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
"--artifact",
self.artifact_dir,
"--build_folder",
Expand All @@ -3556,8 +3556,6 @@ def test_stories_single_llama(self):
"16a4w",
"--temperature",
"0",
"--llama_model",
"stories110m",
]
if self.host:
cmds.extend(["--host", self.host])
Expand Down
7 changes: 5 additions & 2 deletions examples/qualcomm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,11 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
# build qnn_executor_runner
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner)

# build qnn_llama_runner for llama
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)
# build qnn_llama_runner for llama2
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2)

# build qnn_llama_runner for llama3.2
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama3_2)

# build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)
Expand Down
4 changes: 2 additions & 2 deletions examples/qualcomm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ This directory contains examples for some AI models.

We have seperated the example scripts into the following subfolders, please refer to [README.md](../../backends/qualcomm/README.md) for the example scripts' directory structure:

1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama](./oss_scripts/llama/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.
1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.

2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner.
For example, [llama](./oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.
For example, [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.

3. qaihub_scripts: QAIHub stands for [Qualcomm AI Hub](https://aihub.qualcomm.com/). On QAIHub, users can find pre-compiled context binaries, a format used by QNN to save its models. This provides users with a new option for model deployment. Different from oss_scripts & scripts, which the example scripts are converting a model from nn.Module to ExecuTorch .pte files, qaihub_scripts provides example scripts for converting pre-compiled context binaries to ExecuTorch .pte files. Additionaly, users can find customized example runners specific to the QAIHub models for execution. For example [qaihub_llama2_7b](./qaihub_scripts/llama2/qaihub_llama2_7b.py) is a script converting context binaries to ExecuTorch .pte files, and [qaihub_llama2_7b_runner](./qaihub_scripts/llama2/qaihub_llama2_7b_runner.cpp) is a customized example runner to execute llama2 .pte files. Please be aware that context-binaries downloaded from QAIHub are tied to a specific QNN SDK version.
Before executing the scripts and runner, please ensure that you are using the QNN SDK version that is matching the context binary. Tutorial below will also cover how to check the QNN Version for a context binary.
Expand Down
70 changes: 0 additions & 70 deletions examples/qualcomm/oss_scripts/llama/README.md

This file was deleted.

38 changes: 38 additions & 0 deletions examples/qualcomm/oss_scripts/llama2/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set(_qnn_llama_runner__srcs ${_llama_runner__srcs})

# preprocess qnn llama runner src files
list(TRANSFORM _qnn_llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
list(FILTER _qnn_llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
list(
PREPEND
_qnn_llama_runner__srcs
${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
)

# build qnn llama runner
add_executable(qnn_llama_runner ${_qnn_llama_runner__srcs})
target_include_directories(
qnn_llama_runner PUBLIC ${_common_include_directories}
)
target_link_libraries(
qnn_llama_runner
qnn_executorch_backend
full_portable_ops_lib
extension_data_loader
extension_module
extension_tensor
gflags
re2::re2
)
target_compile_options(qnn_llama_runner PUBLIC ${_common_compile_options})
set_target_properties(
qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
)
39 changes: 39 additions & 0 deletions examples/qualcomm/oss_scripts/llama2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Summary

## Overview
This file provides you the instructions to run LLAMA2 with different parameters via Qualcomm HTP backend. Following settings support for Stories 110M

Please check corresponding section for more information.

## Stories 110M
This example demonstrates how to run a smaller LLAMA2, stories110M on mobile via Qualcomm HTP backend. Model architecture is fine-tuned specifically for HTP to accelerate the performance. Weight is quantized via PTQ quantization to fit the model on a phone.

### Instructions
#### Step 1: Setup
1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend.

#### Step2: Prepare Model
Download and preapre stories110M model

```bash
# tokenizer.model & stories110M.pt:
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"

# tokenizer.bin:
python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

# params.json:
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
```

#### Step3: Run default examples
Default example generates the story based on the given prompt, "Once".
```bash
# 16a4w quant:
python examples/qualcomm/oss_scripts/llama2/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --prompt "Once"
```

#### (Note) Customized PTQ data set
User prompts are used for PTQ calibration data. Take the examples above, the word "Once" is the only word for PTQ. If you want to observe more data during the calibration time. Please add more prompts to the args `--prompt`.
43 changes: 43 additions & 0 deletions examples/qualcomm/oss_scripts/llama2/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbsource//xplat/executorch/backends/qualcomm/qnn_version.bzl", "get_qnn_library_verision")
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")


python_library(
name = "static_llama",
srcs = [
"model/static_llama.py",
],
deps = [
"//caffe2:torch",
],
)

python_binary(
name = "llama",
srcs = ["llama.py"],
main_function = "executorch.examples.qualcomm.oss_scripts.llama2.llama.main",
deps = [
":static_llama",
"//caffe2:torch",
"//executorch/extension/pybindings:aten_lib",
"//executorch/backends/qualcomm/partition:partition",
"//executorch/backends/qualcomm/quantizer:quantizer",
"//executorch/devtools:lib",
"//executorch/examples/models:models",
"//executorch/examples/qualcomm:utils",
"//executorch/extension/export_util:export_util",
"//executorch/extension/llm/export:export_lib",
],
)

runtime.command_alias(
name = "llama_qnn",
env = {
"LD_LIBRARY_PATH": "$(location fbsource//third-party/qualcomm/qnn/qnn-{0}:qnn_offline_compile_libs)".format(get_qnn_library_verision()),
},
exe = ":llama",
)
Loading
Loading