backends/qualcomm/builders/node_visitor.py (1 addition, 0 deletions)
@@ -29,6 +29,7 @@
QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
}
QNN_TENSOR_TYPE_MAP = {
torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8,
torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16,
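For context, a map like QNN_TENSOR_TYPE_MAP is consulted when the node visitor turns torch tensors into QNN tensor definitions; adding torch.bool lets boolean tensors (e.g. attention masks) resolve to QNN_DATATYPE_BOOL_8 instead of failing the lookup. A minimal illustration of that lookup (the helper below is hypothetical, not the visitor's actual code):

import torch

def qnn_dtype_for(dtype: torch.dtype, dtype_map: dict):
    # Look up the QNN data type for a torch dtype; with this change,
    # dtype_map[torch.bool] resolves to QNN_DATATYPE_BOOL_8.
    try:
        return dtype_map[dtype]
    except KeyError as e:
        raise RuntimeError(f"No QNN data type registered for {dtype}") from e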
backends/qualcomm/partition/common_defs.py (2 additions, 0 deletions)
@@ -13,6 +13,8 @@
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.slice_scatter.default,
exir_ops.edge.aten.index_put.default,
]

allow_list_operator = [
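For context, the two entries added above extend a deny list that the partitioner consults when deciding which nodes must stay on CPU rather than being delegated to QNN. A simplified illustration of how such a list is typically used (the list name and helper below are hypothetical, not the actual Qualcomm partitioner code):

import torch
from executorch.exir.dialects._ops import ops as exir_ops

# Hypothetical stand-in for the deny list above (its real name is cut off in the diff).
not_delegated_ops = [
    exir_ops.edge.aten.slice_scatter.default,
    exir_ops.edge.aten.index_put.default,
]

def can_delegate(node: torch.fx.Node) -> bool:
    # Nodes whose target is on the deny list are skipped during partitioning
    # and executed outside the QNN delegate.
    return node.op == "call_function" and node.target not in not_delegated_ops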
examples/models/llama2/export_llama_lib.py (67 additions, 10 deletions)
@@ -355,6 +355,13 @@ def build_args_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--pt2e_quantize",
default=None,
choices=[
"xnnpack_dynamic",
"xnnpack_dynamic_qc4",
"qnn_8a8w",
"qnn_16a16w",
"qnn_16a4w",
],
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
)
parser.add_argument(
@@ -627,6 +634,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
if args.use_sdpa_with_kv_cache:
transforms.append(replace_sdpa_with_custom_op)

if args.qnn and args.use_kv_cache:
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)
return (
load_llama_model(
modelname=modelname,
@@ -650,13 +660,16 @@ def _export_llama(modelname, args) -> str: # noqa: C901
# export_to_edge
pt2e_quant_params = _get_pt2e_quantization_params(args)
quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
if args.qnn:
assert (
args.quantization_mode is None
), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
quant_dtype = None
if args.qnn and args.pt2e_quantize:
try:
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer`
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from executorch.backends.qualcomm.quantizer.quantizer import (
get_16a4w_qnn_ptq_config,
get_default_16bit_qnn_ptq_config,
QnnQuantizer,
QuantDtype,
)

# reset quantizers and pt2e_quant_params from xnnpack backend
pt2e_quant_params = None
@@ -666,10 +679,41 @@ def _export_llama(modelname, args) -> str: # noqa: C901
"Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
)

backend, quant_config = args.pt2e_quantize.split("_")
assert (
backend == "qnn"
), f"The quantization config is for backend {backend} instead of qnn."
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
qnn_quantizer = QnnQuantizer()
# more custom quantization configs are supported, e.g. 16a4w; default to 8-bit quantization
custom_annotations = ()
if quant_config == "8a8w":
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
quant_dtype = QuantDtype.use_8a8w
pass
elif quant_config == "16a16w":
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
quant_dtype = QuantDtype.use_16a16w
qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
qnn_quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
elif quant_config == "16a4w":
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
quant_dtype = QuantDtype.use_16a4w
qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
qnn_quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
qnn_quantizer.set_per_channel_weight_dtype(
weight_dtype_for_16bit_act="int4"
)
else:
raise AssertionError(
f"No support for quant type {quant_config}. Support 8a8w, 16a16w and 16a4w."
)

assert (
args.quantization_mode is None
), "Currently qnn backend only supports QnnQuantizer via pt2e flow"
qnn_quantizer.add_custom_quant_annotations(custom_annotations)
quantizers.append(qnn_quantizer)
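For context, the quantizer configured above plugs into the standard pt2e prepare/convert flow, roughly as sketched below; `model` and `example_inputs` are placeholders and the export/calibration details are simplified compared to what the llama export manager actually does.

# Simplified pt2e sketch; `model` / `example_inputs` are placeholders, not code from this PR.
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a4w_qnn_ptq_config,
    QnnQuantizer,
)

quantizer = QnnQuantizer()
quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4")

graph = capture_pre_autograd_graph(model, example_inputs)
prepared = prepare_pt2e(graph, quantizer)
prepared(*example_inputs)        # calibration run
quantized = convert_pt2e(prepared)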

@@ -786,25 +830,38 @@ def _export_llama(modelname, args) -> str: # noqa: C901
"Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm.html"
)

# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
backend_options = generate_htp_compiler_spec(use_fp16=False)
use_fp16 = True
skip_node_op_set = {}
if args.pt2e_quantize:
use_fp16 = False
# TODO: fix the lowering error without skipping nodes
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
if quant_dtype == QuantDtype.use_8a8w:
raise NotImplementedError("8a8w for llama is still under development")
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
elif quant_dtype == QuantDtype.use_16a16w:
raise NotImplementedError("16a16w for llama is still under development")
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
elif quant_dtype == QuantDtype.use_16a4w:
raise NotImplementedError("16a4w for llama is still under development")
partitioners.append(
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
QnnPartitioner(
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
generate_qnn_executorch_compiler_spec(
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
soc_model=QcomChipset.SM8650, # default to SM8650
backend_options=backend_options,
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
debug=False,
saver=False,
),
skip_node_id_set={},
skip_node_op_set={},
skip_node_op_set=skip_node_op_set,
)
)
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`
_transform(builder_exported_to_edge.export_program())
_transform(builder_exported_to_edge.edge_manager.exported_program())

if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
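For orientation, a rough sketch of how a partitioner configured like the one above is typically applied to an edge program outside the llama export helper; the import paths and the `edge_manager` variable are assumptions for illustration, not code from this PR.

# Rough sketch; import paths and `edge_manager` (an exir EdgeProgramManager from to_edge())
# are assumptions, not code from this PR.
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)

compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,
    backend_options=generate_htp_compiler_spec(use_fp16=True),  # fp16 when not quantizing
    debug=False,
    saver=False,
)
partitioner = QnnPartitioner(compile_spec, skip_node_id_set={}, skip_node_op_set={})
delegated = edge_manager.to_backend(partitioner)  # delegate supported subgraphs to QNN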