diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index a96c5b21d42..691ba1607ff 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -43,9 +43,12 @@ class LayoutTransform(ExportPass): layout_sensitive_ops = { exir_ops.edge.aten.adaptive_avg_pool2d.default, exir_ops.edge.aten._adaptive_avg_pool3d.default, + exir_ops.edge.aten.adaptive_max_pool2d.default, exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.avg_pool3d.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.grid_sampler_2d.default, + exir_ops.edge.aten.grid_sampler_3d.default, exir_ops.edge.aten.instance_norm.default, exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 54cfae6591c..2f1c2d54828 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -2,15 +2,19 @@ Thank you for contributing to Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of implementing operator builder to unblock yourself and land pull requests more efficiently. 
## Sections -* [References](#references) -* [Getting Started](#getting-started) - * [Identify Unsupported Operator](#identify-unsupported-operator) - * [Check Operator Spec](#check-operator-spec) - * [Implementation](#implementation) - * [Quantizer Annotation](#quantizer-annotation) -* [Operator Support Status](#operator-support-status) -* [Issues](#issues) -* [Pull Requests](#pull-requests) +- [Contribution for More Operators](#contribution-for-more-operators) + - [Sections](#sections) + - [References](#references) + - [Qualcomm AI Engine Direct](#qualcomm-ai-engine-direct) + - [PyTorch](#pytorch) + - [Getting Started](#getting-started) + - [Identify Unsupported Operator](#identify-unsupported-operator) + - [Check Operator Spec](#check-operator-spec) + - [Implementation](#implementation) + - [Quantizer Annotation](#quantizer-annotation) + - [Operator Support Status](#operator-support-status) + - [Issues](#issues) + - [Pull Requests](#pull-requests) ## References ### Qualcomm AI Engine Direct @@ -365,7 +369,7 @@ Please help update following table if you are contributing new operators: + 🚫 = Deprecated, supported with other QNN Ops -| Operators | HTP - 92/116 Enabled | +| Operators | HTP - 94/116 Enabled | |-----------|---------| | Argmax | ✓ | | Argmin | ✓ | @@ -431,7 +435,7 @@ Please help update following table if you are contributing new operators: | Gelu | ✓ | | GetSparseIndices | ✗ | | GetSparseValues | ✗ | -| GridSample | ✗ | +| GridSample | ✓ | | GroupNorm | ✓ | | HardSwish | ✓ | | InstanceNorm | ✓ | diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 4bf0ea7e210..e982985477d 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -8,6 +8,7 @@ node_visitor, op_abs, op_adaptive_avg_pool2d, + op_adaptive_max_pool2d, op_add, op_amax, op_amin, @@ -44,6 +45,7 @@ op_gather, op_ge, op_gelu, + op_grid_sampler_2d, op_group_norm, op_gt, op_hardsigmoid, @@ -114,6 +116,7 @@ 
@register_node_visitor
class AdaptiveMaxPool2D(NodeVisitor):
    """Builder that lowers ``aten.adaptive_max_pool2d`` to a QNN ``PoolMax2d`` op.

    The adaptive pool is expressed as a fixed-window max pool with
    ``stride = input // output`` and ``filter = input - (output - 1) * stride``.
    That mapping only reproduces the adaptive windows when every input spatial
    size is an exact multiple of the requested output size, so any other
    configuration is rejected with a warning and ``None`` is returned.
    """

    target = ["aten.adaptive_max_pool2d.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Create the ``PoolMax2d`` op wrapper for ``node``.

        Args:
            node: The ``adaptive_max_pool2d`` node; ``args[0]`` is the input
                and ``args[1]`` the requested output size.
            nodes_to_wrappers: Tensor-wrapper cache passed through to
                ``define_tensor``.

        Returns:
            The configured op wrapper, or ``None`` (after a warning) when the
            node uses the indices output, passes ``return_indices``, or asks
            for an output size that does not evenly divide the input size.
        """
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        # Only output 0 (the pooled values) may be consumed; a getitem on any
        # other index means the indices tensor is used.
        for user in node.users:
            if user.target.__name__ == "getitem":
                getitem_index = user.args[1]
                if getitem_index != 0:
                    warnings.warn(
                        f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}",
                        stacklevel=1,
                    )
                    return

        if len(node.args) > 2:
            warnings.warn(
                "[QNN Delegate Op Builder]: The return_indices is not supported, fallback op",
                stacklevel=1,
            )
            return

        # Height/width are read from dims 1 and 2 -- presumably the tensor is
        # already NHWC here because the op is layout-sensitive; TODO confirm.
        input_height = input_tensor.shape[1]
        input_width = input_tensor.shape[2]

        # A single-element output size applies to both spatial dims, and a
        # ``None`` entry means "keep the input size" for that dim.
        out_hw = cast(List[int], node.args[1])
        output_height = out_hw[0]
        output_width = out_hw[0] if len(out_hw) == 1 else out_hw[1]
        if output_height is None:
            output_height = input_height
        if output_width is None:
            output_width = input_width

        # A uniform stride/filter only matches adaptive pooling when the
        # output size divides the input size; otherwise fall back instead of
        # silently producing different pooling windows.
        if (
            output_height <= 0
            or output_width <= 0
            or input_height % output_height != 0
            or input_width % output_width != 0
        ):
            warnings.warn(
                "[QNN Delegate Op Builder]: Only output sizes that evenly divide the input size are supported, fallback op",
                stacklevel=1,
            )
            return

        # NOTE: The rounding mode does not matter here because the output
        # shape is decided by the user.
        mode = OpPoolMax2d.RoundingMode.FLOOR

        stride_height = input_height // output_height
        filter_height = input_height - (output_height - 1) * stride_height
        stride_width = input_width // output_width
        filter_width = input_width - (output_width - 1) * stride_width

        filter = [filter_height, filter_width]
        filter_shape = [len(filter)]

        stride = [stride_height, stride_width]
        stride_shape = [len(stride)]

        padding = [0, 0]
        padding_shape = [len(padding), len(padding)]

        out_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        adaptive_max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPoolMax2d.op_name,
        )
        adaptive_max_pool2d_op.AddInputTensors([input_tensor_wrapper])
        adaptive_max_pool2d_op.AddOutputTensors([output_tensor_wrapper])

        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_filter_size,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(filter_shape),
            filter_shape,
            np.array(filter, dtype=np.uint32),
            True,
        )
        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_stride,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(stride_shape),
            stride_shape,
            np.array(stride, dtype=np.uint32),
            True,
        )
        adaptive_max_pool2d_op.AddTensorParam(
            OpPoolMax2d.param_pad_amount,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(padding_shape),
            padding_shape,
            np.array(
                [[padding[0], padding[0]], [padding[1], padding[1]]],
                dtype=np.uint32,
            ),
            True,
        )
        adaptive_max_pool2d_op.AddScalarParam(
            OpPoolMax2d.param_rounding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(mode)},
        )

        return adaptive_max_pool2d_op
@register_node_visitor
class GridSample(NodeVisitor):
    """Builder that lowers ``aten.grid_sampler_2d/3d`` to QNN ops.

    The node is emitted as two ops: a ``Transpose`` that re-orders the grid
    tensor, followed by the ``GridSample`` op that consumes the input and the
    transposed grid.
    """

    target = ["aten.grid_sampler_2d.default", "aten.grid_sampler_3d.default"]

    # Permutation applied to the grid tensor, keyed by the rank of the input.
    _GRID_PERMUTATION = {4: (0, 3, 1, 2), 5: (0, 4, 1, 2, 3)}

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Create the transpose + grid-sample op wrappers for ``node``.

        Args:
            node: The grid sampler node; ``args`` are
                ``(input, grid, interpolation_mode, padding_mode, align_corners)``
                with the last three optional.
            nodes_to_wrappers: Tensor-wrapper cache passed through to
                ``define_tensor`` / ``define_custom_tensor_wrapper``.

        Returns:
            A list ``[transpose_op, grid_sample_op]``, or ``None`` (after a
            warning) when the input rank or the interpolation mode is not
            supported.
        """
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        grid_node = self.get_node(node.args[1])
        grid_tensor = self.get_tensor(grid_node, node)
        grid_tensor_wrapper = self.define_tensor(
            grid_node,
            node,
            grid_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        # In ATen the input and the grid do not share a layout, while the
        # hardware expects both as NHWC / NDHWC, so the grid is permuted again
        # below. Only rank-4 and rank-5 inputs have a known permutation.
        input_rank = len(input_node.meta["val"].shape)
        dims_shape_back = self._GRID_PERMUTATION.get(input_rank)
        if dims_shape_back is None:
            warnings.warn(
                "[QNN Delegate Op Builder]: The shape is not supported, fallback op",
                stacklevel=1,
            )
            return

        interpo_mode = node.args[2] if len(node.args) > 2 else 0
        padding_mode = node.args[3] if len(node.args) > 3 else 0
        align_corners = node.args[4] if len(node.args) > 4 else False

        # Only the modes enumerated in OpGridSample.Mode can be expressed;
        # anything else falls back instead of being forwarded unchecked.
        if interpo_mode not in {int(mode) for mode in OpGridSample.Mode}:
            warnings.warn(
                f"[QNN Delegate Op Builder]: Not support interpolation mode {interpo_mode}, fallback op",
                stacklevel=1,
            )
            return

        grid_quant_encoding, grid_quant_configs = self.get_quant_encoding_conf(
            grid_node, node
        )
        # Unquantized grids keep their tensor dtype; quantized grids use the
        # dtype from the quant config, with int32 mapped to uint16.
        grid_dtype = (
            QNN_TENSOR_TYPE_MAP[grid_tensor.dtype]
            if grid_quant_encoding
            == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
            else QNN_QUANT_TYPE_MAP[
                (
                    torch.uint16
                    if grid_quant_configs[QCOM_DTYPE] == torch.int32
                    else grid_quant_configs[QCOM_DTYPE]
                )
            ]
        )

        # Transpose op for the grid tensor.
        permute_output_tensor = grid_tensor.permute(dims=dims_shape_back)
        transpose_output_tensor_wrapper = self.define_custom_tensor_wrapper(
            node_name=node.name + "_transpose",
            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            dtype=grid_dtype,
            quant_encoding=grid_quant_encoding,
            quant_configs=grid_quant_configs,
            dims=permute_output_tensor.size(),
            tensor=permute_output_tensor,
            is_fake_tensor=True,
            nodes_to_wrappers=nodes_to_wrappers,
        )

        permute_order = cast(List[int], dims_shape_back)
        permute_order_shape = [len(permute_order)]
        transpose_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTranspose.op_name,
        )
        transpose_op.AddInputTensors([grid_tensor_wrapper])
        transpose_op.AddOutputTensors([transpose_output_tensor_wrapper])
        transpose_op.AddTensorParam(
            OpTranspose.param_perm,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(permute_order_shape),
            permute_order_shape,
            np.array(permute_order, dtype=np.uint32),
            True,
        )

        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        # Grid-sample op consuming the input and the transposed grid.
        grid_sample_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpGridSample.op_name,
        )
        grid_sample_op.AddInputTensors(
            [input_tensor_wrapper, transpose_output_tensor_wrapper]
        )
        grid_sample_op.AddOutputTensors([output_tensor_wrapper])
        grid_sample_op.AddScalarParam(
            OpGridSample.param_align_corners,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: align_corners},
        )
        grid_sample_op.AddScalarParam(
            OpGridSample.param_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(interpo_mode)},
        )
        grid_sample_op.AddScalarParam(
            OpGridSample.param_padding_mode,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(padding_mode)},
        )

        return [transpose_op, grid_sample_op]
