diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index e1f12bb51b8..4ab0777a6cb 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -130,21 +130,33 @@ def gen_ids_and_flags(
         # This will break if we change the way q/dq are partitioned

         # Tensor can still be input if its quantizing node is an input
-        is_input = self.is_graph_input(tensor)
+        if self.is_graph_input(tensor) or (
+            quant_params.is_input if quant_params else False
+        ):
+            tensor_input = tensor
+            if quant_params:
+                if quant_params.is_input and not self.is_graph_input(tensor):
+                    tensor_input = quant_params.q_input
+            assert (
+                tensor_input in self.external_ids.keys()
+            ), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}"
+            ext_id = self.external_ids[tensor_input].external_id
+            xnn_graph.input_ids.append(id_out)
+            flag = self.external_ids[tensor_input].io_type

         # Tensor can still be output if its quantizing node is an output
-        is_output = self.is_graph_output(tensor)
-        # handle logic for input/output tensors
-        if is_input or is_output:
+        elif self.is_graph_output(tensor) or (
+            quant_params.is_output if quant_params else False
+        ):
+            tensor_output = tensor
+            if quant_params:
+                if quant_params.is_output and not self.is_graph_output(tensor):
+                    tensor_output = list(tensor.users)[0]
             assert (
-                tensor in self.external_ids.keys()
-            ), f"Tensor {tensor}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
-            ext_id = self.external_ids[tensor].external_id
-            if is_input:
-                xnn_graph.input_ids.append(id_out)
-                flag = XNN_VALUE_FLAG_EXTERNAL_INPUT
-            if is_output:
-                xnn_graph.output_ids.append(id_out)
-                flag = XNN_VALUE_FLAG_EXTERNAL_OUTPUT
+                tensor_output in self.external_ids.keys()
+            ), f"Tensor {tensor_output}, is_output. ext_ids: {self.external_ids.keys()}"
+            ext_id = self.external_ids[tensor_output].external_id
+            xnn_graph.output_ids.append(id_out)
+            flag = self.external_ids[tensor_output].io_type

         return ext_id, id_out, flag
@@ -230,6 +242,7 @@ def define_tensor(
         # Get new xnn id for tensor value
         ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params)
         dims = get_shape(tensor)
+        dims = [1] if len(dims) == 0 else dims

         # constant values serialize data
         buffer_idx = self.get_serialized_buffer(
@@ -336,6 +349,10 @@ def get_serialized_buffer(
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
+        else:
+            # ensure that the const is fp32
+            const_val = const_val.to(dtype=torch.float32).contiguous()
+
         if swap_nc_for_depthwise_weights:
             const_val = const_val.permute(
                 dims=((1, 0) + tuple(range(2, const_val.dim())))
diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py
index a3b66d3fcb4..99277fa029b 100644
--- a/backends/xnnpack/partition/configs.py
+++ b/backends/xnnpack/partition/configs.py
@@ -37,6 +37,9 @@

 SUPPORTED_MODULES = [
     torch.nn.Conv1d,
+    # TODO(T161981984) recompose hardswish into a single node
+    torch.nn.Hardswish,
+    torch.nn.Hardsigmoid,
     torch.nn.Conv2d,
     torch.nn.ReLU,
     torch.nn.Sigmoid,
@@ -78,6 +81,11 @@
     )
 }

+UNSUPPORTED_QUANT_MODULES = [
+    torch.nn.Hardswish,
+    torch.nn.Hardsigmoid,
+]
+
 # TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support
 SUPPORTED_QUANT_MODULES = [
     torch.clamp,
diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py
index 59319358993..45667bf0fd4 100644
--- a/backends/xnnpack/partition/xnnpack_partitioner.py
+++ b/backends/xnnpack/partition/xnnpack_partitioner.py
@@ -17,6 +17,7 @@
     SUPPORTED_OPS,
     SUPPORTED_QUANT_MODULES,
     SUPPORTED_QUANT_OPS,
+    UNSUPPORTED_QUANT_MODULES,
 )
 from executorch.backends.xnnpack.partition.support_patterns import (
     get_add_graphs,
@@ -103,6 +104,7 @@ def __init__(
             Any, Callable[[torch.fx.Node], bool]
         ] = _OP_SUPPORT_CONSTRAINTS,
         supported_ops: Optional[List] = None,
+        unsupported_modules: Optional[List] = None,
     ):
         """
         @Arg constraints_dict: Dict mapping each node to a lambda function that
@@ -110,10 +112,17 @@ def __init__(
             node.
         @Arg supported_ops: List of supported operators for partitioning
         """
+        self.unsupported_modules = unsupported_modules
         self.supported_ops = supported_ops
        self.constraints = constraints_dict
         assert len(self.constraints)

+    def check_common_constraints(self, node) -> bool:
+        if self.unsupported_modules and "source_fn" in node.meta:
+            return not node.meta["source_fn"][1] in self.unsupported_modules
+
+        return True
+
     @staticmethod
     def check_constraint(node) -> bool:
         """
@@ -132,7 +141,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         if self.supported_ops and node.target not in self.supported_ops:
             return False

-        return self.check_constraint(node)
+        return self.check_constraint(node) and self.check_common_constraints(node)

     def _constraint(target):  # noqa
         """
@@ -540,10 +549,11 @@ def __init__(
         self,
         supported_modules: List[Callable] = SUPPORTED_MODULES,
         supported_ops: Optional[List[Callable]] = SUPPORTED_OPS,
+        unsupported_modules: Optional[List[Callable]] = None,
     ):
         super().__init__()
         self.supported_modules = set(supported_modules)
-
+        self.unsupported_modules = unsupported_modules
         self.supported_ops = set(supported_ops or [])

         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
@@ -614,7 +624,10 @@ def generate_partitions(self, graph_module: torch.fx.GraphModule) -> List[Any]:
         return generate_partitions_from_list_of_nodes(
             graph_module,
             matched_module_nodes,
-            XnnpackOperatorSupport(supported_ops=self.supported_ops),
+            XnnpackOperatorSupport(
+                supported_ops=self.supported_ops,
+                unsupported_modules=self.unsupported_modules,
+            ),
         )

     def tag_nodes(self, partitions: List[Partition]) -> None:
@@ -668,9 +681,12 @@ def __init__(
         self,
         supported_modules=SUPPORTED_QUANT_MODULES,
         supported_ops=SUPPORTED_QUANT_OPS,
+        unsupported_modules=UNSUPPORTED_QUANT_MODULES,
     ):
         supported_ops = supported_ops or []
-        super().__init__(supported_modules, supported_ops + self._QUANT_OPS)
+        super().__init__(
+            supported_modules, supported_ops + self._QUANT_OPS, unsupported_modules
+        )

     # TODO Refactor this
     # TODO Don't be greedy when pulling q->dq pairs for a given op, add convert tracker pass
diff --git a/backends/xnnpack/test/models/mobilenet_v3.py b/backends/xnnpack/test/models/mobilenet_v3.py
new file mode 100644
index 00000000000..4a1df39b09e
--- /dev/null
+++ b/backends/xnnpack/test/models/mobilenet_v3.py
@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+import torchvision.models as models
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackQuantizedPartitioner2,
+)
+from executorch.backends.xnnpack.test.tester import Partition, Tester
+from executorch.backends.xnnpack.test.tester.tester import Export
+from executorch.backends.xnnpack.utils.configs import get_xnnpack_capture_config
+
+
+class TestMobileNetV3(unittest.TestCase):
+    export_stage = Export(get_xnnpack_capture_config(enable_aot=True))
+
+    mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True)
+    mv3 = mv3.eval()
+    model_inputs = (torch.ones(1, 3, 224, 244),)
+
+    all_operators = {
+        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
+        "executorch_exir_dialects_edge__ops_aten_clamp_default",
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
+        "executorch_exir_dialects_edge__ops_aten_addmm_default",
+        "executorch_exir_dialects_edge__ops_aten__to_copy_default",
+        "executorch_exir_dialects_edge__ops_aten_convolution_default",
+        "executorch_exir_dialects_edge__ops_aten_relu_default",
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
+        "executorch_exir_dialects_edge__ops_aten_div_Tensor",
+        "executorch_exir_dialects_edge__ops_aten_mean_dim",
+    }
+
+    def test_fp32(self):
+        (
+            Tester(self.mv3, self.model_inputs)
+            .export(self.export_stage)
+            .to_edge()
+            .check(list(self.all_operators))
+            .partition()
+            .check(["torch.ops.executorch_call_delegate"])
+            .check_not(list(self.all_operators))
+            .to_executorch()
+            .serialize()
+            .run_method()
+            .compare_outputs()
+        )
+
+    def test_qs8_pt2e(self):
+        ops_after_quantization = self.all_operators - {
+            "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
+        }
+        ops_after_lowering = self.all_operators - {
+            # TODO: use the unified partitioner; the decomposed hardswish/hardsigmoid
+            # operators are not quantized and will not be partitioned by the quantized partitioner
+            "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_div_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_clamp_default",
+            "executorch_exir_dialects_edge__ops_aten__to_copy_default",
+        }
+
+        (
+            Tester(self.mv3, self.model_inputs)
+            .quantize2()
+            .export(self.export_stage)
+            .to_edge()
+            .check(list(ops_after_quantization))
+            .partition(Partition(partitioner=XnnpackQuantizedPartitioner2))
+            .check(["torch.ops.executorch_call_delegate"])
+            .check_not(list(ops_after_lowering))
+            .to_executorch()
+            .serialize()
+            .run_method()
+            .compare_outputs()
+        )
diff --git a/examples/backend/xnnpack_examples.py b/examples/backend/xnnpack_examples.py
index 1f0f3edcee3..7e608e8ae45 100644
--- a/examples/backend/xnnpack_examples.py
+++ b/examples/backend/xnnpack_examples.py
@@ -15,6 +15,9 @@
     XnnpackQuantizedPartitioner2,
 )
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import (
+    DuplicateDequantNodePass,
+)

 from ..models import MODEL_NAME_TO_MODEL, MODEL_NAME_TO_OPTIONS
 from ..quantization.utils import quantize
@@ -77,7 +80,13 @@
     # It will eventually be changed to a lifted graph, in which _unlift=False,
     edge = exir.capture(
         model, example_inputs, exir.CaptureConfig(enable_aot=True, _unlift=True)
-    ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
+    ).to_edge(
+        exir.EdgeCompileConfig(
+            # TODO(T162080278): Duplicated Dequant nodes will be in quantizer spec
+            _check_ir_validity=False,
+            passes=[DuplicateDequantNodePass()],
+        )
+    )

     logging.info(f"Exported graph:\n{edge.exported_program.graph}")
     edge.exported_program = to_backend(edge.exported_program, partitioner)
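
Reviewer note (not part of the patch): the core of the partitioner change is the new check_common_constraints filter, which rejects any node whose recorded source_fn module appears in UNSUPPORTED_QUANT_MODULES, so decomposed hardswish/hardsigmoid subgraphs are left out of the quantized partition. The sketch below is a minimal, self-contained approximation of that behavior; the standalone function and the hand-built metadata dicts are illustrative stand-ins, not the actual XnnpackOperatorSupport API.

# Minimal sketch (assumption: mirrors, but does not import, the
# XnnpackOperatorSupport.check_common_constraints logic added in this patch).
import torch

UNSUPPORTED_QUANT_MODULES = [torch.nn.Hardswish, torch.nn.Hardsigmoid]


def check_common_constraints(node_meta: dict, unsupported_modules: list) -> bool:
    # A node is rejected only when its source_fn resolves to an unsupported
    # module; nodes without source_fn metadata pass through unchanged.
    if unsupported_modules and "source_fn" in node_meta:
        return node_meta["source_fn"][1] not in unsupported_modules
    return True


# Hand-built stand-ins for the (name, module_class) pairs that torch.fx
# records in node.meta["source_fn"] during export:
hardswish_meta = {"source_fn": ("hardswish", torch.nn.Hardswish)}
conv_meta = {"source_fn": ("conv2d", torch.nn.Conv2d)}

assert not check_common_constraints(hardswish_meta, UNSUPPORTED_QUANT_MODULES)
assert check_common_constraints(conv_meta, UNSUPPORTED_QUANT_MODULES)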