Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions backends/arm/operator_support/embedding_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool: # type: ignore[override, misc]
# Note aten.embedding.default requires int64 indices and TOSA does not support it.
# Int32 indices here for aten.embedding.default is ok since it will be decomposed into ops that can handle it.
assert (
len(node.all_input_nodes) == 2
), "Number of inputs to aten.embedding is not 2"
# Note aten.embedding.default requires int64 indices and TOSA does not
# support it. Int32 indices here for aten.embedding.default is ok since
# it will be decomposed into ops that can handle it.

if len(node.all_input_nodes) != 2:
self.reporter.report_reject(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is cool!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:D

node,
(f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"),
)
return False
indices_val = node.all_input_nodes[1].meta["val"]
indices_dtype = indices_val.dtype

Expand Down
14 changes: 8 additions & 6 deletions backends/arm/operator_support/ethos_u55_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,20 @@ def is_node_supported(
shape = input_node.meta["val"].shape
rank = len(shape)
if not -rank <= dim < rank:
raise IndexError(
f"Dim {dim} is outside of the range for tensor '{node.target}' of "
f"rank {rank}"
self.reporter.report_reject(
node,
(f"Dimension {dim} out of range for rank {rank}."),
)
return False
dim = dim % rank

size = shape[dim]
if not -size <= index < size:
raise IndexError(
f"Index {index} is outside of the range for dim {dim} with size "
f"{size} for tensor {node.target}"
self.reporter.report_reject(
node,
(f"Index {index} out of range for dim {dim} with size {size}."),
)
return False
index = index % size

# Shape after squeeze. This may get converted into a view which may become
Expand Down
15 changes: 15 additions & 0 deletions backends/arm/operator_support/index_tensor_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,31 @@ def is_node_tosa_supported(
for index in indices: # type: ignore[union-attr]
# Usage 2 guard
if index is None:
self.reporter.report_reject(
node,
(
"None (from slice/unsqueeze/ellipsis) before an indexing tensor"
" is not supported."
),
)
return False

# Usage 1 guard
fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type]
if len(fake_tensor.size()) > 3:
self.reporter.report_reject(
node,
("Indexing tensors of rank >= 4 is not supported."),
)
return False

# Usage 3 guard
total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type]
if total_vals > torch.iinfo(torch.int32).max:
self.reporter.report_reject(
node,
("Value size exceeds int32 range; would overflow flattened indexing."),
)
return False

return True
7 changes: 7 additions & 0 deletions backends/arm/operator_support/minmax_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
)

if not (no_argmax or no_argmax_users):
self.reporter.report_reject(
node,
(
"Using the indices output is not supported; only usage of the "
"values output is supported."
),
)
return False

return True
Loading