@dataclass(init=False, frozen=True)
class OpGridSample:
    """Name and parameter constants for the QNN ``GridSample`` operator."""

    op_name: str = "GridSample"
    param_align_corners: str = "align_corners"
    param_mode: str = "mode"
    param_padding_mode: str = "padding_mode"

    @unique
    class Mode(IntEnum):
        """Values accepted by the ``mode`` parameter."""

        BILINEAR = 0
        NEAREST = 1

    # The member was originally misspelled "BILINAR"; keep that spelling
    # resolvable as a plain class attribute (it is not an extra enum member,
    # so ``@unique`` and iteration over ``Mode`` are unaffected).
    Mode.BILINAR = Mode.BILINEAR

    @unique
    class PaddingMode(IntEnum):
        """Values accepted by the ``padding_mode`` parameter."""

        ZEROS = 0
        BORDER = 1
        REFLECTION = 2
@register_annotator([torch.ops.aten.grid_sampler.default])
def annotate_grid_sampler(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate a grid sampler node with activation quantization specs.

    The first two arguments each receive the input-activation spec when they
    are graph nodes, and the node output receives the output-activation spec.
    Nodes that are already annotated are left untouched.
    """
    if _is_annotated([node]):
        return

    act_in_spec = quantization_config.input_activation
    act_out_spec = quantization_config.output_activation

    # Only arguments that are graph nodes get an entry in the map.
    candidates = (node.args[0], node.args[1])
    specs_by_input = {
        arg: act_in_spec for arg in candidates if isinstance(arg, Node)
    }

    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=specs_by_input,
        output_qspec=act_out_spec,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.adaptive_max_pool2d.default])
def annotate_adaptive_max_pool2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate an adaptive max pool node with activation quantization specs.

    The first argument receives the input-activation spec when it is a graph
    node, and the node output receives the output-activation spec. Nodes that
    are already annotated are left untouched.
    """
    if _is_annotated([node]):
        return

    act_in_spec = quantization_config.input_activation
    act_out_spec = quantization_config.output_activation

    pooled_input = node.args[0]
    specs_by_input = (
        {pooled_input: act_in_spec} if isinstance(pooled_input, Node) else {}
    )

    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=specs_by_input,
        output_qspec=act_out_spec,
        _annotated=True,
    )
