From 58c69ee21a4b7843094f0d00172d19849ceaa74d Mon Sep 17 00:00:00 2001 From: haowhsu-quic <111341466+haowhsu-quic@users.noreply.github.com> Date: Thu, 18 Sep 2025 05:52:54 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - issue fix #2 (#14378) ### Summary - #14048 > add quantized test case with GLU decomposition - #14049 > add e2e example where constant expansion is applied - #14050 > add e2e example and source transform for 6D operation - #14051 > add e2e example and complement missed annotation - #14052 > add e2e example and dedicated pass for 6D partition Fixes #14048 Fixes #14049 Fixes #14050 Fixes #14051 Fixes #14052 ### Test plan MATRIX = {convnext_small, maxvit_t, swin_v2_t, vit_b_16} ```bash python backends/qualcomm/tests/test_qnn_delegate.py TestExampleOssScript.test_${MATRIX} -b build-android/ -m SM8750 -s $SN -a /path/to/test_artifacts/ -i /path/to/imagenet_1k/imagenet-mini/val -r . ``` ```bash python backends/qualcomm/tests/test_qnn_delegate.py TestQuantizedModel.test_qnn_backend_conformer -b build-android/ -m SM8750 -s $SN -a /path/to/test_artifacts/ ``` (cherry picked from commit 2b54a19ec5d35e1981848fa86ad423c2b37d49f4) --- backends/qualcomm/_passes/__init__.py | 2 + .../qualcomm/_passes/annotate_quant_attrs.py | 8 + backends/qualcomm/_passes/decompose_any.py | 28 +- backends/qualcomm/_passes/decompose_cdist.py | 28 +- backends/qualcomm/_passes/decompose_einsum.py | 33 +-- backends/qualcomm/_passes/decompose_glu.py | 55 ++++ .../_passes/decompose_linalg_vector_norm.py | 29 +-- backends/qualcomm/_passes/decompose_roll.py | 29 +-- .../_passes/decompose_wrap_with_autocast.py | 27 +- .../qualcomm/_passes/fixed_linear_keep_dim.py | 23 +- backends/qualcomm/_passes/qnn_pass_manager.py | 2 + backends/qualcomm/_passes/utils.py | 39 +++ backends/qualcomm/quantizer/annotators.py | 4 +- backends/qualcomm/tests/models.py | 20 ++ backends/qualcomm/tests/test_qnn_delegate.py | 230 +++++++++++++++++ examples/qualcomm/oss_scripts/README.md | 6 +- 
.../qualcomm/oss_scripts/convnext_small.py | 145 +++++++++++ examples/qualcomm/oss_scripts/maxvit_t.py | 244 ++++++++++++++++++ examples/qualcomm/oss_scripts/swin_v2_t.py | 185 +++++++++++++ examples/qualcomm/oss_scripts/vit_b_16.py | 135 ++++++++++ examples/qualcomm/utils.py | 3 + 21 files changed, 1140 insertions(+), 135 deletions(-) create mode 100644 backends/qualcomm/_passes/decompose_glu.py create mode 100755 examples/qualcomm/oss_scripts/convnext_small.py create mode 100755 examples/qualcomm/oss_scripts/maxvit_t.py create mode 100755 examples/qualcomm/oss_scripts/swin_v2_t.py create mode 100755 examples/qualcomm/oss_scripts/vit_b_16.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 5bd305335d5..4548c348bd0 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -18,6 +18,7 @@ from .decompose_col_im import DecomposeColIm from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 +from .decompose_glu import DecomposeGlu from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm from .decompose_minmaxdim import DecomposeMinMaxDim from .decompose_roll import DecomposeRoll @@ -59,6 +60,7 @@ DecomposeColIm, DecomposeEinsum, DecomposeExpM1, + DecomposeGlu, DecomposeLinalgVectorNorm, DecomposeMinMaxDim, DecomposeRoll, diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index 610e88e6d3b..6077d51b099 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -19,6 +19,7 @@ QCOM_SCALE, QCOM_ZERO_POINT, ) +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from .utils import get_quant_attrs @@ -38,6 +39,9 @@ def __init__( super(AnnotateQuantAttrs, self).__init__() self.edge_program = edge_program self.skip_advanced_requant = skip_advanced_requant + 
self.skip_requant_allowlist = { + exir_ops.edge.aten.sigmoid.default, + } def _annotate_source_nodes( self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any] @@ -80,6 +84,10 @@ def _annotate_requant(self, n): # node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> .... # We store {node2: quant_attr in dq_int32} in node1.meta if n.target in q_ops and n.args[0].target not in dq_ops: + # for some fixed scale op, there is no need to requantize it + if n.args[0].target in self.skip_requant_allowlist: + return + dq_nodes = self._find_last_dq_nodes(n) q_attrs = get_quant_attrs(self.edge_program, n) for dq_node in dq_nodes: diff --git a/backends/qualcomm/_passes/decompose_any.py b/backends/qualcomm/_passes/decompose_any.py index e92bf11dd18..0cb959ff77f 100644 --- a/backends/qualcomm/_passes/decompose_any.py +++ b/backends/qualcomm/_passes/decompose_any.py @@ -8,6 +8,8 @@ from executorch.exir import to_edge from executorch.exir.pass_base import ExportPass, PassResult +from .utils import merge_decomposed_graph + class Any(torch.nn.Module): def __init__(self, dim, keepdim): @@ -49,26 +51,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": node.args[0]} - - for decomposed_node in decomposed_module.graph.nodes: - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + 
target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_cdist.py b/backends/qualcomm/_passes/decompose_cdist.py index d18a0295ffb..a3c812bdc37 100644 --- a/backends/qualcomm/_passes/decompose_cdist.py +++ b/backends/qualcomm/_passes/decompose_cdist.py @@ -7,6 +7,8 @@ import torch from executorch.exir.pass_base import ExportPass, PassResult +from .utils import merge_decomposed_graph + class CDist(torch.nn.Module): def __init__(self): @@ -54,26 +56,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": node.args[0], "y": node.args[1]} - - for decomposed_node in decomposed_module.graph.nodes: - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_einsum.py b/backends/qualcomm/_passes/decompose_einsum.py index 046c1598311..464d989333f 100644 --- a/backends/qualcomm/_passes/decompose_einsum.py +++ b/backends/qualcomm/_passes/decompose_einsum.py @@ -8,7 +8,7 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.experimental.proxy_tensor import make_fx -from .utils import 
copy_nn_module_stack +from .utils import merge_decomposed_graph class DecomposeEinsum(ExportPass): @@ -37,30 +37,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for i, arg in enumerate(node.args[1]): remap[f"arg1_{i+1}"] = arg - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # This is the arg[0] equation string, which is not required anymore after decomposition - if "arg0" in decomposed_node.name: - continue - - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + predicate=lambda decomp_node: "arg0" not in decomp_node.name, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_glu.py b/backends/qualcomm/_passes/decompose_glu.py new file mode 100644 index 00000000000..de363468799 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_glu.py @@ -0,0 +1,55 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import merge_decomposed_graph + + +# this wrapper is required for IO name mapping with decomposed graph +class Glu(torch.nn.Module): + def __init__(self, dim=-1): + super().__init__() + self.glu = torch.nn.GLU(dim=dim) + + def forward(self, x): + return self.glu(x) + + +class DecomposeGlu(ExportPass): + """ + Decompose glu for quantization annotation to work properly. + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target == torch.ops.aten.glu.default: + ep = torch.export.export( + Glu(dim=-1 if len(node.args) < 2 else node.args[1]), + (node.args[0].meta["val"],), + ) + decomposed_module = ep.run_decompositions().graph_module + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": node.args[0]} + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py index 993f088da12..94a5b10ba3f 100644 --- a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py +++ b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py @@ -8,7 +8,7 @@ from executorch.exir import to_edge from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_nn_module_stack +from .utils import merge_decomposed_graph class LinalgVectorNorm(torch.nn.Module): @@ -62,27 +62,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original 
node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": node.args[0]} - - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_roll.py b/backends/qualcomm/_passes/decompose_roll.py index e13433508f5..e6f60d55464 100644 --- a/backends/qualcomm/_passes/decompose_roll.py +++ b/backends/qualcomm/_passes/decompose_roll.py @@ -7,7 +7,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_nn_module_stack +from .utils import merge_decomposed_graph class SliceCopy(torch.nn.Module): @@ -65,27 +65,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # remap is used to map original node values to new node values, # which ensures that reference to nodes are correctly updated in the new graph remap = {"x": input_node} - - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # no need to copy existent 'output' - if decomposed_node.op == "output": - for user in node.users.copy(): - # remap - user.replace_input_with( - node, - remap[decomposed_node.args[0][0]], - ) - # no need to copy existent placeholders - elif decomposed_node.op == 
"placeholder": - # replace node map from string to graph node - remap[decomposed_node] = remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) graph.erase_node(node) graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/decompose_wrap_with_autocast.py b/backends/qualcomm/_passes/decompose_wrap_with_autocast.py index 6c073bd309c..1b60b740ed3 100644 --- a/backends/qualcomm/_passes/decompose_wrap_with_autocast.py +++ b/backends/qualcomm/_passes/decompose_wrap_with_autocast.py @@ -10,7 +10,7 @@ import torch from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_nn_module_stack +from .utils import merge_decomposed_graph class DecomposeWrapWithAutocast(ExportPass): @@ -52,7 +52,7 @@ def _replace(self, gm: torch.fx.GraphModule) -> None: graph = gm.graph for node in graph.nodes: if isinstance(node.target, torch._higher_order_ops.wrap.WrapWithAutocast): - submod, submod_name = self._get_submod(gm, node) + submod, _ = self._get_submod(gm, node) n_args = node.args input_submod = n_args[4] decomposed_module = submod @@ -61,22 +61,13 @@ def _replace(self, gm: torch.fx.GraphModule) -> None: # which ensures that reference to nodes are correctly updated in the new graph # remap = {"expand_1": node.args[5], "to_4": node.args[6]} remap = {n_args[i].name: n_args[i] for i in range(5, len(n_args))} - - for decomposed_node in decomposed_module.graph.nodes: - copy_nn_module_stack(node, decomposed_node) - # no need to copy existent 'output' - if decomposed_node.op == "output": - self._replace_output(node, decomposed_node, remap) - # no need to copy existent placeholders - elif decomposed_node.op == "placeholder": - # replace node map from string to graph node - remap[decomposed_node] = 
remap.pop(decomposed_node.name) - else: - remap[decomposed_node] = graph.node_copy( - decomposed_node, - arg_transform=lambda x, remap=remap: remap[x], - ) - + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + output_processor=self._replace_output, + ) graph.erase_node(node) graph.erase_node(input_submod) diff --git a/backends/qualcomm/_passes/fixed_linear_keep_dim.py b/backends/qualcomm/_passes/fixed_linear_keep_dim.py index 19f5c631921..04c0f92cebf 100644 --- a/backends/qualcomm/_passes/fixed_linear_keep_dim.py +++ b/backends/qualcomm/_passes/fixed_linear_keep_dim.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.builders.node_visitor import dq_ops +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass +from .utils import copy_meta, get_quant_attrs + class FixedLinearKeepDim(ExportPass): """ @@ -18,8 +22,12 @@ class FixedLinearKeepDim(ExportPass): view_copy = exir_ops.edge.aten.view_copy.default linear = exir_ops.edge.aten.linear.default - def __init__(self): + def __init__( + self, + edge_program: torch.export.ExportedProgram, + ): super(FixedLinearKeepDim, self).__init__() + self.edge_program = edge_program def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: @@ -46,9 +54,15 @@ def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): ) # meta needs to be copied elementwisely for fake-tensor # to be updated correctly and not affect meta of input_node - for k, v in input_node.meta.items(): - squeeze_node.meta[k] = v + squeeze_node.meta = copy_meta(input_node.meta) squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim) + # if input_node is dequantize, we 
need to fetch encodings manually + # TODO: remove this when constant fold mechanism is introduced + if input_node.target in dq_ops: + squeeze_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( + self.edge_program, input_node + ) + for user in input_users: if user == linear_node: user.replace_input_with(input_node, squeeze_node) @@ -66,8 +80,7 @@ def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): ) # meta needs to be copied elementwisely for fake-tensor # to be updated correctly and not affect meta of unsqueeze_node - for k, v in linear_node.meta.items(): - unsqueeze_node.meta[k] = v + unsqueeze_node.meta = copy_meta(linear_node.meta) # update linear node's shape linear_node.meta["val"] = linear_output.reshape( (squeeze_node.meta["val"].shape[0], linear_output.shape[-1]) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index d1f49baf83b..22c7c650ac2 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -23,6 +23,7 @@ DecomposeColIm, DecomposeEinsum, DecomposeExpM1, + DecomposeGlu, DecomposeLinalgVectorNorm, DecomposeMinMaxDim, DecomposeRoll, @@ -203,6 +204,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeWrapWithAutocast()) self.add_pass(DecomposeEinsum()) self.add_pass(DecomposeExpM1()) + self.add_pass(DecomposeGlu()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(ReplaceInfValues()) self.add_pass(LiftConstantScalarOperands()) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 6d908707892..eebfa4d9eb4 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -117,6 +117,45 @@ def copy_nn_module_stack(src, target): target.meta["nn_module_stack"] = value +def merge_decomposed_graph( + remap: Dict[str, torch.fx.Node], + target_node: torch.fx.Node, + target_graph: torch.fx.GraphModule, + 
decomposed_graph_module: torch.fx.GraphModule, + predicate: Callable[[torch.fx.Node], None] = None, + # target_node, decomposed_output_node, remap + output_processor: Callable[ + [torch.fx.Node, torch.fx.Node, Dict[str, torch.fx.Node]], None + ] = None, +) -> None: + def default_output_process(node): + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[decomposed_node.args[0][0]], + ) + + for decomposed_node in decomposed_graph_module.graph.nodes: + copy_nn_module_stack(target_node, decomposed_node) + if predicate is None or predicate(decomposed_node): + # no need to copy existent 'output' + if decomposed_node.op == "output": + if output_processor is None: + default_output_process(target_node) + else: + output_processor(target_node, decomposed_node, remap) + # no need to copy existent placeholders + elif decomposed_node.op == "placeholder": + # replace node map from string to graph node + remap[decomposed_node] = remap.pop(decomposed_node.name) + else: + remap[decomposed_node] = target_graph.node_copy( + decomposed_node, + arg_transform=lambda x, remap=remap: remap[x], + ) + + def is_float_tensor(node: torch.fx.Node) -> bool: if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return False diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index c3213af6338..d5b66ec308a 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -674,7 +674,7 @@ def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) -@register_annotator([torch.ops.aten.reshape.default]) +@register_annotator([torch.ops.aten.reshape.default, torch.ops.aten.unflatten.int]) def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -879,7 +879,7 @@ def annotate_unsqueeze_copy( 
annotate_single_in_share_out(node, quantization_config) -@register_annotator([torch.ops.aten.transpose.int]) +@register_annotator([torch.ops.aten.transpose.int, torch.ops.aten.swapaxes.default]) def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None: annotate_in_out_obs_sharing_op(node, quantization_config) if not _is_annotated([node]): diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 035adc52d53..c457667cb2f 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1976,6 +1976,16 @@ def forward(self, x): return torch.sum(x, dim=(2, 3), keepdim=True) +class SwapAxes(torch.nn.Module): + def __init__(self, axis0, axis1): + super().__init__() + self.axis0 = axis0 + self.axis1 = axis1 + + def forward(self, x): + return torch.swapaxes(x, axis0=self.axis0, axis1=self.axis1) + + class Tanh(torch.nn.Module): def __init__(self): super().__init__() @@ -2002,6 +2012,16 @@ def forward(self, x): return torch.unbind(x) +class Unflatten(torch.nn.Module): + def __init__(self, dim, sizes): + super().__init__() + self.dim = dim + self.sizes = sizes + + def forward(self, x): + return torch.unflatten(x, dim=self.dim, sizes=self.sizes) + + class Unfold(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 2bfff8bbecd..f0f5e8182eb 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -736,6 +736,13 @@ def test_qnn_backend_gelu(self): sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_glu(self): + modules = [torch.nn.GLU(), torch.nn.GLU(dim=0)] + sample_input = (torch.randn(2, 5, 1, 4),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def 
test_qnn_backend_greater_equal(self): test_comb = [ { @@ -1354,11 +1361,21 @@ def test_qnn_backend_sum_int_list(self): sample_input = (torch.randn([1, 4, 8, 8]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_swapaxes(self): + module = SwapAxes(0, 1) # noqa: F405 + sample_input = (torch.randn([1, 2, 3, 4]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_unflatten(self): + module = Unflatten(dim=1, sizes=(2, 3, 4)) # noqa: F405 + sample_input = (torch.randn([1, 24]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_unbind(self): module = Unbind() # noqa: F405 sample_input = (torch.randn([3, 3]),) @@ -2405,6 +2422,14 @@ def test_qnn_backend_gelu(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_glu(self): + modules = [torch.nn.GLU(), torch.nn.GLU(dim=0)] + sample_input = (torch.randn(2, 5, 1, 4),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_greater_equal(self): test_comb = [ { @@ -3120,12 +3145,24 @@ def test_qnn_backend_sum_int_list(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_swapaxes(self): + module = SwapAxes(0, 1) # noqa: F405 + sample_input = (torch.randn([1, 2, 3, 4]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, 
sample_input) + def test_qnn_backend_unflatten(self): + module = Unflatten(dim=1, sizes=(2, 3, 4)) # noqa: F405 + sample_input = (torch.randn([1, 24]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_unbind(self): module = Unbind() # noqa: F405 sample_input = (torch.randn([3, 3]),) @@ -3249,6 +3286,51 @@ def test_qnn_backend_chunk_add(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conformer(self): + from typing import Tuple + + import torchaudio + + class PatchedConformer(torch.nn.Module): + """ + A lightly modified version of the top-level Conformer module, such that it can be exported. + Instead of taking lengths and computing the padding mask, it takes the padding mask directly. + See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215 + """ + + def __init__(self, conformer): + super().__init__() + self.conformer = conformer + + def forward( + self, input: torch.Tensor, encoder_padding_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = input.transpose(0, 1) + for layer in self.conformer.conformer_layers: + x = layer(x, encoder_padding_mask) + return x.transpose(0, 1) + + inner_model = torchaudio.models.Conformer( + input_dim=80, + num_heads=4, + ffn_dim=128, + num_layers=4, + depthwise_conv_kernel_size=31, + ) + lengths = torch.randint(1, 400, (10,)) + encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask( + lengths + ) + sample_input = ( + torch.rand(10, int(lengths.max()), 80), + encoder_padding_mask.to(torch.float32), + ) + module = PatchedConformer(inner_model).eval() + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a8w + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv1d_relu_log_softmax(self): modules = [ Conv1dReluLogSoftmax(dim=1), # noqa: 
F405 @@ -5744,6 +5826,43 @@ def test_conv_former(self): self.assertGreaterEqual(msg["top_1"], 70) self.assertGreaterEqual(msg["top_5"], 92) + def test_convnext_small(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/convnext_small.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 76) + self.assertGreaterEqual(msg["top_5"], 97) + def test_cvt(self): if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") @@ -6242,6 +6361,43 @@ def test_gMLP(self): self.assertGreaterEqual(msg["top_1"], 70) self.assertGreaterEqual(msg["top_5"], 88) + def test_maxvit_t(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/maxvit_t.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + 
self.assertGreaterEqual(msg["top_1"], 72) + self.assertGreaterEqual(msg["top_5"], 91) + @unittest.skip("Only outputs good accuracy in QNN 2.29") def test_mobilevit_v2(self): if not self.required_envs([self.image_dataset]): @@ -6588,6 +6744,43 @@ def test_swin_transformer(self): self.assertGreaterEqual(msg["top_1"], 71) self.assertGreaterEqual(msg["top_5"], 90) + def test_swin_v2_t(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/swin_v2_t.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 63) + self.assertGreaterEqual(msg["top_5"], 92) + def test_t5(self): if not self.required_envs([self.qa_dataset]): self.skipTest("missing required envs") @@ -6624,6 +6817,43 @@ def test_t5(self): else: self.assertGreaterEqual(msg["f1"], 0.72) + def test_vit_b_16(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/vit_b_16.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--seed", + str(1126), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with 
Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 72) + self.assertGreaterEqual(msg["top_5"], 96) + def test_whisper(self): if not self.required_envs(): self.skipTest("missing required envs") diff --git a/examples/qualcomm/oss_scripts/README.md b/examples/qualcomm/oss_scripts/README.md index b68024d5fbf..7971cc4a1de 100644 --- a/examples/qualcomm/oss_scripts/README.md +++ b/examples/qualcomm/oss_scripts/README.md @@ -15,6 +15,7 @@ The following models can be categorized based on their primary use cases. 2. Vision Model: - conv_former + - convnext_small - cvt - deit - dino_v2 @@ -26,6 +27,7 @@ The following models can be categorized based on their primary use cases. - fbnet - focalnet - gMLP_image_classification + - maxvit_t - mobilevit1 - mobilevit_v2 - pvt @@ -34,6 +36,8 @@ The following models can be categorized based on their primary use cases. - squeezenet - ssd300_vgg16 - swin_transformer + - swin_v2_t + - vit_b_16 ## Prerequisite Please follow another [README](../README.md) first to set up environment. @@ -51,7 +55,7 @@ If you want to export the model without running it, please add `--compile_only` ```bash python albert.py -m ${SOC_MODEL} -b path/to/build-android/ -s ${DEVICE_SERIAL} -d path/to/wikisent2 -2. `conv_former`,`cvt`,`deit`,`dino_v2`,`efficientnet`,`fbnet`, `focalnet`, `gMLP_image_classification`, `mobilevit1`,`mobilevit_v2`, `pvt`, `squeezenet`, `swin_transformer` : +2. `conv_former`, `convnext_small`, `cvt`, `deit`, `dino_v2`, `efficientnet`, `fbnet`, `focalnet`, `gMLP_image_classification`, `maxvit_t`, `mobilevit1`, `mobilevit_v2`, `pvt`, `squeezenet`, `swin_transformer`, `swin_v2_t`, `vit_b_16` : - Required Dataset : ImageNet Download [dataset](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000) first, and place it in a valid folder. 
def main(args):
    """Export torchvision's ConvNeXt-Small through the QNN flow and score it on device.

    Steps: build the calibration/evaluation inputs, compile the model with
    ``build_executorch_binary`` (``QuantDtype.use_8a8w`` quantizer and the
    ``ExpandBroadcastTensorShape`` pass activated), then - unless
    ``args.compile_only`` is set - push, execute and pull through ``SimpleADB``
    and report top-1 / top-5 accuracy.

    Args:
        args: parsed command-line namespace. Reads ``artifact``, ``ci``,
            ``dataset``, ``model``, ``shared_buffer``, ``compile_only``,
            ``build_folder``, ``device``, ``host``, ``ip`` and ``port``.

    The result is sent as JSON to ``(args.ip, args.port)`` when both are set,
    otherwise printed. With ``--ci`` the input is one random, unlabeled tensor,
    so scoring is skipped explicitly instead of failing on missing output
    files and an unbound ``targets``.
    """
    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    data_num = 100
    # stays None for --ci, whose random input carries no ground-truth labels
    targets = None
    if args.ci:
        inputs = [(torch.rand(1, 3, 224, 224),)]
        logging.warning(
            "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
        )
    else:
        inputs, targets = get_imagenet_dataset(
            dataset_path=f"{args.dataset}",
            data_size=data_num,
            image_shape=(256, 256),
            crop_size=224,
        )

    pte_filename = "convnext_small_qnn_q8"
    instance = torchvision.models.convnext_small(weights="IMAGENET1K_V1").eval()
    # switch on the broadcast-expansion pass in the default capture pipeline
    passes_job = get_capture_program_passes()
    passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
    build_executorch_binary(
        instance,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        custom_quantizer=make_quantizer(
            quant_dtype=QuantDtype.use_8a8w,
            per_channel_linear=True,
        ),
        passes_job=passes_job,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    if targets is None:
        # Nothing to score against. Still answer a waiting listener so it does
        # not block forever; "Error" is the only non-metric key it understands.
        note = "top-k accuracy is unavailable with --ci (random, unlabeled input)"
        logging.warning(note)
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": note}))
        return

    # top-k analysis: one raw float32 output file per pushed input
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]

    k_val = [1, 5]
    scores = {f"top_{k}": topk_accuracy(predictions, targets, k).item() for k in k_val}
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps(scores))
    else:
        for name, value in scores.items():
            print(f"{name}->{value}%")
class WindowPartition(torch.nn.Module):
    """
    Split a feature map into non-overlapping P x P windows while keeping every
    intermediate tensor at rank 5 or below.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, p: int) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
            p (int): Edge length P of one square window; H and W are split
                into H // P and W // P windows respectively.
        Returns:
            Tensor: Output tensor with layout [B, (H/P * W/P), P*P, C].
        """
        batch, channels, height, width = x.shape
        rows, cols = height // p, width // p
        # fold batch and channel together so the tiling stays 5D: [B*C, H/P, P, W/P, P]
        tiles = x.reshape(batch * channels, rows, p, cols, p)
        # put the two window-grid axes side by side: [B*C, H/P, W/P, P, P]
        tiles = tiles.permute(0, 1, 3, 2, 4)
        # collapse the grid into H/P * W/P windows and each window into P * P pixels
        windows = tiles.reshape(batch, channels, rows * cols, p * p)
        # move channels last: [B, G, P*P, C]
        return windows.permute(0, 2, 3, 1)
def main(args):
    """Export torchvision's MaxViT-T through the QNN flow and score it on device.

    Before export, every ``PartitionAttentionLayer`` in the model gets its
    partition / departition ops replaced by ``WindowPartition`` /
    ``WindowDepartition``, and every ``RelativePositionalMultiHeadAttention``
    inside it gets its ``forward`` rebound to this file's module-level
    ``forward``. The model is then compiled with ``build_executorch_binary``
    (``QuantDtype.use_8a8w`` quantizer) and - unless ``args.compile_only`` is
    set - pushed, executed and pulled through ``SimpleADB``.

    Args:
        args: parsed command-line namespace. Reads ``artifact``, ``ci``,
            ``dataset``, ``model``, ``shared_buffer``, ``compile_only``,
            ``build_folder``, ``device``, ``host``, ``ip`` and ``port``.

    The result is sent as JSON to ``(args.ip, args.port)`` when both are set,
    otherwise printed. With ``--ci`` the input is one random, unlabeled tensor,
    so scoring is skipped explicitly instead of failing on missing output
    files and an unbound ``targets``.
    """
    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    data_num = 100
    # stays None for --ci, whose random input carries no ground-truth labels
    targets = None
    if args.ci:
        inputs = [(torch.rand(1, 3, 224, 224),)]
        logging.warning(
            "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
        )
    else:
        inputs, targets = get_imagenet_dataset(
            dataset_path=f"{args.dataset}",
            data_size=data_num,
            image_shape=(256, 256),
            crop_size=224,
        )

    pte_filename = "maxvit_t_qnn_q8"
    instance = torchvision.models.maxvit_t(weights="IMAGENET1K_V1").eval()
    # source transform: swap in the replacement partition ops and attention forward
    for block in instance.blocks:
        for layer in block.layers:
            for sub_layer in layer.layers:
                if not isinstance(sub_layer, PartitionAttentionLayer):
                    continue
                sub_layer.partition_op = WindowPartition()
                sub_layer.departition_op = WindowDepartition()
                for attn_sub_layer in sub_layer.attn_layer:
                    if isinstance(attn_sub_layer, RelativePositionalMultiHeadAttention):
                        # bind the module itself as `self` of the free function
                        attn_sub_layer.forward = functools.partial(
                            forward, attn_sub_layer
                        )

    build_executorch_binary(
        instance,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        custom_quantizer=make_quantizer(
            quant_dtype=QuantDtype.use_8a8w,
            per_channel_linear=True,
        ),
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    if targets is None:
        # Nothing to score against. Still answer a waiting listener so it does
        # not block forever; "Error" is the only non-metric key it understands.
        note = "top-k accuracy is unavailable with --ci (random, unlabeled input)"
        logging.warning(note)
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": note}))
        return

    # top-k analysis: one raw float32 output file per pushed input
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]

    k_val = [1, 5]
    scores = {f"top_{k}": topk_accuracy(predictions, targets, k).item() for k in k_val}
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps(scores))
    else:
        for name, value in scores.items():
            print(f"{name}->{value}%")
class RewritePartition(ExportPass):
    """
    Rewrite 6D window partition pattern to 5D one.

    A ``permute_copy`` with axis order ``[0, 1, 3, 2, 4, 5]`` never moves the
    last two axes, so they can be merged into one: the producer's 6-element
    shape ``[d0, d1, d2, d3, d4, d5]`` becomes ``[d0, d1, d2, d3, d4 * d5]``
    and the permute becomes ``[0, 1, 3, 2, 4]``. The elements stay in the same
    memory order; only the rank drops from 6 to 5.
    """

    def __init__(self):
        super(RewritePartition, self).__init__()

    @staticmethod
    def _is_rewritable_producer(producer) -> bool:
        """Return True when `producer` can safely be reshaped in place."""
        if not isinstance(producer, torch.fx.Node) or producer.op != "call_function":
            return False
        if len(producer.args) < 2 or not isinstance(producer.args[1], (list, tuple)):
            return False
        # the rewrite changes the producer's output shape, so the matched
        # permute must be its only consumer
        return len(producer.args[1]) == 6 and len(producer.users) == 1

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Apply the rewrite to every matching permute in `graph_module`.

        Returns:
            PassResult: the (possibly mutated) graph module and whether any
            node was actually rewritten.
        """
        graph = graph_module.graph
        modified = False
        # math equivalent implementation
        for node in graph.nodes:
            is_6d_window_permute = (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.permute_copy.default
                and node.args[1] == [0, 1, 3, 2, 4, 5]
            )
            if not is_6d_window_permute:
                continue

            # NOTE(review): the producer is expected to be the view/reshape of
            # the window partition; its target is not checked here, only the
            # structure of its arguments.
            view_node = node.args[0]
            if not self._is_rewritable_producer(view_node):
                continue

            # adjust original view node to take 5D tensor
            d0, d1, d2, d3, d4, d5 = view_node.args[1]
            shape = [d0, d1, d2, d3, d4 * d5]
            view_node.args = (view_node.args[0], shape)
            view_node.meta["val"] = view_node.meta["val"].reshape(shape)
            # change current permute node accordingly
            axis_order = [0, 1, 3, 2, 4]
            node.args = (view_node, axis_order)
            node.meta["val"] = view_node.meta["val"].permute(axis_order)
            modified = True

        if modified:
            graph_module.recompile()
        return PassResult(graph_module, modified)
