From 57773ffc0d079ac1c9669cfdd9380e66b539c0bf Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 27 Oct 2025 16:44:58 +0100 Subject: [PATCH 1/2] Arm backend: Use output.name in node visitors As mentioned in #15381, TOSA tensors need unique naming, which gets tricky with submodules. It is handled in the TosaArg object, and therefore node visitors need to use output.name rather than node.name when creating new tensors. Signed-off-by: Erik Lundell Change-Id: I7a943deda0888c1de8796dd573e8befda3f074b2 --- backends/arm/operators/op_index_tensor.py | 10 +++++----- backends/arm/operators/op_mul.py | 4 ++-- backends/arm/operators/op_repeat.py | 2 +- backends/arm/operators/op_slice.py | 4 ++-- backends/arm/operators/op_tosa_matmul.py | 4 ++-- backends/arm/operators/op_tosa_resize.py | 6 +++--- backends/arm/operators/op_tosa_table.py | 19 ++++++++----------- backends/arm/operators/op_view.py | 2 +- 8 files changed, 24 insertions(+), 27 deletions(-) diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 760e744923c..b2adb785ef6 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -165,14 +165,14 @@ def define_node( # channels and thus the stride-shift. data = np.full(index_shape, int(values_strides[i] / C)) mul_const = tosa_graph.addConst(index_shape, index_dtype, data) - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_{i}_shift") attr = ts.TosaSerializerAttribute() attr.MulAttribute() self._serialize_operator( node, tosa_graph, ts.Op.MUL, - [index_name, mul_const.name, f"{node.name}_{i}_shift"], + [index_name, mul_const.name, f"{output.name}_{i}_shift"], [stride_shifted_indices.name], attr, ) @@ -186,7 +186,7 @@ def define_node( stride_shifted_indices.name, gather_idx_shape, reshaped_idxs.name, - shape_name_override=f"{node.name}_{i}_shape", + shape_name_override=f"{output.name}_{i}_shape", ) # Guarantees that the accumulation tensor is properly @@ -218,7 +218,7 @@ def define_node( values.name, gather_vals_shape, reshaped_input.name, - shape_name_override=f"{node.name}_index_shape", + shape_name_override=f"{output.name}_index_shape", ) gather_out_shape = (N, W, C) @@ -244,5 +244,5 @@ def define_node( gather_out.name, list(output_shape), output.name, - shape_name_override=f"{node.name}_output_shape", + shape_name_override=f"{output.name}_output_shape", ) diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 78b0b1b6675..f1cd5de6fd6 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -48,14 +48,14 @@ def define_node( output.tosa_spec, ) - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift") attr = ts.TosaSerializerAttribute() attr.MulAttribute() self._serialize_operator( node, tosa_graph, ts.Op.MUL, - [inputs[0].name, inputs[1].name, f"{node.name}_shift"], + [inputs[0].name, inputs[1].name, f"{output.name}_shift"], [output.name], attr, ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 21a8f8e1b04..99c0ecce0b2 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -56,7 +56,7 @@ def define_node( (len(multiples),), ts.DType.SHAPE, list(tosa_shape(multiples, output.dim_order)), - name=node.name + "_multiples", + name=output.name + "_multiples", ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index c5510493eae..7366703083c 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -120,7 +120,7 @@ def define_node( (starts_len,), ts.DType.SHAPE, starts, - node.name + "_start_shape", + output.name + "_start_shape", ) sizes = [size if i == dim else shape[i] for i in input_node.dim_order] @@ -130,7 +130,7 @@ def define_node( sizes_len = 1 sizes = [0] sizes_tensor = tosa_graph.addConst( - (sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape" + (sizes_len,), ts.DType.SHAPE, sizes, output.name + "_sizes_shape" ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index 2281564a0c4..be73a60f7c7 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -72,8 +72,8 @@ def define_node( else: input0_zp, input1_zp = 0, 0 - input_A_ZP_name = f"{node.name}_A_ZP" - input_B_ZP_name = f"{node.name}_B_ZP" + input_A_ZP_name = f"{output.name}_A_ZP" + input_B_ZP_name = f"{output.name}_B_ZP" tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=input_A_ZP_name) tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name) diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index fb8e305839f..40c6d4ce6a6 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -84,15 +84,15 @@ def in_int16_range(x): scale_d_vals[1], ] scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" + [len(scales)], ts.DType.SHAPE, scales, output.name + "_scales" ) offset = [int(v) for v in offset_yx.tolist()] offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" + [len(offset)], ts.DType.SHAPE, offset, output.name + "_offset" ) border = [int(v) for v in border_yx.tolist()] border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, node.name + "_border" + [len(border)], ts.DType.SHAPE, border, output.name + "_border" ) attr = ts.TosaSerializerAttribute() attr.ResizeAttribute(resize_mode) diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index 11407517b6a..d867b5efd7b 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -44,27 +44,24 @@ def define_node( if inputs[0].dtype == ts.DType.INT16: validate_valid_dtype(self.target, output, ts.DType.INT32, output.tosa_spec) - if inputs[1].name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] + # The name of the table constant is a bit complex. + # The name of the pytorch buffer will be the target of last node argument. + # However, when it is serialized to TOSA, a submodule suffix might be added. The TOSA buffer name thus + # needs to be taken from the last TosaArg. + pytorch_table_buffer_name = node.args[-1].target # type: ignore[union-attr] + tosa_table_buffer_name = inputs[-1].name + if pytorch_table_buffer_name not in self._exported_program.state_dict.keys(): raise RuntimeError( f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." ) - table = self._exported_program.state_dict[inputs[1].name] # type: ignore[union-attr] - - table_tensor_name = node.name + "_table" - tosa_graph.addConst( - table.shape, - ts.DType.INT8 if inputs[0].dtype == ts.DType.INT8 else ts.DType.INT16, - table.detach().numpy(), - name=table_tensor_name, - ) attr = ts.TosaSerializerAttribute() attr.TableAttribute() self._serialize_operator( node, tosa_graph, ts.Op.TABLE, - [inputs[0].name, table_tensor_name], + [inputs[0].name, tosa_table_buffer_name], [output.name], attr, ) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index f13c386a5ee..a32cb3aac06 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -66,7 +66,7 @@ def define_node( shape_len, ts.DType.SHAPE, shape_data, - name=node.name + "_shape", + name=output.name + "_shape", ) attr = ts.TosaSerializerAttribute() From c709e67638f6a845db590aed4b760418a449fed8 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 21 Oct 2025 16:48:14 +0200 Subject: [PATCH 2/2] Arm backend: Support conditional operator - Add partition check to make sure that the submodules with the if/else codepaths are fully delegated. - Fix some partitioning issues with submodule nodes, since they point to a submodule rather than a tensor they dont have a fake tensor. - Add node visitor. - Add tests. Signed-off-by: Erik Lundell Change-Id: I00dbfdedb04c686ce04b4fb1d682816038b7e1bf --- backends/arm/_passes/arm_pass_utils.py | 18 +- backends/arm/_passes/cast_int64_pass.py | 2 + .../arm/_passes/to_tosa_memory_format_pass.py | 2 + .../tosa_supported_operators.py | 147 ++++++++++++- backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_cond_if.py | 61 ++++++ backends/arm/operators/ops_identity.py | 4 +- backends/arm/scripts/parse_test_names.py | 1 + backends/arm/test/ops/test_cond.py | 201 ++++++++++++++++++ backends/arm/tosa/backend.py | 13 ++ backends/arm/tosa/mapping.py | 15 +- 11 files changed, 446 insertions(+), 19 deletions(-) create mode 100644 backends/arm/operators/op_cond_if.py create mode 100644 backends/arm/test/ops/test_cond.py diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index de42c961d08..8590142d72c 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -31,11 +31,25 @@ from torch.export.graph_signature import InputKind +def is_submodule_node(node: torch.fx.Node): + if node.op not in ("get_attr", "placeholder"): + return False + try: + node.graph.owning_module.get_submodule(node.target) + except AttributeError: + return False + return True + + def is_get_attr_node(node: torch.fx.Node) -> bool: """ - Returns true if the given node is a get attr node for a tensor of the model + Returns true if the given node is a get attr node for a tensor of the model. """ - return isinstance(node, torch.fx.Node) and node.op == "get_attr" + return ( + isinstance(node, torch.fx.Node) + and node.op == "get_attr" + and not is_submodule_node(node) + ) def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 4822c6c25c0..02a9cbeceaf 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -41,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: if len(node.users) == 0: continue + if "val" not in node.meta: + continue fake_tensor = node.meta["val"] if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): continue diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 956eb77b62c..7e998e3a436 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -299,6 +299,8 @@ def remove_dim_order_kwargs( def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: + if "val" not in node.meta: + continue node_data = get_first_fake_tensor(node).data self.remove_dim_order_kwargs(graph_module, node) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 3a1d11eab8c..1f8405e8744 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -7,12 +7,15 @@ import itertools import operator import typing -from typing import final, Optional, Sequence, Type +from typing import cast, final, Optional, Sequence, Type import torch import torch.fx as fx -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.arm_pass_utils import ( + get_first_fake_tensor, + is_submodule_node, +) from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( FuseQuantizedActivationPass, @@ -31,6 +34,7 @@ TOSA_PRO_INT_SupportList, ) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.specification import Tosa_1_00 from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -110,7 +114,9 @@ def tosa_support_factory( Additional checks can be supplied to avoid partitioning additional nodes. """ # Postive checks: Add nodes to partitioning - positive_checks: list[OperatorSupportBase] = [] + positive_checks: list[OperatorSupportBase] = [ + CondSupported(exported_program, tosa_spec, reporter) + ] if tosa_spec.support_integer(): positive_checks.append(TOSAProINTSupportList()) @@ -350,7 +356,8 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - + if is_submodule_node(node): + return True vals = node.meta["val"] tensor_list = vals if isinstance(vals, (list, tuple)) else [vals] @@ -390,7 +397,11 @@ def is_node_supported( # Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned. # If it is not partitioned, the partition will get an int64 input and fail. - for input_node in node.all_input_nodes: + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor_in = get_first_fake_tensor(input_node) if tensor_in.dtype != torch.int64: continue @@ -426,8 +437,13 @@ def __init__( def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - - for input_node in node.all_input_nodes: + if is_submodule_node(node): + return True + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.float64: self.reporter.report_reject( @@ -449,7 +465,13 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - input_nodes = node.all_input_nodes + if is_submodule_node(node): + return True + input_nodes = ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ) # check if any input node has an unsupported rank for input_node in input_nodes: input_node_shape = get_first_fake_tensor(input_node).shape @@ -484,3 +506,112 @@ def is_node_supported( ) return False return True + + +class CondSupported(OperatorSupportBase): + """Checks whether the cond operator, and it's submodule args, should be partitioned.""" + + def __init__( + self, + exported_program: ExportedProgram, + tosa_spec: TosaSpecification, + reporter: WhyNoPartitionReporter, + ): + self.exported_program = exported_program + self.reporter = reporter + self.tosa_spec = tosa_spec + super().__init__() + + def _fully_partitioned(self, submodule: fx.GraphModule) -> bool: + partition_tag = None + for submodule_node in submodule.graph.nodes: + if submodule_node.op == "call_function": + # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported. + if ( + submodule_node.target in Q_OPS + and list(submodule_node.all_input_nodes)[0].op == "placeholder" + ): + continue + if ( + submodule_node.target in DQ_OPS + and list(submodule_node.users)[0].op == "output" + ): + continue + if "delegation_tag" not in submodule_node.meta: + return False + if partition_tag is None: + partition_tag = submodule_node.meta["delegation_tag"] + elif submodule_node.meta["delegation_tag"] != partition_tag: + return False + return True + + def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool: + """Returns whether the submodule arguments to a cond node were fully partitioned. + Updates "val" meta of the submodules if they are. + """ + cond_submodules = ( + ( + self.exported_program.graph_module.get_submodule( + str(cast(torch.fx.Node, submodule_node).target) + ), + cast(torch.fx.Node, submodule_node), + ) + for submodule_node in node.args[1:3] + ) + for submodule, submodule_node in cond_submodules: + submodule = cast(torch.fx.GraphModule, submodule) + + if self._fully_partitioned(submodule): + submodule_node.meta["val"] = submodule.graph.output_node().meta["val"] + else: + return False + return True + + def is_node_supported( # noqa: C901 + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if is_submodule_node(node): + if not isinstance(self.tosa_spec, Tosa_1_00): + self.reporter.report_reject( + node, "Control flow extension not supported for TOSA version <1.0" + ) + return False + if not self.tosa_spec.support_extension("cf"): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + for user in node.users: + if user.target != torch.ops.higher_order.cond: + self.reporter.report_reject( + node, f"Submodule had unsupported user {user}" + ) + return False + if not self._cond_submodules_fully_partitioned(user): + self.reporter.report_reject( + node, "One submodule was not fully partitioned" + ) + return False + return True + if node.target == torch.ops.higher_order.cond: + if not isinstance(self.tosa_spec, Tosa_1_00): + self.reporter.report_reject( + node, "Control flow extension not supported for TOSA version <1.0" + ) + return False + if not self.tosa_spec.support_extension("cf"): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + + if not self._cond_submodules_fully_partitioned(node): + self.reporter.report_reject( + node, "Submodule was not fully partitioned." + ) + return False + return True + + return False diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a180d0a6e86..b3d8f5676d5 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -16,6 +16,7 @@ op_cat, op_ceil, op_clamp, + op_cond_if, op_constant_pad_nd, op_cos, op_eq, diff --git a/backends/arm/operators/op_cond_if.py b/backends/arm/operators/op_cond_if.py new file mode 100644 index 00000000000..5ea12f83a99 --- /dev/null +++ b/backends/arm/operators/op_cond_if.py @@ -0,0 +1,61 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import Any, cast, List + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( # type: ignore + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore +from executorch.backends.arm.tosa.specification import Tosa_1_00 +from torch.fx import Node + + +@register_node_visitor +class CondVisitor(NodeVisitor): + target = "cond" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + validate_num_inputs(self.target, inputs, 4) + validate_valid_dtype(self.target, [inputs[0]], ts.DType.BOOL, self.tosa_spec) + if not isinstance(self.tosa_spec, Tosa_1_00): + raise ValueError("Trying to lower cond, but TOSA version is <1.0.") + if not self.tosa_spec.support_extension("cf"): + raise ValueError( + f"Trying to lower cond, but TOSA specification {self.tosa_spec} does not support the cf extension." + ) + + attr = ts.TosaSerializerAttribute() + if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3]) + attr.CondIfAttribute(if_graph, else_graph) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.COND_IF, + [ + inputs[0].name, + *(subgraph_input.name for subgraph_input in inputs[-1].special), + ], + [output.name], + attr, + ) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index d570c52ed31..6aa6a746664 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -40,8 +40,8 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) + validate_num_inputs(self.target, inputs, [1, 2]) + validate_same_dtype(self.target, [inputs[0], output], ts) # Simply add an identityOp attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index a663ba2e8b7..e388d5b30cb 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -7,6 +7,7 @@ # Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here. CUSTOM_EDGE_OPS = [ "linspace.default", + "cond.default", "eye.default", "expm1.default", "vector_norm.default", diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py new file mode 100644 index 00000000000..eaf53cc3347 --- /dev/null +++ b/backends/arm/test/ops/test_cond.py @@ -0,0 +1,201 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Tuple + +import pytest + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + +aten_op = "torch.ops.higher_order.cond" +exir_op = "torch.ops.higher_order.cond" + +input_t1 = Tuple[torch.Tensor] +input_t2 = Tuple[torch.Tensor, torch.Tensor] + + +class CondZeroArgsOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch() -> torch.Tensor: + return torch.zeros(10) + + def false_branch() -> torch.Tensor: + return torch.ones(10) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, []) + + +class CondOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.cos(arg) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgAndScalarOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return arg + 1.0 + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return arg - 1.0 + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgTwoOutputs(torch.nn.Module): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg + torch.sin(arg), arg - torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg - arg.mean(), arg + arg.mean() + + predicate = x.flatten().sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondNestedOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def inner_true(arg: torch.Tensor) -> torch.Tensor: + return arg + 1.0 + + def inner_false(arg: torch.Tensor) -> torch.Tensor: + return arg - 1.0 + + def outer_true(arg: torch.Tensor) -> torch.Tensor: + inner_predicate = arg.mean() > 0 + return torch.cond(inner_predicate, inner_true, inner_false, [arg]) + + def outer_false(arg: torch.Tensor) -> torch.Tensor: + return arg * 0.5 + + predicate = x.sum() > 0 + return torch.cond(predicate, outer_true, outer_false, [x]) + + +class CondMultipleOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def first_true(arg: torch.Tensor) -> torch.Tensor: + return arg + 2.0 + + def first_false(arg: torch.Tensor) -> torch.Tensor: + return arg - 2.0 + + first_predicate = x.sum() > 0 + intermediate = torch.cond(first_predicate, first_true, first_false, [x]) + + def second_true(arg: torch.Tensor) -> torch.Tensor: + return arg * 3.0 + + def second_false(arg: torch.Tensor) -> torch.Tensor: + return arg / 3.0 + + second_predicate = intermediate.mean() > 0 + return torch.cond(second_predicate, second_true, second_false, [intermediate]) + + +class CondTwoArgsOneOutput(torch.nn.Module): + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + def true_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l + arg_r + + def false_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l - arg_r + + predicate = (lhs - rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +class CondTwoArgsTwoOutputs(torch.nn.Module): + def forward( + self, lhs: torch.Tensor, rhs: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + return arg_l + arg_r, arg_l * arg_r + + def false_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + diff = arg_l - arg_r + return diff, arg_l + diff + + predicate = (lhs * rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +def _single_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t1]]: + def _create() -> tuple[torch.nn.Module, input_t1]: + return module_factory(), (torch.randn(2, 3),) + + return _create + + +def _dual_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t2]]: + def _create() -> tuple[torch.nn.Module, input_t2]: + return module_factory(), (torch.randn(2, 3), torch.randn(2, 3)) + + return _create + + +test_cases: dict[str, Callable[[], tuple[torch.nn.Module, tuple]]] = { + "zero_args_one_output": _single_input_case(CondZeroArgsOneOutput), + "one_arg_one_output": _single_input_case(CondOneArgOneOutput), + "one_arg_and_scalar_one_output": _single_input_case(CondOneArgAndScalarOneOutput), + "one_arg_two_outputs": _single_input_case(CondOneArgTwoOutputs), + "two_args_one_output": _dual_input_case(CondTwoArgsOneOutput), + "two_args_two_outputs": _dual_input_case(CondTwoArgsTwoOutputs), + "nested_one_arg_one_output": _single_input_case(CondNestedOneArgOneOutput), + "multiple_one_arg_one_output": _single_input_case(CondMultipleOneArgOneOutput), +} + + +@common.parametrize( + "case", + test_cases, + xfails={ + "one_arg_two_outputs": "Multiple outputs is not supported.", + "one_arg_and_scalar_one_output": "Scalars become get_attr nodes that are not supported.", + "two_args_two_outputs": "Nodes with multiple outputs are not properly supported.", + "multiple_one_arg_one_output": "Scalars become get_attr nodes that are not supported.", + }, +) +def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineFP[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + pipeline.run() + + +@pytest.mark.skip("Quantization on submodules is not implemented yet.") +@common.parametrize( + "case", + test_cases, +) +def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineINT[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + pipeline.run() diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 42af22ff09f..fd2c7c74930 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -221,6 +221,19 @@ def _preprocess_module( # noqa: C901 process_placeholder(node, tosa_graph, edge_program, tosa_spec) elif node.op == "output": process_output(node, tosa_graph, tosa_spec) + elif node.op == "get_attr": + attr = getattr(graph_module, str(node.target), None) + if attr is None: + raise RuntimeError( + "get_attr node is not targeting anything in graph module." + ) + if not isinstance(attr, GraphModule): + raise RuntimeError( + "get_attr node is not targeting a GraphModule." + ) + + # If the above conditions are ok, we don't need to handle this node here. + # Only the string value of node.target is important. else: # This will only happen if an unpartitioned graph is passed without # any checking of compatibility. diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 5162d2c6a53..cad2000a994 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -147,15 +147,16 @@ def __process_node(self, argument: torch.fx.Node): """ self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "") - output_dtype, self.shape, self.dim_order = extract_tensor_meta( - argument.meta, self.tosa_spec - ) - # Handle special case of types not representable in torch (i.e. i48_t) - if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): - output_dtype = special_type.get_tosa_dtype() + if "val" in argument.meta: + output_dtype, self.shape, self.dim_order = extract_tensor_meta( + argument.meta, self.tosa_spec + ) + # Handle special case of types not representable in torch (i.e. i48_t) + if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): + output_dtype = special_type.get_tosa_dtype() - self.dtype = output_dtype + self.dtype = output_dtype def __process_list(self, argument): """Capture a sequence argument as ``special``.