diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 3a0cea4c809..b3ddecbc298 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -12,6 +12,7 @@
     AnnotateChannelsLastDimOrder,
 )
 from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
+from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
 from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
     ConvertExpandCopyToRepeatPass,
 )
@@ -69,6 +70,7 @@ def transform_to_backend_pipeline(
         self.add_pass(DecomposeDivPass())
         self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index 0e74701ab6d..280864cbc91 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -1,3 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
 # Copyright 2024 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
@@ -9,11 +10,57 @@
 import torch
 import torch.fx
 
+from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
+
+from torch._export.utils import (
+    get_buffer,
+    get_lifted_tensor_constant,
+    get_param,
+    is_buffer,
+    is_lifted_tensor_constant,
+    is_param,
+)
 from torch._ops import OpOverload
 from torch._subclasses.fake_tensor import FakeTensor
 
+
+def is_get_attr_node(node: torch.fx.Node) -> bool:
+    """
+    Returns True if the given node is a get_attr node for a tensor of the model.
+    """
+    return isinstance(node, torch.fx.Node) and node.op == "get_attr"
+
+
+def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_get_attr_node(node)
+        or is_param(exp_prog, node)
+        or is_buffer(exp_prog, node)
+        or is_lifted_tensor_constant(exp_prog, node)
+    )
+
+
+def get_param_tensor(
+    exp_prog: ExportedProgram, node: torch.fx.Node
+) -> Optional[torch.Tensor]:
+    if node is None:
+        return None
+    elif is_param(exp_prog, node):
+        return get_param(exp_prog, node)
+    elif is_buffer(exp_prog, node):
+        return get_buffer(exp_prog, node)
+    elif is_lifted_tensor_constant(exp_prog, node):
+        return get_lifted_tensor_constant(exp_prog, node)
+    elif is_get_attr_node(node):
+        # This is a hack to support both lifted and unlifted graphs
+        try:
+            return getattr(node.graph.owning_module, node.target)
+        except AttributeError:
+            return getattr(exp_prog.graph_module, node.target)
+    raise RuntimeError(f"unsupported param type, {node.op}.")
+
+
 def create_node(
     graph: torch.fx.Graph,
     op_target: OpOverload,
diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py
new file mode 100644
index 00000000000..7fe5c6f7b6d
--- /dev/null
+++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py
@@ -0,0 +1,164 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_param_tensor,
+    insert_q_dq_pair,
+    is_param_node,
+)
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class Conv1dUnsqueezePass(ExportPass):
+    """
+    This pass is used to change conv1d ops into conv2d since TOSA only
+    supports 2d and 3d convolution. This is done by modifying the graph to do the
+    following:
+    1) unsqueeze the convolution's input from 3d to 4d
+    2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
+    3) perform a conv2d (with a modified version of the original conv1d args)
+    4) squeeze the output back down to 3d
+    5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
+    """
+
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        super().__init__()
+        self.exported_program = exported_program
+
+    def unsqueeze_kernel_weights(self, kernel_node):
+        """
+        Unsqueezes the weights of a conv1d to make them 4-dimensional.
+
+        Args:
+            kernel_node: the weights node of the conv1d to be unsqueezed
+        """
+        kernel_param_3d = get_param_tensor(self.exported_program, kernel_node)
+        if kernel_param_3d is None:
+            raise AssertionError("Expected param tensor for the kernel node")
+
+        kernel_param_4d = torch.nn.Parameter(
+            data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
+            requires_grad=False,
+        )
+
+        if torch._export.utils.is_param(self.exported_program, kernel_node):
+            parameter_name = self.exported_program.graph_signature.inputs_to_parameters[
+                kernel_node.name
+            ]
+            self.exported_program.state_dict[parameter_name] = kernel_param_4d
+            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
+        elif torch._export.utils.is_buffer(self.exported_program, kernel_node):
+            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
+                kernel_node.name
+            ]
+            self.exported_program.state_dict[buffer_name] = kernel_param_4d
+            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
+        elif torch._export.utils.is_lifted_tensor_constant(
+            self.exported_program, kernel_node
+        ):
+            buffer_name = (
+                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
+                    kernel_node.name
+                ]
+            )
+            self.exported_program.constants[buffer_name] = kernel_param_4d
+            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
+        else:
+            setattr(
+                kernel_node.graph.owning_module,
+                kernel_node.target,
+                kernel_param_4d,
+            )
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        node_list = list(graph.nodes)
+        for node in node_list:
+            if node.op == "call_function":
+                if node.target == exir_ops.edge.aten.convolution.default:
+                    stride = list(node.args[3])
+                    if len(stride) != 1:
+                        # Skip convs that are not 1d
+                        continue
+
+                    kernel_node = node.args[1]
+                    if kernel_node.target == dq_op:
+                        kernel_node = kernel_node.args[0]
+
+                    if not is_param_node(self.exported_program, kernel_node):
+                        raise AssertionError(
+                            "Expected op for convolution weight node to be a get_attr node or a parameter"
+                        )
+
+                    # (a) Modify the graph such that the conv changes from 1d to 2d
+                    self.unsqueeze_kernel_weights(kernel_node)
+
+                    # (b) Extend stride, padding, and dilation for the extra dim
+                    node.args = (
+                        node.args[0],
+                        node.args[1],
+                        node.args[2],
+                        node.args[3] + [1],  # stride
+                        node.args[4] + [0],  # padding
+                        node.args[5] + [1],  # dilation
+                        node.args[6],
+                        node.args[7] + [0],  # output padding
+                        node.args[8],
+                    )
+
+                    # (c) Add unsqueeze to input (3d -> 4d) and squeeze to output
+                    # (4d -> 3d): unsqueeze -> conv2d -> squeeze
+                    with graph.inserting_before(node):
+                        input_node = node.args[0]
+                        unsqueeze_before = create_node(
+                            graph, exir_ops.edge.aten.unsqueeze_copy.default
+                        )
+                        unsqueeze_before.args = (
+                            input_node,  # Input is node's original input
+                            -1,  # Last dimension
+                        )
+                        node.replace_input_with(input_node, unsqueeze_before)
+
+                    # If quantized, we must insert unsqueeze --> q --> dq --> node
+                    if input_node.target == dq_op:
+                        q_params = input_node.args[1:]
+                        insert_q_dq_pair(graph, unsqueeze_before, q_params)
+
+                    with graph.inserting_after(node):
+                        squeeze_after = create_node(
+                            graph,
+                            exir_ops.edge.aten.squeeze_copy.dims,
+                        )
+                        squeeze_after.args = (
+                            node,  # Input is the conv node
+                            [-1],  # Last dimension
+                        )
+                        original_users = [
+                            user for user in node.users if user != squeeze_after
+                        ]
+                        for user in original_users:
+                            user.replace_input_with(node, squeeze_after)
+
+                    # If quantized, insert conv2d --> q --> dq --> squeeze
+                    if all(
+                        original_user.target == q_op for original_user in original_users
+                    ):
+                        q_params = original_users[0].args[1:]
+                        insert_q_dq_pair(graph, node, q_params)
+
+        graph_module.recompile()
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py
index 787cebec9d7..ed2dcd4008f 100644
--- a/backends/arm/_passes/convert_split_to_slice.py
+++ b/backends/arm/_passes/convert_split_to_slice.py
@@ -70,4 +70,5 @@ def call(self, graph_module: torch.fx.GraphModule):
             output_node.replace_all_uses_with(slice_node)
         graph.eliminate_dead_code()
         graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py
