Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions backends/arm/_passes/rewrite_conv_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def _add_bias(
weight_node: torch.fx.Node,
) -> torch.fx.Node:
output_channels = get_first_fake_tensor(node).shape[1]
# add a node containging zeros if quantized, use int32, otherwise use float32
if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
# add a node containing zeros if quantized, use int32, otherwise use float32
if self._is_quantized_conv(node):
bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32)
else:
output_dtype = node.meta["val"].dtype
Expand All @@ -188,9 +188,40 @@ def _add_bias(
node.update_arg(2, bias_node)
return bias_node

def insert_output_rescale(self, graph_module, node):
input_qparams = get_input_qparams(node)
output_qparams = get_output_qparams(node)[0]
def _is_quantized_conv(self, node: torch.fx.Node) -> bool:
return bool(node.meta.get("input_qparams", {}))

def _get_effective_output_qparams(self, node: torch.fx.Node):
"""Return the quantized output domain for a conv node.

Quantization annotation may place output qparams on a following
activation instead of on the conv itself. If that activation is not
fuseable, it survives as a quantized ``clamp`` and still owns the
branch output qparams needed for the conv output rescale.

"""
output_qparams = node.meta.get("output_qparams", {})
if output_qparams:
return output_qparams

users = list(node.users)
if len(users) != 1:
raise ValueError(
f"RewriteConvPass: No output quantization parameter found in node {node}\n"
f"original_aten={node.meta.get('original_aten', 'None')}"
)

activation = users[0]
if activation.target == exir_ops.edge.aten.clamp.default:
activation_output_qparams = activation.meta.get("output_qparams", {})
if activation_output_qparams:
return activation_output_qparams

return get_output_qparams(node)

def insert_output_rescale(self, graph_module, source_node, conv_node):
input_qparams = get_input_qparams(source_node)
output_qparams = self._get_effective_output_qparams(source_node)[0]
weight_qparams = input_qparams[1]
input_qparams = input_qparams[0]
is_per_channel = weight_qparams.per_channel
Expand All @@ -207,18 +238,18 @@ def insert_output_rescale(self, graph_module, node):
itertools.cycle([output_qparams.get_scale_per_tensor()]),
)
]
with graph_module.graph.inserting_after(node):
with graph_module.graph.inserting_after(conv_node):
rescale_node = create_node(
graph=graph_module.graph,
op_target=exir_ops.backend.tosa.RESCALE.default,
args=(
node,
conv_node,
output_qparams.dtype,
post_conv2d_scale,
0,
output_qparams.get_zp_per_tensor(),
),
from_node=node,
from_node=source_node,
)
return rescale_node

Expand Down Expand Up @@ -347,15 +378,17 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
tosa_node_fake_tensor.dtype == torch.int32
and input_fake_tensor.dtype == torch.int8
):
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
output_rescale = self.insert_output_rescale(graph_module, node, tosa_op)
node.replace_all_uses_with(output_rescale)
elif (
tosa_node_fake_tensor.dtype == torch.int32
and input_fake_tensor.dtype == torch.int16
):
has_bias = len(node.meta["input_qparams"]) > 2
if not has_bias:
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
output_rescale = self.insert_output_rescale(
graph_module, node, tosa_op
)
node.replace_all_uses_with(output_rescale)
else:
node.replace_all_uses_with(tosa_op)
Expand Down
133 changes: 132 additions & 1 deletion backends/arm/test/passes/test_rewrite_conv_pass.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,98 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.arm._passes import (
ConvertToClampPass,
FoldAndAnnotateQParamsPass,
FuseQuantizedActivationPass,
QuantizeClampArgumentsPass,
)
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
VgfQuantizer,
)
from executorch.backends.arm.test.misc.test_dw_convs_with_shared_weights import (
DWConvsModule,
)
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
from executorch.backends.arm.tosa.specification import TosaLoweringContext
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower
from executorch.exir.dialects._ops import ops as exir_ops


class TinyConvReluCat(nn.Module):
    """Tiny conv -> relu -> cat -> conv network used as a lowering fixture."""

    def __init__(self, conv1_bias: bool = True) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(4, 4, 3, padding=1, bias=conv1_bias)
        self.conv2 = nn.Conv2d(8, 4, 1)
        # Keep all parameters small and uniform so the network is well-behaved.
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-0.1, 0.1)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        activated = F.relu(self.conv1(x))
        return self.conv2(torch.cat((activated, y), dim=1))


