From 3baa6f8981eb4366005910ef70d1ba4825b89059 Mon Sep 17 00:00:00 2001 From: maxren Date: Fri, 25 Aug 2023 16:56:19 -0700 Subject: [PATCH 1/3] handle static ints and floats Differential Revision: D48667679 fbshipit-source-id: ca01aff1fdc7888e0abcb5a5437b0b6f2231dbd8 --- backends/xnnpack/operators/node_visitor.py | 5 +++++ backends/xnnpack/partition/configs.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index e1f12bb51b8..5f4e871f90b 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -230,6 +230,7 @@ def define_tensor( # Get new xnn id for tensor value ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params) dims = get_shape(tensor) + dims = [1] if len(dims) == 0 else dims # constant values serialize data buffer_idx = self.get_serialized_buffer( @@ -336,6 +337,10 @@ def get_serialized_buffer( # Quantize buffer if static data is indeed quantized if quant_params is not None and not quant_params.is_dynamic: const_val = quant_params.quantize_tensor(const_val).contiguous() + else: + # ensure that the const is fp32 + const_val = const_val.to(dtype=torch.float32).contiguous() + if swap_nc_for_depthwise_weights: const_val = const_val.permute( dims=((1, 0) + tuple(range(2, const_val.dim()))) diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index a3b66d3fcb4..1a5f567bf43 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -37,6 +37,9 @@ SUPPORTED_MODULES = [ torch.nn.Conv1d, + # TODO(T161981984) recomposed hardswish into a single node + torch.nn.Hardswish, + torch.nn.Hardsigmoid, torch.nn.Conv2d, torch.nn.ReLU, torch.nn.Sigmoid, From 93b28d4f6a25f0b869309f8317b7f309feb10682 Mon Sep 17 00:00:00 2001 From: maxren Date: Fri, 25 Aug 2023 16:56:19 -0700 Subject: [PATCH 2/3] Use IDs for placeholder/outputs of implicit q/dq nodes Differential Revision: D48667676 fbshipit-source-id: 19dc3ff786f6a57dd68102bf69acf2cb0e68363e --- backends/xnnpack/operators/node_visitor.py | 38 ++++++++++++++-------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 5f4e871f90b..4ab0777a6cb 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -130,21 +130,33 @@ def gen_ids_and_flags( # This will break if we change the way q/dq are partitioned # Tensor can still be input if its quantizing node is an input - is_input = (self).is_graph_input(tensor) + if self.is_graph_input(tensor) or ( + quant_params.is_input if quant_params else False + ): + tensor_input = tensor + if quant_params: + if quant_params.is_input and not self.is_graph_input(tensor): + tensor_input = quant_params.q_input + assert ( + tensor_input in self.external_ids.keys() + ), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}" + ext_id = self.external_ids[tensor_input].external_id + xnn_graph.input_ids.append(id_out) + flag = self.external_ids[tensor_input].io_type # Tensor can still be output if its quantizing node is an output - is_output = self.is_graph_output(tensor) - # handle logic for input/output tensors - if is_input or is_output: + elif self.is_graph_output(tensor) or ( + quant_params.is_output if quant_params else False + ): + tensor_output = tensor + if quant_params: + if quant_params.is_output and not self.is_graph_output(tensor): + tensor_output = list(tensor.users)[0] assert ( - tensor in self.external_ids.keys() - ), f"Tensor {tensor}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}" - ext_id = self.external_ids[tensor].external_id - if is_input: - xnn_graph.input_ids.append(id_out) - flag = XNN_VALUE_FLAG_EXTERNAL_INPUT - if is_output: - xnn_graph.output_ids.append(id_out) - flag = XNN_VALUE_FLAG_EXTERNAL_OUTPUT + tensor_output in self.external_ids.keys() + ), f"Tensor {tensor_output} is_output: ext_ids: {self.external_ids.keys()}" + ext_id = self.external_ids[tensor_output].external_id + xnn_graph.output_ids.append(id_out) + flag = self.external_ids[tensor_output].io_type return ext_id, id_out, flag From c5a4e0fd4427547d04a927f6d5ff7008757ffed2 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Fri, 25 Aug 2023 16:56:42 -0700 Subject: [PATCH 3/3] add unsupported module list for partitioner (#139) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/139 For Quantized Mobilenetv3. Right now decomposed hardswish and hardsigmoid gets partitioned with XNNPACKQuantizedPartitioner2, because they contain simple ops like add, mul, div, etc. These ops are all in FP32 since xnnpack doesn't support the quantized variants. Since we don't currently have support for a unified partition, we want to block the from our Quantized Partitioner. Reviewed By: digantdesai Differential Revision: D48667680 fbshipit-source-id: d055d4b7d5d0e0f0d3633635c16cfca7e6c5f48a --- backends/xnnpack/partition/configs.py | 5 ++++ .../xnnpack/partition/xnnpack_partitioner.py | 24 +++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index 1a5f567bf43..99277fa029b 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -81,6 +81,11 @@ ) } +UNSUPPORTED_QUANT_MODULES = [ + torch.nn.Hardswish, + torch.nn.Hardsigmoid, +] + # TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support SUPPORTED_QUANT_MODULES = [ torch.clamp, diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index 59319358993..45667bf0fd4 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -17,6 +17,7 @@ SUPPORTED_OPS, SUPPORTED_QUANT_MODULES, SUPPORTED_QUANT_OPS, + UNSUPPORTED_QUANT_MODULES, ) from executorch.backends.xnnpack.partition.support_patterns import ( get_add_graphs, @@ -103,6 +104,7 @@ def __init__( Any, Callable[[torch.fx.Node], bool] ] = _OP_SUPPORT_CONSTRAINTS, supported_ops: Optional[List] = None, + unsupported_modules: Optional[List] = None, ): """ @Arg constraints_dict: Dict mapping each node to a lambda function that @@ -110,10 +112,17 @@ def __init__( node. @Arg supported_ops: List of supported operators for partitioning """ + self.unsupported_modules = unsupported_modules self.supported_ops = supported_ops self.constraints = constraints_dict assert len(self.constraints) + def check_common_constraints(self, node) -> bool: + if self.unsupported_modules and "source_fn" in node.meta: + return not node.meta["source_fn"][1] in self.unsupported_modules + + return True + @staticmethod def check_constraint(node) -> bool: """ @@ -132,7 +141,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: if self.supported_ops and node.target not in self.supported_ops: return False - return self.check_constraint(node) + return self.check_constraint(node) and self.check_common_constraints(node) def _constraint(target): # noqa """ @@ -540,10 +549,11 @@ def __init__( self, supported_modules: List[Callable] = SUPPORTED_MODULES, supported_ops: Optional[List[Callable]] = SUPPORTED_OPS, + unsupported_modules: Optional[List[Callable]] = None, ): super().__init__() self.supported_modules = set(supported_modules) - + self.unsupported_modules = unsupported_modules self.supported_ops = set(supported_ops or []) self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, []) @@ -614,7 +624,10 @@ def generate_partitions(self, graph_module: torch.fx.GraphModule) -> List[Any]: return generate_partitions_from_list_of_nodes( graph_module, matched_module_nodes, - XnnpackOperatorSupport(supported_ops=self.supported_ops), + XnnpackOperatorSupport( + supported_ops=self.supported_ops, + unsupported_modules=self.unsupported_modules, + ), ) def tag_nodes(self, partitions: List[Partition]) -> None: @@ -668,9 +681,12 @@ def __init__( self, supported_modules=SUPPORTED_QUANT_MODULES, supported_ops=SUPPORTED_QUANT_OPS, + unsupported_modules=UNSUPPORTED_QUANT_MODULES, ): supported_ops = supported_ops or [] - super().__init__(supported_modules, supported_ops + self._QUANT_OPS) + super().__init__( + supported_modules, supported_ops + self._QUANT_OPS, unsupported_modules + ) # TODO Refactor this # TODO Don't be greedy when pulling q->dq pairs for a given op, add convert tracker pass