+ ) + else: + inputs, targets = get_imagenet_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + image_shape=(256, 256), + crop_size=224, + ) + + pte_filename = "swin_v2_t_qnn_q8" + instance = torchvision.models.swin_v2_t(weights="IMAGENET1K_V1").eval() + passes_job = get_capture_program_passes() + passes_job[RewritePartition] = { + QCOM_PASS_ACTIVATE_KEY: True, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {}, + } + passes_dep = get_passes_dependency_for_capture_program() + passes_dep[RewritePartition] = [FoldQDQ] + build_executorch_binary( + instance, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + custom_quantizer=make_quantizer( + quant_dtype=QuantDtype.use_8a8w, + per_channel_linear=True, + ), + shared_buffer=args.shared_buffer, + passes_job=passes_job, + passes_dependency=passes_dep, + ) + + if args.compile_only: + return + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=inputs) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + # top-k analysis + predictions = [] + for i in range(data_num): + predictions.append( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + ) + + k_val = [1, 5] + topk = [topk_accuracy(predictions, targets, k).item() for k in k_val] + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)})) + else: + for i, k in enumerate(k_val): + print(f"top_{k}->{topk[i]}%") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-d", + 
def main(args):
    """Export torchvision's ViT-B/16 through the QNN flow and score it on device.

    Steps: build the calibration/evaluation inputs, compile the model with
    ``build_executorch_binary`` (``QuantDtype.use_8a8w`` quantizer), then -
    unless ``args.compile_only`` is set - push, execute and pull through
    ``SimpleADB`` and report top-1 / top-5 accuracy.

    Args:
        args: parsed command-line namespace. Reads ``artifact``, ``ci``,
            ``dataset``, ``model``, ``shared_buffer``, ``compile_only``,
            ``build_folder``, ``device``, ``host``, ``ip`` and ``port``.

    The result is sent as JSON to ``(args.ip, args.port)`` when both are set,
    otherwise printed. With ``--ci`` the input is one random, unlabeled tensor,
    so scoring is skipped explicitly instead of failing on missing output
    files and an unbound ``targets``.
    """
    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    data_num = 100
    # stays None for --ci, whose random input carries no ground-truth labels
    targets = None
    if args.ci:
        inputs = [(torch.rand(1, 3, 224, 224),)]
        logging.warning(
            "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
        )
    else:
        inputs, targets = get_imagenet_dataset(
            dataset_path=f"{args.dataset}",
            data_size=data_num,
            image_shape=(256, 256),
            crop_size=224,
        )

    pte_filename = "vit_b_16_qnn_q8"
    instance = torchvision.models.vit_b_16(weights="IMAGENET1K_V1").eval()
    build_executorch_binary(
        instance,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        custom_quantizer=make_quantizer(
            quant_dtype=QuantDtype.use_8a8w,
            per_channel_linear=True,
        ),
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    if targets is None:
        # Nothing to score against. Still answer a waiting listener so it does
        # not block forever; "Error" is the only non-metric key it understands.
        note = "top-k accuracy is unavailable with --ci (random, unlabeled input)"
        logging.warning(note)
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": note}))
        return

    # top-k analysis: one raw float32 output file per pushed input
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]

    k_val = [1, 5]
    scores = {f"top_{k}": topk_accuracy(predictions, targets, k).item() for k in k_val}
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps(scores))
    else:
        for name, value in scores.items():
            print(f"{name}->{value}%")
--dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=False, + ) + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./vit_b_16", + default="./vit_b_16", + type=str, + ) + + args = parser.parse_args() + args.validate(args) + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index e43821bda64..11b9ab88bfe 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -384,6 +384,7 @@ def build_executorch_binary( metadata=None, dump_intermediate_outputs=False, passes_job=None, + passes_dependency=None, qat_training_data=None, online_prepare=False, optrace=False, @@ -406,6 +407,7 @@ def build_executorch_binary( metadata (dict, optional): An optional dictionary that maps each method name to a constant value in eager mode. dump_intermediate_outputs (bool, optional): Enables dumping model intermediate outputs. passes_job (OrderedDict, optional): Custom passes job in capture_program, users can enable/disable specific passes or modify their attributes. + passes_dependency (Dict, optional): A dictionary mapping each pass to its corresponding list of dependencies. qat_training_data (List[torch.Tensor], optional): A dataset for quantization aware training(QAT). Typically is a pair of tensors, such as [features, ground truth]. online_prepare (bool, optional): Compose QNN graph on device if set to True. optrace (bool, optional): Enable optrace mode for performance analysis if set to True. @@ -449,6 +451,7 @@ def build_executorch_binary( compile_spec, constant_methods=metadata, passes_job=passes_job, + dep_table=passes_dependency, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, )