def _example_inputs() -> tuple[torch.Tensor, torch.Tensor]:
torch.manual_seed(0)
x = torch.rand(1, 4, 16, 16)
y = torch.rand(1, 4, 16, 16) - 0.065
return x, y


def _compile_spec() -> VgfCompileSpec:
    """Create the VGF compile spec targeting the TOSA-1.0 INT+FP profile."""
    spec = VgfCompileSpec("TOSA-1.0+INT+FP")
    return spec


def _quantizer() -> VgfQuantizer:
    """Build a VGF quantizer with a symmetric, per-channel global config."""
    config = get_symmetric_quantization_config(
        is_per_channel=True,
        act_qmin=-127,
        act_qmax=127,
        weight_qmin=-127,
        weight_qmax=127,
    )
    vgf_quantizer = VgfQuantizer(_compile_spec())
    vgf_quantizer.set_global(config)
    return vgf_quantizer


def _export_quantized(model: nn.Module):
    """Export ``model``, quantize it, and re-export the quantized module."""
    example_inputs = _example_inputs()
    exported_module = torch.export.export(model.eval(), example_inputs).module(
        check_guards=False
    )
    quantizer = _quantizer()
    quantized_module = quantizer._quantize_with_submodules(
        exported_module, [example_inputs]
    )
    return torch.export.export(quantized_module, example_inputs)


def _run_pre_rewrite_passes(exported_program: torch.export.ExportedProgram):
    """Run the graph passes that normally precede RewriteConvPass."""
    passes = [
        FuseQuantizedActivationPass(),
        ConvertToClampPass(),
        FoldAndAnnotateQParamsPass(exported_program),
        QuantizeClampArgumentsPass(),
    ]
    graph_module = exported_program.graph_module
    for graph_pass in passes:
        pass_result = graph_pass(graph_module)
        assert pass_result is not None
        graph_module = pass_result.graph_module
    return graph_module


def _get_call_function_node(gm: torch.fx.GraphModule, target):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == target:
return node
raise AssertionError(f"Node with target {target} not found")


def test_rewrite_conv_tosa_FP():
Expand All @@ -18,3 +103,49 @@ def test_rewrite_conv_tosa_FP():
# We can't run TOSA backend dialect operators in eager mode
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()


def test_fold_and_annotate_q_params_vgf_quant_preserves_output_qparams_on_non_fuseable_clamp() -> (
    None
):
    """The non-fuseable clamp, not the conv, should own the output qparams."""
    program = _export_quantized(TinyConvReluCat())
    graph_module = _run_pre_rewrite_passes(to_edge(program).exported_program())

    conv_node = _get_call_function_node(
        graph_module, exir_ops.edge.aten.convolution.default
    )
    clamp_node = _get_call_function_node(
        graph_module, exir_ops.edge.aten.clamp.default
    )

    assert conv_node.meta["input_qparams"]
    assert not conv_node.meta["output_qparams"]
    assert clamp_node.meta["output_qparams"]


def test_rewrite_conv_vgf_quant_handles_non_fuseable_conv_clamp_cat_branch() -> None:
    """Lowering the conv -> clamp -> cat branch end to end must not raise."""
    program = _export_quantized(TinyConvReluCat())
    partitioner = VgfPartitioner(_compile_spec())

    to_edge_transform_and_lower(
        program,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[partitioner],
    )


def test_rewrite_conv_vgf_quant_infers_quantized_bias_dtype_from_inputs() -> None:
    """A quantized conv without bias should get a synthetic int32 bias input."""
    exported = _export_quantized(TinyConvReluCat(conv1_bias=False))
    edge_program = to_edge(
        exported, compile_config=EdgeCompileConfig(_check_ir_validity=False)
    ).exported_program()
    graph_module = _run_pre_rewrite_passes(edge_program)

    with TosaLoweringContext(_compile_spec().tosa_spec):
        pass_result = RewriteConvPass(edge_program)(graph_module)
    assert pass_result is not None
    graph_module = pass_result.graph_module

    bias_placeholders = []
    for node in graph_module.graph.nodes:
        if node.op == "placeholder" and node.name.endswith("_bias"):
            bias_placeholders.append(node)

    assert len(bias_placeholders) == 1
    assert bias_placeholders[0].meta["val"].dtype == torch.int32
Loading