diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index 58e4084b227..ae87dcc9c31 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -23,13 +23,31 @@ from torch.fx import Node -# TOSA uses the RESCALE operation to scale between values with differing precision. -# The RESCALE operator is defined using an integer multiply, add, and shift. -# This utility function is for calculating the multiplier and shift given a scale. -# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling def _compute_multiplier_and_shift( scales: list[float], scaleWidth: int = 32 ) -> Tuple[list[int], list[int]]: + """Derive integer multipliers and shifts from floating-point scales. + + TOSA uses the RESCALE operation to scale between values with differing + precision. The RESCALE operator is defined using an integer multiply, add, + and shift. This utility function is for calculating the multiplier and shift + given a scale. + Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling + + Args: + scales (list[float]): Scale factors to decompose into multiplier and + shift pairs. + scaleWidth (int): Bit-width of the multiplier representation; expects + ``16`` or ``32``. + + Returns: + Tuple[list[int], list[int]]: Parallel lists containing the computed + multipliers and right shifts. + + Raises: + ValueError: If ``scaleWidth`` is not supported. + + """ if scaleWidth == 16: offset = 15 elif scaleWidth == 32: @@ -78,8 +96,6 @@ def _compute_multiplier_and_shift( return multipliers, shifts -# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be -# const inputs. Create constant operators from the data already initialized. def _create_const_ops_for_rescale( tosa_fb, scale_32, @@ -92,6 +108,29 @@ def _create_const_ops_for_rescale( output_dtype, ts, ): + """Materialize constant operands required by the TOSA RESCALE op. + + For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp + and output_zp to be const inputs. Create constant operators from the data + already initialized. + + Args: + tosa_fb (Any): Graph builder used to emit TOSA operators and tensors. + scale_32 (bool): Flag indicating whether multipliers use 32-bit width. + input_dtype (ts.DType): Data type of the input tensor. + node_name (str): Base name reused for created constant tensors. + multipliers (list[int]): Precomputed multiplier coefficients. + shifts (list[int]): Precomputed shift coefficients. + input_zp (list[int]): Quantization zero points for the input. + output_zp (list[int]): Quantization zero points for the output. + output_dtype (ts.DType): Data type of the output tensor. + ts (module): Reference to the ``tosa_serializer`` module. + + Returns: + list[str]: Names of the constant tensors added to ``tosa_fb`` in the + order expected by RESCALE. + + """ multipliers = tosa_fb.addConst( (len(multipliers),), @@ -124,6 +163,22 @@ def _build_rescale( per_channel: bool = False, is_scale32: bool = True, ): + """Insert a TOSA RESCALE operator configured for the quantized path. + + Args: + tosa_fb (Any): Graph builder receiving the RESCALE operator. + scale (list[float]): Scale factors applied during rescaling. + input_node (Any): Input tensor node feeding the operator. + output_name (str): Name assigned to the RESCALE output tensor. + output_type (ts.DType): Data type of the output tensor. + input_zp (list[int]): Quantization zero points for the input tensor. + output_zp (list[int]): Quantization zero points for the output tensor. + rounding_mode (ts.RoundingMode): Rounding policy for the RESCALE op. + per_channel (bool): Whether scales are applied per output channel. + is_scale32 (bool): Declared scale width; ignored when the input type is + ``ts.DType.INT48``. + + """ scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 is_scale32 = False if input_node.dtype == ts.DType.INT48 else True multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 14a22298d8a..60ed0376697 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -2,7 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +"""Utility helpers for building TOSA graphs in the Arm backend.""" import logging from typing import Any @@ -26,19 +26,21 @@ def are_fake_tensors_broadcastable( fake_tensors: list[FakeTensor], ) -> tuple[bool, list[int]]: - """ - Determines whether a list of FakeTensors can be broadcast together. + """Determine whether the fake tensors share a broadcastable shape. + Args: - fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors - who's shapes to evaluate + fake_tensors (list[FakeTensor]): Fake tensors whose shapes should + be validated for broadcasting. Returns: - tuple[bool, list[int]]: First element is whether the shapes are - broadcastable. Second element is the common shape if compatible. - If not, empty list. + tuple[bool, list[int]]: Tuple where the first element indicates + whether broadcasting is possible and the second element contains + the broadcast shape. The shape list is empty when broadcasting + fails. Raises: - RuntimeError: If less than 2 tensors are passed in. + RuntimeError: Raised when fewer than two tensors are supplied. + """ if len(fake_tensors) < 1: raise RuntimeError(f"Expected 2 or more tensors got {len(fake_tensors)}") @@ -65,26 +67,27 @@ def are_fake_tensors_broadcastable( def broadcast_tensors( tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification ) -> list[Any]: - """ - Given a list of nodes it determines the common shape they broadcast to - and adds the necessary reshape and tile operations to perform the broadcast. + """Broadcast the FX nodes to a shared shape inside the TOSA graph. + + This mirrors ``reshape_for_broadcast`` but also emits the tile operators + needed to materialize the broadcast and supports any number of inputs. Args: - tosa_fb: Tosa graph to add nodes to - nodes (list[Node]): List of nodes to broadcast together - tosa_spec (TosaSpecification): Tosa spec + tosa_fb (Any): TOSA graph builder that receives the broadcast + operators. + nodes (list[Node]): FX nodes whose tensor metadata should be + broadcast. + tosa_spec (TosaSpecification): Active TOSA specification used to + decode tensor metadata. Returns: - list[Any]: List containing the fx.Nodes or TosaSerializerTensors - of the right common shape. Order of output matches order of input. + list[Any]: Broadcast versions of the inputs. Each element is either + the original FX node or a TOSA serializer tensor, ordered to match + ``nodes``. Raises: RuntimeError: If the supplied nodes are not broadcastable. - Note: - This function and `reshape_for_broadcast` both reshape the tensors - for broadcast. However this function also performs the broadcast and - does not have a limit on only two input tensors. """ index_fake_tensors = [node.meta["val"] for node in nodes] broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors) @@ -137,6 +140,17 @@ def broadcast_tensors( def build_reshape_tosa_1_0( tosa_graph, input_name, new_shape, output_name, shape_name_override="" ): + """Insert a TOSA reshape operator using the v1.0 semantics. + + Args: + tosa_graph (Any): Graph builder used to emit TOSA operators. + input_name (str): Name of the tensor that should be reshaped. + new_shape (list[int]): Target tensor shape. + output_name (str): Name assigned to the reshaped tensor. + shape_name_override (str): Optional override for the shape constant + name. + + """ shape = tosa_graph.addConst( np.array(new_shape).shape, ts.DType.SHAPE, @@ -155,6 +169,19 @@ def build_reshape_tosa_1_0( def tosa_shape(shape, dim_order): + """Reorder a shape tuple into TOSA layout while resolving symints. + + Args: + shape (Sequence[int | torch.SymInt]): Original tensor shape, + possibly containing ``torch.SymInt``. + dim_order (Sequence[int]): Desired dimension order for the output + shape. + + Returns: + list[int]: List containing the reordered dimensions where symbolic + values become ``-1``. + + """ reordered = tuple([shape[dim] for dim in dim_order]) # Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes, # in TOSA we do not have this concept and instead use -1. @@ -170,6 +197,26 @@ def get_resize_parameters_1d( resize_mode: int, align_corners: bool, ): + """Compute resize coefficients for a single spatial dimension. + + Args: + input_size (int | torch.SymInt): Input size for the axis, possibly + symbolic. + output_size (int | torch.SymInt): Output size for the axis, possibly + symbolic. + resize_mode (int): Target resize mode defined by TOSA. + align_corners (bool): Whether the resize should align the corner + pixels. + + Returns: + tuple[int, int, int, int]: Numerator, denominator, offset, and border + terms encoded as integers. + + Raises: + RuntimeError: If symbolic shapes are used with ``align_corners`` or if + the computed ratio or border is not constant. + + """ # We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky. if align_corners: if (not isinstance(input_size, int)) or (not isinstance(output_size, int)): @@ -229,19 +276,23 @@ def get_resize_parameters( resize_mode: int, align_corners: bool, ) -> tuple[torch.IntTensor, ...]: - """Get the tosa.resize parameters based on the input and output size. + """Calculate 2D resize parameters for TOSA emission. Args: - input_size_xy (tuple[int | torch.SymInt]): Size of the input - output_size_xy (tuple[int | torch.SymInt]): Size of the output - resize_mode (tosa.ResizeMode): The TOSA resize mode - align_corners (bool): Align the corners pixels of the input and output + input_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height + and width of the input tensor. + output_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height + and width of the output tensor. + resize_mode (int): TOSA resize mode used for coefficient generation. + align_corners (bool): Whether to align corner pixels between input and + output. Returns: - scale_n (torch.IntTensor), scale_d (torch.IntTensor), - offset (torch.IntTensor), border (torch.IntTensor) - """ + tuple[torch.IntTensor, ...]: Four-element tuple of tensors describing + the scale numerator, scale denominator, offset, and border for Y + and X dimensions. + """ # Get the parameters for each dimension independently y_params = get_resize_parameters_1d( input_size_xy[0], output_size_xy[0], resize_mode, align_corners