Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions backends/arm/operator_support/convolution_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Declare operator support for ``aten.convolution`` in TOSA.

Provide general checks and hardware-specific constraints (e.g., U55 subset) for
convolution nodes prior to delegation to the TOSA backend.

"""

from typing import cast

Expand All @@ -18,15 +24,24 @@

@register_tosa_support_check
class ConvolutionSupported(SupportedTOSAOperatorCheck):
"""Provide TOSA support check for convolutions."""

targets = [exir_ops.edge.aten.convolution.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
"""Return True if the node is supported by TOSA.

Reject transposed convolutions and convolutions with non-zero output
padding. Apply additional hardware-specific constraints for U55.

"""
# Not implemented
transposed = cast(bool, node.args[6])
output_padding = cast(list[int], node.args[7])
Expand All @@ -46,9 +61,19 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
else:
return True

def _is_node_supported_u55(self, node: fx.Node):
"""Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""
def _is_node_supported_u55(self, node: fx.Node) -> bool:
"""Enforce Ethos-U55-specific constraints (Vela 4.2.0).

Check channel dimensions, kernel sizes, and stride/pad/dilation
combinations permitted on U55.

Args:
node (fx.Node): Convolution node to validate.

Returns:
bool: True if supported; otherwise, False.

"""
shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
shape_out = node.meta["val"].shape
kernel = cast(fx.Node, node.args[1]).meta["val"].shape
Expand Down Expand Up @@ -98,13 +123,17 @@ def _is_node_supported_u55(self, node: fx.Node):
return True

def _stride_condition(self, node: fx.Node) -> bool:
"""This condition is somewhat complex but boils down
to not supporting stride > 3, unless we have some special conditions.
This condition is a simplified, relaxed version of the hardware constraint,
since the actual constraint requires information not available
here (without a lot of work).
"""Check a simplified stride/padding/dilation constraint.

Disallow strides greater than 3 unless there is no padding and the
dilation is 1. For 3D convolutions, enforce ``stride_z <= 1``.

Args:
node (fx.Node): Convolution node to evaluate.

Returns:
bool: True if the condition is satisfied.

This means that we might accept ops that are not actually supported.
"""
strides = cast(list[int], node.args[3])
has_padding = any(pad > 0 for pad in cast(list[int], node.args[4]))
Expand Down
Loading