diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index e1f12bb51b8..4ab0777a6cb 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -130,21 +130,33 @@ def gen_ids_and_flags(
         # This will break if we change the way q/dq are partitioned

         # Tensor can still be input if its quantizing node is an input
-        is_input = (self).is_graph_input(tensor)
+        if self.is_graph_input(tensor) or (
+            quant_params.is_input if quant_params else False
+        ):
+            tensor_input = tensor
+            if quant_params:
+                if quant_params.is_input and not self.is_graph_input(tensor):
+                    tensor_input = quant_params.q_input
+            assert (
+                tensor_input in self.external_ids.keys()
+            ), f"Tensor {tensor_input}, is_input. ext_ids: {self.external_ids.keys()}"
+            ext_id = self.external_ids[tensor_input].external_id
+            xnn_graph.input_ids.append(id_out)
+            flag = self.external_ids[tensor_input].io_type
         # Tensor can still be output if its quantizing node is an output
-        is_output = self.is_graph_output(tensor)
-        # handle logic for input/output tensors
-        if is_input or is_output:
+        elif self.is_graph_output(tensor) or (
+            quant_params.is_output if quant_params else False
+        ):
+            tensor_output = tensor
+            if quant_params:
+                if quant_params.is_output and not self.is_graph_output(tensor):
+                    tensor_output = list(tensor.users)[0]
             assert (
-                tensor in self.external_ids.keys()
-            ), f"Tensor {tensor}, is_input: {is_input}, is_output: {is_output}, ext_ids: {self.external_ids.keys()}"
-            ext_id = self.external_ids[tensor].external_id
-            if is_input:
-                xnn_graph.input_ids.append(id_out)
-                flag = XNN_VALUE_FLAG_EXTERNAL_INPUT
-            if is_output:
-                xnn_graph.output_ids.append(id_out)
-                flag = XNN_VALUE_FLAG_EXTERNAL_OUTPUT
+                tensor_output in self.external_ids.keys()
+            ), f"Tensor {tensor_output} is_output: ext_ids: {self.external_ids.keys()}"
+            ext_id = self.external_ids[tensor_output].external_id
+            xnn_graph.output_ids.append(id_out)
+            flag = self.external_ids[tensor_output].io_type

         return ext_id, id_out, flag

@@ -230,6 +242,7 @@ def define_tensor(
         # Get new xnn id for tensor value
         ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params)
         dims = get_shape(tensor)
+        dims = [1] if len(dims) == 0 else dims

         # constant values serialize data
         buffer_idx = self.get_serialized_buffer(
@@ -336,6 +349,10 @@ def get_serialized_buffer(
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
+        else:
+            # ensure that the const is fp32
+            const_val = const_val.to(dtype=torch.float32).contiguous()
+
         if swap_nc_for_depthwise_weights:
             const_val = const_val.permute(
                 dims=((1, 0) + tuple(range(2, const_val.dim())))
diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py
index a3b66d3fcb4..1a5f567bf43 100644
--- a/backends/xnnpack/partition/configs.py
+++ b/backends/xnnpack/partition/configs.py
@@ -37,6 +37,9 @@

 SUPPORTED_MODULES = [
     torch.nn.Conv1d,
+    # TODO(T161981984) recompose hardswish into a single node
+    torch.nn.Hardswish,
+    torch.nn.Hardsigmoid,
     torch.nn.Conv2d,
     torch.nn.ReLU,
     torch.nn.Sigmoid,
diff --git a/examples/models/models.py b/examples/models/models.py
index e76b98a5948..f4a11318b07 100644
--- a/examples/models/models.py
+++ b/examples/models/models.py
@@ -155,4 +155,5 @@ class OptimizationOptions(object):
     "add": OptimizationOptions(True, True),
     "add_mul": OptimizationOptions(True, True),
     "mv2": OptimizationOptions(True, True),
+    "mv3": OptimizationOptions(True, False),
 }
diff --git a/examples/quantization/example.py b/examples/quantization/example.py
index 4ee1dc9495b..f74bbda9365 100644
--- a/examples/quantization/example.py
+++ b/examples/quantization/example.py
@@ -46,6 +46,7 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
     m = prepare_pt2e(m, quantizer)
     # calibration
     after_prepare_result = m(*example_inputs)
+    print("pt2e prepare:", m)
     m = convert_pt2e(m)
     after_quant_result = m(*example_inputs)

@@ -57,6 +58,7 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
         m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
     )
     after_prepare_result_fx = m_fx(*example_inputs)
+    print("fx prepare:", m_fx)
     m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
     after_quant_result_fx = m_fx(*example_inputs)

@@ -69,10 +71,10 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
     print("m_fx:", m_fx)
     print("prepare sqnr:", compute_sqnr(after_prepare_result, after_prepare_result_fx))
     assert compute_sqnr(after_prepare_result, after_prepare_result_fx) > 100
-    print("quant diff max:", torch.max(after_quant_result - after_quant_result_fx))
+    print("diff max:", torch.max(after_quant_result - after_quant_result_fx))
+    print("sqnr:", compute_sqnr(after_quant_result, after_quant_result_fx))
     assert torch.max(after_quant_result - after_quant_result_fx) < 1e-1
-    print("quant sqnr:", compute_sqnr(after_quant_result, after_quant_result_fx))
-    assert compute_sqnr(after_quant_result, after_quant_result_fx) > 30
+    assert compute_sqnr(after_quant_result, after_quant_result_fx) > 35


 if __name__ == "__main__":
@@ -121,7 +123,7 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
         raise RuntimeError(
             f"Model {args.model_name} is not a valid name. or not quantizable right now, "
             "please contact executorch team if you want to learn why or how to support "
-            "quantization for the requested model"
+            "quantization for the requested model "
             f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
         )

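
Notes on the patch (the sketches below use stand-in names such as SimpleNamespace records and string node keys, not the real ExecuTorch types).

The gen_ids_and_flags rewrite lets a tensor inherit its external id from a neighboring quantize/dequantize node: a tensor that is not itself a graph input can still be external if the node feeding its quantize op (quant_params.q_input) is, and likewise on the output side via the quantizing node that consumes it. A minimal sketch of the input-side lookup:

    # Simplified model of the new input-side branch in gen_ids_and_flags.
    # In the real visitor, external_ids maps torch.fx nodes to records with
    # .external_id and .io_type; SimpleNamespace objects stand in here.
    from types import SimpleNamespace


    def lookup_input_entry(tensor, graph_inputs, external_ids, quant_params):
        # Redirect the lookup to the quantize node's source when the tensor
        # itself is not a graph input but that source is.
        key = tensor
        if quant_params is not None and quant_params.is_input:
            if tensor not in graph_inputs:
                key = quant_params.q_input
        assert key in external_ids, f"{key} missing from external_ids"
        entry = external_ids[key]
        return entry.external_id, entry.io_type


    # Usage: quantize node "q0" is the true graph input for tensor "t0".
    quant_params = SimpleNamespace(is_input=True, q_input="q0")
    external_ids = {"q0": SimpleNamespace(external_id=0, io_type=1)}
    print(lookup_input_entry("t0", {"q0"}, external_ids, quant_params))  # (0, 1)

The output branch mirrors this, except the redirected key comes from list(tensor.users)[0], the quantizing consumer of the tensor.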
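
The define_tensor and get_serialized_buffer hunks add two normalizations: rank-0 tensors are promoted to shape [1] so every serialized XNNPACK value carries at least one dimension, and unquantized constants are coerced to fp32 before their data is written out. Both are reproducible in plain PyTorch:

    import torch

    # A 0-dim (scalar) tensor reports an empty shape; promote it to rank 1,
    # as the new line in define_tensor does.
    scalar = torch.tensor(3.0)
    dims = list(scalar.shape)               # []
    dims = [1] if len(dims) == 0 else dims  # [1]

    # Unquantized static data is forced to fp32 before serialization,
    # mirroring the new else-branch in get_serialized_buffer.
    const_val = torch.ones(4, dtype=torch.float64)
    const_val = const_val.to(dtype=torch.float32).contiguous()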
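
Finally, the example.py hunk prints the diff max and SQNR before asserting (so failures are easier to diagnose) and tightens the quantized-output bar from an SQNR of 30 to 35. compute_sqnr comes from the PyTorch quantization utilities; a sketch of the standard formulation, which may differ from the exact helper in minor details:

    import torch

    def compute_sqnr_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Signal-to-quantization-noise ratio in dB: signal power relative to
        # the power of the error (x - y). Higher means closer agreement.
        return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))

With after_quant_result and after_quant_result_fx as arguments, values above 35 dB indicate the pt2e and FX reference paths produce nearly identical quantized outputs.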