# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np

import torch

from executorch.exir import EdgeProgramManager
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass
from executorch.exir.tensor import scalar_type_enum
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)


def quantize_input(
    exported_program, input_index: int, qparams: Optional[Dict[str, Any]] = None
):
    """
    Modify the program to expect quantized input at given index. The model is expected
    to be quantizing this input as the first step. Must be called before
    permute_input_layout. Returns the scale, zero point, qmin, qmax, and dtype of the
    expected quantization.

    Args:
        exported_program: Program whose ``graph_module.graph`` is edited in place.
        input_index: Index into ``graph_signature.user_inputs``.
        qparams: Optional dict with keys ``"scale"``, ``"zp"`` and ``"dtype"``.
            When given and different from the arguments of the existing quantize
            op, the quantize op is kept and a dequantize op using ``qparams`` is
            inserted in front of it (requantization). Otherwise the quantize op
            is removed and the placeholder feeds its former users directly.

    Returns:
        ``(scale, zero_point, qmin, qmax, dtype)`` when requantizing, otherwise
        ``quantize.args[1:]`` of the removed quantize op.

    Raises:
        ValueError: If the input does not have exactly one user, that user is
            not ``quantize_per_tensor``, or a dtype involved in requantization
            is not in the supported list.
        AssertionError: If no placeholder matches the input name, or ``qparams``
            is missing one of the required keys.
    """
    graph = exported_program.graph_module.graph
    name = exported_program.graph_signature.user_inputs[input_index]
    placeholders = [n for n in graph.nodes if n.op == "placeholder" and n.name == name]
    assert placeholders, f"No placeholder named {name} found for input {input_index}"
    target_placeholder = placeholders[0]

    # Exactly one consumer is required; report the zero-user case accurately
    # instead of claiming there is "more than one".
    num_users = len(target_placeholder.users)
    if num_users == 0:
        raise ValueError(f"Input {input_index} has no users")
    if num_users > 1:
        raise ValueError(f"Input {input_index} has more than one users")
    quantize = next(iter(target_placeholder.users))
    if (
        quantize.target
        != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    ):
        raise ValueError(f"Input {input_index} is not used by a quantize op")

    # If user specified qparams are different from args of quantize op, we do
    # requantization instead of eliminating quantize op.
    # quantize.args layout: (input, scale, zero_point, qmin, qmax, dtype).
    need_requant = False
    if qparams is not None:
        assert all(
            qparam in qparams for qparam in ["scale", "zp", "dtype"]
        ), "dtype/scale/zp must be specified in qparam for input requantization"
        if qparams["dtype"] != quantize.args[5]:
            # Built lazily so the dtype attributes are only touched when needed.
            supported_dtypes = [
                torch.int8,
                torch.uint8,
                torch.bool,
                torch.int16,
                torch.uint16,
            ]
            if any(
                dtype not in supported_dtypes
                for dtype in [qparams["dtype"], quantize.args[5]]
            ):
                raise ValueError(
                    f"Only limited data types are supported for requantization, but got {qparams['dtype']} -> {quantize.args[5]}"
                )

            need_requant = True
        elif (
            not np.isclose(qparams["scale"], quantize.args[1])
            or qparams["zp"] != quantize.args[2]
        ):
            need_requant = True

    if need_requant:
        assert qparams is not None
        dtype = qparams["dtype"]
        qmin = torch.iinfo(dtype).min
        qmax = torch.iinfo(dtype).max
        scale = qparams["scale"]
        zero_point = qparams["zp"]
        quant_args = (scale, zero_point, qmin, qmax, dtype)
        logger.info(
            "Modifying program to requantize quantized input at index %s",
            input_index,
        )
        logger.info("Quantization parameters: %s", quant_args)

        # placeholder -> dequantize(user qparams) -> existing quantize op
        with graph.inserting_before(quantize):
            input_dequant = graph.call_function(
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                args=(
                    target_placeholder,
                    *quant_args,
                ),
            )
            input_dequant.meta["input_qparams"] = [
                {
                    "scale": scale,
                    "zero_point": zero_point,
                    "qmin": qmin,
                    "qmax": qmax,
                    "dtype": dtype,
                }
            ]
            input_dequant.meta["val"] = quantize.meta["val"].to(torch.float32)
            target_placeholder.meta["val"] = target_placeholder.meta["val"].to(dtype)
            quantize.replace_input_with(target_placeholder, input_dequant)
    else:
        quant_args = quantize.args[1:]
        logger.info(
            "Modifying program to take quantized input at index %s", input_index
        )
        logger.info("Quantization parameters: %s", quant_args)

        # The placeholder now carries the quantized value, so refresh its
        # metadata by running the quantize op on the old metadata value.
        target_placeholder.meta["val"] = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
                target_placeholder.meta["val"], *quant_args
            )
        )
        quantize.replace_all_uses_with(quantize.args[0])

    graph.eliminate_dead_code()
    return quant_args


