From eb96f2bb36ee1be00310a9ffc8cdee9403a1560d Mon Sep 17 00:00:00 2001 From: maxren Date: Fri, 25 Aug 2023 16:55:31 -0700 Subject: [PATCH 1/5] handle static ints and floats Differential Revision: D48667679 fbshipit-source-id: 961dda60a88209f85e95276ce26cbc6d8ccc3a4b --- backends/xnnpack/operators/node_visitor.py | 5 +++++ backends/xnnpack/partition/configs.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index e1f12bb51b8..5f4e871f90b 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -230,6 +230,7 @@ def define_tensor( # Get new xnn id for tensor value ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params) dims = get_shape(tensor) + dims = [1] if len(dims) == 0 else dims # constant values serialize data buffer_idx = self.get_serialized_buffer( @@ -336,6 +337,10 @@ def get_serialized_buffer( # Quantize buffer if static data is indeed quantized if quant_params is not None and not quant_params.is_dynamic: const_val = quant_params.quantize_tensor(const_val).contiguous() + else: + # ensure that the const is fp32 + const_val = const_val.to(dtype=torch.float32).contiguous() + if swap_nc_for_depthwise_weights: const_val = const_val.permute( dims=((1, 0) + tuple(range(2, const_val.dim()))) diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index a3b66d3fcb4..1a5f567bf43 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -37,6 +37,9 @@ SUPPORTED_MODULES = [ torch.nn.Conv1d, + # TODO(T161981984) recomposed hardswish into a single node + torch.nn.Hardswish, + torch.nn.Hardsigmoid, torch.nn.Conv2d, torch.nn.ReLU, torch.nn.Sigmoid, From d0c28663e1dc69f8ca1e45471ed3100a8a043ac6 Mon Sep 17 00:00:00 2001 From: maxren Date: Fri, 25 Aug 2023 16:55:31 -0700 Subject: [PATCH 2/5] Use IDs for placeholder/outputs of implicit q/dq nodes Differential Revision: D48667676 fbshipit-source-id: 9998b49f97971029545a4d3698de501e2edade56 --- backends/xnnpack/operators/node_visitor.py | 38 ++++++++++++++-------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 5f4e871f90b..4ab0777a6cb 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -130,21 +130,33 @@ def gen_ids_and_flags( # This will break if we change the way q/dq are partitioned # Tensor can still be input if its quantizing node is an input - is_input = (self).is_graph_input(tensor) + if self.is_graph_input(tensor) or ( + quant_params.is_input if quant_params else False + ): + tensor_input = tensor + if quant_params: + if quant_params.is_input and not self.is_graph_input(tensor): + tensor_input = quant_params.q_input + assert ( + tensor_input in self.external_ids.keys() + ), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}" + ext_id = self.external_ids[tensor_input].external_id + xnn_graph.input_ids.append(id_out) + flag = self.external_ids[tensor_input].io_type # Tensor can still be output if its quantizing node is an output - is_output = self.is_graph_output(tensor) - # handle logic for input/output tensors - if is_input or is_output: + elif self.is_graph_output(tensor) or ( + quant_params.is_output if quant_params else False + ): + tensor_output = tensor + if quant_params: + if quant_params.is_output and not self.is_graph_output(tensor): + tensor_output = list(tensor.users)[0] assert ( - tensor in self.external_ids.keys() - ), f"Tensor {tensor}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}" - ext_id = self.external_ids[tensor].external_id - if is_input: - xnn_graph.input_ids.append(id_out) - flag = XNN_VALUE_FLAG_EXTERNAL_INPUT - if is_output: - xnn_graph.output_ids.append(id_out) - flag = XNN_VALUE_FLAG_EXTERNAL_OUTPUT + tensor_output in self.external_ids.keys() + ), f"Tensor {tensor_output} is_output: ext_ids: {self.external_ids.keys()}" + ext_id = self.external_ids[tensor_output].external_id + xnn_graph.output_ids.append(id_out) + flag = self.external_ids[tensor_output].io_type return ext_id, id_out, flag From ef009c27e323c307465337be196e7a0280425dea Mon Sep 17 00:00:00 2001 From: maxren Date: Fri, 25 Aug 2023 16:55:31 -0700 Subject: [PATCH 3/5] add unsupported module list for partitioner Differential Revision: D48667680 fbshipit-source-id: 2b69e340dbadd1cd0215e4b128e3952c0fe94f45 --- backends/xnnpack/partition/configs.py | 5 ++++ .../xnnpack/partition/xnnpack_partitioner.py | 24 +++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index 1a5f567bf43..99277fa029b 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -81,6 +81,11 @@ ) } +UNSUPPORTED_QUANT_MODULES = [ + torch.nn.Hardswish, + torch.nn.Hardsigmoid, +] + # TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support SUPPORTED_QUANT_MODULES = [ torch.clamp, diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index 59319358993..45667bf0fd4 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -17,6 +17,7 @@ SUPPORTED_OPS, SUPPORTED_QUANT_MODULES, SUPPORTED_QUANT_OPS, + UNSUPPORTED_QUANT_MODULES, ) from executorch.backends.xnnpack.partition.support_patterns import ( get_add_graphs, @@ -103,6 +104,7 @@ def __init__( Any, Callable[[torch.fx.Node], bool] ] = _OP_SUPPORT_CONSTRAINTS, supported_ops: Optional[List] = None, + unsupported_modules: Optional[List] = None, ): """ @Arg constraints_dict: Dict mapping each node to a lambda function that @@ -110,10 +112,17 @@ def __init__( node. @Arg supported_ops: List of supported operators for partitioning """ + self.unsupported_modules = unsupported_modules self.supported_ops = supported_ops self.constraints = constraints_dict assert len(self.constraints) + def check_common_constraints(self, node) -> bool: + if self.unsupported_modules and "source_fn" in node.meta: + return not node.meta["source_fn"][1] in self.unsupported_modules + + return True + @staticmethod def check_constraint(node) -> bool: """ @@ -132,7 +141,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: if self.supported_ops and node.target not in self.supported_ops: return False - return self.check_constraint(node) + return self.check_constraint(node) and self.check_common_constraints(node) def _constraint(target): # noqa """ @@ -540,10 +549,11 @@ def __init__( self, supported_modules: List[Callable] = SUPPORTED_MODULES, supported_ops: Optional[List[Callable]] = SUPPORTED_OPS, + unsupported_modules: Optional[List[Callable]] = None, ): super().__init__() self.supported_modules = set(supported_modules) - + self.unsupported_modules = unsupported_modules self.supported_ops = set(supported_ops or []) self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, []) @@ -614,7 +624,10 @@ def generate_partitions(self, graph_module: torch.fx.GraphModule) -> List[Any]: return generate_partitions_from_list_of_nodes( graph_module, matched_module_nodes, - XnnpackOperatorSupport(supported_ops=self.supported_ops), + XnnpackOperatorSupport( + supported_ops=self.supported_ops, + unsupported_modules=self.unsupported_modules, + ), ) def tag_nodes(self, partitions: List[Partition]) -> None: @@ -668,9 +681,12 @@ def __init__( self, supported_modules=SUPPORTED_QUANT_MODULES, supported_ops=SUPPORTED_QUANT_OPS, + unsupported_modules=UNSUPPORTED_QUANT_MODULES, ): supported_ops = supported_ops or [] - super().__init__(supported_modules, supported_ops + self._QUANT_OPS) + super().__init__( + supported_modules, supported_ops + self._QUANT_OPS, unsupported_modules + ) # TODO Refactor this # TODO Don't be greedy when pulling q->dq pairs for a given op, add convert tracker pass From 881fd8ea600a28aadeac973a4716e0a22071eca3 Mon Sep 17 00:00:00 2001 From: maxren Date: Fri, 25 Aug 2023 16:55:31 -0700 Subject: [PATCH 4/5] MobileNetv3 FP32 Differential Revision: D48667678 fbshipit-source-id: 97e0dcfb04c3185fb3da3356e15ea1bb5ae9c041 --- backends/xnnpack/test/models/mobilenet_v3.py | 82 ++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 backends/xnnpack/test/models/mobilenet_v3.py diff --git a/backends/xnnpack/test/models/mobilenet_v3.py b/backends/xnnpack/test/models/mobilenet_v3.py new file mode 100644 index 00000000000..4a1df39b09e --- /dev/null +++ b/backends/xnnpack/test/models/mobilenet_v3.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torchvision.models as models +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackQuantizedPartitioner2, +) +from executorch.backends.xnnpack.test.tester import Partition, Tester +from executorch.backends.xnnpack.test.tester.tester import Export +from executorch.backends.xnnpack.utils.configs import get_xnnpack_capture_config + + +class TestMobileNetV3(unittest.TestCase): + export_stage = Export(get_xnnpack_capture_config(enable_aot=True)) + + mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True) + mv3 = mv3.eval() + model_inputs = (torch.ones(1, 3, 224, 244),) + + all_operators = { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default", + "executorch_exir_dialects_edge__ops_aten_clamp_default", + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + "executorch_exir_dialects_edge__ops_aten_addmm_default", + "executorch_exir_dialects_edge__ops_aten__to_copy_default", + "executorch_exir_dialects_edge__ops_aten_convolution_default", + "executorch_exir_dialects_edge__ops_aten_relu_default", + "executorch_exir_dialects_edge__ops_aten_add_Tensor", + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_div_Tensor", + "executorch_exir_dialects_edge__ops_aten_mean_dim", + } + + def test_fp32(self): + ( + Tester(self.mv3, self.model_inputs) + .export(self.export_stage) + .to_edge() + .check(list(self.all_operators)) + .partition() + .check(["torch.ops.executorch_call_delegate"]) + .check_not(list(self.all_operators)) + .to_executorch() + .serialize() + .run_method() + .compare_outputs() + ) + + def test_qs8_pt2e(self): + ops_after_quantization = self.all_operators - { + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default", + } + ops_after_lowering = self.all_operators - { + # TODO: unified partitioner since hardswish/hardsigmoid decomposed operators are not quantized + # They will not be partitioned by quantized partitioner + "executorch_exir_dialects_edge__ops_aten_add_Tensor", + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_div_Tensor", + "executorch_exir_dialects_edge__ops_aten_clamp_default", + "executorch_exir_dialects_edge__ops_aten__to_copy_default", + } + + ( + Tester(self.mv3, self.model_inputs) + .quantize2() + .export(self.export_stage) + .to_edge() + .check(list(ops_after_quantization)) + .partition(Partition(partitioner=XnnpackQuantizedPartitioner2)) + .check(["torch.ops.executorch_call_delegate"]) + .check_not(list(ops_after_lowering)) + .to_executorch() + .serialize() + .run_method() + .compare_outputs() + ) From 48df6517556a8a7056d743d70edcf08562fa3f26 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 25 Aug 2023 16:55:50 -0700 Subject: [PATCH 5/5] Fix XNNPACK Example for MV3 Summary: xnnpack requires dequant nodes to be duplicated such that len(dq.users) == 1. We run this pass on our internal XNNPACK tests From Jerry: This will be a config that is used within the XNNPACKQuantizer. This will be added later, for the sake of enabling this in the example, we add the duplicatedequantnodepass to our edge compile config Additionally, graph_validation is failing for MV3, so we have to disable_validation for now. cccclai is looking into it Reviewed By: digantdesai Differential Revision: D48703639 fbshipit-source-id: 6bdca0480ea90782fc6a857095ec773d5ded1071 --- examples/backend/xnnpack_examples.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/backend/xnnpack_examples.py b/examples/backend/xnnpack_examples.py index 1f0f3edcee3..7e608e8ae45 100644 --- a/examples/backend/xnnpack_examples.py +++ b/examples/backend/xnnpack_examples.py @@ -15,6 +15,9 @@ XnnpackQuantizedPartitioner2, ) from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import ( + DuplicateDequantNodePass, +) from ..models import MODEL_NAME_TO_MODEL, MODEL_NAME_TO_OPTIONS from ..quantization.utils import quantize @@ -77,7 +80,13 @@ # It will eventually be changed to a lifted graph, in which _unlift=False, edge = exir.capture( model, example_inputs, exir.CaptureConfig(enable_aot=True, _unlift=True) - ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) + ).to_edge( + exir.EdgeCompileConfig( + # TODO(T162080278): Duplicated Dequant nodes will be in quantizer spec + _check_ir_validity=False, + passes=[DuplicateDequantNodePass()], + ) + ) logging.info(f"Exported graph:\n{edge.exported_program.graph}") edge.exported_program = to_backend(edge.exported_program, partitioner)