diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index de42c961d08..8590142d72c 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -31,11 +31,25 @@ from torch.export.graph_signature import InputKind +def is_submodule_node(node: torch.fx.Node): + if node.op not in ("get_attr", "placeholder"): + return False + try: + node.graph.owning_module.get_submodule(node.target) + except AttributeError: + return False + return True + + def is_get_attr_node(node: torch.fx.Node) -> bool: """ - Returns true if the given node is a get attr node for a tensor of the model + Returns true if the given node is a get attr node for a tensor of the model. """ - return isinstance(node, torch.fx.Node) and node.op == "get_attr" + return ( + isinstance(node, torch.fx.Node) + and node.op == "get_attr" + and not is_submodule_node(node) + ) def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 4822c6c25c0..02a9cbeceaf 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -41,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: if len(node.users) == 0: continue + if "val" not in node.meta: + continue fake_tensor = node.meta["val"] if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): continue diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 956eb77b62c..7e998e3a436 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -299,6 +299,8 @@ def remove_dim_order_kwargs( def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: + if "val" not in node.meta: + continue node_data = get_first_fake_tensor(node).data self.remove_dim_order_kwargs(graph_module, node) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 3a1d11eab8c..1f8405e8744 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -7,12 +7,15 @@ import itertools import operator import typing -from typing import final, Optional, Sequence, Type +from typing import cast, final, Optional, Sequence, Type import torch import torch.fx as fx -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.arm_pass_utils import ( + get_first_fake_tensor, + is_submodule_node, +) from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( FuseQuantizedActivationPass, @@ -31,6 +34,7 @@ TOSA_PRO_INT_SupportList, ) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.specification import Tosa_1_00 from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -110,7 +114,9 @@ def tosa_support_factory( Additional checks can be supplied to avoid partitioning additional nodes. """ # Postive checks: Add nodes to partitioning - positive_checks: list[OperatorSupportBase] = [] + positive_checks: list[OperatorSupportBase] = [ + CondSupported(exported_program, tosa_spec, reporter) + ] if tosa_spec.support_integer(): positive_checks.append(TOSAProINTSupportList()) @@ -350,7 +356,8 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - + if is_submodule_node(node): + return True vals = node.meta["val"] tensor_list = vals if isinstance(vals, (list, tuple)) else [vals] @@ -390,7 +397,11 @@ def is_node_supported( # Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned. # If it is not partitioned, the partition will get an int64 input and fail. - for input_node in node.all_input_nodes: + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor_in = get_first_fake_tensor(input_node) if tensor_in.dtype != torch.int64: continue @@ -426,8 +437,13 @@ def __init__( def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - - for input_node in node.all_input_nodes: + if is_submodule_node(node): + return True + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.float64: self.reporter.report_reject( @@ -449,7 +465,13 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - input_nodes = node.all_input_nodes + if is_submodule_node(node): + return True + input_nodes = ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ) # check if any input node has an unsupported rank for input_node in input_nodes: input_node_shape = get_first_fake_tensor(input_node).shape @@ -484,3 +506,112 @@ def is_node_supported( ) return False return True + + +class CondSupported(OperatorSupportBase): + """Checks whether the cond operator, and it's submodule args, should be partitioned.""" + + def __init__( + self, + exported_program: ExportedProgram, + tosa_spec: TosaSpecification, + reporter: WhyNoPartitionReporter, + ): + self.exported_program = exported_program + self.reporter = reporter + self.tosa_spec = tosa_spec + super().__init__() + + def _fully_partitioned(self, submodule: fx.GraphModule) -> bool: + partition_tag = None + for submodule_node in submodule.graph.nodes: + if submodule_node.op == "call_function": + # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported. + if ( + submodule_node.target in Q_OPS + and list(submodule_node.all_input_nodes)[0].op == "placeholder" + ): + continue + if ( + submodule_node.target in DQ_OPS + and list(submodule_node.users)[0].op == "output" + ): + continue + if "delegation_tag" not in submodule_node.meta: + return False + if partition_tag is None: + partition_tag = submodule_node.meta["delegation_tag"] + elif submodule_node.meta["delegation_tag"] != partition_tag: + return False + return True + + def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool: + """Returns whether the submodule arguments to a cond node were fully partitioned. + Updates "val" meta of the submodules if they are. + """ + cond_submodules = ( + ( + self.exported_program.graph_module.get_submodule( + str(cast(torch.fx.Node, submodule_node).target) + ), + cast(torch.fx.Node, submodule_node), + ) + for submodule_node in node.args[1:3] + ) + for submodule, submodule_node in cond_submodules: + submodule = cast(torch.fx.GraphModule, submodule) + + if self._fully_partitioned(submodule): + submodule_node.meta["val"] = submodule.graph.output_node().meta["val"] + else: + return False + return True + + def is_node_supported( # noqa: C901 + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if is_submodule_node(node): + if not isinstance(self.tosa_spec, Tosa_1_00): + self.reporter.report_reject( + node, "Control flow extension not supported for TOSA version <1.0" + ) + return False + if not self.tosa_spec.support_extension("cf"): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + for user in node.users: + if user.target != torch.ops.higher_order.cond: + self.reporter.report_reject( + node, f"Submodule had unsupported user {user}" + ) + return False + if not self._cond_submodules_fully_partitioned(user): + self.reporter.report_reject( + node, "One submodule was not fully partitioned" + ) + return False + return True + if node.target == torch.ops.higher_order.cond: + if not isinstance(self.tosa_spec, Tosa_1_00): + self.reporter.report_reject( + node, "Control flow extension not supported for TOSA version <1.0" + ) + return False + if not self.tosa_spec.support_extension("cf"): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + + if not self._cond_submodules_fully_partitioned(node): + self.reporter.report_reject( + node, "Submodule was not fully partitioned." + ) + return False + return True + + return False diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a180d0a6e86..b3d8f5676d5 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -16,6 +16,7 @@ op_cat, op_ceil, op_clamp, + op_cond_if, op_constant_pad_nd, op_cos, op_eq, diff --git a/backends/arm/operators/op_cond_if.py b/backends/arm/operators/op_cond_if.py new file mode 100644 index 00000000000..5ea12f83a99 --- /dev/null +++ b/backends/arm/operators/op_cond_if.py @@ -0,0 +1,61 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import Any, cast, List + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( # type: ignore + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore +from executorch.backends.arm.tosa.specification import Tosa_1_00 +from torch.fx import Node + + +@register_node_visitor +class CondVisitor(NodeVisitor): + target = "cond" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + validate_num_inputs(self.target, inputs, 4) + validate_valid_dtype(self.target, [inputs[0]], ts.DType.BOOL, self.tosa_spec) + if not isinstance(self.tosa_spec, Tosa_1_00): + raise ValueError("Trying to lower cond, but TOSA version is <1.0.") + if not self.tosa_spec.support_extension("cf"): + raise ValueError( + f"Trying to lower cond, but TOSA specification {self.tosa_spec} does not support the cf extension." + ) + + attr = ts.TosaSerializerAttribute() + if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3]) + attr.CondIfAttribute(if_graph, else_graph) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.COND_IF, + [ + inputs[0].name, + *(subgraph_input.name for subgraph_input in inputs[-1].special), + ], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 760e744923c..b2adb785ef6 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -165,14 +165,14 @@ def define_node( # channels and thus the stride-shift. data = np.full(index_shape, int(values_strides[i] / C)) mul_const = tosa_graph.addConst(index_shape, index_dtype, data) - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_{i}_shift") attr = ts.TosaSerializerAttribute() attr.MulAttribute() self._serialize_operator( node, tosa_graph, ts.Op.MUL, - [index_name, mul_const.name, f"{node.name}_{i}_shift"], + [index_name, mul_const.name, f"{output.name}_{i}_shift"], [stride_shifted_indices.name], attr, ) @@ -186,7 +186,7 @@ def define_node( stride_shifted_indices.name, gather_idx_shape, reshaped_idxs.name, - shape_name_override=f"{node.name}_{i}_shape", + shape_name_override=f"{output.name}_{i}_shape", ) # Guarantees that the accumulation tensor is properly @@ -218,7 +218,7 @@ def define_node( values.name, gather_vals_shape, reshaped_input.name, - shape_name_override=f"{node.name}_index_shape", + shape_name_override=f"{output.name}_index_shape", ) gather_out_shape = (N, W, C) @@ -244,5 +244,5 @@ def define_node( gather_out.name, list(output_shape), output.name, - shape_name_override=f"{node.name}_output_shape", + shape_name_override=f"{output.name}_output_shape", ) diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 78b0b1b6675..f1cd5de6fd6 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -48,14 +48,14 @@ def define_node( output.tosa_spec, ) - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift") attr = ts.TosaSerializerAttribute() attr.MulAttribute() self._serialize_operator( node, tosa_graph, ts.Op.MUL, - [inputs[0].name, inputs[1].name, f"{node.name}_shift"], + [inputs[0].name, inputs[1].name, f"{output.name}_shift"], [output.name], attr, ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 21a8f8e1b04..99c0ecce0b2 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -56,7 +56,7 @@ def define_node( (len(multiples),), ts.DType.SHAPE, list(tosa_shape(multiples, output.dim_order)), - name=node.name + "_multiples", + name=output.name + "_multiples", ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index c5510493eae..7366703083c 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -120,7 +120,7 @@ def define_node( (starts_len,), ts.DType.SHAPE, starts, - node.name + "_start_shape", + output.name + "_start_shape", ) sizes = [size if i == dim else shape[i] for i in input_node.dim_order] @@ -130,7 +130,7 @@ def define_node( sizes_len = 1 sizes = [0] sizes_tensor = tosa_graph.addConst( - (sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape" + (sizes_len,), ts.DType.SHAPE, sizes, output.name + "_sizes_shape" ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index 2281564a0c4..be73a60f7c7 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -72,8 +72,8 @@ def define_node( else: input0_zp, input1_zp = 0, 0 - input_A_ZP_name = f"{node.name}_A_ZP" - input_B_ZP_name = f"{node.name}_B_ZP" + input_A_ZP_name = f"{output.name}_A_ZP" + input_B_ZP_name = f"{output.name}_B_ZP" tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=input_A_ZP_name) tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name) diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index fb8e305839f..40c6d4ce6a6 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -84,15 +84,15 @@ def in_int16_range(x): scale_d_vals[1], ] scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" + [len(scales)], ts.DType.SHAPE, scales, output.name + "_scales" ) offset = [int(v) for v in offset_yx.tolist()] offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" + [len(offset)], ts.DType.SHAPE, offset, output.name + "_offset" ) border = [int(v) for v in border_yx.tolist()] border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, node.name + "_border" + [len(border)], ts.DType.SHAPE, border, output.name + "_border" ) attr = ts.TosaSerializerAttribute() attr.ResizeAttribute(resize_mode) diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index 11407517b6a..d867b5efd7b 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -44,27 +44,24 @@ def define_node( if inputs[0].dtype == ts.DType.INT16: validate_valid_dtype(self.target, output, ts.DType.INT32, output.tosa_spec) - if inputs[1].name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] + # The name of the table constant is a bit complex. + # The name of the pytorch buffer will be the target of last node argument. + # However, when it is serialized to TOSA, a submodule suffix might be added. The TOSA buffer name thus + # needs to be taken from the last TosaArg. + pytorch_table_buffer_name = node.args[-1].target # type: ignore[union-attr] + tosa_table_buffer_name = inputs[-1].name + if pytorch_table_buffer_name not in self._exported_program.state_dict.keys(): raise RuntimeError( f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." ) - table = self._exported_program.state_dict[inputs[1].name] # type: ignore[union-attr] - - table_tensor_name = node.name + "_table" - tosa_graph.addConst( - table.shape, - ts.DType.INT8 if inputs[0].dtype == ts.DType.INT8 else ts.DType.INT16, - table.detach().numpy(), - name=table_tensor_name, - ) attr = ts.TosaSerializerAttribute() attr.TableAttribute() self._serialize_operator( node, tosa_graph, ts.Op.TABLE, - [inputs[0].name, table_tensor_name], + [inputs[0].name, tosa_table_buffer_name], [output.name], attr, ) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index f13c386a5ee..a32cb3aac06 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -66,7 +66,7 @@ def define_node( shape_len, ts.DType.SHAPE, shape_data, - name=node.name + "_shape", + name=output.name + "_shape", ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index d570c52ed31..6aa6a746664 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -40,8 +40,8 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) + validate_num_inputs(self.target, inputs, [1, 2]) + validate_same_dtype(self.target, [inputs[0], output], ts) # Simply add an identityOp attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index a663ba2e8b7..e388d5b30cb 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -7,6 +7,7 @@ # Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here. CUSTOM_EDGE_OPS = [ "linspace.default", + "cond.default", "eye.default", "expm1.default", "vector_norm.default", diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py new file mode 100644 index 00000000000..eaf53cc3347 --- /dev/null +++ b/backends/arm/test/ops/test_cond.py @@ -0,0 +1,201 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Tuple + +import pytest + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + +aten_op = "torch.ops.higher_order.cond" +exir_op = "torch.ops.higher_order.cond" + +input_t1 = Tuple[torch.Tensor] +input_t2 = Tuple[torch.Tensor, torch.Tensor] + + +class CondZeroArgsOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch() -> torch.Tensor: + return torch.zeros(10) + + def false_branch() -> torch.Tensor: + return torch.ones(10) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, []) + + +class CondOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.cos(arg) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgAndScalarOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return arg + 1.0 + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return arg - 1.0 + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgTwoOutputs(torch.nn.Module): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg + torch.sin(arg), arg - torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg - arg.mean(), arg + arg.mean() + + predicate = x.flatten().sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondNestedOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def inner_true(arg: torch.Tensor) -> torch.Tensor: + return arg + 1.0 + + def inner_false(arg: torch.Tensor) -> torch.Tensor: + return arg - 1.0 + + def outer_true(arg: torch.Tensor) -> torch.Tensor: + inner_predicate = arg.mean() > 0 + return torch.cond(inner_predicate, inner_true, inner_false, [arg]) + + def outer_false(arg: torch.Tensor) -> torch.Tensor: + return arg * 0.5 + + predicate = x.sum() > 0 + return torch.cond(predicate, outer_true, outer_false, [x]) + + +class CondMultipleOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def first_true(arg: torch.Tensor) -> torch.Tensor: + return arg + 2.0 + + def first_false(arg: torch.Tensor) -> torch.Tensor: + return arg - 2.0 + + first_predicate = x.sum() > 0 + intermediate = torch.cond(first_predicate, first_true, first_false, [x]) + + def second_true(arg: torch.Tensor) -> torch.Tensor: + return arg * 3.0 + + def second_false(arg: torch.Tensor) -> torch.Tensor: + return arg / 3.0 + + second_predicate = intermediate.mean() > 0 + return torch.cond(second_predicate, second_true, second_false, [intermediate]) + + +class CondTwoArgsOneOutput(torch.nn.Module): + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + def true_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l + arg_r + + def false_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l - arg_r + + predicate = (lhs - rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +class CondTwoArgsTwoOutputs(torch.nn.Module): + def forward( + self, lhs: torch.Tensor, rhs: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + return arg_l + arg_r, arg_l * arg_r + + def false_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + diff = arg_l - arg_r + return diff, arg_l + diff + + predicate = (lhs * rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +def _single_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t1]]: + def _create() -> tuple[torch.nn.Module, input_t1]: + return module_factory(), (torch.randn(2, 3),) + + return _create + + +def _dual_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t2]]: + def _create() -> tuple[torch.nn.Module, input_t2]: + return module_factory(), (torch.randn(2, 3), torch.randn(2, 3)) + + return _create + + +test_cases: dict[str, Callable[[], tuple[torch.nn.Module, tuple]]] = { + "zero_args_one_output": _single_input_case(CondZeroArgsOneOutput), + "one_arg_one_output": _single_input_case(CondOneArgOneOutput), + "one_arg_and_scalar_one_output": _single_input_case(CondOneArgAndScalarOneOutput), + "one_arg_two_outputs": _single_input_case(CondOneArgTwoOutputs), + "two_args_one_output": _dual_input_case(CondTwoArgsOneOutput), + "two_args_two_outputs": _dual_input_case(CondTwoArgsTwoOutputs), + "nested_one_arg_one_output": _single_input_case(CondNestedOneArgOneOutput), + "multiple_one_arg_one_output": _single_input_case(CondMultipleOneArgOneOutput), +} + + +@common.parametrize( + "case", + test_cases, + xfails={ + "one_arg_two_outputs": "Multiple outputs is not supported.", + "one_arg_and_scalar_one_output": "Scalars become get_attr nodes that are not supported.", + "two_args_two_outputs": "Nodes with multiple outputs are not properly supported.", + "multiple_one_arg_one_output": "Scalars become get_attr nodes that are not supported.", + }, +) +def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineFP[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + pipeline.run() + + +@pytest.mark.skip("Quantization on submodules is not implemented yet.") +@common.parametrize( + "case", + test_cases, +) +def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineINT[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + pipeline.run() diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 69c49f654f0..2200cc7f86d 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -269,6 +269,19 @@ def _preprocess_module( # noqa: C901 process_placeholder(node, tosa_graph, edge_program, tosa_spec) elif node.op == "output": process_output(node, tosa_graph, tosa_spec) + elif node.op == "get_attr": + attr = getattr(graph_module, str(node.target), None) + if attr is None: + raise RuntimeError( + "get_attr node is not targeting anything in graph module." + ) + if not isinstance(attr, GraphModule): + raise RuntimeError( + "get_attr node is not targeting a GraphModule." + ) + + # If the above conditions are ok, we don't need to handle this node here. + # Only the string value of node.target is important. else: # This will only happen if an unpartitioned graph is passed without # any checking of compatibility. diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 5162d2c6a53..cad2000a994 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -147,15 +147,16 @@ def __process_node(self, argument: torch.fx.Node): """ self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "") - output_dtype, self.shape, self.dim_order = extract_tensor_meta( - argument.meta, self.tosa_spec - ) - # Handle special case of types not representable in torch (i.e. i48_t) - if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): - output_dtype = special_type.get_tosa_dtype() + if "val" in argument.meta: + output_dtype, self.shape, self.dim_order = extract_tensor_meta( + argument.meta, self.tosa_spec + ) + # Handle special case of types not representable in torch (i.e. i48_t) + if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): + output_dtype = special_type.get_tosa_dtype() - self.dtype = output_dtype + self.dtype = output_dtype def __process_list(self, argument): """Capture a sequence argument as ``special``.