32 changes: 3 additions & 29 deletions backends/arm/quantizer/arm_quantizer.py
@@ -46,6 +46,7 @@
 from torchao.quantization.pt2e.quantizer import (
     annotate_input_qspec_map,
     annotate_output_qspec,
+    get_module_name_filter,
     QuantizationSpec,
     Quantizer,
 )
@@ -248,33 +249,6 @@ def get_symmetric_a16w8_quantization_config(
     """
 
 
-def _get_module_name_filter(module_name: str) -> NodeFilterType:
-    """Get the module_name_filter function for a given module name, the filter accepts
-    a node and checks if the node comes from a module that has certain module name
-
-    For example:
-        node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
-
-    >> module_name_filter = _get_module_name_filter("blocks.sub")
-    >> print(module_name_filter(node))
-    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
-    """
-
-    name_start = len("L['self'].")
-
-    def module_name_filter(n: Node) -> bool:
-        # node_stack example: {
-        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
-        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
-        # }
-        # get_attr nodes doesn't have nn_module_stack?
-        nn_module_stack = n.meta.get("nn_module_stack", {})
-        names = [name[name_start:] for name, _ in nn_module_stack.values()]
-        return module_name in names
-
-    return module_name_filter
-
-
 def _get_module_type_filter(tp: Callable) -> NodeFilterType:
     """Get the module_type_filter function for a given module type, the filter accepts
     a node and checks if the node comes from a module that has certain module type
@@ -306,7 +280,7 @@ def _get_not_module_type_or_name_filter(
     tp_list: List[Callable], module_name_list: List[str]
 ) -> NodeFilterType:
     module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
-    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+    module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
 
     def not_module_type_or_name_filter(n: Node) -> bool:
         return not any(f(n) for f in module_type_filters + module_name_list_filters)
@@ -455,7 +429,7 @@ def _annotate_for_static_quantization_config(
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
-                model, config, _get_module_name_filter(module_name)
+                model, config, get_module_name_filter(module_name)
            )
 
         tp_list = list(self.module_type_config.keys())
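For context, torchao's `get_module_name_filter` is a drop-in replacement for the deleted helper: given a fully qualified module name it returns a predicate over FX nodes that matches against each node's `nn_module_stack` metadata. A minimal sketch of that contract (`collect_module_nodes` is an illustrative helper, not part of this PR):

```python
from torch.fx import GraphModule, Node
from torchao.quantization.pt2e.quantizer import get_module_name_filter


def collect_module_nodes(model: GraphModule, module_name: str) -> list[Node]:
    """Collect every node whose nn_module_stack places it under `module_name`."""
    node_filter = get_module_name_filter(module_name)  # Node -> bool predicate
    return [n for n in model.graph.nodes if node_filter(n)]
```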
158 changes: 158 additions & 0 deletions backends/arm/test/quantizer/test_set_module_name.py
@@ -0,0 +1,158 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    get_symmetric_quantization_config,
+    is_annotated,
+    QuantizationConfig,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.quantizer.quantization_config import QuantizationSpec
+from executorch.backends.arm.tosa import TosaSpecification
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+DQ_PER_CHANNEL = torch.ops.quantized_decomposed.dequantize_per_channel.default
+DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default
+
+
+class ConvModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv0 = torch.nn.Conv2d(
+            3,
+            16,
+            kernel_size=4,
+        )
+        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=3, bias=False)
+        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3)
+
+    def forward(self, x):
+        x = self.conv0(x)
+        x = torch.sigmoid(x)
+        x = self.conv1(x)
+        x = torch.tanh(x)
+        x = self.conv2(x)
+        return x
+
+
+test_inputs = (torch.randn(1, 3, 64, 64),)
+
+
+def validate_per_tensor_quant(node: torch.fx.Node, qspec: QuantizationSpec):
+    _, _, zero_point, qmin, qmax, dtype = node.args
+    if qspec.qscheme == torch.per_tensor_symmetric:
+        assert (
+            zero_point == 0
+        ), f"Zero point {zero_point} is not zero for symmetric quantization"
+    assert (
+        qmin == qspec.quant_min
+    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
+    assert (
+        qmax == qspec.quant_max
+    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
+    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"
+
+
+def validate_per_channel_quant(node: torch.fx.Node, qspec: QuantizationSpec):
+    _, _, _, channel_axis, qmin, qmax, dtype = node.args
+    assert (
+        channel_axis == qspec.ch_axis
+    ), f"Channel axis {channel_axis} does not match expected {qspec.ch_axis}"
+    assert (
+        qmin == qspec.quant_min
+    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
+    assert (
+        qmax == qspec.quant_max
+    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
+    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"
+
+
+def validate_input(input_node: torch.fx.Node, qspec: QuantizationSpec | None):
+    if qspec is None:
+        return
+
+    per_channel = qspec.qscheme == torch.per_channel_symmetric
+    expected_dequant_op = DQ_PER_CHANNEL if per_channel else DQ_PER_TENSOR
+    assert (
+        input_node.target == expected_dequant_op
+    ), f"Input node {input_node} is not quantized as expected"
+    if per_channel:
+        validate_per_channel_quant(input_node, qspec)
+    else:
+        validate_per_tensor_quant(input_node, qspec)
+
+
+def validate_output(node: torch.fx.Node, qspec: QuantizationSpec | None):
+    if qspec is None:
+        return
+    users = list(node.users)
+    assert len(users) == 1, f"Node {node} should have exactly one user"
+    assert (
+        users[0].target == Q_PER_TENSOR
+    ), f"Output node {users[0]} is not quantized as expected"
+    validate_per_tensor_quant(users[0], qspec)
+
+
+def validate_node(
+    node: torch.fx.Node, quantization_config: QuantizationConfig | None
+) -> None:
+    if quantization_config is None:
+        assert not is_annotated(node), f"Node {node} is unexpectedly annotated"
+        return
+
+    assert is_annotated(node), f"Node {node} is not annotated"
+    input_qspec = quantization_config.get_input_act_qspec()
+    output_qspec = quantization_config.get_output_act_qspec()
+    weight_qspec = quantization_config.get_weight_qspec()
+
+    if len(node.all_input_nodes) == 3:
+        input_node, weight_node, bias_node = node.all_input_nodes
+        bias_qspec = quantization_config.get_bias_qspec(node)
+        validate_input(bias_node, bias_qspec)
+    else:
+        input_node, weight_node = node.all_input_nodes
+
+    validate_input(input_node, input_qspec)
+    validate_input(weight_node, weight_qspec)
+    validate_output(node, output_qspec)
+
+
+def test_set_module_name() -> None:
+    model = ConvModel()
+    model.eval()
+
+    # Set up quantizer with different configs for different modules
+    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
+    quantizer = TOSAQuantizer(tosa_spec)
+    int8_config = get_symmetric_quantization_config(is_per_channel=False)
+    a16w8_config = get_symmetric_a16w8_quantization_config()
+    # Set module-specific configurations but don't set global config to test that
+    # only specified modules are quantized
+    quantizer.set_module_name("conv0", int8_config)
+    quantizer.set_module_name("conv1", a16w8_config)
+
+    # Export model
+    exported_model = torch.export.export(model, test_inputs)
+
+    # Prepare, calibrate and convert model
+    prepared_model = prepare_pt2e(exported_model.module(), quantizer)
+    prepared_model(*test_inputs)
+    converted_model = convert_pt2e(prepared_model)
+
+    validate_node(
+        [node for node in converted_model.graph.nodes if node.name == "conv2d"][0],
+        int8_config,
+    )
+    validate_node(
+        [node for node in converted_model.graph.nodes if node.name == "conv2d_1"][0],
+        a16w8_config,
+    )
+    validate_node(
+        [node for node in converted_model.graph.nodes if node.name == "conv2d_2"][0],
+        None,
+    )
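To exercise the new test locally, something like the following should work, assuming an ExecuTorch development install with the Arm backend and torchao available:

```python
# Hypothetical local invocation of the new test module via pytest.
import pytest

pytest.main(["-v", "backends/arm/test/quantizer/test_set_module_name.py"])
```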