diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py
index c707eed8013..9bea6337655 100644
--- a/backends/arm/test/tester/analyze_output_utils.py
+++ b/backends/arm/test/tester/analyze_output_utils.py
@@ -5,6 +5,7 @@
 
 import logging
 import tempfile
+from typing import Any, cast, Sequence
 
 import torch
 from executorch.backends.arm.test.runner_utils import (
@@ -17,9 +18,30 @@
 logger = logging.getLogger(__name__)
 
 
-def _print_channels(result, reference, channels_close, C, H, W, rtol, atol):
+TensorLike = torch.Tensor | tuple[torch.Tensor, ...]
+
+
+def _ensure_tensor(value: TensorLike) -> torch.Tensor:
+    if isinstance(value, torch.Tensor):
+        return value
+    if value and isinstance(value[0], torch.Tensor):
+        return value[0]
+    raise TypeError("Expected a Tensor or a non-empty tuple of Tensors")
+
+
+def _print_channels(
+    result: torch.Tensor,
+    reference: torch.Tensor,
+    channels_close: Sequence[bool],
+    C: int,
+    H: int,
+    W: int,
+    rtol: float,
+    atol: float,
+) -> str:
     output_str = ""
+    exp = "000"
     booldata = False
     if reference.dtype == torch.bool or result.dtype == torch.bool:
         booldata = True
@@ -62,7 +84,15 @@
     return output_str
 
 
-def _print_elements(result, reference, C, H, W, rtol, atol):
+def _print_elements(
+    result: torch.Tensor,
+    reference: torch.Tensor,
+    C: int,
+    H: int,
+    W: int,
+    rtol: float,
+    atol: float,
+) -> str:
     output_str = ""
     for y in range(H):
         res = "["
@@ -92,14 +122,16 @@
 
 
 def print_error_diffs(
-    tester,
-    result: torch.Tensor | tuple,
-    reference: torch.Tensor | tuple,
-    quantization_scale=None,
-    atol=1e-03,
-    rtol=1e-03,
-    qtol=0,
-):
+    tester_or_result: Any,
+    result_or_reference: TensorLike,
+    reference: TensorLike | None = None,
+    # Force remaining args to be keyword-only to keep the two positional call patterns unambiguous.
+    *,
+    quantization_scale: float | None = None,
+    atol: float = 1e-03,
+    rtol: float = 1e-03,
+    qtol: float = 0,
+) -> None:
     """
     Prints the error difference between a result tensor and a reference tensor in NCHW format.
     Certain formatting rules are applied to clarify errors:
@@ -130,15 +162,16 @@
     """
-
-    if isinstance(reference, tuple):
-        reference = reference[0]
-    if isinstance(result, tuple):
-        result = result[0]
-
-    if not result.shape == reference.shape:
+    if reference is None:
+        result = _ensure_tensor(cast(TensorLike, tester_or_result))
+        reference_tensor = _ensure_tensor(result_or_reference)
+    else:
+        result = _ensure_tensor(result_or_reference)
+        reference_tensor = _ensure_tensor(reference)
+
+    if result.shape != reference_tensor.shape:
         raise ValueError(
-            f"Output needs to be of same shape: {result.shape} != {reference.shape}"
+            f"Output needs to be of same shape: {result.shape} != {reference_tensor.shape}"
         )
     shape = result.shape
@@ -161,29 +194,29 @@
 
     # Reshape tensors to 4D NCHW format
     result = torch.reshape(result, (N, C, H, W))
-    reference = torch.reshape(reference, (N, C, H, W))
+    reference_tensor = torch.reshape(reference_tensor, (N, C, H, W))
 
     output_str = ""
     for n in range(N):
         output_str += f"BATCH {n}\n"
         result_batch = result[n, :, :, :]
-        reference_batch = reference[n, :, :, :]
+        reference_batch = reference_tensor[n, :, :, :]
 
         is_close = torch.allclose(result_batch, reference_batch, rtol, atol)
         if is_close:
             output_str += ".\n"
         else:
-            channels_close = [None] * C
+            channels_close: list[bool] = [False] * C
             for c in range(C):
                 result_hw = result[n, c, :, :]
-                reference_hw = reference[n, c, :, :]
+                reference_hw = reference_tensor[n, c, :, :]
                 channels_close[c] = torch.allclose(result_hw, reference_hw, rtol, atol)
 
             if any(channels_close) or len(channels_close) == 1:
                 output_str += _print_channels(
                     result[n, :, :, :],
-                    reference[n, :, :, :],
+                    reference_tensor[n, :, :, :],
                     channels_close,
                     C,
                     H,
@@ -193,7 +226,13 @@
                 )
             else:
                 output_str += _print_elements(
-                    result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol
+                    result[n, :, :, :],
+                    reference_tensor[n, :, :, :],
+                    C,
+                    H,
+                    W,
+                    rtol,
+                    atol,
                 )
         if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool:
             mismatches = (reference_batch != result_batch).sum().item()
@@ -201,9 +240,9 @@
             output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n"
 
     # Only compute numeric error metrics if tensor is not boolean
-    if reference.dtype != torch.bool and result.dtype != torch.bool:
-        reference_range = torch.max(reference) - torch.min(reference)
-        diff = torch.abs(reference - result).flatten()
+    if reference_tensor.dtype != torch.bool and result.dtype != torch.bool:
+        reference_range = torch.max(reference_tensor) - torch.min(reference_tensor)
+        diff = torch.abs(reference_tensor - result).flatten()
         diff = diff[diff.nonzero()]
         if not len(diff) == 0:
             diff_percent = diff / reference_range
@@ -230,14 +269,14 @@
 
 
 def dump_error_output(
-    tester,
-    reference_output,
-    stage_output,
-    quantization_scale=None,
-    atol=1e-03,
-    rtol=1e-03,
-    qtol=0,
-):
+    tester: Any,
+    reference_output: TensorLike,
+    stage_output: TensorLike,
+    quantization_scale: float | None = None,
+    atol: float = 1e-03,
+    rtol: float = 1e-03,
+    qtol: float = 0,
+) -> None:
     """
     Prints Quantization info and error tolerances, and saves the differing tensors to disc.
     """
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 21b537cfda0..1d9ee42c19e 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -14,9 +14,11 @@
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     List,
+    no_type_check,
     Optional,
     Sequence,
     Tuple,
@@ -60,7 +62,14 @@
 
 from executorch.backends.test.harness.error_statistics import ErrorStatistics
 from executorch.backends.test.harness.stages import Stage, StageType
-from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester import (
+    Partition as XnnpackPartitionStage,
+    Quantize as XnnpackQuantize,
+    Tester,
+    ToEdge as XnnpackToEdge,
+    ToEdgeTransformAndLower as XnnpackToEdgeTransformAndLower,
+    ToExecutorch as XnnpackToExecutorch,
+)
 from executorch.devtools.backend_debug import get_delegation_info
 
 from executorch.exir import (
@@ -71,11 +80,7 @@
     to_edge_transform_and_lower,
 )
 from executorch.exir.backend.backend_api import validation_disabled
-from executorch.exir.backend.operator_support import (
-    DontPartition,
-    DontPartitionModule,
-    DontPartitionName,
-)
+from executorch.exir.backend.operator_support import OperatorSupportBase
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.lowered_backend_module import LoweredBackendModule
 from executorch.exir.pass_base import ExportPass
@@ -84,7 +89,7 @@
     _copy_module,
     _update_exported_program_graph_module,
 )
-from tabulate import tabulate
+from tabulate import tabulate  # type: ignore[import-untyped]
 
 from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
 from torch.fx import Graph
@@ -94,9 +99,13 @@
 
 def _dump_lowered_modules_artifact(
     path_to_dump: Optional[str],
-    artifact: ExecutorchProgramManager,
-    graph_module: torch.fx.GraphModule,
-):
+    artifact: Union[EdgeProgramManager, ExecutorchProgramManager],
+    graph_module: torch.fx.GraphModule | None,
+) -> None:
+    if graph_module is None:
+        logger.warning("No graph module available to dump lowered modules.")
+        return
+
     output = "Formated Graph Signature:\n"
     output += _format_export_graph_signature(
         artifact.exported_program().graph_signature
@@ -117,7 +126,7 @@
                 output += f"\nTOSA deserialized {node.name}: \n{to_print}\n"
             elif isinstance(compile_spec, EthosUCompileSpec):
                 vela_cmd_stream = lowered_module.processed_bytes
-                output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n"
+                output += f"\nVela command stream {node.name}: \n{vela_cmd_stream!r}\n"
             else:
                 logger.warning(
                     f"No TOSA nor Vela compile spec found in compile specs of {node.name}."
@@ -134,7 +143,14 @@
 class Partition(tester.Partition):
     def dump_artifact(self, path_to_dump: Optional[str]):
         super().dump_artifact(path_to_dump)
-        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)
+        artifact = cast(Optional[EdgeProgramManager], self.artifact)
+        graph_module = cast(Optional[torch.fx.GraphModule], self.graph_module)
+        if artifact is None:
+            logger.warning(
+                "Partition stage artifact missing; skipping lowered module dump."
+            )
+            return
+        _dump_lowered_modules_artifact(path_to_dump, artifact, graph_module)
 
 
 class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower):
@@ -153,7 +169,14 @@ def __init__(
 
     def dump_artifact(self, path_to_dump: Optional[str]):
         super().dump_artifact(path_to_dump)
-        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)
+        artifact = cast(Optional[EdgeProgramManager], self.artifact)
+        graph_module = cast(Optional[torch.fx.GraphModule], self.graph_module)
+        if artifact is None:
+            logger.warning(
+                "ToEdgeTransformAndLower stage artifact missing; skipping lowered module dump."
+            )
+            return
+        _dump_lowered_modules_artifact(path_to_dump, artifact, graph_module)
 
     def run(
         self, artifact: ExportedProgram, inputs=None, generate_etrecord: bool = False
@@ -177,15 +200,18 @@ def run_artifact(self, inputs):
 
 
 class RunPasses(tester.RunPasses):
+    @no_type_check
     def __init__(
         self,
-        pass_list: Optional[List[Type[ExportPass]]] = None,
+        pass_list: Optional[List[Type[PassType]]] = None,
         pass_functions: Optional[List[Callable]] = None,
         passes_with_exported_program: Optional[List[Type[ExportPass]]] = None,
     ):
         """Passes are run in the order they are passed: first pass_list, second pass_functions,
         and lastly passes_with_exported_program."""
-        self.pass_with_exported_program = passes_with_exported_program
+        self.pass_with_exported_program: Optional[List[Type[ExportPass]]] = (
+            passes_with_exported_program
+        )
 
         super().__init__(pass_list, pass_functions)
 
@@ -193,14 +219,15 @@
     def run(
         self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None
     ) -> None:
         if self.pass_with_exported_program is not None:
-            self.pass_functions = self.pass_functions or []  # type: ignore
+            pass_functions = list(self.pass_functions or [])  # type: ignore[has-type]
 
             # pass_function list from superclass expects functions that take in
             # and return ExportedPrograms.
             # Create a wrapper to fit pass_with_exported_program into this.
             def wrap_ep_pass(ep_pass: Type[ExportPass]):
                 def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
-                    pass_result = ep_pass(ep).call(ep.graph_module)
+                    pass_instance = ep_pass(ep)  # type: ignore[call-arg]
+                    pass_result = pass_instance.call(ep.graph_module)
                     with validation_disabled():
                         return _update_exported_program_graph_module(
                             ep, pass_result.graph_module
@@ -208,9 +235,10 @@
                         )
 
                 return wrapped_ep_pass
 
-            self.pass_functions.extend(
+            pass_functions.extend(
                 [wrap_ep_pass(ep_pass) for ep_pass in self.pass_with_exported_program]
             )
+            self.pass_functions = pass_functions
 
         super().run(artifact, inputs)
 
@@ -243,7 +271,7 @@
 class ArmTester(Tester):
     def __init__(
         self,
         model: torch.nn.Module,
-        example_inputs: Tuple,
+        example_inputs: Tuple[Any, ...],
         compile_spec: ArmCompileSpec,
         tosa_ref_model_path: str | None = None,
         dynamic_shapes: Optional[Tuple[Any]] = None,
@@ -272,15 +300,17 @@
         self.original_module.requires_grad_(False)
 
         # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
-        self.stages[StageType.INITIAL_MODEL] = None
+        self.stages[StageType.INITIAL_MODEL] = cast(Stage, None)
         self._run_stage(InitialModel(self.original_module))
 
         self.use_portable_ops = use_portable_ops
         self.timeout = timeout
 
+    @no_type_check
     def quantize(
         self,
-        quantize_stage: Optional[tester.Quantize] = None,
+        quantize_stage: Optional[XnnpackQuantize] = None,
     ):
+        # Same stage type as parent but exposed via module alias
         if quantize_stage is None:
             quantizer = create_quantizer(self.compile_spec)
             quantize_stage = tester.Quantize(
@@ -289,11 +319,15 @@
             )
         return super().quantize(quantize_stage)
 
+    @no_type_check
     def to_edge(
         self,
-        to_edge_stage: Optional[tester.ToEdge] = None,
+        to_edge_stage: Optional[XnnpackToEdge] = None,
+        # Keep config keyword-only to avoid positional clashes with legacy calls.
+        *,
         config: Optional[EdgeCompileConfig] = None,
     ):
+        # Allow optional config override beyond base signature
        if to_edge_stage is None:
             to_edge_stage = tester.ToEdge(config)
         else:
@@ -302,25 +336,29 @@
 
         return super().to_edge(to_edge_stage)
 
-    def partition(self, partition_stage: Optional[Partition] = None):
+    @no_type_check
+    def partition(self, partition_stage: Optional[XnnpackPartitionStage] = None):
+        # Accept Arm-specific partition stage subclass
         if partition_stage is None:
             arm_partitioner = create_partitioner(self.compile_spec)
             partition_stage = Partition(arm_partitioner)
         return super().partition(partition_stage)
 
+    @no_type_check
     def to_edge_transform_and_lower(
         self,
-        to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None,
+        to_edge_and_lower_stage: Optional[XnnpackToEdgeTransformAndLower] = None,
+        generate_etrecord: bool = False,
+        # Force the optional tuning knobs to be keyword-only for readability/back-compat.
+        *,
         partitioners: Optional[List[Partitioner]] = None,
         edge_compile_config: Optional[EdgeCompileConfig] = None,
-        additional_checks: Optional[
-            List[DontPartition | DontPartitionModule | DontPartitionName]
-        ] = None,
+        additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
         transform_passes: Optional[
             Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
         ] = None,
-        generate_etrecord: bool = False,
     ):
+        # Arm flow exposes extra stage wiring knobs
         if transform_passes is not None:
             raise RuntimeError(
                 "transform passes are given to ArmTester at construction."
@@ -328,9 +366,10 @@ def to_edge_transform_and_lower(
 
         if to_edge_and_lower_stage is None:
             if partitioners is None:
-                arm_partitioner = create_partitioner(
-                    self.compile_spec, additional_checks
+                operator_checks = (
+                    list(additional_checks) if additional_checks is not None else None
                 )
+                arm_partitioner = create_partitioner(self.compile_spec, operator_checks)
                 partitioners = [arm_partitioner]
             to_edge_and_lower_stage = ToEdgeTransformAndLower(
                 partitioners,
@@ -347,14 +386,20 @@
             to_edge_and_lower_stage, generate_etrecord=generate_etrecord
         )
 
-    def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None):
+    @no_type_check
+    def to_executorch(self, to_executorch_stage: Optional[XnnpackToExecutorch] = None):
+        # Allow custom ExecuTorch stage subclass
         if to_executorch_stage is None:
             to_executorch_stage = ToExecutorch()
         return super().to_executorch(to_executorch_stage)
 
+    @no_type_check
     def serialize(
         self,
         serialize_stage: Optional[Serialize] = None,
+        # Keep timeout keyword-only so positional usage matches the base class.
+        *,
+        timeout: int = 480,
     ):
         if serialize_stage is None:
             serialize_stage = Serialize(
@@ -374,15 +419,17 @@ def is_quantized(self) -> bool:
 
     def run_method_and_compare_outputs(
         self,
-        inputs: Optional[Tuple[torch.Tensor]] = None,
-        stage: Optional[str] = None,
-        num_runs=1,
-        atol=1e-03,
-        rtol=1e-03,
-        qtol=0,
-        error_callbacks=None,
-        run_eager_mode=False,
+        stage: Optional[StageType] = None,
+        inputs: Optional[Tuple[torch.Tensor, ...]] = None,
+        num_runs: int = 1,
+        atol: float = 1e-03,
+        rtol: float = 1e-03,
+        qtol: int = 0,
         statistics_callback: Callable[[ErrorStatistics], None] | None = None,
+        # Preserve positional compatibility while keeping new flags keyword-only.
+        *,
+        error_callbacks: Optional[Sequence[Callable[..., None]]] = None,
+        run_eager_mode: bool = False,
     ):
         """
         Compares the run_artifact output of 'stage' with the output of a reference stage.
@@ -399,6 +446,12 @@
 
             The default is random data.
         """
+        # backward-compatible ordering (accept inputs as the first positional argument)
+        if inputs is None and isinstance(stage, tuple):
+            if all(isinstance(arg, torch.Tensor) for arg in stage):
+                inputs = cast(Tuple[torch.Tensor, ...], stage)
+                stage = None
+
         if not run_eager_mode:
             edge_stage = self.stages[StageType.TO_EDGE]
             if edge_stage is None:
@@ -415,6 +468,8 @@
             ), "To compare outputs in eager mode, the model must be at Export stage"
 
         stage = stage or self.cur
+        if stage is None:
+            raise RuntimeError("No stage has been executed yet.")
         test_stage = self.stages[stage]
         is_quantized = self.is_quantized()
 
@@ -423,7 +478,8 @@
         else:
             reference_stage = self.stages[StageType.INITIAL_MODEL]
 
-        exported_program = self.stages[StageType.EXPORT].artifact
+        exported_stage = self.stages[StageType.EXPORT]
+        exported_program = cast(ExportedProgram, exported_stage.artifact)
         output_node = exported_program.graph_module.graph.output_node()
         output_qparams = get_output_quantization_params(output_node)
 
@@ -436,7 +492,9 @@
         )
 
         # Loop inputs and compare reference stage with the compared stage.
-        for run_iteration in range(num_runs):
+        number_of_runs = 1 if inputs is not None else num_runs
+
+        for run_iteration in range(number_of_runs):
             reference_input = inputs if inputs else next(self.generate_random_inputs())
 
             # Avoid issues with inplace operators
@@ -455,11 +513,10 @@
             )
             if run_eager_mode:
                 # Run exported module directly
-                test_outputs, _ = pytree.tree_flatten(
-                    self._calculate_reference_output(
-                        exported_program.module(), test_input
-                    )
+                eager_output, _ = self._calculate_reference_output(
+                    exported_program, test_input
                 )
+                test_outputs, _ = pytree.tree_flatten(eager_output)
             else:
                 # Run lowered model with target
                 test_outputs, _ = pytree.tree_flatten(
@@ -480,14 +537,17 @@
                 atol,
                 rtol,
                 qtol,
-                error_callbacks,
+                statistics_callback=statistics_callback,
+                error_callbacks=error_callbacks,
             )
 
         return self
 
-    def get_graph(self, stage: str | None = None) -> Graph:
+    def get_graph(self, stage: StageType | None = None) -> Graph:
         if stage is None:
             stage = self.cur
+        if stage is None:
+            raise RuntimeError("No stage has been executed yet.")
         artifact = self.get_artifact(stage)
         if (
             self.cur == StageType.TO_EDGE
@@ -527,8 +587,8 @@
             and print_table
         ):
             graph_module = self.get_artifact().exported_program().graph_module
+            delegation_info = get_delegation_info(graph_module)
             if print_table:
-                delegation_info = get_delegation_info(graph_module)
                 op_dist = delegation_info.get_operator_delegation_dataframe()
             else:
                 op_dist = dict(_get_operator_distribution(graph_module.graph))
@@ -572,6 +632,7 @@
         all_dtypes = set(dtype_dist_placeholders.keys()) | set(
             dtype_dirst_tensors.keys()
         )
+        dtype_dist: dict[str, Any]
         if print_table:
             dtype_dist = {
                 "Dtype": all_dtypes,
@@ -589,13 +650,14 @@
                 ],
             }
         else:
-            dtype_dist = dict(dtype_dist_placeholders + dtype_dirst_tensors)
+            combined_counts = dtype_dist_placeholders + dtype_dirst_tensors
+            dtype_dist = {key: combined_counts[key] for key in combined_counts}
         to_print += _format_dict(dtype_dist, print_table) + "\n"
         _dump_str(to_print, path_to_dump)
         return self
 
     def run_transform_for_annotation_pipeline(
-        self, stage: str | None = None
+        self, stage: StageType | None = None
     ) -> torch.fx.GraphModule:
         """Run transform_for_annotation_pipeline on exported program to ensure passes
         do not break the initial model before quantization.
@@ -609,6 +671,8 @@
 
         if stage is None:
             stage = self.cur
+        if stage is None:
+            raise RuntimeError("No stage has been executed yet.")
         # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
         artifact = self.get_artifact(stage)
         if self.cur == StageType.EXPORT:
@@ -622,8 +686,8 @@
 
     @staticmethod
     def _calculate_reference_output(
-        module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
-    ) -> torch.Tensor:
+        program: ExportedProgram, inputs: Tuple[Any, ...]
+    ) -> Tuple[torch.Tensor, Optional[float]]:
         """
         Note: I'd prefer to use the base class method here, but since it use the
         exported program, I can't. The partitioner stage clears the state_dict
@@ -631,8 +695,10 @@
         of the exported program, which causes an issue when evaluating the
         module.
         """
-        return module.forward(*inputs)
+        module = program.module()
+        return module.forward(*inputs), None
 
+    @no_type_check
     def _compare_outputs(
         self,
         reference_output,
@@ -641,9 +707,12 @@
         atol=1e-03,
         rtol=1e-03,
         qtol=0,
-        error_callbacks=None,
         statistics_callback: Callable[[ErrorStatistics], None] | None = None,
+        # Extra debugging hooks are keyword-only to keep the signature stable.
+        *,
+        error_callbacks: Optional[Sequence[Callable[..., None]]] = None,
     ):
+        # Accept extra error callback hook for debugging
         try:
             super()._compare_outputs(
                 reference_output,
@@ -655,14 +724,17 @@
                 statistics_callback=statistics_callback,
             )
         except AssertionError as e:
-            if error_callbacks is None:
-                error_callbacks = [print_error_diffs, dump_error_output]
-            for callback in error_callbacks:
+            callbacks = (
+                list(error_callbacks)
+                if error_callbacks is not None
+                else [print_error_diffs, dump_error_output]
+            )
+            for callback in callbacks:
                 callback(
                     self,
                     stage_output,
                     reference_output,
-                    quantization_scale=None,
+                    quantization_scale=quantization_scale,
                     atol=1e-03,
                     rtol=1e-03,
                     qtol=0,
@@ -680,12 +752,12 @@ def __del__(self):
 
 def _get_dtype_distribution(
     graph: Graph, tosa_spec: TosaSpecification
-) -> tuple[dict, dict]:
+) -> tuple[Counter[str], Counter[str]]:
     """Counts the occurences of placeholder and call_function dtypes in a graph.
     The result is a tuple of Counters (placeholder_distribution, call_function_distribution)
     """
-    placeholder_dtypes = []
-    call_function_dtypes = []
+    placeholder_dtypes: list[str] = []
+    call_function_dtypes: list[str] = []
     for node in graph.nodes:
         if node.op == "placeholder":
             placeholder_dtypes.append(str(node.meta["val"].dtype))
@@ -706,7 +778,7 @@
 
 
 def _format_export_graph_signature(signature: ExportGraphSignature) -> str:
-    def specs_dict(specs: list[InputSpec | OutputSpec], title: str):
+    def specs_dict(specs: Sequence[InputSpec | OutputSpec], title: str):
         _dict: dict[str, list] = {title: [], "arg": [], "kind": [], "target": []}
         for i, spec in enumerate(specs):
             _dict[title].append(i)
diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py
index 54a8f08ee50..f3f5ab390e5 100644
--- a/backends/arm/test/tester/test_pipeline.py
+++ b/backends/arm/test/tester/test_pipeline.py
@@ -41,10 +41,42 @@
 from torch._export.pass_base import PassType
 
 logger = logging.getLogger(__name__)
-T = TypeVar("T")
+T = TypeVar("T", bound=Tuple[Any, ...])
 """ Generic type used for test data in the pipeline. Depends on which type the operator expects."""
 
 
+def _require_tosa_version() -> str:
+    version = conftest.get_option("tosa_version")
+    if not isinstance(version, str):
+        raise TypeError(f"TOSA version option must be a string, got {type(version)}.")
+    return version
+
+
+class PipelineStage:
+    """Container for a pipeline stage (callable plus arguments)."""
+
+    def __init__(self, func: Callable, id: str, *args, **kwargs):
+        self.id: str = id
+        self.func: Callable = func
+        self.args = args
+        self.kwargs = kwargs
+        self.is_called = False
+
+    def __call__(self):
+        if not self.is_called:
+            self.func(*self.args, **self.kwargs)
+        else:
+            raise RuntimeError(f"{self.id} called twice.")
+        self.is_called = True
+
+    def update(self, *args, **kwargs):
+        if not self.is_called:
+            self.args = args
+            self.kwargs = kwargs
+        else:
+            raise RuntimeError(f"{self.id} args updated after being called.")
+
+
 class BasePipelineMaker(Generic[T]):
     """
     The BasePiplineMaker defines a list of stages to be applied to a torch.nn.module for lowering it
@@ -65,46 +97,21 @@ class BasePipelineMaker(Generic[T]):
         tester.to_edge().check(exir_ops).partition()
     """
 
-    class PipelineStage:
-        """
-        Helper class to store a pipeline stage as a function call + args for calling later on.
-
-        Attributes:
-            id: name of the function to be called, used for refering to stages in the pipeline.
-            func: handle to the function to be called.
-            args: args used when called.
-            kwargs: kwargs used when called.
-            is_called: keeps track of if the function has been called.
-        """
-
-        def __init__(self, func: Callable, id: str, *args, **kwargs):
-            self.id: str = id
-            self.func: Callable = func
-            self.args = args
-            self.kwargs = kwargs
-            self.is_called = False
-
-        def __call__(self):
-            if not self.is_called:
-                self.func(*self.args, **self.kwargs)
-            else:
-                raise RuntimeError(f"{self.id} called twice.")
-            self.is_called = True
-
-        def update(self, *args, **kwargs):
-            if not self.is_called:
-                self.args = args
-                self.kwargs = kwargs
-            else:
-                raise RuntimeError(f"{self.id} args updated after being called.")
+    @staticmethod
+    def _normalize_ops(ops: str | Sequence[str] | None) -> list[str]:
+        if ops is None:
+            return []
+        if isinstance(ops, str):
+            return [ops]
+        return list(ops)
 
     def __init__(
         self,
         module: torch.nn.Module,
         test_data: T,
-        aten_ops: str | List[str],
+        aten_ops: str | Sequence[str] | None,
         compile_spec: ArmCompileSpec,
-        exir_ops: Optional[str | List[str]] = None,
+        exir_ops: str | Sequence[str] | None = None,
         use_to_edge_transform_and_lower: bool = True,
         dynamic_shapes: Optional[Tuple[Any]] = None,
         transform_passes: Optional[
@@ -120,15 +127,10 @@
             transform_passes=transform_passes,
         )
 
-        self.aten_ops = aten_ops if isinstance(aten_ops, list) else [aten_ops]
-        if exir_ops is None:
-            self.exir_ops = []
-        elif isinstance(exir_ops, list):
-            self.exir_ops = exir_ops
-        else:
-            self.exir_ops = [exir_ops]
+        self.aten_ops = self._normalize_ops(aten_ops)
+        self.exir_ops = self._normalize_ops(exir_ops)
         self.test_data = test_data
-        self._stages = []
+        self._stages: list[PipelineStage] = []
 
         self.add_stage(self.tester.export)
         self.add_stage(self.tester.check, self.aten_ops, suffix="aten")
@@ -203,7 +205,7 @@
         if stage_id in id_list:
             raise ValueError("Suffix must be unique in pipeline")
 
-        pipeline_stage = self.PipelineStage(func, stage_id, *args, **kwargs)
+        pipeline_stage = PipelineStage(func, stage_id, *args, **kwargs)
         self._stages.insert(pos, pipeline_stage)
 
         logger.debug(f"Added stage {stage_id} to {type(self).__name__}")
@@ -217,6 +219,8 @@ def pop_stage(self, identifier: int | str):
         elif isinstance(identifier, str):
             pos = self.find_pos(identifier)
             stage = self._stages.pop(pos)
+        else:
+            raise TypeError("identifier must be an int or str")
 
         logger.debug(f"Removed stage {stage.id} from {type(self).__name__}")
@@ -244,19 +248,19 @@
         self.add_stage(func, *args, **kwargs)
         return self
 
-    def dump_artifact(self, stage_id: str, suffix: str = None):
+    def dump_artifact(self, stage_id: str, suffix: str | None = None):
         """Adds a dump_artifact stage after the given stage id."""
         self.add_stage_after(stage_id, self.tester.dump_artifact, suffix=suffix)
         return self
 
-    def dump_operator_distribution(self, stage_id: str, suffix: str = None):
+    def dump_operator_distribution(self, stage_id: str, suffix: str | None = None):
         """Adds a dump_operator_distribution stage after the given stage id."""
         self.add_stage_after(
             stage_id, self.tester.dump_operator_distribution, suffix=suffix
         )
         return self
 
-    def visualize(self, stage_id: str, suffix: str = None):
+    def visualize(self, stage_id: str, suffix: str | None = None):
         """Adds a dump_operator_distribution stage after the given stage id."""
         self.add_stage_after(stage_id, self.tester.visualize, suffix=suffix)
         return self
@@ -289,7 +293,7 @@
         # Not all deployments of ET have the TOSA reference model available.
         # Make sure we don't try to use it if it's not available.
         try:
-            import tosa_reference_model
+            import tosa_reference_model  # type: ignore[import-not-found, import-untyped]
 
             # Check if the module has content
             return bool(dir(tosa_reference_model))
@@ -338,7 +342,7 @@
         symmetric_io_quantization: bool = False,
         per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
-        custom_path: str = None,
+        custom_path: str | None = None,
         tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
@@ -348,12 +352,12 @@
     ):
         if tosa_extensions is None:
             tosa_extensions = []
-        tosa_profiles = {
+        tosa_profiles: dict[str, TosaSpecification] = {
             "1.0": TosaSpecification.create_from_string(
                 "TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions])
             ),
         }
-        tosa_version = conftest.get_option("tosa_version")
+        tosa_version = _require_tosa_version()
 
         compile_spec = common.get_tosa_compile_spec(
             tosa_profiles[tosa_version],
@@ -443,7 +447,7 @@
         exir_op: Optional[str | List[str]] = None,
         run_on_tosa_ref_model: bool = True,
         use_to_edge_transform_and_lower: bool = True,
-        custom_path: str = None,
+        custom_path: str | None = None,
         tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
@@ -456,12 +460,12 @@
     ):
         if tosa_extensions is None:
             tosa_extensions = []
-        tosa_profiles = {
+        tosa_profiles: dict[str, TosaSpecification] = {
             "1.0": TosaSpecification.create_from_string(
                 "TOSA-1.0+FP" + "".join([f"+{ext}" for ext in tosa_extensions])
             ),
         }
-        tosa_version = conftest.get_option("tosa_version")
+        tosa_version = _require_tosa_version()
 
         compile_spec = common.get_tosa_compile_spec(
             tosa_profiles[tosa_version],
@@ -524,7 +528,7 @@
         symmetric_io_quantization: bool = False,
         per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
-        custom_path: str = None,
+        custom_path: str | None = None,
         tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
@@ -610,12 +614,12 @@
         module: torch.nn.Module,
         test_data: T,
         aten_ops: str | List[str],
-        exir_ops: str | List[str] = None,
+        exir_ops: str | List[str] | None = None,
         run_on_fvp: bool = True,
         symmetric_io_quantization: bool = False,
         per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
-        custom_path: str = None,
+        custom_path: str | None = None,
         tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
@@ -715,20 +719,20 @@
         pass_list: Optional[List[Type[PassType]]] = None,
         pass_functions: Optional[List[Callable]] = None,
         passes_with_exported_program: Optional[List[Type[ExportPass]]] = None,
-        custom_path: str = None,
+        custom_path: str | None = None,
         tosa_extensions: Optional[List[str]] = None,
     ):
         if tosa_extensions is None:
             tosa_extensions = []
-        tosa_profiles = {
+        tosa_profiles: dict[str, TosaSpecification] = {
             "1.0": TosaSpecification.create_from_string(
                 "TOSA-1.0+"
                 + ("INT" if quantize else "FP")
                 + "".join([f"+{ext}" for ext in tosa_extensions]),
             ),
         }
-        tosa_version = conftest.get_option("tosa_version")
-        self.tosa_spec = tosa_profiles[tosa_version]
+        tosa_version = _require_tosa_version()
+        self.tosa_spec: TosaSpecification = tosa_profiles[tosa_version]
 
         compile_spec = common.get_tosa_compile_spec(
             self.tosa_spec, custom_path=custom_path
@@ -758,9 +762,9 @@
         self.add_stage(self.tester.check_count, ops_before_pass, suffix="before")
         if ops_not_before_pass:
             self.add_stage(self.tester.check_not, ops_not_before_pass, suffix="before")
-        test_pass_stage = RunPasses(
-            pass_list, pass_functions, passes_with_exported_program
-        )
+        test_pass_stage = RunPasses(  # type: ignore[arg-type]
+            pass_list, pass_functions, passes_with_exported_program  # type: ignore[arg-type]
+        )  # Legacy pass APIs expose callable classes rather than ExportPass subclasses
         self.add_stage(self.tester.run_passes, test_pass_stage)
@@ -791,17 +795,17 @@
         self,
         module: torch.nn.Module,
         test_data: T,
-        custom_path: str = None,
+        custom_path: str | None = None,
         tosa_extensions: Optional[List[str]] = None,
     ):
         if tosa_extensions is None:
             tosa_extensions = []
-        tosa_profiles = {
+        tosa_profiles: dict[str, TosaSpecification] = {
             "1.0": TosaSpecification.create_from_string(
                 "TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions]),
             ),
         }
-        tosa_version = conftest.get_option("tosa_version")
+        tosa_version = _require_tosa_version()
 
         compile_spec = common.get_tosa_compile_spec(
             tosa_profiles[tosa_version], custom_path=custom_path
@@ -852,14 +856,14 @@
         test_data: T,
         non_delegated_ops: Dict[str, int],
         n_expected_delegates: int = 0,
-        custom_path: str = None,
+        custom_path: str | None = None,
         quantize: Optional[bool] = False,
         u55_subset: Optional[bool] = False,
         tosa_extensions: Optional[List[str]] = None,
     ):
         if tosa_extensions is None:
             tosa_extensions = []
-        tosa_profiles = {
+        tosa_profiles: dict[str, TosaSpecification] = {
             "1.0": TosaSpecification.create_from_string(
                 "TOSA-1.0+"
                 + ("INT" if quantize else "FP")
@@ -867,7 +871,7 @@
                 + "".join([f"+{ext}" for ext in tosa_extensions]),
             ),
         }
-        tosa_version = conftest.get_option("tosa_version")
+        tosa_version = _require_tosa_version()
 
         tosa_spec = tosa_profiles[tosa_version]
@@ -928,7 +932,7 @@
         symmetric_io_quantization: bool = False,
         per_channel_quantization: bool = True,
         use_to_edge_transform_and_lower: bool = True,
-        custom_path: str = None,
+        custom_path: str | None = None,
         tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None,
         atol: float = 1e-03,
         rtol: float = 1e-03,
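Finally, a short usage sketch of the `PipelineStage` contract now that the class is hoisted to module level: a stage runs exactly once, and its arguments may only be updated before it has run. The import path follows the file touched above and is an assumption of this sketch.

```python
from executorch.backends.arm.test.tester.test_pipeline import PipelineStage

stage = PipelineStage(print, "print_stage", "hello")
stage.update("hello, updated")  # allowed: the stage has not run yet

stage()  # runs print("hello, updated") and marks the stage as called

try:
    stage()  # a second call raises
except RuntimeError as err:
    print(err)  # print_stage called twice.

try:
    stage.update("too late")  # updating after the call also raises
except RuntimeError as err:
    print(err)  # print_stage args updated after being called.
```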