From d505ec6e82dc12571b13368ebf5dd374bb5874f8 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 18 Mar 2025 15:13:49 -0700
Subject: [PATCH] Remove Per-Op mode from DQPartitioner

Differential Revision: [D71427234](https://our.internmc.facebook.com/intern/diff/D71427234/)

[ghstack-poisoned]
---
Reviewer notes (text between the "---" marker and the first "diff --git"
line is not part of the commit message and is skipped when applying):
 * Line breaks and indentation were restored from a whitespace-flattened copy
   using the hunk line counts. Leading whitespace on context lines is a best
   guess; verify with `git apply --check` (or use `--ignore-whitespace`).
 * NOTE(review): SharedDQChain always passes a bias to linear(), yet
   test_qd8_f32_per_channel_shared_dq_chain loops over use_bias and forwards
   uses_bias=False on its first pass. Confirm _test_dqlinear tolerates that,
   or drop the loop.

 .../xnnpack/partition/xnnpack_partitioner.py |  2 +-
 backends/xnnpack/test/ops/test_linear.py     | 37 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py
index 358b3085c80..e5532e17f36 100644
--- a/backends/xnnpack/partition/xnnpack_partitioner.py
+++ b/backends/xnnpack/partition/xnnpack_partitioner.py
@@ -115,7 +115,7 @@ def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]:
 class XnnpackDynamicallyQuantizedPartitioner(XnnpackPartitioner):
     def __init__(self):
         super().__init__(
-            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
         )
 
 
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index fec6005d706..cc22906ec1a 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -191,6 +191,21 @@ def forward(self, x, y):
         return a + b
 
 
+class SharedDQChain(torch.nn.Module):
+    def __init__(self, input_size, output_size):
+        super().__init__()
+        self.linear1_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
+        self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))
+
+        self.linear2_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
+        self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))
+
+    def forward(self, x):
+        a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
+        b = torch.nn.functional.linear(x, self.linear2_weight, self.linear2_bias)
+        return a + b
+
+
 class TestLinear(unittest.TestCase):
     """
     Test Class for XNNPACK Linear Operators.
@@ -316,6 +331,7 @@ def _test_dqlinear(
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
         atol=5e-02,  # TODO(T212995726): Investigate right atol for rand[n] inputs
+        no_per_op_mode=False,
     ):
         """
         Helper function to test dynamic quantized linear op with different configurations.
@@ -324,8 +340,9 @@ def _test_dqlinear(
             is_per_channel=is_per_channel,
             is_dynamic=True,
         )
+        per_op_mode_choices = [False] if no_per_op_mode else [True, False]
         for legacy_partitioner in (True, False):
-            for per_op_mode in (True, False):
+            for per_op_mode in per_op_mode_choices:
                 DynamicallyQuantizedPartitioner = XnnpackPartitioner(
                     config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
                     per_op_mode=per_op_mode,
@@ -520,6 +537,24 @@ def get_qnode_checks(quant_node_checks, dialect):
         #     qtol=bool(quant_config), atol=atol
         # )
 
+    def test_qd8_f32_per_channel_shared_dq_chain(self):
+        for use_bias in (False, True):
+            module = SharedDQChain(
+                input_size=13,
+                output_size=17,
+            )
+            inputs = (torch.randn(1, 2, 13),)
+
+            self._test_dqlinear(
+                module,
+                inputs,
+                dynamic_shapes=None,
+                is_per_channel=True,
+                linear_count=2,
+                uses_bias=use_bias,
+                no_per_op_mode=True,
+            )
+
     def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
         for uses_bias in (False, True):
             module = BaseLinear(