class AdaptiveMaxPool2D(torch.nn.Module):
    """Test model wrapping 2D adaptive max pooling.

    Args:
        output_size: Target output size forwarded to the pooling call.
        return_indices: Whether the pooling call also returns the indices.
    """

    def __init__(self, output_size, return_indices=False):
        super().__init__()
        self.output_size = output_size
        self.return_indices = return_indices

    def forward(self, x):
        # Call the functional form directly rather than building a new
        # nn.AdaptiveMaxPool2d module on every forward pass.
        return torch.nn.functional.adaptive_max_pool2d(
            x, self.output_size, return_indices=self.return_indices
        )


class GridSample(torch.nn.Module):
    """Test model wrapping ``torch.nn.functional.grid_sample``.

    Args:
        mode: Interpolation mode forwarded to ``grid_sample``.
        padding_mode: Padding mode forwarded to ``grid_sample``.
        align_corners: ``align_corners`` flag forwarded to ``grid_sample``.
    """

    def __init__(self, mode, padding_mode, align_corners):
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners
        self.padding_mode = padding_mode

    def forward(self, x, grid):
        # Keyword arguments keep each option bound to the intended parameter.
        return torch.nn.functional.grid_sample(
            x,
            grid,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
def test_qnn_backend_adaptive_max_pool2d(self):
    """Lower adaptive max pool models with several output sizes (float path)."""
    sample_input = (torch.randn(1, 512, 24, 24),)
    # NOTE: Currently, we only support the return_indices is False and default is False.
    # NOTE: Currently, we only support the case mod(in_w, out_w)=0 and mod(in_h, out_h)=0.
    pool_models = (
        AdaptiveMaxPool2D((1, 1), False),  # noqa: F405
        AdaptiveMaxPool2D((4, 4)),  # noqa: F405
        AdaptiveMaxPool2D((24, 24)),  # noqa: F405
        AdaptiveMaxPool2D((None, 4)),  # noqa: F405
        AdaptiveMaxPool2D((12, None)),  # noqa: F405
    )
    for i, pool_model in enumerate(pool_models):
        with self.subTest(i=i):
            self.lower_module_and_test_output(pool_model, sample_input)


def test_qnn_backend_grid_sampler(self):
    """Lower grid sampler models for every mode/padding/align combination (float path)."""
    # NOTE: The grid_sampler 3d version is not supported in fp16.
    modes = ["bilinear", "nearest"]
    padding_modes = ["zeros", "border", "reflection"]
    align_corners = [False, True]
    grid_samples = [
        GridSample(*options)  # noqa: F405
        for options in itertools.product(modes, padding_modes, align_corners)
    ]
    sample_inputs = [
        # for grid_sampler 2d
        (torch.randn(1, 12, 14, 14), torch.randn(1, 3, 3, 2)),
    ]

    for j, inputs in enumerate(sample_inputs):
        for i, module in enumerate(grid_samples):
            with self.subTest(i=i, j=j, module=module):
                self.lower_module_and_test_output(module, inputs)
def test_qnn_backend_adaptive_max_pool2d(self):
    """Quantize, then lower adaptive max pool models with several output sizes."""
    sample_input = (torch.randn(1, 512, 24, 24),)
    # NOTE: Currently, we only support the return_indices is False and default is False.
    # NOTE: Currently, we only support the case mod(in_w, out_w)=0 and mod(in_h, out_h)=0.
    pool_models = (
        AdaptiveMaxPool2D((1, 1), False),  # noqa: F405
        AdaptiveMaxPool2D((4, 4)),  # noqa: F405
        AdaptiveMaxPool2D((24, 24)),  # noqa: F405
        AdaptiveMaxPool2D((None, 4)),  # noqa: F405
        AdaptiveMaxPool2D((12, None)),  # noqa: F405
    )
    for i, pool_model in enumerate(pool_models):
        with self.subTest(i=i):
            qdq_model = self.get_qdq_module(pool_model, sample_input)
            self.lower_module_and_test_output(qdq_model, sample_input)


def test_qnn_backend_grid_sampler(self):
    """Quantize (16a16w), then lower grid sampler models for 2d and 3d inputs."""
    modes = ["bilinear", "nearest"]
    padding_modes = ["zeros", "border", "reflection"]
    align_corners = [False, True]
    grid_samples = [
        GridSample(*options)  # noqa: F405
        for options in itertools.product(modes, padding_modes, align_corners)
    ]
    sample_inputs = [
        # for grid_sampler 2d
        (torch.randn(1, 12, 14, 14), torch.randn(1, 3, 3, 2)),
        # for grid_sampler 3d
        (torch.randn(1, 15, 9, 17, 33), torch.randn(1, 7, 8, 9, 3)),
    ]
    for j, inputs in enumerate(sample_inputs):
        for i, module in enumerate(grid_samples):
            with self.subTest(i=i, j=j, module=module):
                qdq_module = self.get_qdq_module(
                    module, inputs, quant_dtype=QuantDtype.use_16a16w
                )
                self.lower_module_and_test_output(qdq_module, inputs)