diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 1839b0e1a47..8f5431d3ffa 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -46,6 +46,7 @@ from torchao.quantization.pt2e.quantizer import (
     annotate_input_qspec_map,
     annotate_output_qspec,
+    get_module_name_filter,
     QuantizationSpec,
     Quantizer,
 )
 
@@ -248,33 +249,6 @@ def get_symmetric_a16w8_quantization_config(
     """
 
 
-def _get_module_name_filter(module_name: str) -> NodeFilterType:
-    """Get the module_name_filter function for a given module name, the filter accepts
-    a node and checks if the node comes from a module that has certain module name
-
-    For example:
-        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1
-
-    >> module_name_filter = _get_module_name_filter("blocks.sub")
-    >> print(module_name_filter(node))
-    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
-    """
-
-    name_start = len("L['self'].")
-
-    def module_name_filter(n: Node) -> bool:
-        # node_stack example: {
-        #     'L__self___sub': ("L['self'].sub", <class ...>),
-        #     'L__self___sub_linear': ("L['self'].sub.linear", <class ...>)
-        # }
-        # get_attr nodes doesn't have nn_module_stack?
-        nn_module_stack = n.meta.get("nn_module_stack", {})
-        names = [name[name_start:] for name, _ in nn_module_stack.values()]
-        return module_name in names
-
-    return module_name_filter
-
-
 def _get_module_type_filter(tp: Callable) -> NodeFilterType:
     """Get the module_type_filter function for a given module type, the filter accepts
     a node and checks if the node comes from a module that has certain module type
@@ -306,7 +280,7 @@ def _get_not_module_type_or_name_filter(
     tp_list: List[Callable], module_name_list: List[str]
 ) -> NodeFilterType:
     module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
-    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+    module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
 
     def not_module_type_or_name_filter(n: Node) -> bool:
         return not any(f(n) for f in module_type_filters + module_name_list_filters)
@@ -455,7 +429,7 @@ def _annotate_for_static_quantization_config(
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
-                model, config, _get_module_name_filter(module_name)
+                model, config, get_module_name_filter(module_name)
             )
 
         tp_list = list(self.module_type_config.keys())
diff --git a/backends/arm/test/quantizer/test_set_module_name.py b/backends/arm/test/quantizer/test_set_module_name.py
new file mode 100644
index 00000000000..56131a83e86
--- /dev/null
+++ b/backends/arm/test/quantizer/test_set_module_name.py
@@ -0,0 +1,158 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    get_symmetric_quantization_config,
+    is_annotated,
+    QuantizationConfig,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.quantizer.quantization_config import QuantizationSpec
+from executorch.backends.arm.tosa import TosaSpecification
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+DQ_PER_CHANNEL = torch.ops.quantized_decomposed.dequantize_per_channel.default
+DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default
+
+
+class ConvModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv0 = torch.nn.Conv2d(
+            3,
+            16,
+            kernel_size=4,
+        )
+        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=3, bias=False)
+        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3)
+
+    def forward(self, x):
+        x = self.conv0(x)
+        x = torch.sigmoid(x)
+        x = self.conv1(x)
+        x = torch.tanh(x)
+        x = self.conv2(x)
+        return x
+
+
+test_inputs = (torch.randn(1, 3, 64, 64),)
+
+
+def validate_per_tensor_quant(node: torch.fx.Node, qspec: QuantizationSpec):
+    _, _, zero_point, qmin, qmax, dtype = node.args
+    if qspec.qscheme == torch.per_tensor_symmetric:
+        assert (
+            zero_point == 0
+        ), f"Zero point {zero_point} is not zero for symmetric quantization"
+    assert (
+        qmin == qspec.quant_min
+    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
+    assert (
+        qmax == qspec.quant_max
+    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
+    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"
+
+
+def validate_per_channel_quant(node: torch.fx.Node, qspec: QuantizationSpec):
+    _, _, _, channel_axis, qmin, qmax, dtype = node.args
+    assert (
+        channel_axis == qspec.ch_axis
+    ), f"Channel axis {channel_axis} does not match expected {qspec.ch_axis}"
+    assert (
+        qmin == qspec.quant_min
+    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
+    assert (
+        qmax == qspec.quant_max
+    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
+    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"
+
+
+def validate_input(input_node: torch.fx.Node, qspec: QuantizationSpec | None):
+    if qspec is None:
+        return
+
+    per_channel = qspec.qscheme == torch.per_channel_symmetric
+    expected_dequant_op = DQ_PER_CHANNEL if per_channel else DQ_PER_TENSOR
+    assert (
+        input_node.target == expected_dequant_op
+    ), f"Input node {input_node} is not quantized as expected"
+    if per_channel:
+        validate_per_channel_quant(input_node, qspec)
+    else:
+        validate_per_tensor_quant(input_node, qspec)
+
+
+def validate_output(node: torch.fx.Node, qspec: QuantizationSpec | None):
+    if qspec is None:
+        return
+    users = list(node.users)
+    assert len(users) == 1, f"Node {node} should have exactly one user"
+    assert (
+        users[0].target == Q_PER_TENSOR
+    ), f"Output node {users[0]} is not quantized as expected"
+    validate_per_tensor_quant(users[0], qspec)
+
+
+def validate_node(
+    node: torch.fx.Node, quantization_config: QuantizationConfig | None
+) -> None:
+    if quantization_config is None:
+        assert not is_annotated(node), f"Node {node} is unexpectedly annotated"
+        return
+
+    assert is_annotated(node), f"Node {node} is not annotated"
+    input_qspec = quantization_config.get_input_act_qspec()
+    output_qspec = quantization_config.get_output_act_qspec()
+    weight_qspec = quantization_config.get_weight_qspec()
+
+    if len(node.all_input_nodes) == 3:
+        input_node, weight_node, bias_node = node.all_input_nodes
+        bias_qspec = quantization_config.get_bias_qspec(node)
+        validate_input(bias_node, bias_qspec)
+    else:
+        input_node, weight_node = node.all_input_nodes
+
+    validate_input(input_node, input_qspec)
+    validate_input(weight_node, weight_qspec)
+    validate_output(node, output_qspec)
+
+
+def test_set_module_name() -> None:
+    model = ConvModel()
+    model.eval()
+
+    # Set up quantizer with different configs for different modules
+    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
+    quantizer = TOSAQuantizer(tosa_spec)
+    int8_config = get_symmetric_quantization_config(is_per_channel=False)
+    a16w8_config = get_symmetric_a16w8_quantization_config()
+    # Set module-specific configurations but don't set global config to test that
+    # only specified modules are quantized
+    quantizer.set_module_name("conv0", int8_config)
+    quantizer.set_module_name("conv1", a16w8_config)
+
+    # Export model
+    exported_model = torch.export.export(model, test_inputs)
+
+    # Prepare, calibrate and convert model
+    prepared_model = prepare_pt2e(exported_model.module(), quantizer)
+    prepared_model(*test_inputs)
+    converted_model = convert_pt2e(prepared_model)
+
+    validate_node(
+        [node for node in converted_model.graph.nodes if node.name == "conv2d"][0],
+        int8_config,
+    )
+    validate_node(
+        [node for node in converted_model.graph.nodes if node.name == "conv2d_1"][0],
+        a16w8_config,
+    )
+    validate_node(
+        [node for node in converted_model.graph.nodes if node.name == "conv2d_2"][0],
+        None,
+    )
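Reviewer note: get_module_name_filter, which this diff imports from torchao.quantization.pt2e.quantizer, plays the same role as the deleted in-tree _get_module_name_filter: it returns a predicate that is true for nodes whose nn_module_stack metadata contains the given module path. The sketch below illustrates that behavior under the assumption that the torchao helper matches on nn_module_stack the way the deleted code did; the Top/Sub modules and the printed node names are illustrative only and not part of this change.

import torch
from torchao.quantization.pt2e.quantizer import get_module_name_filter


class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class Top(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)


# Export so that every op node carries nn_module_stack metadata.
gm = torch.export.export(Top(), (torch.randn(1, 4),)).module()

# Predicate that is true for nodes traced from the "sub" submodule,
# e.g. the linear op originating from Top.sub.linear.
is_in_sub = get_module_name_filter("sub")
print([n.name for n in gm.graph.nodes if is_in_sub(n)])

This predicate is the mechanism the new test exercises indirectly through quantizer.set_module_name("conv0", ...) and quantizer.set_module_name("conv1", ...): only nodes matched by the corresponding module-name filter receive that module's quantization config.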