new file mode 100644
index 00000000000..3b275542213
--- /dev/null
+++ b/backends/arm/test/ops/test_conv1d.py
@@ -0,0 +1,298 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.backend_details import CompileSpec
+from parameterized import parameterized
+
+
+class Conv1d(torch.nn.Module):
+    """
+    Creates one or many chained 1D-convolutions. For multiple convolutions, the
+    respective parameters are provided as lists.
+ """ + + def __init__( + self, + inputs: Optional[torch.Tensor] = None, + length=8, + nbr_conv=1, # Number of chained convs + in_channels: Union[List, int, None] = None, + out_channels: Union[List, int, None] = None, + kernel_size: Union[List, Tuple, None] = None, + stride: Union[List, Tuple, None] = None, + padding: Union[List, Tuple, None] = None, + dilation: Union[List, Tuple, None] = None, + groups: Union[List, int, None] = None, + bias: Union[List, bool, None] = None, + padding_mode: Union[List, str, None] = None, + batches=1, + dtype=torch.float32, + ): + super().__init__() + self.nbr_convs = nbr_conv + + # Handle default values + in_channels = [2] * nbr_conv if in_channels is None else in_channels + out_channels = [1 * nbr_conv] if out_channels is None else out_channels + kernel_size = [3] * nbr_conv if kernel_size is None else kernel_size + stride = [2] * nbr_conv if stride is None else stride + padding = [1] * nbr_conv if padding is None else padding + dilation = [1] * nbr_conv if dilation is None else dilation + groups = [1] * nbr_conv if groups is None else groups + bias = [True] * nbr_conv if bias is None else bias + padding_mode = ["zeros"] * nbr_conv if padding_mode is None else padding_mode + + # This allows the input parameters to be either a single value or a list + # as type hint implies + if not isinstance(in_channels, List): + in_channels = [in_channels] + if not isinstance(out_channels, List): + out_channels = [out_channels] + if not isinstance(kernel_size, List): + kernel_size = [kernel_size] + if not isinstance(stride, List): + stride = [stride] + if not isinstance(padding, List): + padding = [padding] + if not isinstance(dilation, List): + dilation = [dilation] + if not isinstance(groups, List): + groups = [groups] + if not isinstance(bias, List): + bias = [bias] + if not isinstance(padding_mode, List): + padding_mode = [padding_mode] + + # Generate test data if not provided + if inputs is None: + self.inputs = (torch.randn(batches, in_channels[0], length).to(dtype),) + else: + self.inputs = (inputs,) + + # Build chain of convs + for i in range(self.nbr_convs): + setattr( + self, + f"conv_{i}", + torch.nn.Conv1d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=kernel_size[i], + stride=stride[i], + padding=padding[i], + dilation=dilation[i], + groups=groups[i], + bias=bias[i], + padding_mode=padding_mode[i], + ).to(dtype), + ) + + def get_inputs(self): + return self.inputs + + def forward(self, x): + for i in range(self.nbr_convs): + conv = getattr(self, f"conv_{i}") + x = conv(x) + return x + + +conv1d_2_3x2x40_nobias = Conv1d( + in_channels=2, + out_channels=3, + kernel_size=2, + stride=1, + bias=False, + padding=0, + length=40, + batches=1, +) + +conv1d_3_1x3x256_st1 = Conv1d( + in_channels=3, + out_channels=10, + kernel_size=3, + stride=1, + padding=0, + length=256, + batches=1, +) + +conv1d_3_1x3x12_st2_pd1 = Conv1d( + in_channels=3, + out_channels=4, + kernel_size=3, + stride=2, + padding=1, + length=12, + batches=1, +) + +conv1d_1_1x2x128_st1 = Conv1d( + in_channels=2, + out_channels=1, + kernel_size=1, + stride=1, + padding=0, + length=128, + batches=1, +) + +conv1d_2_1x2x14_st2 = Conv1d( + in_channels=2, + out_channels=1, + kernel_size=2, + stride=2, + padding=0, + length=14, + batches=1, +) + +conv1d_5_3x2x128_st1 = Conv1d( + in_channels=2, + out_channels=3, + kernel_size=5, + stride=1, + padding=0, + length=128, + batches=3, +) + +conv1d_3_1x3x224_st2_pd1 = Conv1d( + in_channels=3, + out_channels=16, + kernel_size=3, + 
+    stride=2,
+    padding=1,
+    length=224,
+    batches=1,
+)
+
+two_conv1d_nobias = Conv1d(
+    nbr_conv=2,
+    length=256,
+    in_channels=[3, 10],
+    out_channels=[10, 15],
+    kernel_size=[5, 5],
+    stride=[1, 1],
+    padding=[0, 0],
+    bias=[False, False],
+    batches=1,
+)
+
+two_conv1d = Conv1d(
+    nbr_conv=2,
+    length=256,
+    in_channels=[3, 10],
+    out_channels=[10, 15],
+    kernel_size=[5, 5],
+    stride=[1, 1],
+    padding=[0, 0],
+    bias=[True, True],
+    batches=1,
+)
+
+# Shenanigan to get a nicer output when a test fails. With unittest it looks like:
+# FAIL: test_conv1d_tosa_BI_2_3_1x3x12_st2_pd1
+testsuite = [
+    ("2_3x2x40_nobias", conv1d_2_3x2x40_nobias),
+    ("3_1x3x256_st1", conv1d_3_1x3x256_st1),
+    ("3_1x3x12_st2_pd1", conv1d_3_1x3x12_st2_pd1),
+    ("1_1x2x128_st1", conv1d_1_1x2x128_st1),
+    ("2_1x2x14_st2", conv1d_2_1x2x14_st2),
+    ("5_3x2x128_st1", conv1d_5_3x2x128_st1),
+    ("3_1x3x224_st2_pd1", conv1d_3_1x3x224_st2_pd1),
+    ("two_conv1d_nobias", two_conv1d_nobias),
+    ("two_conv1d", two_conv1d),
+]
+
+
+class TestConv1D(unittest.TestCase):
+    def _test_conv1d_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
+            )
+            .export()
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_conv1d_tosa_BI_pipeline(
+        self,
+        module: torch.nn.Module,
+        test_data: Tuple[torch.Tensor],
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
+            )
+            .quantize()
+            .export()
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+        )
+
+    def _test_conv1d_ethosu_BI_pipeline(
+        self,
+        module: torch.nn.Module,
+        compile_spec: CompileSpec,
+        test_data: Tuple[torch.Tensor],
+    ):
+        (
+            ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
+            .quantize()
+            .export()
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
+            .to_executorch()
+        )
+
+    @parameterized.expand(testsuite)
+    def test_conv1d_tosa_MI(self, test_name, model):
+        self._test_conv1d_tosa_MI_pipeline(model, model.get_inputs())
+
+    @parameterized.expand(testsuite)
+    def test_conv1d_tosa_BI(self, test_name, model):
+        self._test_conv1d_tosa_BI_pipeline(model, model.get_inputs())
+
+    # Expected to fail as Conv1D requires transposes, which aren't supported on u55
+    @parameterized.expand(testsuite)
+    @unittest.expectedFailure
+    def test_conv1d_u55_BI(self, test_name, model):
+        self._test_conv1d_ethosu_BI_pipeline(
+            model, common.get_u55_compile_spec(), model.get_inputs()
+        )
+
+    @parameterized.expand(testsuite)
+    def test_conv1d_u85_BI(self, test_name, model):
+        self._test_conv1d_ethosu_BI_pipeline(
+            model, common.get_u85_compile_spec(), model.get_inputs()
+        )
diff --git a/backends/arm/test/ops/test_conv.py b/backends/arm/test/ops/test_conv2d.py
similarity index 98%
rename from backends/arm/test/ops/test_conv.py
rename to backends/arm/test/ops/test_conv2d.py
index decf790ce51..46adfc8a016 100644
--- a/backends/arm/test/ops/test_conv.py
+++ b/backends/arm/test/ops/test_conv2d.py
@@ -6,7 +6,7 @@
 
 import unittest
 
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from executorch.backends.arm.test import common
@@ -18,13 +18,13 @@
 
 class Conv2d(torch.nn.Module):
     """
-    Creates one or many chained convolutions. For multiple convolutions, the
+    Creates one or many chained 2D-convolutions. For multiple convolutions, the
     respective parameteres are provided as lists.
     """
 
     def __init__(
         self,
-        inputs: torch.Tensor = None,
+        inputs: Optional[torch.Tensor] = None,
         height=8,
         width=8,
         nbr_conv=1,  # Number of chained convs
diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py
index a63066bee68..01ffbc10543 100644
--- a/backends/arm/test/ops/test_depthwise_conv.py
+++ b/backends/arm/test/ops/test_depthwise_conv.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import logging
 import unittest
 
 from typing import Tuple
@@ -13,14 +12,13 @@
 
 import torch
 from executorch.backends.arm.test import common
-from executorch.backends.arm.test.ops.test_conv import Conv2d
+from executorch.backends.arm.test.ops.test_conv1d import Conv1d
+from executorch.backends.arm.test.ops.test_conv2d import Conv2d
 from executorch.backends.arm.test.tester.arm_tester import ArmTester
 from executorch.exir.backend.backend_details import CompileSpec
 from parameterized import parameterized
 
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
 
 """
 The configuration when
@@ -29,6 +27,29 @@
 where K is a positive integer is termed in literature as depthwise convolution.
 """
 
+
+dw_conv1d_3_1x3x14_gp3_st1 = Conv1d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=7,
+    stride=1,
+    groups=3,
+    length=14,
+    batches=1,
+    padding=3,
+)
+
+dw_conv1d_2_1x6x4_gp6_st1 = Conv1d(
+    in_channels=6,
+    out_channels=12,
+    kernel_size=2,
+    stride=1,
+    groups=6,
+    padding=0,
+    length=4,
+    batches=1,
+)
+
 dw_conv2d_2x2_1x6x4x4_gp6_st1 = Conv2d(
     in_channels=6,
     out_channels=12,
@@ -41,6 +62,17 @@
     batches=1,
 )
 
