From daf7e6ec40a442f15981ee99cceb94de216c2fdb Mon Sep 17 00:00:00 2001 From: haowhsu-quic Date: Wed, 18 Sep 2024 12:35:51 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - support Conv2dTranspose Summary: - Conv2dTranspose op enablement - test cases --- backends/qualcomm/builders/op_conv2d.py | 118 ++++++++++--------- backends/qualcomm/builders/qnn_constants.py | 9 ++ backends/qualcomm/quantizer/utils.py | 8 +- backends/qualcomm/tests/models.py | 40 +++++++ backends/qualcomm/tests/test_qnn_delegate.py | 32 +++++ 5 files changed, 149 insertions(+), 58 deletions(-) diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index 4b58edbac63..b6e70c374e0 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -18,6 +18,7 @@ OpDepthWiseConv2d, OpExpandDims, OpReshape, + OpTransposeConv2d, QNN_OP_PACKAGE_NAME_QTI_AISW, ) from .utils import get_parameter @@ -42,6 +43,9 @@ def _add_conv_op_parameter( padding_shape, dilation, dilation_shape, + output_padding=None, + output_padding_shape=None, + transpose_conv=False, groups=None, ) -> PyQnnWrapper.PyQnnOpWrapper: """ @@ -68,14 +72,26 @@ def _add_conv_op_parameter( ), True, ) - conv_op.AddTensorParam( - OP.param_dilation, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(dilation_shape), - dilation_shape, - np.array(dilation, dtype=np.uint32), - True, - ) + + if transpose_conv: + conv_op.AddTensorParam( + OP.param_output_padding, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(output_padding_shape), + output_padding_shape, + np.array(output_padding, dtype=np.uint32), + True, + ) + else: + conv_op.AddTensorParam( + OP.param_dilation, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(dilation_shape), + dilation_shape, + np.array(dilation, dtype=np.uint32), + True, + ) + if groups is not None: conv_op.AddScalarParam( OP.param_group, @@ -94,6 +110,11 @@ def _define_conv1d( Conv1D is a special case for convolutional operation. QNN does not support Conv1D, therefore, we need to cast from input -> Conv1d -> output to input -> unsqueeze -> Conv2d -> squeeze -> output. 
""" + transpose_conv = cast(bool, node.args[6]) + if transpose_conv: + print("ConvTranspose1d is not yet supported") + return + op_wrapper_list = [] # op_wrapper to return unsqueeze_input_node = node.args[0] input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf( @@ -239,9 +260,9 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - if get_parameter(node.args[1], self.edge_program).dim() == 3: return self._define_conv1d(node, nodes_to_wrappers) + input_node = node.args[0] input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( @@ -254,8 +275,9 @@ def define_node( filter_node = node.args[1] filter_tensor = get_parameter(filter_node, self.edge_program) - # weight of pytorch OIHW, yet QNN is HWIO - filter_axis_order = (2, 3, 1, 0) + # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO + is_transpose_conv = cast(bool, node.args[6]) + filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0) filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() filter_tensor_wrapper = self.define_tensor( filter_node, @@ -291,6 +313,7 @@ def define_node( stride = cast(List[int], node.args[3]) padding = cast(List[int], node.args[4]) dilation = cast(List[int], node.args[5]) + output_padding = cast(List[int], node.args[7]) groups = cast(int, node.args[8]) # Qnn filter tensor is (H, W, Cin, Cout) @@ -308,57 +331,38 @@ def define_node( if len(padding) == 1: padding = padding + padding - # args[6] = transposed - if cast(bool, node.args[6]): - print("Currently, No support for transposed convolution") - return - - # args[7] = output padding - if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])): - print("QNN does not support output padding") - return - stride_shape = [len(stride)] padding_shape = [2, 2] dilation_shape = [len(dilation)] + output_padding_shape = [len(output_padding)] if is_depthwise_conv: - conv_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpDepthWiseConv2d.op_name, - ) - conv_op = self._add_conv_op_parameter( - OpDepthWiseConv2d, - conv_op, - conv_input_tensors, - conv_output_tensors, - stride, - stride_shape, - padding, - padding_shape, - dilation, - dilation_shape, - ) - + op_class = OpDepthWiseConv2d + elif is_transpose_conv: + op_class = OpTransposeConv2d else: - conv_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpConv2d.op_name, - ) - conv_op = self._add_conv_op_parameter( - OpConv2d, - conv_op, - conv_input_tensors, - conv_output_tensors, - stride, - stride_shape, - padding, - padding_shape, - dilation, - dilation_shape, - groups, - ) + op_class = OpConv2d + + conv_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + op_class.op_name, + ) + conv_op = self._add_conv_op_parameter( + op_class, + conv_op, + conv_input_tensors, + conv_output_tensors, + stride, + stride_shape, + padding, + padding_shape, + dilation, + dilation_shape, + output_padding, + output_padding_shape, + is_transpose_conv, + None if is_depthwise_conv else groups, + ) return conv_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 8ac702f2ad5..9c589c76784 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -356,3 +356,12 @@ class OpTile: class OpTranspose: op_name: str = "Transpose" param_perm: str = "perm" + + 
+@dataclass(init=False, frozen=True) +class OpTransposeConv2d: + op_name: str = "TransposeConv2d" + param_stride: str = "stride" + param_pad_amount: str = "pad_amount" + param_group: str = "group" + param_output_padding: str = "output_padding" diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index d3ae1194acd..ed9afd70ce6 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -941,7 +941,13 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: node.meta["source_fn_stack"] = [(node, torch.bmm)] -@register_annotator([torch.ops.aten.conv2d.default, torch.ops.aten.conv1d.default]) +@register_annotator( + [ + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv_transpose2d.input, + ] +) def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): return diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index e448a219284..ee3d6cf93a7 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -361,6 +361,46 @@ def forward(self, x): return self.conv(x) +class ConvTranspose2dSingle(torch.nn.Module): + def __init__(self, bias=True): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channels=1, + out_channels=3, + kernel_size=3, + stride=2, + padding=1, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(x) + + +class Conv2dDownUpSample(torch.nn.Module): + def __init__(self, bias=True): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=16, + out_channels=16, + kernel_size=3, + stride=2, + padding=1, + bias=bias, + ) + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channels=16, + out_channels=16, + kernel_size=3, + stride=2, + padding=1, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(self.conv(x)) + + class Conv2dSumReduceDim(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index d022ac96c48..8abed68c630 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -130,6 +130,16 @@ def test_qnn_backend_conv2d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose2d(self): + modules = [ + ConvTranspose2dSingle(), # noqa: F405 + ConvTranspose2dSingle(bias=False), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_element_wise_add(self): test_comb = [ { @@ -521,6 +531,11 @@ def test_qnn_backend_conv2d_cat(self): sample_input = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_down_up_sample(self): + module = Conv2dDownUpSample() # noqa: F405 + sample_input = (torch.randn(1, 16, 224, 224),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_max_pool2d(self): module = Conv2dMaxPool2d() # noqa: F405 sample_input = (torch.rand(1, 2, 14, 14),) @@ -713,6 +728,17 @@ def test_qnn_backend_conv2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose2d(self): + modules = [ 
+            ConvTranspose2dSingle(),  # noqa: F405
+            ConvTranspose2dSingle(bias=False),  # noqa: F405
+        ]
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_element_wise_add(self):
         test_comb = [
             {
@@ -1157,6 +1183,12 @@ def test_qnn_backend_conv2d_cat(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_conv2d_down_up_sample(self):
+        module = Conv2dDownUpSample()  # noqa: F405
+        sample_input = (torch.randn(1, 16, 224, 224),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_conv2d_max_pool2d(self):
         module = Conv2dMaxPool2d()  # noqa: F405
         sample_input = (torch.rand(1, 2, 14, 14),)
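-- 
Note (not part of the patch): a quick sanity check of the filter axis orders used in
op_conv2d.py. PyTorch stores Conv2d weights as OIHW but ConvTranspose2d weights as
IOHW, while QNN expects HWIO in both cases, hence the builder picks (2, 3, 0, 1) for
the transposed case and (2, 3, 1, 0) otherwise. A minimal sketch against stock
PyTorch only:

    import torch

    conv = torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
    deconv = torch.nn.ConvTranspose2d(in_channels=4, out_channels=8, kernel_size=3)

    # Conv2d weight is OIHW: (out_channels, in_channels / groups, kH, kW)
    print(conv.weight.shape)                        # torch.Size([8, 4, 3, 3])
    # ConvTranspose2d weight is IOHW: (in_channels, out_channels / groups, kH, kW)
    print(deconv.weight.shape)                      # torch.Size([4, 8, 3, 3])

    # Both permutations land on QNN's HWIO layout (kH, kW, Cin, Cout)
    print(conv.weight.permute(2, 3, 1, 0).shape)    # torch.Size([3, 3, 4, 8])
    print(deconv.weight.permute(2, 3, 0, 1).shape)  # torch.Size([3, 3, 4, 8])

For the ConvTranspose2dSingle test above (kernel 3, stride 2, padding 1, output
padding 0, input 1x1x3x3), the expected output spatial size follows
H_out = (H_in - 1) * stride - 2 * padding + (kernel - 1) + output_padding + 1
      = (3 - 1) * 2 - 2 + 2 + 0 + 1 = 5, i.e. an output of shape 1x3x5x5.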