diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py
index 358b3085c80..e5532e17f36 100644
--- a/backends/xnnpack/partition/xnnpack_partitioner.py
+++ b/backends/xnnpack/partition/xnnpack_partitioner.py
@@ -115,7 +115,7 @@ def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]:
 class XnnpackDynamicallyQuantizedPartitioner(XnnpackPartitioner):
     def __init__(self):
         super().__init__(
-            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
         )
 
 
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index fec6005d706..cf9473180bb 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -191,6 +191,21 @@ def forward(self, x, y):
         return a + b
 
 
+class SharedDQChain(torch.nn.Module):
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.linear1_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
+        self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))
+
+        self.linear2_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
+        self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))
+
+    def forward(self, x):
+        a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
+        b = torch.nn.functional.linear(x, self.linear2_weight, self.linear2_bias)
+        return a + b
+
+
 class TestLinear(unittest.TestCase):
     """
     Test Class for XNNPACK Linear Operators.
@@ -520,6 +535,23 @@ def get_qnode_checks(quant_node_checks, dialect):
         #     qtol=bool(quant_config), atol=atol
         # )
 
+    def test_qd8_f32_per_channel_shared_dq_chain(self):
+        for use_bias in (False, True):
+            module = SharedDQChain(
+                input_size=13,
+                output_size=17,
+            )
+            inputs = (torch.randn(1, 2, 13),)
+
+            self._test_dqlinear(
+                module,
+                inputs,
+                dynamic_shapes=None,
+                is_per_channel=True,
+                linear_count=2,
+                uses_bias=use_bias,
+            )
+
     def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
         for uses_bias in (False, True):
             module = BaseLinear(
diff --git a/backends/xnnpack/test/tester/TARGETS b/backends/xnnpack/test/tester/TARGETS
index 0ba34cc0bfa..231de970d7b 100644
--- a/backends/xnnpack/test/tester/TARGETS
+++ b/backends/xnnpack/test/tester/TARGETS
@@ -26,5 +26,6 @@ runtime.python_library(
         "//executorch/exir/backend:partitioner",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/extension/pybindings:portable_lib",  # @manual
+        "//executorch/backends/transforms:duplicate_dynamic_quant_chain"
     ],
 )
diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index 79544256022..21d34a8d30b 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -15,6 +15,9 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 import torch
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
 from executorch.backends.xnnpack._passes import XNNPACKPassManager
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
@@ -177,6 +180,8 @@ def run(
             prepared(*inputs)
 
         converted = convert_pt2e(prepared)
+        DuplicateDynamicQuantChainPass()(converted)
+
         self.converted_graph = converted
 
     @property