def quantize_output(exported_program, output_index: int):
    """
    Modify the program to produce quantized output at given index. The model is expected
    to be dequantizing this output as the last step. Must be called before
    permute_output_layout. Returns the scale, zero point, qmin, qmax, and dtype of the
    output quantization.

    Args:
        exported_program: Program whose ``graph_module.graph`` is edited in place.
        output_index: Position of the output inside the single output node's
            first argument.

    Returns:
        ``dequant.args[1:]`` of the dequantize op that was bypassed.

    Raises:
        NotImplementedError: If the graph does not have exactly one output node.
        ValueError: If ``output_index`` is out of bounds or the selected output
            is not produced by ``dequantize_per_tensor``.
    """
    graph = exported_program.graph_module.graph
    outputs = [n for n in graph.nodes if n.op == "output"]
    if len(outputs) != 1:
        raise NotImplementedError("Only 1 output node is supported")

    output_node = outputs[0]
    output_list = list(output_node.args[0])
    if output_index >= len(output_list):
        raise ValueError(
            f"{len(output_list)} outputs available, "
            f"output index out of bounds: {output_index}"
        )

    target_output = output_list[output_index]
    if (
        target_output.target
        != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    ):
        # Fixed: the original literal lacked the f-prefix, so the index was
        # never interpolated into the message.
        raise ValueError(f"Output {output_index} is not a dequantize op")

    # Route the dequantize op's input straight to the output; the now-unused
    # dequantize node is removed by dead-code elimination.
    dequant = target_output
    output_list[output_index] = dequant.args[0]
    output_node.args = (output_list,)
    dequant_args = dequant.args[1:]
    graph.eliminate_dead_code()

    logger.info(
        "Modifying program to produce quantized output at index %s", output_index
    )
    logger.info("Dequantization parameters: %s", dequant_args)
    return dequant_args


def get_config_method_name(
    prefix: Optional[str] = "forward",
    arg_type: str = "input",
    index: int = 0,
    key: str = "scale",
) -> str:
    """
    Build the config-method name under which a quantization parameter is stored.

    The result has the form ``{prefix}_{arg_type}{index}_{key}``; when ``prefix``
    is ``None`` the leading ``{prefix}_`` part is omitted.

    Args:
        prefix: Name prefix, or ``None`` for no prefix.
        arg_type: Either ``"input"`` or ``"output"``.
        index: Non-negative index of the input/output.
        key: One of ``scale``, ``zp``, ``quant_min``, ``quant_max``, ``dtype``.

    Raises:
        AssertionError: If ``arg_type``, ``index`` or ``key`` is invalid.
    """
    if prefix is None:
        prefix = ""
    else:
        prefix = prefix + "_"
    assert arg_type in ["input", "output"], "arg_type must be either input or output"
    assert index >= 0, "index must be non-negative"
    assert key in [
        "scale",
        "zp",
        "quant_min",
        "quant_max",
        "dtype",
    ], "key must be one of scale, zp, quant_min, quant_max, dtype"
    return f"{prefix}{arg_type}{index}_{key}"


