Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self) -> None:
}
)

def permute_subgraph(self, subgraph):
def permute_subgraph(self, subgraph) -> bool:
# Original function will always permute constant nodes which is wrong for table ops
# Remove constant tosa.TABLE edges before running full function
new_constant_edges_in = set()
Expand All @@ -32,4 +32,4 @@ def permute_subgraph(self, subgraph):
new_constant_edges_in.add((const_node, user_node))

subgraph.constant_edges_in = new_constant_edges_in
super().permute_subgraph(subgraph)
return super().permute_subgraph(subgraph)
21 changes: 18 additions & 3 deletions backends/transforms/remove_permutes_around_elementwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901

modified = False
for subgraph in subgraphs_found:
self.permute_subgraph(subgraph)
modified = True
if self.permute_subgraph(subgraph):
modified = True

if modified:
graph_module.graph.eliminate_dead_code()
Expand Down Expand Up @@ -399,7 +399,20 @@ def is_node_permutable(self, node: torch.fx.Node) -> bool:
return True
return self._is_pointwise(node.target)

def permute_subgraph(self, subgraph: Subgraph) -> None: # noqa: C901
def permute_subgraph(self, subgraph: Subgraph) -> bool: # noqa: C901
# Validate: every view_copy node's permutation rank must match its
# input tensor rank. A mismatch can occur when a squeeze/unsqueeze
# view is reached via upstream traversal with a permutation that was
# already adapted to a different rank. Applying the optimisation in
# this case would produce an invalid graph, so skip the subgraph.
for node in subgraph.nodes:
if node.target in self._VIEW_OPS:
perm = subgraph.node_start_permute.get(node, subgraph.start_permute)
inp = node.args[0]
if isinstance(inp, torch.fx.Node) and inp.meta.get("val") is not None:
if len(perm) != len(inp.meta["val"].shape):
return False

# Handle dimension related node arguments FIRST, before
# bypassing permutes (which changes node inputs/metadata).
for node in subgraph.nodes:
Expand Down Expand Up @@ -480,6 +493,8 @@ def permute_subgraph(self, subgraph: Subgraph) -> None: # noqa: C901
assert out.target == exir_ops.edge.aten.permute_copy.default
out.replace_all_uses_with(inp)

return True

def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
dim = get_arg(node, "dim", int)
set_arg(node, "dim", start_permute[dim])
Expand Down
64 changes: 64 additions & 0 deletions backends/transforms/test/test_permute_optimization_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,3 +998,67 @@ def test_permute_unsqueeze_copy_neg_dim_mul_squeeze_copy_permute(self) -> None:
[x_data],
"permute_unsqueeze_copy_neg_dim_mul_squeeze_copy_permute",
)

def test_upstream_view_rank_mismatch_no_crash(self) -> None:
"""Regression test for IndexError when a squeeze/unsqueeze view_copy
is reached via upstream traversal with a permutation whose rank does
not match the view's input tensor rank.

Graph:
full([16, 128], 1.0) x [1, 128, 16]
| |
view_copy (unsqueeze 2D→3D) permute [0, 2, 1]
[1, 16, 128] [1, 16, 128]
\\ /
---- add (3D) -----------
|
permute [0, 2, 1]
|
output

The view_copy (unsqueeze) is reached as an upstream input to `add`.
Its node_start_permute gets the 3D permutation [0, 2, 1], but its
input (the full op) is 2D. Before the fix, update_view_copy would
crash with IndexError: tuple index out of range."""
builder = GraphBuilder()
x_data = torch.randn(1, 128, 16)
x = builder.placeholder("x", x_data)
# 2D constant — treated as compile-time constant by _is_constant
const_2d = builder.call_operator(
op=exir_ops.edge.aten.full.default, args=([16, 128], 1.0)
)
# Unsqueeze via view_copy: [16, 128] → [1, 16, 128]
view_unsq = builder.call_operator(
op=exir_ops.edge.aten.view_copy.default, args=(const_2d, [1, 16, 128])
)
# Start permute: [1, 128, 16] → [1, 16, 128]
p1 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 1])
)
# Add the permuted input with the unsqueezed constant
add = builder.call_operator(
op=exir_ops.edge.aten.add.Tensor, args=(p1, view_unsq)
)
# End permute: [1, 16, 128] → [1, 128, 16]
p2 = builder.call_operator(
op=exir_ops.edge.aten.permute_copy.default, args=(add, [0, 2, 1])
)
builder.output([p2])
original = builder.get_graph_module()
gm_before = copy.deepcopy(original)

# Should not crash, and should skip the subgraph due to rank mismatch
p = RemovePermutesAroundElementwiseOps()
result = cast(PassResult, p(original))
# The subgraph is skipped, so the graph should be unmodified
self.assertFalse(result.modified)
# Both permutes are preserved
self.assertEqual(
count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default), 2
)
validate_numerics(
gm_before,
result.graph_module,
[x_data],
"upstream_view_rank_mismatch_no_crash",
)
Loading