diff --git a/backends/arm/operator_support/embedding_support.py b/backends/arm/operator_support/embedding_support.py index bf95014e575..24395d56cbf 100644 --- a/backends/arm/operator_support/embedding_support.py +++ b/backends/arm/operator_support/embedding_support.py @@ -27,11 +27,16 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - # Note aten.embedding.default requires int64 indices and TOSA does not support it. - # Int32 indices here for aten.embedding.default is ok since it will be decomposed into ops that can handle it. - assert ( - len(node.all_input_nodes) == 2 - ), "Number of inputs to aten.embedding is not 2" + # Note aten.embedding.default requires int64 indices and TOSA does not + # support it. Int32 indices here for aten.embedding.default is ok since + # it will be decomposed into ops that can handle it. + + if len(node.all_input_nodes) != 2: + self.reporter.report_reject( + node, + (f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"), + ) + return False indices_val = node.all_input_nodes[1].meta["val"] indices_dtype = indices_val.dtype diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index bf9e29d5cb7..2e9bd846045 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -236,18 +236,20 @@ def is_node_supported( shape = input_node.meta["val"].shape rank = len(shape) if not -rank <= dim < rank: - raise IndexError( - f"Dim {dim} is outside of the range for tensor '{node.target}' of " - f"rank {rank}" + self.reporter.report_reject( + node, + (f"Dimension {dim} out of range for rank {rank}."), ) + return False dim = dim % rank size = shape[dim] if not -size <= index < size: - raise IndexError( - f"Index {index} is outside of the range for dim {dim} with size " - f"{size} for tensor {node.target}" + self.reporter.report_reject( + node, + (f"Index {index} out of range for dim {dim} with size {size}."), ) + return False index = index % size # Shape after squeeze. This may get converted into a view which may become diff --git a/backends/arm/operator_support/index_tensor_support.py b/backends/arm/operator_support/index_tensor_support.py index 4b226a9c407..25bc79ea938 100644 --- a/backends/arm/operator_support/index_tensor_support.py +++ b/backends/arm/operator_support/index_tensor_support.py @@ -111,16 +111,31 @@ def is_node_tosa_supported( for index in indices: # type: ignore[union-attr] # Usage 2 guard if index is None: + self.reporter.report_reject( + node, + ( + "None (from slice/unsqueeze/ellipsis) before an indexing tensor" + " is not supported." + ), + ) return False # Usage 1 guard fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type] if len(fake_tensor.size()) > 3: + self.reporter.report_reject( + node, + ("Indexing tensors of rank >= 4 is not supported."), + ) return False # Usage 3 guard total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type] if total_vals > torch.iinfo(torch.int32).max: + self.reporter.report_reject( + node, + ("Value size exceeds int32 range; would overflow flattened indexing."), + ) return False return True diff --git a/backends/arm/operator_support/minmax_support.py b/backends/arm/operator_support/minmax_support.py index edbf7f61818..68433819f4b 100644 --- a/backends/arm/operator_support/minmax_support.py +++ b/backends/arm/operator_support/minmax_support.py @@ -32,6 +32,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): ) if not (no_argmax or no_argmax_users): + self.reporter.report_reject( + node, + ( + "Using the indices output is not supported; only usage of the " + "values output is supported." + ), + ) return False return True