def _record_quant_params(
    edge_program_manager: EdgeProgramManager,
    prefix: Optional[str],
    arg_type: str,
    index: int,
    quant_args,
) -> None:
    """Store (scale, zp, qmin, qmax, dtype) from quant_args as config methods."""
    config_methods = edge_program_manager._config_methods
    if not config_methods:
        config_methods = {}
        edge_program_manager._config_methods = config_methods

    for position, key in enumerate(("scale", "zp", "quant_min", "quant_max")):
        config_methods[get_config_method_name(prefix, arg_type, index, key)] = (
            quant_args[position]
        )
    # The dtype is stored as its scalar-type enum value rather than the
    # torch.dtype object itself.
    config_methods[get_config_method_name(prefix, arg_type, index, "dtype")] = (
        scalar_type_enum(quant_args[4])
    )


class QuantizeInputs(ExportPass):
    """
    Pass that makes the selected inputs of the program quantized.

    For every selected input, ``quantize_input`` is applied to
    ``edge_program_manager.exported_program()`` and the resulting quantization
    parameters are recorded in ``edge_program_manager._config_methods``.

    NOTE(review): ``method_name`` is only used as the prefix of the recorded
    config-method names; the program is always fetched with
    ``exported_program()`` called without arguments.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose program and config methods are
                modified in place.
            quantized_inputs_idx: Either a list of input indices, or a dict
                mapping input index to the qparams passed to ``quantize_input``.
            method_name: Prefix for the generated config-method names.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager

        if isinstance(quantized_inputs_idx, dict):
            self.quantized_inputs_idx_dict = quantized_inputs_idx
        else:
            # A plain index list means "no user-specified qparams".
            self.quantized_inputs_idx_dict = {
                idx: None for idx in quantized_inputs_idx
            }
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Quantize every configured input and record its qparams."""
        for i, qparams in self.quantized_inputs_idx_dict.items():
            quant_args = quantize_input(
                self.edge_program_manager.exported_program(), i, qparams
            )
            _record_quant_params(
                self.edge_program_manager,
                self.param_prefix_name,
                "input",
                i,
                quant_args,
            )
        return PassResult(graph_module, True)


