diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py index 6c989b6c136..d11309894d1 100644 --- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py +++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py @@ -13,11 +13,14 @@ from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import ( FuseBatchNormWithLinearPass, ) +from executorch.backends.nxp.aten_passes.split_group_convolution import ( + SplitGroupConvolution, +) from executorch.exir.pass_manager import PassManager from torch import nn from torch.fx.passes.infra.pass_base import PassResult -PassType = list[type[Callable[[torch.fx.GraphModule], PassResult]]] +PassType = type[Callable[[torch.fx.GraphModule], PassResult]] class NeutronAtenPassManager(PassManager): @@ -26,6 +29,7 @@ def __init__(self, passes: list[PassType] = None): passes: list[PassType] = passes or [ FuseBatchNormWithConvPass(), FuseBatchNormWithLinearPass(), + SplitGroupConvolution(), ] super().__init__(passes) diff --git a/backends/nxp/aten_passes/split_group_convolution.py b/backends/nxp/aten_passes/split_group_convolution.py new file mode 100644 index 00000000000..58c87730c84 --- /dev/null +++ b/backends/nxp/aten_passes/split_group_convolution.py @@ -0,0 +1,281 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator + +import torch + +from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import ( + group_conv_convertible_into_multiple_convolutions, +) +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.nn.parameter import Parameter + + +class SplitGroupConvolution(PassBase): + """The eIQ Neutron NPU supports only regular and depthwise convolutions. Group convolutions must be decomposed into + multiple parallel single group convolutions. + Replace the nodes in the following pattern. The square brackets indicate the tensor shapes. + + + │[N, Ic, ...] + ┌───▼───┐ + │ split │ + └┬─────┬┘ + ┌──────────────────┘ ... └────────────────┐ + │[N, Ic, ...] │[N, Ic/G, ...] │[N, Ic/G, ...] + ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ + │ convolution ◄──W [Oc, Ic/G, ...] replace │ convolution ◄──W [Oc/G, Ic/G, ...] │ convolution ◄──W [Oc/G, Ic/G, ...] + │ group=G ◄──B [Oc] ────────► │ group=1 ◄──B [Oc/G] ... │ group=1 ◄──B [Oc/G] + └──────┬──────┘ with └──────┬──────┘ └──────┬──────┘ + ▼[N, Oc, ...] │ [N, Oc/G, ...] │[N, Oc/G, ...] + └──────────────────┐ ... ┌────────────────┘ + ┌▼─────▼┐ + │ cat │ + └───┬───┘ + ▼[N, Oc, ...] + """ + + module: GraphModule + + def _get_tensor_constant_from_node(self, node) -> Parameter | None: + """Get the static data from a given node. If it doesn't have any data, return `None`.""" + if node is None or node.op != "get_attr": + return None + + target_atoms = node.target.split(".") + attr_itr = self.module + for atom in target_atoms: + if not hasattr(attr_itr, atom): + return None + attr_itr = getattr(attr_itr, atom) + return attr_itr + + def _create_and_insert_get_item_node(self, input_node: Node, idx: int) -> Node: + """Create a `GetItem` node which extracts the output of `input_node` on index `idx`. 
+ The `GetItem` is also added to the graph right after the `input_node`. + """ + with self.module.graph.inserting_after(input_node): + get_item_node = self.module.graph.create_node( + "call_function", + operator.getitem, + (input_node, idx), + {}, + ) + + # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization. + get_item_node.meta["source_fn_stack"] = [ + (get_item_node.name, input_node.meta["source_fn_stack"]) + ] + get_item_node.meta["val"] = input_node.meta["val"][idx] + + return get_item_node + + def _create_split_node(self, *split_args) -> Node: + split_target = torch.ops.aten.split.default + split_node = self.module.graph.call_function(split_target, split_args) + + # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization. + split_node.meta["source_fn_stack"] = [(split_node.name, torch.split)] + + # Compute the output shapes for the `split`, and assign the `val` meta. + x_val = split_args[0].meta["val"] + with FakeTensorMode() as mode: + fake_input = FakeTensor.from_tensor( + torch.empty(x_val.shape, dtype=x_val.dtype), mode + ) + output_shapes = [t.shape for t in split_target(fake_input, *split_args[1:])] + split_node.meta["val"] = tuple( + [ + FakeTensor.from_tensor(torch.empty(shape, dtype=x_val.dtype), mode) + for shape in output_shapes + ] + ) + + return split_node + + def _create_convolution_node(self, conv_target, args: tuple) -> Node: + convolution_node = self.module.graph.call_function(conv_target, args) + + # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization. + convolution_node.meta["source_fn_stack"] = [ + (convolution_node.name, torch.convolution) + ] + + # Compute the output shapes for the `convolution`, and assign the `val` meta. + with FakeTensorMode() as mode: + input_shapes = [ + input_.meta["val"].shape if hasattr(input_, "meta") else input_.shape + for input_ in args[:3] + ] + input_dtypes = [ + input_.meta["val"].dtype if hasattr(input_, "meta") else input_.dtype + for input_ in args[:3] + ] + fake_inputs = [ + FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode) + for shape, dtype in zip(input_shapes, input_dtypes) + ] + output = conv_target(*fake_inputs, *args[3:]) + convolution_node.meta["val"] = FakeTensor.from_tensor( + torch.empty(output.shape, dtype=output.dtype), mode + ) + + return convolution_node + + def _create_concat_node(self, *cat_args) -> Node: + cat_target = torch.ops.aten.cat.default + concat_node = self.module.graph.call_function(cat_target, cat_args) + + # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization. + concat_node.meta["source_fn_stack"] = [(concat_node.name, torch.cat)] + + # Compute the output shape for the `concat`, and assign the `val` meta. 
+ with FakeTensorMode() as mode: + fake_inputs = [ + FakeTensor.from_tensor( + torch.empty( + input_.meta["val"].shape, dtype=input_.meta["val"].dtype + ), + mode, + ) + for input_ in cat_args[0] + ] + output = cat_target(fake_inputs, *cat_args[1:]) + concat_node.meta["val"] = FakeTensor.from_tensor( + torch.empty(output.shape, dtype=output.dtype), mode + ) + + return concat_node + + def _get_topologically_last_node(self, nodes: list[Node]) -> Node: + """Return the node from `nodes` which appears last in the graph.""" + for node in reversed(self.module.graph.nodes): + if node in nodes: + return node + + raise RuntimeError(f"None of the nodes `{nodes}` are in the graph.") + + def _create_parameter_node_for_data( + self, data: torch.Tensor, name: str, insert_after_node: torch.Node + ) -> torch.Node: + """Create a parameter node in the graph, which contains the provided `data`.""" + new_name = get_new_attr_name_with_prefix(name)(self.module) + + # Create the node for the parameter. + param = torch.nn.Parameter(data, False) + _assign_attr(param, self.module, str(new_name), _AttrKind.PARAMETER) + with self.module.graph.inserting_after(insert_after_node): + static_parameter_node = self.module.graph.get_attr(new_name) + + with FakeTensorMode() as mode: + static_parameter_node.meta["val"] = FakeTensor.from_tensor( + torch.empty(data.shape, dtype=data.dtype), mode + ) + + return static_parameter_node + + def call(self, module: GraphModule): + self.module = module + + def _is_conv(node_: Node): + return node_.op == "call_function" and node_.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ) + + made_changes = False + + for node in self.module.graph.nodes: + if not _is_conv(conv_node := node): + continue + + if len(conv_node.args) < 7: + # The `aten.conv` can have fewer args if the others use default values. + # So in this case, `groups == 1`. + continue + x, w, b, stride, padding, dilation, groups = conv_node.args + + if not group_conv_convertible_into_multiple_convolutions(conv_node, groups): + continue + + if len(x.meta["val"].shape) not in [3, 4]: + # Only 1D and 2D convolutions are supported by the Neutron backend. Don't decompose anything else. + continue + + w_data = self._get_tensor_constant_from_node(w) + b_data = self._get_tensor_constant_from_node(b) + if w_data is None or b_data is None: + continue # Only the standard case with static weights and bias is supported. + + # Create a `split` node to split the main input. + # Split across dimension `1` (channels), `groups` slices of size `input_split_size`. + num_input_channels = x.meta["val"].shape[1] + input_split_sizes = [num_input_channels // groups] * groups + with self.module.graph.inserting_before(conv_node): + split_node = self._create_split_node(x, input_split_sizes, 1) + + # Add `GetItem` nodes to extract the outputs of the `split_node`. + split_getitem_nodes = [ + self._create_and_insert_get_item_node(split_node, i) + for i in range(groups) + ] + + # Split the weights and bias, across dimension `0`, slices of size `weight_split_size`. + weight_split_size = w.meta["val"].shape[0] // groups + split_weights_data = torch.split(w_data, weight_split_size, 0) + split_bias_data = torch.split(b_data, weight_split_size, 0) + + # Turn the weights and biases into parameter nodes containing the data. + # Use a different name for every parameter. The function internally ensures the name's uniqueness, but + # relying on it sometimes causes strange failures when `groups > 5` for some weird reason. 
+ split_weight_nodes = [ + self._create_parameter_node_for_data( + weight_data, w.name + f"_{i}_", split_node + ) + for i, weight_data in enumerate(split_weights_data) + ] + split_bias_nodes = [ + self._create_parameter_node_for_data( + bias_data, b.name + f"_{i}_", split_node + ) + for i, bias_data in enumerate(split_bias_data) + ] + + # Create the `conv` nodes. + with self.module.graph.inserting_after( + self._get_topologically_last_node( + split_getitem_nodes + split_weight_nodes + split_bias_nodes + ) + ): + split_conv_nodes = [ + self._create_convolution_node( + conv_node.target, # Use the same target as the original convolution (1d/2d/3d/...). + (input_getitem, weight, bias, stride, padding, dilation, 1), + ) + for input_getitem, weight, bias in zip( + split_getitem_nodes, split_weight_nodes, split_bias_nodes + ) + ] + + # Create the `cat` node. + with self.module.graph.inserting_after( + self._get_topologically_last_node(split_conv_nodes) + ): + concat_node = self._create_concat_node( + split_conv_nodes, 1 + ) # Concatenate along the channels. + + # Replace the uses of the original convolution with the `concat_node`. + conv_node.replace_all_uses_with(concat_node) + self.module.graph.erase_node(conv_node) + + made_changes = True + + return PassResult(self.module, made_changes) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index c4b6e6713ca..dff003445ae 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging + import numpy as np import torch @@ -79,8 +81,9 @@ def _is_supported_on_target( return False elif conv_utils.group_conv_convertible_into_multiple_convolutions( node, groups - ): # Separable conv. - # Requires addition of `Split` and `Concatenation` operators, which are not supported on Neutron. + ): # Separable conv. This should never be reached, as the node should have been decomposed into + # multiple parallel convolutions by the `SplitGroupConvolution` pre-processing pass. + logging.warning("Group convolution was not decomposed.") return False else: # Unexpected case (should never happen). return False @@ -324,17 +327,8 @@ def _convert_2d_conv( elif conv_utils.group_conv_convertible_into_multiple_convolutions( t_op, conv_params.groups ): - # Note: by default the Group Separable Convolution is rejected by the Neutron Partitioner, see the - # ConvolutionConveter._is_supported_in_IR() - t_op.builtin_options = conv_2d_options.Conv2D() - - return conv_utils.create_separated_convolutions_based_on_group( - t_op, - conv_params, - self.builder, - self._convert_unpadded_2D, - conv_utils.conv_op_factory, - ) + # This case should have been rejected in the `is_supported_on_target()` method. + raise RuntimeError("Group convolution was not decomposed.") else: # Convert to regular `Conv2D`. 
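Note: `SplitGroupConvolution` relies on the standard identity that a convolution with
`groups=G` is equivalent to splitting the input channels into G slices, applying one
single-group convolution per slice (using the corresponding weight/bias slices), and
concatenating the results along the channel axis — exactly the split -> conv -> cat
pattern drawn in the pass docstring above. A minimal standalone sketch of that identity
in plain PyTorch (all names and shapes below are illustrative, not taken from the patch):

    import torch
    import torch.nn.functional as F

    groups, in_channels, out_channels = 4, 16, 32
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=groups, bias=True)
    x = torch.randn(1, in_channels, 10, 10)

    # Split the input along the channel dimension, and the weight/bias along dim 0,
    # mirroring what the pass does with `aten.split` / `torch.split`.
    x_slices = torch.split(x, in_channels // groups, dim=1)
    w_slices = torch.split(conv.weight, out_channels // groups, dim=0)
    b_slices = torch.split(conv.bias, out_channels // groups, dim=0)

    # One regular (groups=1) convolution per slice, concatenated along the channels.
    y_split = torch.cat(
        [
            F.conv2d(xs, ws, bs, conv.stride, conv.padding, conv.dilation, groups=1)
            for xs, ws, bs in zip(x_slices, w_slices, b_slices)
        ],
        dim=1,
    )

    assert torch.allclose(conv(x), y_split, atol=1e-6)
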
diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py index 3422e214982..5817fd127b3 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py @@ -3,28 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from copy import copy from dataclasses import dataclass -from typing import Callable, cast -import numpy as np - -from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( - ModelBuilder, -) -from executorch.backends.nxp.backend.ir.converter.conversion import aten_translator from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList -from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( - tf_lite_type_to_numpy, -) -from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data -from executorch.backends.nxp.backend.ir.lib.tflite.Padding import Padding from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( - concatenation_options, - conv_2d_options, - split_options, -) + from torch.fx import Node @@ -64,6 +47,9 @@ def group_conv_convertible_into_multiple_convolutions( if group == 1: return False + if group_conv_convertible_as_depthwise(node, group): + return False + _, output_channels = _get_IO_channels(node) if output_channels % group != 0: return False # Unable to split group Conv into separated convolutions because out_channels % group != 0. @@ -90,321 +76,3 @@ def __init__( self.conv_bias_tensor = bias_tensor self.conv_output_tensor = output_tensor self.ops_list = OpsList() - - -ConvBuiltinOptions = conv_2d_options.Conv2D -ConvOpFactory = Callable[ - [ - ConvParameters, - tflite_model.Tensor, - tflite_model.Tensor, - tflite_model.Tensor, - tflite_model.Tensor, - ModelBuilder, - ConvBuiltinOptions, - ], - OpsList, -] -ConvConversionFn = Callable[ - [tflite_model.Operator, ConvParameters], ConvConversionResult -] - - -class _InputTensorsSplitter: - """Splits the tensors of a `Conv2D` operator. Static tensors are split statically, and for dynamic tensors, a - TFLite `Split` operator is added. 
- """ - - input_tensors: list[tflite_model.Tensor] - weight_tensors: list[tflite_model.Tensor] - bias_tensors: list[tflite_model.Tensor] - split_ops: list[tflite_model.Operator] - - def __init__( - self, - input_tensor: tflite_model.Tensor, - weight_tensor: tflite_model.Tensor, - bias_tensor: tflite_model.Tensor, - groups: int, - builder: ModelBuilder, - ): - self.input_tensors = [] - self.weight_tensors = [] - self.bias_tensors = [] - self.split_ops = [] - - inputs = [ - # input tensor, split by axis, output tensors container - (input_tensor, -1, self.input_tensors), - (weight_tensor, 0, self.weight_tensors), - (bias_tensor, 0, self.bias_tensors), - ] - - for i in inputs: - if tensor_has_data(i[0]): - self._generate_static_tensors(builder, groups, i[0], i[1], i[2]) - else: - self._generate_dynamic_tensors(builder, groups, i[0], i[1], i[2]) - - def _generate_dynamic_tensors( - self, builder, groups, split_tensor, axis, target_list - ): - quantization = None - if split_tensor.quantization is not None: - if split_tensor.quantization.is_per_channel(): - scale = np.split( - np.array(split_tensor.quantization.scale.vector, "float32"), groups - ) - zero_point = np.split( - np.array(split_tensor.quantization.zero_point.vector, "int32"), - groups, - ) - quantization = [ - tflite_model.Quantization( - scale=tflite_model.Scale(s), - zero_point=tflite_model.ZeroPoint(zp), - ) - for s, zp in zip(scale, zero_point) - ] - else: - quantization = [split_tensor.quantization] * groups - - split_op = self._create_split_op(builder, groups, split_tensor, axis) - - new_tensor_shape = split_tensor.shape.vector.copy() - new_tensor_shape[axis] = new_tensor_shape[axis] // groups - - for i in range(groups): - conv_split_tensor = builder.duplicate_tensor( - split_tensor, name_suffix="_group_" + str(i) - ) - conv_split_tensor.shape = tflite_model.Shape(new_tensor_shape) - if quantization is not None: - conv_split_tensor.quantization = copy(quantization[i]) - - split_op.tmp_outputs.append(conv_split_tensor) - target_list.append(conv_split_tensor) - self.split_ops.append(split_op) - - # noinspection PyMethodMayBeStatic - def _generate_static_tensors( - self, builder, groups, split_tensor, axis, target_list - ): - quantization = None - if split_tensor.quantization is not None: - if split_tensor.quantization.is_per_channel(): - scale = np.split( - np.array(split_tensor.quantization.scale.vector, "float32"), groups - ) - zero_point = np.split( - np.array(split_tensor.quantization.zero_point.vector, "int32"), - groups, - ) - quantization = [ - tflite_model.Quantization( - scale=tflite_model.Scale(s), - zero_point=tflite_model.ZeroPoint(zp), - ) - for s, zp in zip(scale, zero_point) - ] - else: - quantization = [split_tensor.quantization] * groups - - input_data = np.split(split_tensor.tmp_buffer.data, groups, axis) - - for i in range(len(input_data)): - tensor_name = split_tensor.name + "_group_" + str(i) - conv_input_tensor = builder.create_tensor_for_data( - input_data[i], tensor_name - ) - if quantization is not None: - conv_input_tensor.quantization = copy(quantization[i]) - - target_list.append(conv_input_tensor) - - # noinspection PyMethodMayBeStatic - def _create_split_op(self, builder, groups, input_tensor, axis): - axis_tensor = builder.create_tensor_for_data( - np.asarray([axis], np.int32), "split_dim_" - ) - input_split_op = tflite_model.Operator( - builtin_options=split_options.Split(groups) - ) - input_split_op.tmp_inputs = [axis_tensor, input_tensor] - - return input_split_op - - def get_input_tensor(self, 
idx) -> tflite_model.Tensor: - return self.input_tensors[idx] - - def get_weight_tensor(self, idx) -> tflite_model.Tensor: - return self.weight_tensors[idx] - - def get_bias_tensor(self, idx) -> tflite_model.Tensor: - return self.bias_tensors[idx] - - def get_ops(self) -> list[tflite_model.Operator]: - return self.split_ops - - -class _OutputTensorsCombiner: - """Handles creation and aggregation of the TFLite Conv2D output tensors. - Aggregation is done with `Concatenation` op. - """ - - output_tensors: list[tflite_model.Tensor] - concat_op: tflite_model.Operator - - def __init__(self, output_tensor, groups, builder): - self.output_tensors = [] - combine_axis = -1 - - new_conv_output_shape = output_tensor.shape.vector.copy() - new_conv_output_shape[combine_axis] = ( - new_conv_output_shape[combine_axis] // groups - ) - conv_output_shape = tflite_model.Shape(new_conv_output_shape) - - self.concat_op = tflite_model.Operator( - builtin_options=concatenation_options.Concatenation(combine_axis) - ) - self.concat_op.tmp_outputs = [output_tensor] - - for i in range(groups): - tensor_name = output_tensor.name + "_group_" + str(i) - output_tensor = builder.duplicate_tensor(output_tensor, tensor_name) - output_tensor.shape = conv_output_shape - - self.output_tensors.append(output_tensor) - self.concat_op.tmp_inputs.append(output_tensor) - - def get_output_tensor(self, idx): - return self.output_tensors[idx] - - def get_ops(self): - return [self.concat_op] - - -def build_input_tensor_padding( - t_op, conv_params: ConvParameters, builder, input_idx=0 -) -> (Padding, tflite_model.Operator | None): - """Build padding for input tensor of Conv2D op 't_op'.""" - - tfl_padding, explicit_padding = aten_translator.convert_padding(conv_params.padding) - if explicit_padding is not None: - # Must add extra 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). - input_quantization = t_op.tmp_inputs[0].quantization - pad_value = ( - None - if input_quantization is None - else np.array(input_quantization.zero_point[0]).astype( - tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) - ) - ) - return tfl_padding, builder.create_pad_operator_before( - t_op, input_idx, explicit_padding, pad_value - ) - - return tfl_padding, None - - -def conv_op_factory( - conv_params: ConvParameters, - input_tensor: tflite_model.Tensor, - weight_tensor: tflite_model.Tensor, - bias_tensor: tflite_model.Tensor, - output_tensor: tflite_model.Tensor, - builder, - builtin_options, -) -> OpsList: - """Build padded 'Conv2D' TFLite operator. Padding is realized by 'builtin_options.padding' definition and by - optional prepended 'Pad' operator. - """ - - conv_op = tflite_model.Operator(builtin_options=copy(builtin_options)) - conv_op.tmp_inputs = [input_tensor, weight_tensor, bias_tensor] - conv_op.tmp_outputs = [output_tensor] - - padding, pad_op = build_input_tensor_padding(conv_op, conv_params, builder) - conv_op.builtin_options.padding = padding - - if pad_op is not None: - return OpsList(pre_ops=[pad_op], middle_op=conv_op) - else: - return OpsList(middle_op=conv_op) - - -# noinspection GrazieInspection -def create_separated_convolutions_based_on_group( - t_op: tflite_model.Operator, - conv_params: ConvParameters, - builder: ModelBuilder, - conv_conversion_fn: ConvConversionFn, - conv_op_factory_fn: ConvOpFactory, -) -> list[tflite_model.Operator]: - """Build a subgraph with multiple TFLite Conv2D operators that replace an `aten.convolution` operator with 'group' - attribute higher than one. 
The number of new Conv2D operators corresponds to the number of groups. Input - tensors of the Aten operator are split and distributed into related convolution operators. Outputs are then - concatenated back together. - - Example: 'aten.convolution' operator with group=2 converted into TFLite subgraph will have - the following structure (tensor dimensions are just for illustrative purposes): - - │ (1,4,4,48) - ┌───▼──┐ - │Split │ - └┬────┬┘ - (1,4,4,24) │ │ (1,4,4,24) - ┌─────▼┐ ┌▼─────┐ - │Conv2D│ │Conv2D│ - └────┬─┘ └─┬────┘ - (1,4,4,18)│ │(1,4,4,18) - ┌─▼──────▼──┐ - │Concatenate│ - └─────┬─────┘ - │ (1,4,4,36) - ▼ - """ - - conversion_result = conv_conversion_fn(t_op, conv_params) - - splitter = _InputTensorsSplitter( - conversion_result.conv_input_tensor, - conversion_result.conv_weight_tensor, - conversion_result.conv_bias_tensor, - conv_params.groups, - builder, - ) - combiner = _OutputTensorsCombiner( - conversion_result.conv_output_tensor, conv_params.groups, builder - ) - - conv_ops = [] - for i in range(conv_params.groups): - input_tensor = splitter.get_input_tensor(i) - weight_tensor = splitter.get_weight_tensor(i) - bias_tensor = splitter.get_bias_tensor(i) - output_tensor = combiner.get_output_tensor(i) - - conv_builtin_options = cast( - ConvBuiltinOptions, conversion_result.ops_list.middle_op.builtin_options - ) - conv_ops_list = conv_op_factory_fn( - conv_params, - input_tensor, - weight_tensor, - bias_tensor, - output_tensor, - builder, - conv_builtin_options, - ) - - conv_ops.extend(conv_ops_list.flatten()) - - return ( - conversion_result.ops_list.pre_ops # `Pad` operator - + splitter.get_ops() - + conv_ops - + combiner.get_ops() # Split, Conv2D, Concatenate ops - + conversion_result.ops_list.post_ops - ) # Currently not used diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index 68550692049..745b26ef8ff 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -240,160 +240,6 @@ def test_conv1d_quant_conversion__depthwise__padded( ) # `Conv` input zp. 
-@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(3,), 4]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", [((1, 4, 12), 2, 2), ((1, 16, 9), 4, 16)] -) -def test_conv1d_conversion__separated( - input_shape, group, out_channels, stride, dilation, kernel_size, mocker -): - model = Conv1dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - stride=stride, - dilation=dilation, - kernel_size=kernel_size, - ) - ops_spy = mocker.spy(ModelBuilder, "finish") - - # Run conversion - edge_program = to_edge_program(model, input_shape).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - convert_run_compare( - edge_program, - input_data, - tflite_input_preprocess=ToChannelLastPreprocess(), - tflite_output_preprocess=ToChannelFirstPreprocess(), - atol=3.0e-7, - ) - - # Capture IR model ops - ops = ops_spy.spy_return.sub_graphs[0].operators.vector - - assert ( - len(ops) == 1 + 1 + group + 1 + 1 - ) # Reshape + Split -> Conv (group times) -> Concat + Reshape - assert ops[0].builtin_options.operator_type == BuiltinOperator.RESHAPE - assert ops[1].builtin_options.operator_type == BuiltinOperator.SPLIT - for op in ops[3:-2]: - assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D - assert ops[-2].builtin_options.operator_type == BuiltinOperator.CONCATENATION - assert ops[-1].builtin_options.operator_type == BuiltinOperator.RESHAPE - - -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(3,), 4]) -@pytest.mark.parametrize("padding", [2, (1,)]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", [((1, 4, 12), 2, 2), ((1, 16, 9), 4, 16)] -) -def test_conv1d_conversion__separated__padded( - input_shape, group, out_channels, stride, dilation, kernel_size, padding, mocker -): - model = Conv1dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - stride=stride, - dilation=dilation, - kernel_size=kernel_size, - padding=padding, - ) - ops_spy = mocker.spy(ModelBuilder, "finish") - - # Run conversion - edge_program = to_edge_program(model, input_shape).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - convert_run_compare( - edge_program, - input_data, - tflite_input_preprocess=ToChannelLastPreprocess(), - tflite_output_preprocess=ToChannelFirstPreprocess(), - atol=3.0e-7, - ) - - # Capture IR model ops - ops = ops_spy.spy_return.sub_graphs[0].operators.vector - - assert ( - len(ops) == 1 + 1 + 2 * group + 1 + 1 - ) # Reshape + Split -> Pad + Conv (group times) -> Concat + Reshape - assert ops[0].builtin_options.operator_type == BuiltinOperator.RESHAPE - assert ops[1].builtin_options.operator_type == BuiltinOperator.SPLIT - for op in ops[2:-3:2]: - assert op.builtin_options.operator_type == BuiltinOperator.PAD - for op in ops[3:-2:2]: - assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D - assert ops[-2].builtin_options.operator_type == BuiltinOperator.CONCATENATION - assert ops[-1].builtin_options.operator_type == BuiltinOperator.RESHAPE - - -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", [((1, 4, 12), 2, 2), ((1, 16, 9), 4, 16)] -) -def test_conv1d_quant_conversion__separated( - input_shape, group, 
out_channels, stride, dilation, kernel_size -): - model = Conv1dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - stride=stride, - dilation=dilation, - kernel_size=kernel_size, - ) - - # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() - - nodes = list(edge_program.graph.nodes) - assert len(nodes) == 11 - assert ( - nodes[7].target.__name__ == "aten.convolution.default" - ) # Convolution not delegated. - - -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [2, 1]) -@pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -@pytest.mark.parametrize("padding", [(1,), 2]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", [((1, 4, 12), 2, 2), ((1, 16, 9), 4, 16)] -) -def test_conv1d_quant_conversion__separated__padded( - input_shape, group, out_channels, stride, dilation, kernel_size, padding -): - model = Conv1dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - stride=stride, - dilation=dilation, - kernel_size=kernel_size, - padding=padding, - ) - - # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() - - nodes = list(edge_program.graph.nodes) - assert len(nodes) == 11 - assert ( - nodes[7].target.__name__ == "aten.convolution.default" - ) # Convolution not delegated. - - @pytest.mark.parametrize( "model, input_shape", [ @@ -649,170 +495,3 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): len(nodes) == 7 ) # input, Quant, lowered_module, delegate_call, getitem, Deq, output assert nodes[2].target == "lowered_module_0" - - # Make sure the padding used the `zero-point`. - assert ( - ops[0].tmp_inputs[2].tmp_buffer.data.item() - == ops[0].tmp_outputs[0].quantization.zero_point[0] - ) - - -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [1, 2]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", - [((1, 4, 12, 12), 2, 2), ((2, 3, 8, 15), 3, 6), ((11, 16, 9, 8), 4, 16)], -) -def test_conv2d_conversion__separated( - input_shape, group, out_channels, stride, dilation, mocker -): - edge_program = to_edge_program( - Conv2dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - stride=stride, - dilation=dilation, - ), - input_shape, - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - spy = mocker.spy(ModelBuilder, "finish") - convert_run_compare( - edge_program, - input_data, - tflite_input_preprocess=ToChannelLastPreprocess(), - tflite_output_preprocess=ToChannelFirstPreprocess(), - atol=3.0e-7, - ) - - ops = spy.spy_return.sub_graphs[0].operators.vector - assert len(ops) == 1 + group + 1 # Split -> Conv (group times) -> Concat - assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT - for op in ops[1:-1]: - assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D - assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION - - -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [1, 2]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", - [((1, 4, 12, 12), 2, 2), ((2, 3, 17, 9), 3, 6), ((11, 16, 9, 8), 4, 16)], -) -def test_conv2d_conversion__separated__quantized( - input_shape, group, out_channels, stride, dilation -): - - # Note: The generic group convolution is not yet supported by Neutron Converter. 
Once supported, the - # commented out code allows usuall testing flow for this test-case. - # spy = mocker.spy(ModelBuilder, 'finish') - - # The convert_run_compare skips the partitioner call, hence conversion failure indicated by exception - # is expected behavior now. - edge_program = to_quantized_edge_program( - Conv2dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - stride=stride, - dilation=dilation, - ), - tuple(input_shape), - target="imxrt700", - ).exported_program() - - # ops = spy.spy_return.sub_graphs[0].operators.vector - # assert len(ops) == 1 + group + 1 # Split -> Conv (group times) -> Concat - # assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT - # for op in ops[1:-1]: - # assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D - # assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION - - nodes = list(edge_program.graph.nodes) - assert len(nodes) == 11 - assert ( - nodes[7].target.__name__ == "aten.convolution.default" - ) # Convolution not delegated. - - -@pytest.mark.parametrize("padding", [1, 2]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", - [((1, 4, 12, 12), 2, 2), ((2, 3, 4, 5), 3, 6), ((11, 16, 9, 8), 4, 16)], -) -def test_conv2d_conversion__separated__padded( - input_shape, group, out_channels, padding, mocker -): - edge_program = to_edge_program( - Conv2dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - padding=padding, - ), - input_shape, - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - spy = mocker.spy(ModelBuilder, "finish") - - convert_run_compare( - edge_program, - input_data, - tflite_input_preprocess=ToChannelLastPreprocess(), - tflite_output_preprocess=ToChannelFirstPreprocess(), - atol=3.0e-7, - ) - - conversion_result = spy.spy_return - ops = conversion_result.sub_graphs[0].operators.vector - assert len(ops) == 1 + 2 * group + 1 # Split -> Pad + Conv (group times) -> Concat - assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT - for op in ops[1:-2:2]: - assert op.builtin_options.operator_type == BuiltinOperator.PAD - for op in ops[2:-1:2]: - assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D - assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION - - -@pytest.mark.parametrize("padding", [1, 2]) -@pytest.mark.parametrize( - "input_shape, group, out_channels", - [((1, 4, 12, 12), 2, 2), ((2, 3, 4, 5), 3, 6), ((11, 16, 9, 8), 4, 16)], -) -def test_conv2d_conversion__separated__padded__quantized( - input_shape, group, out_channels, padding -): - - # Note: The generic group convolution is not yet supported by Neutron Converter. Once supported, the - # commented out code allows usuall testing flow for this test-case. 
- # spy = mocker.spy(ModelBuilder, 'finish') - - edge_program = to_quantized_edge_program( - Conv2dModule( - group=group, - in_channels=input_shape[1], - out_channels=out_channels, - padding=padding, - ), - tuple(input_shape), - ).exported_program() - - # ops = spy.spy_return.sub_graphs[0].operators.vector - # assert len(ops) == 1 + 2 * group + 1 # Split -> Pad + Conv (group times) -> Concat - # assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT - # for op in ops[1:-2:2]: - # assert op.builtin_options.operator_type == BuiltinOperator.PAD - # for op in ops[2:-1:2]: - # assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D - # assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION - - nodes = list(edge_program.graph.nodes) - assert len(nodes) == 11 - assert ( - nodes[7].target.__name__ == "aten.convolution.default" - ) # Convolution not delegated. diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 7f552d185e3..bdad9ddc4b4 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -67,6 +67,35 @@ def forward(self, x): return self.conv(x) +class Conv3dModule(torch.nn.Module): + def __init__( + self, + bias: bool = True, + dilation: Union[int, tuple[int, int]] = 1, + in_channels: int = 4, + kernel_size: Union[int, tuple[int, int]] = 3, + out_channels: int = 8, + padding: Union[str, int, Collection[int]] = 0, + stride: Union[int, tuple[int, int]] = 2, + group: int = 1, + ): + super().__init__() + + self.conv = torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=group, + ) + + def forward(self, x): + return self.conv(x) + + class Conv2dAndMaxPool2DModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/nxp/tests/test_split_group_convolution.py b/backends/nxp/tests/test_split_group_convolution.py new file mode 100644 index 00000000000..5786f00de07 --- /dev/null +++ b/backends/nxp/tests/test_split_group_convolution.py @@ -0,0 +1,237 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy + +import numpy as np +import torch + +from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( + NeutronAtenPassManager, +) +from executorch.backends.nxp.aten_passes.split_group_convolution import ( + SplitGroupConvolution, +) +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops +from executorch.backends.nxp.tests.models import ( + Conv1dModule, + Conv2dModule, + Conv3dModule, +) +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized + + +class TestSplitGroupConvolution(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests. 
+ + @classmethod + def setUp(cls): + torch.manual_seed(23) + np.random.seed(42) + + @parameterized.expand( + [ + ["group = 2", [1, 16, 10, 10], 2], + ["group = 3", [1, 24, 10, 10], 3], + ["group = 8", [1, 8, 10, 10], 8], + ] + ) + def test_split_group_convolution__2d(self, _, input_shape: list[int], group: int): + example_input = (torch.ones(input_shape),) + + module = Conv2dModule( + bias=True, + in_channels=input_shape[1], + out_channels=8 + * group, # Make sure the output channels are multiple of 8, so the `cat` can be delegated. + group=group, + stride=1, + ) + graph_module = torch.export.export(module, example_input, strict=True).module() + original_module = deepcopy(graph_module) + + modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( + graph_module + ).graph_module + + # Make sure the fusion worked. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 5 + assert original_nodes[3].target == torch.ops.aten.conv2d.default + assert original_nodes[3].args[-1] == group + + assert len(modified_nodes) == 4 + group * 4 + assert modified_nodes[1].target == torch.ops.aten.split.default + for node in modified_nodes[2 + 3 * group : 4 + 3 * group]: + assert node.target == torch.ops.aten.conv2d.default + assert node.args[-1] == 1 # Groups. + assert modified_nodes[-2].target == torch.ops.aten.cat.default + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2, atol=2.0e-7) + + # Make sure the graph can be correctly quantized and lowered to edge. + ep = to_quantized_edge_program( + modified_module, tuple(input_shape) + ).exported_program() + nodes = list(ep.graph.nodes) + assert nodes[-5].name == "lowered_module_0" + assert not graph_contains_any_of_ops( + ep.graph, + [exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.cat.default], + ) + + @parameterized.expand( + [ + ["group = 2", [1, 16, 10], 2], + ["group = 3", [1, 24, 10], 3], + ["group = 6", [1, 24, 10], 6], + ] + ) + def test_split_group_convolution__1d(self, _, input_shape: list[int], group: int): + example_input = (torch.ones(input_shape),) + + module = Conv1dModule( + bias=True, + in_channels=input_shape[1], + out_channels=8 + * group, # Make sure the output channels are multiple of 8, so the `cat` can be delegated. + group=group, + stride=1, + ) + graph_module = torch.export.export(module, example_input).module() + original_module = deepcopy(graph_module) + + modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( + graph_module + ).graph_module + + # Make sure the fusion worked. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 5 + assert original_nodes[3].target == torch.ops.aten.conv1d.default + assert original_nodes[3].args[-1] == group + + assert len(modified_nodes) == 4 + group * 4 + assert modified_nodes[1].target == torch.ops.aten.split.default + for node in modified_nodes[2 + 3 * group : 4 + 3 * group]: + assert node.target == torch.ops.aten.conv1d.default + assert node.args[-1] == 1 # Groups. + assert modified_nodes[-2].target == torch.ops.aten.cat.default + + # Verify that the behavior has not changed. 
+ input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2, atol=2.0e-7) + + # Make sure the graph can be correctly quantized and lowered to edge. + ep = to_quantized_edge_program( + modified_module, tuple(input_shape) + ).exported_program() + nodes = list(ep.graph.nodes) + assert nodes[-5].name == "lowered_module_0" + assert not graph_contains_any_of_ops( + ep.graph, + [exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.cat.default], + ) + + @parameterized.expand( + [ + ["group = 2", [1, 16, 10, 10, 10], 2], + ] + ) + def test_split_group_convolution__3d(self, _, input_shape: list[int], group: int): + example_input = (torch.ones(input_shape),) + + module = Conv3dModule( + bias=True, + in_channels=input_shape[1], + out_channels=8 + * group, # Make sure the output channels are multiple of 8, so the `cat` can be delegated. + group=group, + ) + graph_module = torch.export.export(module, example_input).module() + original_module = deepcopy(graph_module) + + modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( + graph_module + ).graph_module + + # Verify that the pass has NOT made any changes, as it is disabled for 3D convolution. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == len(modified_nodes) + for original_node, modified_node in zip(original_nodes, modified_nodes): + assert original_node.name == modified_node.name + assert original_node.target == modified_node.target + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2) + + def test_split_group_convolution__applied_by_default(self): + input_shape = [1, 16, 10, 10] + group = 2 + example_input = (torch.ones(input_shape),) + + module = Conv2dModule( + in_channels=input_shape[1], + out_channels=8 + * group, # Make sure the output channels are multiple of 8, so the `cat` can be delegated. + group=group, + stride=1, + ) + graph_module = torch.export.export(module, example_input).module() + original_module = deepcopy(graph_module) + + modified_module = NeutronAtenPassManager()( + graph_module + ).graph_module # Default passes. + + # Make sure the fusion worked. + original_nodes = list(original_module.graph.nodes) + modified_nodes = list(modified_module.graph.nodes) + + assert len(original_nodes) == 5 + assert original_nodes[3].target == torch.ops.aten.conv2d.default + assert original_nodes[3].args[-1] == group + + assert len(modified_nodes) == 4 + group * 4 + assert modified_nodes[1].target == torch.ops.aten.split.default + for node in modified_nodes[2 + 3 * group : 4 + 3 * group]: + assert node.target == torch.ops.aten.conv2d.default + assert node.args[-1] == 1 # Groups. + assert modified_nodes[-2].target == torch.ops.aten.cat.default + + # Verify that the behavior has not changed. + input_data = torch.randn(input_shape, dtype=torch.float32) + out1 = original_module(input_data).detach().numpy() + out2 = modified_module(input_data).detach().numpy() + assert np.allclose(out1, out2, atol=5.0e-7) + + # Make sure the graph can be correctly quantized and lowered to edge. 
+ ep = to_quantized_edge_program( + modified_module, tuple(input_shape) + ).exported_program() + nodes = list(ep.graph.nodes) + assert nodes[-5].name == "lowered_module_0" + assert not graph_contains_any_of_ops( + ep.graph, + [exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.cat.default], + )
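
For a quick manual check outside the unittest harness, the pass can be applied in
isolation the same way the tests above do. A short sketch that mirrors the test setup
(module parameters and shapes are illustrative):

    import torch

    from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
        NeutronAtenPassManager,
    )
    from executorch.backends.nxp.aten_passes.split_group_convolution import (
        SplitGroupConvolution,
    )
    from executorch.backends.nxp.tests.models import Conv2dModule

    group = 2
    module = Conv2dModule(
        bias=True, in_channels=16, out_channels=8 * group, group=group, stride=1
    )
    example_input = (torch.ones(1, 16, 10, 10),)

    graph_module = torch.export.export(module, example_input, strict=True).module()
    result = NeutronAtenPassManager([SplitGroupConvolution()])(graph_module)

    # The rewritten graph should contain `aten.split`, `group` parallel `aten.conv2d`
    # nodes with groups=1, and a final `aten.cat` along the channel dimension.
    result.graph_module.graph.print_tabular()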