Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions backends/arm/operators/op_tosa_rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,31 @@
from torch.fx import Node


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multiplier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def _compute_multiplier_and_shift(
scales: list[float], scaleWidth: int = 32
) -> Tuple[list[int], list[int]]:
"""Derive integer multipliers and shifts from floating-point scales.

TOSA uses the RESCALE operation to scale between values with differing
precision. The RESCALE operator is defined using an integer multiply, add,
and shift. This utility function is for calculating the multiplier and shift
given a scale.
Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling

Args:
scales (list[float]): Scale factors to decompose into multiplier and
shift pairs.
scaleWidth (int): Bit-width of the multiplier representation; expects
``16`` or ``32``.

Returns:
Tuple[list[int], list[int]]: Parallel lists containing the computed
multipliers and right shifts.

Raises:
ValueError: If ``scaleWidth`` is not supported.

"""
if scaleWidth == 16:
offset = 15
elif scaleWidth == 32:
Expand Down Expand Up @@ -78,8 +96,6 @@ def _compute_multiplier_and_shift(
return multipliers, shifts


# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
# const inputs. Create constant operators from the data already initialized.
def _create_const_ops_for_rescale(
tosa_fb,
scale_32,
Expand All @@ -92,6 +108,29 @@ def _create_const_ops_for_rescale(
output_dtype,
ts,
):
"""Materialize constant operands required by the TOSA RESCALE op.

For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp
and output_zp to be const inputs. Create constant operators from the data
already initialized.

Args:
tosa_fb (Any): Graph builder used to emit TOSA operators and tensors.
scale_32 (bool): Flag indicating whether multipliers use 32-bit width.
input_dtype (ts.DType): Data type of the input tensor.
node_name (str): Base name reused for created constant tensors.
multipliers (list[int]): Precomputed multiplier coefficients.
shifts (list[int]): Precomputed shift coefficients.
input_zp (list[int]): Quantization zero points for the input.
output_zp (list[int]): Quantization zero points for the output.
output_dtype (ts.DType): Data type of the output tensor.
ts (module): Reference to the ``tosa_serializer`` module.

Returns:
list[str]: Names of the constant tensors added to ``tosa_fb`` in the
order expected by RESCALE.

"""

multipliers = tosa_fb.addConst(
(len(multipliers),),
Expand Down Expand Up @@ -124,6 +163,22 @@ def _build_rescale(
per_channel: bool = False,
is_scale32: bool = True,
):
"""Insert a TOSA RESCALE operator configured for the quantized path.

Args:
tosa_fb (Any): Graph builder receiving the RESCALE operator.
scale (list[float]): Scale factors applied during rescaling.
input_node (Any): Input tensor node feeding the operator.
output_name (str): Name assigned to the RESCALE output tensor.
output_type (ts.DType): Data type of the output tensor.
input_zp (list[int]): Quantization zero points for the input tensor.
output_zp (list[int]): Quantization zero points for the output tensor.
rounding_mode (ts.RoundingMode): Rounding policy for the RESCALE op.
per_channel (bool): Whether scales are applied per output channel.
is_scale32 (bool): Declared scale width; ignored when the input type is
``ts.DType.INT48``.

"""
scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32
is_scale32 = False if input_node.dtype == ts.DType.INT48 else True
multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth)
Expand Down
109 changes: 80 additions & 29 deletions backends/arm/tosa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Utility helpers for building TOSA graphs in the Arm backend."""

import logging
from typing import Any
Expand All @@ -26,19 +26,21 @@
def are_fake_tensors_broadcastable(
fake_tensors: list[FakeTensor],
) -> tuple[bool, list[int]]:
"""
Determines whether a list of FakeTensors can be broadcast together.
"""Determine whether the fake tensors share a broadcastable shape.

Args:
fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors
who's shapes to evaluate
fake_tensors (list[FakeTensor]): Fake tensors whose shapes should
be validated for broadcasting.

Returns:
tuple[bool, list[int]]: First element is whether the shapes are
broadcastable. Second element is the common shape if compatible.
If not, empty list.
tuple[bool, list[int]]: Tuple where the first element indicates
whether broadcasting is possible and the second element contains
the broadcast shape. The shape list is empty when broadcasting
fails.

Raises:
RuntimeError: If less than 2 tensors are passed in.
RuntimeError: Raised when fewer than two tensors are supplied.

"""
if len(fake_tensors) < 1:
raise RuntimeError(f"Expected 2 or more tensors got {len(fake_tensors)}")
Expand All @@ -65,26 +67,27 @@ def are_fake_tensors_broadcastable(
def broadcast_tensors(
tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification
) -> list[Any]:
"""
Given a list of nodes it determines the common shape they broadcast to
and adds the necessary reshape and tile operations to perform the broadcast.
"""Broadcast the FX nodes to a shared shape inside the TOSA graph.

This mirrors ``reshape_for_broadcast`` but also emits the tile operators
needed to materialize the broadcast and supports any number of inputs.

Args:
tosa_fb: Tosa graph to add nodes to
nodes (list[Node]): List of nodes to broadcast together
tosa_spec (TosaSpecification): Tosa spec
tosa_fb (Any): TOSA graph builder that receives the broadcast
operators.
nodes (list[Node]): FX nodes whose tensor metadata should be
broadcast.
tosa_spec (TosaSpecification): Active TOSA specification used to
decode tensor metadata.

Returns:
list[Any]: List containing the fx.Nodes or TosaSerializerTensors
of the right common shape. Order of output matches order of input.
list[Any]: Broadcast versions of the inputs. Each element is either
the original FX node or a TOSA serializer tensor, ordered to match
``nodes``.

Raises:
RuntimeError: If the supplied nodes are not broadcastable.

Note:
This function and `reshape_for_broadcast` both reshape the tensors
for broadcast. However this function also performs the broadcast and
does not have a limit on only two input tensors.
"""
index_fake_tensors = [node.meta["val"] for node in nodes]
broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors)
Expand Down Expand Up @@ -137,6 +140,17 @@ def broadcast_tensors(
def build_reshape_tosa_1_0(
tosa_graph, input_name, new_shape, output_name, shape_name_override=""
):
"""Insert a TOSA reshape operator using the v1.0 semantics.

Args:
tosa_graph (Any): Graph builder used to emit TOSA operators.
input_name (str): Name of the tensor that should be reshaped.
new_shape (list[int]): Target tensor shape.
output_name (str): Name assigned to the reshaped tensor.
shape_name_override (str): Optional override for the shape constant
name.

"""
shape = tosa_graph.addConst(
np.array(new_shape).shape,
ts.DType.SHAPE,
Expand All @@ -155,6 +169,19 @@ def build_reshape_tosa_1_0(


def tosa_shape(shape, dim_order):
"""Reorder a shape tuple into TOSA layout while resolving symints.

Args:
shape (Sequence[int | torch.SymInt]): Original tensor shape,
possibly containing ``torch.SymInt``.
dim_order (Sequence[int]): Desired dimension order for the output
shape.

Returns:
list[int]: List containing the reordered dimensions where symbolic
values become ``-1``.

"""
reordered = tuple([shape[dim] for dim in dim_order])
# Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes,
# in TOSA we do not have this concept and instead use -1.
Expand All @@ -170,6 +197,26 @@ def get_resize_parameters_1d(
resize_mode: int,
align_corners: bool,
):
"""Compute resize coefficients for a single spatial dimension.

Args:
input_size (int | torch.SymInt): Input size for the axis, possibly
symbolic.
output_size (int | torch.SymInt): Output size for the axis, possibly
symbolic.
resize_mode (int): Target resize mode defined by TOSA.
align_corners (bool): Whether the resize should align the corner
pixels.

Returns:
tuple[int, int, int, int]: Numerator, denominator, offset, and border
terms encoded as integers.

Raises:
RuntimeError: If symbolic shapes are used with ``align_corners`` or if
the computed ratio or border is not constant.

"""
# We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky.
if align_corners:
if (not isinstance(input_size, int)) or (not isinstance(output_size, int)):
Expand Down Expand Up @@ -229,19 +276,23 @@ def get_resize_parameters(
resize_mode: int,
align_corners: bool,
) -> tuple[torch.IntTensor, ...]:
"""Get the tosa.resize parameters based on the input and output size.
"""Calculate 2D resize parameters for TOSA emission.

Args:
input_size_xy (tuple[int | torch.SymInt]): Size of the input
output_size_xy (tuple[int | torch.SymInt]): Size of the output
resize_mode (tosa.ResizeMode): The TOSA resize mode
align_corners (bool): Align the corners pixels of the input and output
input_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height
and width of the input tensor.
output_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height
and width of the output tensor.
resize_mode (int): TOSA resize mode used for coefficient generation.
align_corners (bool): Whether to align corner pixels between input and
output.

Returns:
scale_n (torch.IntTensor), scale_d (torch.IntTensor),
offset (torch.IntTensor), border (torch.IntTensor)
"""
tuple[torch.IntTensor, ...]: Four-element tuple of tensors describing
the scale numerator, scale denominator, offset, and border for Y
and X dimensions.

"""
# Get the parameters for each dimension independently
y_params = get_resize_parameters_1d(
input_size_xy[0], output_size_xy[0], resize_mode, align_corners
Expand Down
Loading