diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py
index c07d27e4231..3c82a65ad71 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -251,6 +251,15 @@ class QuantPattern:
     torch.ops.aten.convolution.default,
 }
 
+CONV_TRANSPOSE_TARGETS = {
+    torch.ops.aten.conv_transpose1d,
+    torch.ops.aten.conv_transpose1d.default,
+    torch.ops.aten.conv_transpose2d,
+    torch.ops.aten.conv_transpose2d.input,
+    torch.ops.aten.conv_transpose3d,
+    torch.ops.aten.conv_transpose3d.input,
+}
+
 LINEAR_TARGETS = {
     torch.ops.aten.linear.default,
 }
@@ -269,14 +278,14 @@ class XNNPACKQuantizer(Quantizer):
     SUPPORTED_PATTERNS = [
         QuantPattern("conv_bn_relu", False, True, CONV_TARGETS),
         QuantPattern("conv_bn", False, True, CONV_TARGETS),
-        QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS),
-        QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS),
+        QuantPattern("conv_transpose_bn_relu", False, True, CONV_TRANSPOSE_TARGETS),
+        QuantPattern("conv_transpose_bn", False, True, CONV_TRANSPOSE_TARGETS),
         QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
         QuantPattern("linear", True, False, LINEAR_TARGETS),
         QuantPattern("conv", True, False, CONV_TARGETS),
-        QuantPattern("conv_transpose", True, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", True, False, CONV_TRANSPOSE_TARGETS),
         QuantPattern("conv_relu", False, False, CONV_TARGETS),
-        QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose_relu", False, False, CONV_TRANSPOSE_TARGETS),
         QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
         QuantPattern("add_relu", False, False, ADD_TARGETS),
         QuantPattern("add", False, False, ADD_TARGETS),
diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
index 0a317ad8822..84b1a932a5b 100644
--- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
+++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py
@@ -120,6 +120,116 @@ def test_conv1d_with_conv2d(self):
             node_list,
         )
 
+    def test_q_tconv_and_conv2d(self):
+        class TConv2dConv2d(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.first = torch.nn.ConvTranspose2d(
+                    in_channels=1,
+                    out_channels=3,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+                self.second = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=2,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+
+            def forward(self, x):
+                y = self.first(x)
+                return self.second(y)
+
+            def example_inputs(self):
+                return (torch.randn(1, 1, 3, 3),)
+
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_operator_type(
+            torch.ops.aten.conv_transpose2d.input, quantization_config
+        )
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv_transpose2d.input,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+        ]
+        m = TConv2dConv2d()
+        self._test_quantizer(
+            m,
+            m.example_inputs(),
+            quantizer,
+            node_occurrence,
+            node_list,
+            is_debug_mode=True,
+        )
+
+    def test_q_conv2d_and_tconv2d(self):
+        class TConv2dConv2d(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.first = torch.nn.ConvTranspose2d(
+                    in_channels=1,
+                    out_channels=3,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+                self.second = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=2,
+                    kernel_size=(3, 3),
+                    padding=1,
+                    bias=False,
+                )
+
+            def forward(self, x):
+                y = self.first(x)
+                return self.second(y)
+
+            def example_inputs(self):
+                return (torch.randn(1, 1, 3, 3),)
+
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_operator_type(torch.ops.aten.conv2d.default, quantization_config)
+        node_occurrence = {
+            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.aten.conv_transpose2d.input,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        m = TConv2dConv2d()
+        self._test_quantizer(
+            m,
+            m.example_inputs(),
+            quantizer,
+            node_occurrence,
+            node_list,
+            is_debug_mode=True,
+        )
+
     def test_linear(self):
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(is_per_channel=True)