From 09f621ba76aac90dd60fbad5972e751d3598b348 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 17 Jun 2025 11:16:13 -0700
Subject: [PATCH 1/3] [Quantized DeConv Support] Enable Quantized Transposed
 Convs with groups==1

Pull Request resolved: https://github.com/pytorch/executorch/pull/11730

Support quantized transposed convs with groups == 1.

Previously there was some support for quantized transposed convolutions,
but only when the channel axis is 1 and groups == 1. The quantizer could
not produce this pattern because it only allows quantizing along dim 0,
which is generally the output-channel dimension. For transposed convs,
however, the weights have shape:
```
[in_channels, out_channels/groups, h, w]
```
Since we want to keep quantization along the output channels, we now need
to quantize along axis = 1.

We require groups to be 1 because XNNPACK takes in filters of shape:
```
[out_channels, H, W, in_channels/groups]
```
Since we quantize along the output channels, PyTorch produces
out_channels/groups scales, but XNNPACK expects out_channels scales.
Supporting groups > 1 would realistically require an affine quantization
scheme that provides a scale for every (group, out_channel) pair; for now
we simply enforce the constraint groups == 1. (An illustrative sketch of
the two weight layouts follows this patch.)

ghstack-source-id: 291033630
@exported-using-ghexport

Differential Revision: [D76631781](https://our.internmc.facebook.com/intern/diff/D76631781/)
---
 .../quantizer/xnnpack_quantizer_utils.py  | 24 ++++++++++++++++---
 backends/xnnpack/test/ops/test_conv2d.py  | 19 ++++++++-------
 2 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
index 3d687d0b513..a0d04314733 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -267,7 +267,18 @@ def _do_annotate_conv(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
-        num_groups = get_groups_from_conv(conv_node)
+        if is_conv_transpose:
+            # transposed convs per output channel quantization
+            weight_qspec = QuantizationSpec(
+                dtype=weight_qspec.dtype,
+                quant_min=weight_qspec.quant_min,
+                quant_max=weight_qspec.quant_max,
+                qscheme=weight_qspec.qscheme,
+                ch_axis=1,
+                is_dynamic=False,
+                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
+            )
+        input_qspec_map[weight] = weight_qspec

         # skip if transposed conv has more than 1 group
         skip = skip or (is_conv_transpose and num_groups != 1)
@@ -350,10 +361,17 @@ def _do_annotate_conv_relu(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
-        groups = get_groups_from_conv(conv_node)
         if is_conv_transpose:
             # transposed convs per output channel quantization
-            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+            weight_qspec = QuantizationSpec(
+                dtype=weight_qspec.dtype,
+                quant_min=weight_qspec.quant_min,
+                quant_max=weight_qspec.quant_max,
+                qscheme=weight_qspec.qscheme,
+                ch_axis=1,
+                is_dynamic=False,
+                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
+            )
         input_qspec_map[weight] = weight_qspec

         # adding weight node to the partition as well
diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py
index 2a0a82d99b6..2f3500a4f71 100644
--- a/backends/xnnpack/test/ops/test_conv2d.py
+++ b/backends/xnnpack/test/ops/test_conv2d.py
@@ -507,14 +507,17 @@ def forward(self, x):
             def get_inputs(self):
                 return (torch.randn(batches, in_channels, height, width) * 11,)

-        for per_channel_quant in (False, True):
-            model = ModelConvReLU()
-            self._test(
-                model,
-                quant_config=get_symmetric_quantization_config(
-                    is_per_channel=per_channel_quant
-                ),
-            )
+        for transpose in (True, False):
+            for per_channel_quant in (False, True):
+                if transpose and per_channel_quant:
+                    continue
+                model = ModelConvReLU(transpose=transpose)
+                self._test(
+                    model,
+                    quant_config=get_symmetric_quantization_config(
+                        is_per_channel=per_channel_quant
+                    ),
+                )

     def test_qs8_conv2d_relu_seq(self):
         class ConvReLUSeq(torch.nn.Module):
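[Editor's sketch, not part of the patch: a minimal illustration of why the per-output-channel axis differs between regular and transposed convs, as described in the commit message above. The `/ 127.0` symmetric-scale math is a stand-in for what the observer computes.]
```
import torch

# Regular Conv2d weights: [out_channels, in_channels/groups, kH, kW]
# -> per-output-channel quantization reduces over dims (1, 2, 3), i.e. ch_axis=0.
conv_w = torch.nn.Conv2d(8, 16, 3).weight                # shape [16, 8, 3, 3]
conv_scales = conv_w.abs().amax(dim=(1, 2, 3)) / 127.0   # 16 scales, one per out_channel

# ConvTranspose2d weights: [in_channels, out_channels/groups, kH, kW]
# -> the output channels live on dim 1, so per-output-channel
#    quantization must use ch_axis=1.
tconv_w = torch.nn.ConvTranspose2d(8, 16, 3).weight       # shape [8, 16, 3, 3]
tconv_scales = tconv_w.abs().amax(dim=(0, 2, 3)) / 127.0  # 16 scales

# With groups > 1, dim 1 only holds out_channels/groups entries, while
# XNNPACK expects one scale per out_channel -- hence the groups == 1 constraint.
```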
From 7d3ec3d818ea98306493dc22273cf279a5b94d05 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 17 Jun 2025 11:16:14 -0700
Subject: [PATCH 2/3] [Quantized DeConv Support] Dynamically Quantized
 Deconvolutions with groups == 1

Pull Request resolved: https://github.com/pytorch/executorch/pull/11731

Here we add support for dynamically quantized deconvolutions. This
includes some refactoring of the previous diff; the main change is to
remove the constraint in the dynamism check that the convolution must not
be transposed. For the same reasons as before, this only supports
channel_axis = 1 and groups == 1. (A usage sketch follows this patch.)

ghstack-source-id: 291033632
@exported-using-ghexport

Differential Revision: [D76638904](https://our.internmc.facebook.com/intern/diff/D76638904/)
---
 .../quantizer/xnnpack_quantizer_utils.py  | 24 +++--------------
 backends/xnnpack/test/ops/test_conv2d.py  | 19 +++++++--------
 2 files changed, 11 insertions(+), 32 deletions(-)

diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
index a0d04314733..3d687d0b513 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -267,18 +267,7 @@ def _do_annotate_conv(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
-        if is_conv_transpose:
-            # transposed convs per output channel quantization
-            weight_qspec = QuantizationSpec(
-                dtype=weight_qspec.dtype,
-                quant_min=weight_qspec.quant_min,
-                quant_max=weight_qspec.quant_max,
-                qscheme=weight_qspec.qscheme,
-                ch_axis=1,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
-            )
-        input_qspec_map[weight] = weight_qspec
+        num_groups = get_groups_from_conv(conv_node)

         # skip if transposed conv has more than 1 group
         skip = skip or (is_conv_transpose and num_groups != 1)
@@ -361,17 +350,10 @@ def _do_annotate_conv_relu(
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
         weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
         if is_conv_transpose:
             # transposed convs per output channel quantization
-            weight_qspec = QuantizationSpec(
-                dtype=weight_qspec.dtype,
-                quant_min=weight_qspec.quant_min,
-                quant_max=weight_qspec.quant_max,
-                qscheme=weight_qspec.qscheme,
-                ch_axis=1,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
-            )
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
         input_qspec_map[weight] = weight_qspec

         # adding weight node to the partition as well
diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py
index 2f3500a4f71..2a0a82d99b6 100644
--- a/backends/xnnpack/test/ops/test_conv2d.py
+++ b/backends/xnnpack/test/ops/test_conv2d.py
@@ -507,17 +507,14 @@ def forward(self, x):
             def get_inputs(self):
                 return (torch.randn(batches, in_channels, height, width) * 11,)

-        for transpose in (True, False):
-            for per_channel_quant in (False, True):
-                if transpose and per_channel_quant:
-                    continue
-                model = ModelConvReLU(transpose=transpose)
-                self._test(
-                    model,
-                    quant_config=get_symmetric_quantization_config(
-                        is_per_channel=per_channel_quant
-                    ),
-                )
+        for per_channel_quant in (False, True):
+            model = ModelConvReLU()
+            self._test(
+                model,
+                quant_config=get_symmetric_quantization_config(
+                    is_per_channel=per_channel_quant
+                ),
+            )

     def test_qs8_conv2d_relu_seq(self):
         class ConvReLUSeq(torch.nn.Module):
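[Editor's sketch, not part of the patch: a rough end-to-end use of the dynamic path enabled above. Assumptions: the PT2E entry points shown match your PyTorch version (they have moved between torch.ao and torchao over time), and `set_global` with `is_dynamic=True` is the intended way to request dynamic quantization.]
```
import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# A lone transposed conv with groups == 1 -- the case this patch enables.
model = torch.nn.ConvTranspose2d(3, 8, kernel_size=3, groups=1, bias=False).eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

exported = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # no calibration needed for dynamic quantization
quantized = convert_pt2e(prepared)
```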
From b7572d0681a69f36899b2095837a1ca8b2dfc646 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 17 Jun 2025 11:16:15 -0700
Subject: [PATCH 3/3] [XNNPACK Quantizer] Select between TConvs and Convs

Pull Request resolved: https://github.com/pytorch/executorch/pull/11732

Allow transposed convs and regular convs to be selected independently.
Previously we grouped all conv targets together (transposed and regular
convs); now each pattern maps to its own target set, enabling better
per-operator selection. (A selection sketch follows the tests below.)

ghstack-source-id: 291033631

Differential Revision: [D76641838](https://our.internmc.facebook.com/intern/diff/D76641838/)
---
 backends/xnnpack/quantizer/xnnpack_quantizer.py |  17 ++-
 .../test/quantizer/test_xnnpack_quantizer.py   | 110 ++++++++++++++++++
 2 files changed, 123 insertions(+), 4 deletions(-)

diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py
index c07d27e4231..3c82a65ad71 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -251,6 +251,15 @@ class QuantPattern:
     torch.ops.aten.convolution.default,
 }

+CONV_TRANSPOSE_TARGETS = {
+    torch.ops.aten.conv_transpose1d,
+    torch.ops.aten.conv_transpose1d.default,
+    torch.ops.aten.conv_transpose2d,
+    torch.ops.aten.conv_transpose2d.input,
+    torch.ops.aten.conv_transpose3d,
+    torch.ops.aten.conv_transpose3d.input,
+}
+
 LINEAR_TARGETS = {
     torch.ops.aten.linear.default,
 }
@@ -269,14 +278,14 @@ class XNNPACKQuantizer(Quantizer):
     SUPPORTED_PATTERNS = [
         QuantPattern("conv_bn_relu", False, True, CONV_TARGETS),
         QuantPattern("conv_bn", False, True, CONV_TARGETS),
-        QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS),
-        QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn_relu", False, True, CONV_TRANSPOSE_TARGETS),
+        QuantPattern("conv_transpose_bn", False, True, CONV_TRANSPOSE_TARGETS),
         QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
         QuantPattern("linear", True, False, LINEAR_TARGETS),
         QuantPattern("conv", True, False, CONV_TARGETS),
-        QuantPattern("conv_transpose", True, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", True, False, CONV_TRANSPOSE_TARGETS),
         QuantPattern("conv_relu", False, False, CONV_TARGETS),
-        QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose_relu", False, False, CONV_TRANSPOSE_TARGETS),
         QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
         QuantPattern("add_relu", False, False, ADD_TARGETS),
         QuantPattern("add", False, False, ADD_TARGETS),
diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
index 0a317ad8822..84b1a932a5b 100644
--- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
+++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
@@ -120,6 +120,116 @@ def test_conv1d_with_conv2d(self):
             node_list,
         )

+    def test_q_tconv_and_conv2d(self):
+        class TConv2dConv2d(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.first = torch.nn.ConvTranspose2d(
+                    in_channels=1,
+                    out_channels=3,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+                self.second = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=2,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+
+            def forward(self, x):
+                y = self.first(x)
+                return self.second(y)
+
+            def example_inputs(self):
+                return (torch.randn(1, 1, 3, 3),)
+
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_operator_type(
+            torch.ops.aten.conv_transpose2d.input, quantization_config
+        )
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv_transpose2d.input,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+        ]
+        m = TConv2dConv2d()
+        self._test_quantizer(
+            m,
+            m.example_inputs(),
+            quantizer,
+            node_occurrence,
+            node_list,
+            is_debug_mode=True,
+        )
+
+    def test_q_conv2_and_tconv2d(self):
+        class TConv2dConv2d(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.first = torch.nn.ConvTranspose2d(
+                    in_channels=1,
+                    out_channels=3,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+                self.second = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=2,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+
+            def forward(self, x):
+                y = self.first(x)
+                return self.second(y)
+
+            def example_inputs(self):
+                return (torch.randn(1, 1, 3, 3),)
+
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_operator_type(torch.ops.aten.conv2d.default, quantization_config)
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.aten.conv_transpose2d.input,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        m = TConv2dConv2d()
+        self._test_quantizer(
+            m,
+            m.example_inputs(),
+            quantizer,
+            node_occurrence,
+            node_list,
+            is_debug_mode=True,
+        )
+
     def test_linear(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)
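[Editor's sketch, not part of the patch: the split between CONV_TARGETS and CONV_TRANSPOSE_TARGETS in PATCH 3/3 is what makes per-operator selection like the following possible. Modeled on the tests above; the per-channel/per-tensor split is purely illustrative.]
```
import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
# Quantize regular convs per-channel...
quantizer.set_operator_type(
    torch.ops.aten.conv2d.default,
    get_symmetric_quantization_config(is_per_channel=True),
)
# ...while transposed convs get a per-tensor config. Before this patch both
# calls would have hit the same shared conv target set; omitting the call
# below would now leave transposed convs unannotated entirely.
quantizer.set_operator_type(
    torch.ops.aten.conv_transpose2d.input,
    get_symmetric_quantization_config(is_per_channel=False),
)
```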