+dw_conv1d_3_1x3x256_gp3_st1 = Conv1d(
+    in_channels=3,
+    out_channels=3,
+    kernel_size=3,
+    stride=1,
+    groups=3,
+    padding=0,
+    length=256,
+    batches=1,
+)
+
 dw_conv2d_3x3_1x3x256x256_gp3_st1 = Conv2d(
     in_channels=3,
     out_channels=3,
@@ -89,6 +121,19 @@
     batches=1,
 )
 
+two_dw_conv1d = Conv1d(
+    nbr_conv=2,
+    length=64,
+    in_channels=[4, 8],
+    out_channels=[8, 24],
+    kernel_size=[3, 3],
+    stride=[1, 1],
+    padding=[0, 0],
+    groups=[4, 8],
+    bias=[True, True],
+    batches=1,
+)
+
 two_dw_conv2d = Conv2d(
     nbr_conv=2,
     width=64,
@@ -104,7 +149,7 @@
 )
 
 # Shenanigan to get a nicer output when test fails.
-testsuite = [
+testsuite_conv2d = [
     ("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1),
     ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
     ("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
@@ -113,12 +158,19 @@
     ("two_dw_conv2d", two_dw_conv2d),
 ]
 
+testsuite_conv1d = [
+    ("2_1x6x4_gp6_st1", dw_conv1d_2_1x6x4_gp6_st1),
+    ("3_1x3x256_gp3_st1", dw_conv1d_3_1x3x256_gp3_st1),
+    ("two_dw_conv1d", two_dw_conv1d),
+    ("3_1x3x14_gp3_st1", dw_conv1d_3_1x3x14_gp3_st1),
+]
+
 
-class TestDepthwiseConv2D(unittest.TestCase):
-    """Tests Conv2D where groups == in_channels and out_channels = K * in_channels. This
+class TestDepthwiseConv(unittest.TestCase):
+    """Tests Conv1D and Conv2D where groups == in_channels and out_channels = K * in_channels. This
This is a special case enables depthwise convolution.""" - def _test_dw_conv2d_tosa_MI_pipeline( + def _test_dw_conv_tosa_MI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): ( @@ -136,7 +188,7 @@ def _test_dw_conv2d_tosa_MI_pipeline( .run_method_and_compare_outputs(inputs=test_data) ) - def _test_dw_conv2d_tosa_BI_pipeline( + def _test_dw_conv_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): ( @@ -155,7 +207,7 @@ def _test_dw_conv2d_tosa_BI_pipeline( .run_method_and_compare_outputs(inputs=test_data, qtol=1) ) - def _test_dw_conv2d_ethos_BI_pipeline( + def _test_dw_conv_ethos_BI_pipeline( self, module: torch.nn.Module, compile_spec: CompileSpec, @@ -176,21 +228,36 @@ def _test_dw_conv2d_ethos_BI_pipeline( .to_executorch() ) - @parameterized.expand(testsuite) - def test_dw_conv2d_tosa_MI(self, test_name: str, model: torch.nn.Module): - self._test_dw_conv2d_tosa_MI_pipeline(model, model.get_inputs()) + @parameterized.expand(testsuite_conv1d + testsuite_conv2d) + def test_dw_conv_tosa_MI(self, test_name: str, model: torch.nn.Module): + self._test_dw_conv_tosa_MI_pipeline(model, model.get_inputs()) # TODO: Investigate flakyness (MLTORCH-307) - @parameterized.expand(testsuite) + @parameterized.expand(testsuite_conv1d + testsuite_conv2d) @pytest.mark.flaky(reruns=3) - def test_dw_conv2d_tosa_BI(self, test_name: str, model: torch.nn.Module): - self._test_dw_conv2d_tosa_BI_pipeline(model, model.get_inputs()) + def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module): + self._test_dw_conv_tosa_BI_pipeline(model, model.get_inputs()) - @parameterized.expand(testsuite, skip_on_empty=True) + @parameterized.expand(testsuite_conv2d, skip_on_empty=True) def test_dw_conv2d_u55_BI( self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False ): - self._test_dw_conv2d_ethos_BI_pipeline( + self._test_dw_conv_ethos_BI_pipeline( + model, + common.get_u55_compile_spec( + permute_memory_to_nhwc=True, quantize_io=set_quantize_io + ), + model.get_inputs(), + ) + + # Expected to fail as conv1d needs transpose which is not supported + # on u55. + @parameterized.expand(testsuite_conv1d, skip_on_empty=True) + @unittest.expectedFailure + def test_dw_conv1d_u55_BI( + self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False + ): + self._test_dw_conv_ethos_BI_pipeline( model, common.get_u55_compile_spec( permute_memory_to_nhwc=True, quantize_io=set_quantize_io @@ -198,11 +265,11 @@ def test_dw_conv2d_u55_BI( model.get_inputs(), ) - @parameterized.expand(testsuite) - def test_dw_conv2d_u85_BI( + @parameterized.expand(testsuite_conv1d + testsuite_conv2d) + def test_dw_conv_u85_BI( self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False ): - self._test_dw_conv2d_ethos_BI_pipeline( + self._test_dw_conv_ethos_BI_pipeline( model, common.get_u85_compile_spec( permute_memory_to_nhwc=True, quantize_io=set_quantize_io diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 167b99a3281..3e9d3620cca 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -101,7 +101,7 @@ def _get_input_quantization_params( ): # break early if we have all the inputs quantized parameters break if len(quant_params) == 0: - raise RuntimeError("No Quantization parameters not found in exported model.") + raise RuntimeError("No Quantization parameters found in exported model.") return quant_params