diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 9419e116789..32c01143f4f 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -2,46 +2,42 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide validation helpers for operator inputs and dtypes. + +Use these utilities to validate input counts, ensure dtype consistency, check +allowed dtypes, and compute pooling padding adjustments. + +""" from math import ceil, floor from typing import Any, List, Optional def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]): - """ - Validates the number of inputs provided to an operation against expected values. - - This function checks whether the length of the input list matches the expected - number(s) of inputs. - - Parameters: - ----------- - op_name : str - The name of the operation for which the inputs are being validated. - Used in the error message to provide context. + """Validate the number of inputs against expected values. - inputs : List[TosaArg] - A list of inputs to be validated, where each input is assumed to be an - instance of `TosaArg`. + This function checks whether the length of the input list matches the + expected number(s) of inputs. - expected : int or List[int] - The expected number of inputs. Can be either an integer or a list of integers. + Args: + op_name (str): The name of the operation for which the inputs are being + validated. Used in the error message to provide context. + inputs (List[TosaArg]): A list of inputs to be validated, where each + input is assumed to be an instance of ``TosaArg``. + expected (int | List[int]): The expected number of inputs. Can be either + an integer or a list of integers. Raises: - ------- - ValueError - If the number of inputs does not match the expected value(s), a `ValueError` is - raised with a message indicating the operation name and the mismatch in expected - versus provided number of inputs. + ValueError: If the number of inputs does not match the expected + value(s); the message indicates the operation name and the mismatch + in expected versus provided counts. Example: - -------- - # Example usage: - from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - ) + from executorch.backends.arm.operators.operator_validation_utils import \ + validate_num_inputs + + validate_num_inputs(self.target, inputs, [3, 4]) - validate_num_inputs(self.target, inputs, [3, 4]) """ if isinstance(expected, int): expected = [expected] @@ -54,39 +50,28 @@ def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[in def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = None): - """ - Validates that all given tensors have the same dtype attribute. - - This function checks whether all items in the `tensors` list have the same - `dtype` as the first item. - - Parameters: - ----------- - op_name : str - The name of the operation for which the dtype validation is being performed. - Used in the error message to provide context. + """Validate that all given tensors have the same dtype. - tensors : List[Any] - A list of tensors to be validated, each is assumed to have a `dtype` attribute. + This function checks whether all items in the ``tensors`` list have the + same ``dtype`` as the first item. - ts: Optional[Any] - TOSA serializer. Not required but only to get clearer error messages. + Args: + op_name (str): The name of the operation for which the dtype validation + is being performed. Used in the error message to provide context. + tensors (List[Any]): A list of tensors to be validated, each assumed to + have a ``dtype`` attribute. + ts (Optional[Any]): TOSA serializer (optional) to improve readability of + dtype names in error messages. Raises: - ------- - ValueError - If the dtype of any item in the list does not match the dtype of the first item, - a `ValueError` is raised with a message indicating the operation name and the - mismatch in dtypes. + ValueError: If the dtype of any item in the list does not match the + dtype of the first item, or if the list is empty. Example: - -------- - # Example usage: - from executorch.backends.arm.operators.operator_validation_utils import ( - validate_same_dtype, - ) + from executorch.backends.arm.operators.operator_validation_utils import \ + validate_same_dtype - validate_same_dtype(self.target, [input1, input2, output]) + validate_same_dtype(self.target, [input1, input2, output]) """ if not tensors: @@ -110,48 +95,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = No def validate_valid_dtype( op_name: str, tensors: Any | List[Any], valid_dtypes: Any | List[Any], tosa_spec ): - """ - Validates that one or more tensors have dtypes within a set of allowed dtypes. - - This function checks whether the `dtype` attribute of the provided tensor(s) is one - of the valid dtype values. It supports checking a single tensor or a list of - tensors. - - Parameters: - ----------- - op_name : str - The name of the operation performing the validation. - tensors : Any or List[Any] - A tensor or list of tensors (each assumed to have `dtype` and `name` attributes) - whose dtype will be validated. - valid_dtypes : Any or List[Any] - A dtype enum or list of dtype enums representing allowed dtype values. - tosa_spec : Any - A TosaSpecification instance indicating which TOSA version is targeted. This - determines which serializer to use for dtype name resolution. + """Validate that one or more tensors have allowed dtypes. + + This function checks whether the ``dtype`` attribute of the provided + tensor(s) is one of the valid dtype values. It supports checking a single + tensor or a list of tensors. + + Args: + op_name (str): The name of the operation performing the validation. + tensors (Any | List[Any]): A tensor or list of tensors (each assumed to + have ``dtype`` and ``name`` attributes) whose dtype will be + validated. + valid_dtypes (Any | List[Any]): A dtype enum or list of dtype enums + representing allowed dtype values. + tosa_spec (Any): A TosaSpecification instance indicating which TOSA + version is targeted. This determines which serializer to use for + dtype name resolution. Raises: - ------- - ValueError - If no tensors are provided, or if any tensor has a dtype not in `valid_dtypes`. + ValueError: If no tensors are provided, or if any tensor has a dtype not + in ``valid_dtypes``. Example: - -------- - # Example usage: - from executorch.backends.arm.operators.operator_validation_utils import ( - validate_valid_dtype, - ) - - - validate_valid_dtype( - self.target, - [*inputs, output], - [ts.DType.INT8, ts.DType.INT32], - output.tosa_spec, - ) + from executorch.backends.arm.operators.operator_validation_utils import \ + validate_valid_dtype + import serializer.tosa_serializer as ts + + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT32], + output.tosa_spec, + ) """ - if not tensors: raise ValueError( f"{op_name}: Input tensor list is empty, cannot validate dtypes" @@ -176,36 +153,27 @@ def validate_valid_dtype( def adjust_pooling_pad_if_needed( input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool ) -> int: - """ - The Aten pooling ops has one value 'pad' per dimension to specify padding, but they - do not require input and output sizes to match up perfectly. Instead, the output - size is rounded up or down depending on ceil_mode, and padding at the end of the - input is automatically added or removed. TOSA on the other hand specifies two - padding values, one for pre-padding and one for post-padding, and these must satisfy + """Compute the post padding needed for pooling. - output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1 + ATen pooling uses a single symmetric ``pad`` per dimension and rounds the + output size up or down depending on ``ceil_mode``. TOSA requires distinct + pre- and post-padding values that satisfy: - This function returns the post_pad value required to satisfy the above condition. + output_size == (input_size + pre_pad + post_pad - kernel_size) / stride + 1 - Parameters: - ----------- - input_size : int - The size of the input to the operator. + This function returns the required ``post_pad`` given a symmetric ``pad``. - kernel_size : int - The size of the kernel. + Args: + input_size (int): Input size. + kernel_size (int): Kernel size. + stride (int): Stride size. + pad (int): Symmetric padding specified by ATen. + ceil_mode (bool): Use ceil when computing output size. - stride : int - The size of the stride. + Returns: + int: Post-padding to satisfy the TOSA formula. - pad : int - The amount of padding. - - Output: - ------- - An int, giving the post-padding to use for the """ - if ceil_mode: output_size = ceil((input_size - kernel_size + 2 * pad) / stride) + 1 else: