From 7b3bd8fd4c4fae74cba02a2b05b79741a6a06646 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Mon, 15 Sep 2025 11:25:21 +0200 Subject: [PATCH] Arm backend: Add docstrings for bmm and conv2d operators Change-Id: I603e09f79de4d87c0fdfe23d0a5bb87c14a5b61f Signed-off-by: Sebastian Larsson --- backends/arm/operators/op_bmm.py | 11 +++++++- backends/arm/operators/op_conv2d.py | 44 +++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 382386ffa26..2636a08d7c5 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe +"""Provide a visitor for lowering batched matmul (BMM) to TOSA.""" + from typing import Any, List import torch @@ -30,6 +32,13 @@ @register_node_visitor class BMMVisitor(NodeVisitor): + """Provide a visitor that lowers ``aten.bmm`` to TOSA ``MATMUL``. + + INT8 accumulates into INT32; add a rescale to INT8 using SINGLE_ROUND + rounding and output zero-point. + + """ + target = "aten.bmm.default" tosa_specs = [ @@ -47,7 +56,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - + """Define the TOSA ``MATMUL`` operator and optional rescale.""" import serializer.tosa_serializer as ts # type: ignore validate_num_inputs(self.target, inputs, 2) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 6bfe0ab21eb..3d3bbb48aaf 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe +"""Provide a visitor for lowering 2D convolution to TOSA (INT/FP).""" + import itertools from typing import Any, List @@ -28,6 +30,12 @@ @register_node_visitor class Conv2dVisitor(NodeVisitor): + """Provide a visitor that lowers ``aten.convolution`` to TOSA. + + Map to ``CONV2D`` or ``DEPTHWISE_CONV2D`` as appropriate. + + """ + target = "aten.convolution.default" tosa_specs = [ @@ -38,13 +46,32 @@ class Conv2dVisitor(NodeVisitor): def __init__(self, *args): super().__init__(*args) - # torch.nn.Conv2d does not require the result of - # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` - # to be an integer, but tosa currently strictly require this property. - # This function adjusts the pad value to meet the requirement. def adjust_pad_if_needed( self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int ) -> int: + """Adjust padding to satisfy TOSA's integer output-size requirement. + + Torch ``Conv2d`` does not require the result of + ``(input + 2 * pad - dilation * (weight - 1) - 1) / stride`` to be an + integer, but TOSA does. This helper reduces the provided padding so + that the expression becomes divisible by ``stride``. + + Args: + input_size (int): Spatial input size along the dimension (H or W). + input_weight (int): Kernel size along the same dimension. + stride (int): Stride along the same dimension. + pad (int): Padding value to adjust (bottom or right after duplication). + dilation (int): Dilation along the same dimension. + + Returns: + int: Adjusted padding value that yields an integer output size. + + Raises: + RuntimeError: If the required adjustment exceeds the provided + padding, which should be handled by the ``SizeAdjustInputPass`` + pass instead. + + """ mod_remainder = ( input_size + 2 * pad - dilation * (input_weight - 1) - 1 ) % stride @@ -55,7 +82,8 @@ def adjust_pad_if_needed( if mod_remainder > pad: raise RuntimeError( - "This case should be handled by the SizeAdjustConv2d pass, is it enabled?" + "This case should be handled by the SizeAdjustInputPass pass, " + "is it enabled?" ) return pad - mod_remainder @@ -66,7 +94,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - + """Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale.""" import serializer.tosa_serializer as ts # type: ignore from tosa.RoundingMode import RoundingMode # type: ignore @@ -133,7 +161,7 @@ def define_node( in_channels = input.shape[1] out_channels = weight.shape[0] if (in_channels == group.number) and (out_channels % in_channels) == 0: - """Depthwise convolution case""" + """Depthwise convolution case.""" # Reshape torch shape format of weight tensor to tosa required format. # https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d m_length = int(out_channels / in_channels) @@ -178,7 +206,7 @@ def define_node( acc_type=acc_type, ) else: - """Regular convolution case""" + """Regular convolution case.""" tosa_op = ts.TosaOp.Op().CONV2D weight_name = weight.name