class QuantizeOutputs(ExportPass):
    """
    Pass that makes the selected outputs of the program quantized.

    For every selected output, ``quantize_output`` is applied to
    ``edge_program_manager.exported_program()`` and the resulting
    dequantization parameters are recorded in
    ``edge_program_manager._config_methods``.

    NOTE(review): ``method_name`` is only used as the prefix of the recorded
    config-method names; the program is always fetched with
    ``exported_program()`` called without arguments.
    """

    def __init__(
        self,
        edge_program_manager: EdgeProgramManager,
        quantized_outputs_idx_list: List[int],
        method_name: Optional[str] = None,
    ):
        """
        Args:
            edge_program_manager: Manager whose program and config methods are
                modified in place.
            quantized_outputs_idx_list: Indices of the outputs to quantize.
            method_name: Prefix for the generated config-method names.
        """
        super().__init__()
        self.edge_program_manager = edge_program_manager
        self.quantized_outputs_idx_list = quantized_outputs_idx_list
        self.param_prefix_name = method_name

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Quantize every configured output and record its qparams."""
        for i in self.quantized_outputs_idx_list:
            dequant_args = quantize_output(
                self.edge_program_manager.exported_program(), i
            )
            _record_quant_params(
                self.edge_program_manager,
                self.param_prefix_name,
                "output",
                i,
                dequant_args,
            )

        return PassResult(graph_module, True)
import copy
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.exir.passes.quantize_io_pass import (
    get_config_method_name,
    QuantizeInputs,
    QuantizeOutputs,
)
from executorch.exir.tensor import get_scalar_type
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.testing import FileCheck

# Substrings of the generated graph code that identify quantize ("q") and
# dequantize ("dq") edge ops.
op_str = {
    "q": "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
    "dq": "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
}


class TestQuantIOPass(unittest.TestCase):
    """Tests for the QuantizeInputs / QuantizeOutputs passes on a two-input add."""

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    def _quantize(self, mod, example_inputs):
        """Run the PT2E prepare/calibrate/convert flow and re-export the result."""
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(get_symmetric_quantization_config())

        training_module = torch.export.export_for_training(
            mod, copy.deepcopy(example_inputs)
        ).module()
        prepared = prepare_pt2e(training_module, quantizer)
        prepared(*example_inputs)  # calibration run; the result is not needed
        converted = convert_pt2e(prepared)
        return torch.export.export_for_training(converted, example_inputs)

    def _check_count(self, op, count, epm):
        """Assert that `op` occurs exactly `count` times in the program's code."""
        FileCheck().check_count(op, count, exactly=True).run(
            epm.exported_program().graph_module.code
        )

    def _get_edge_prog_manager(self, mod, example_inputs):
        """Quantize `mod`, lower it to edge, and sanity-check the q/dq op counts."""
        edge_program_manager = to_edge_transform_and_lower(
            self._quantize(mod, example_inputs),
            transform_passes=[],
            partitioner=None,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

        self._check_count(op_str["dq"], 3, edge_program_manager)
        self._check_count(op_str["q"], 3, edge_program_manager)
        return edge_program_manager

    def _load_qparams(self, epm, arg_type, index):
        """Read (scale, zp, qmin, qmax, dtype) for one argument from the config methods."""
        config = epm._config_methods
        scale = config[get_config_method_name("forward", arg_type, index, "scale")]
        zp = config[get_config_method_name("forward", arg_type, index, "zp")]
        q_min = config[get_config_method_name("forward", arg_type, index, "quant_min")]
        q_max = config[get_config_method_name("forward", arg_type, index, "quant_max")]
        dtype = get_scalar_type(
            config[get_config_method_name("forward", arg_type, index, "dtype")]
        )
        return scale, zp, q_min, q_max, dtype

    def test_add_drop_q_inputs(self) -> None:
        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        edge_program_manager = self._get_edge_prog_manager(
            self.Add().eval(), example_inputs
        )
        reference_outputs = edge_program_manager.exported_program().module()(
            *example_inputs
        )

        input_pass = QuantizeInputs(
            edge_program_manager=edge_program_manager,
            quantized_inputs_idx=[0, 1],
            method_name="forward",
        )
        edge_program_manager_qin = edge_program_manager.transform([input_pass])

        # Op counts are checked on the manager that was handed to the pass.
        self._check_count(op_str["dq"], 3, edge_program_manager)
        self._check_count(op_str["q"], 1, edge_program_manager)

        # Quantize the float inputs with the parameters the pass recorded.
        quantized_example_inputs = tuple(
            torch.ops.quantized_decomposed.quantize_per_tensor.default(
                float_input,
                *self._load_qparams(edge_program_manager_qin, "input", idx),
            )
            for idx, float_input in enumerate(example_inputs)
        )
        output = edge_program_manager_qin.exported_program().module()(
            *quantized_example_inputs
        )
        torch.testing.assert_close(
            reference_outputs[0],
            output[0],
        )

    def test_add_drop_dq_output(self) -> None:
        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        edge_program_manager = self._get_edge_prog_manager(
            self.Add().eval(), example_inputs
        )
        reference_outputs = edge_program_manager.exported_program().module()(
            *example_inputs
        )

        output_pass = QuantizeOutputs(
            edge_program_manager=edge_program_manager,
            quantized_outputs_idx_list=[0],
            method_name="forward",
        )
        edge_program_manager_dqout = edge_program_manager.transform([output_pass])

        # Op counts are checked on the manager that was handed to the pass.
        self._check_count(op_str["dq"], 2, edge_program_manager)
        self._check_count(op_str["q"], 3, edge_program_manager)

        quantized_outputs = edge_program_manager_dqout.exported_program().module()(
            *example_inputs
        )

        # Dequantize the outputs with the parameters the pass recorded.
        dequantized_outputs = tuple(
            torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                quantized_output,
                *self._load_qparams(edge_program_manager_dqout, "output", idx),
            )
            for idx, quantized_output in enumerate(quantized_outputs)
        )

        torch.testing.assert_close(
            reference_outputs[0],
            dequantized_outputs[0],
        )