backends/cadence/aot/fuse_ops.py: 50 additions & 51 deletions
@@ -34,6 +34,7 @@
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     register_cadence_pass,
+    RemoveOrReplacePassInterface,
 )
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -454,7 +455,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseCascadedTransposeOrPermuteOps(ExportPass):
+class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
     """
     Fuse a cascaded chain of transpose and permute ops
     """
@@ -464,63 +465,61 @@ class FuseCascadedTransposeOrPermuteOps(ExportPass):
         exir_ops.edge.aten.permute_copy.default,
     }
 
-    # Find a chain of transpose or permute ops, and fuse them into a single permute op.
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return list(self.transpose_or_permute_target)
 
-    def fuse_cascaded_transpose_or_permute_ops(
-        self, graph_module: torch.fx.GraphModule
-    ):
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in permute/transpose ops
-            if node.target not in self.transpose_or_permute_target:
-                continue
-            # Get the cascaded chain of transpose/permute ops starting at node
-            cascaded_transpose_or_permute_ops = get_cascaded_ops(
-                [node], self.transpose_or_permute_target
-            )
-            # The chain must have more than 1 node
-            if len(cascaded_transpose_or_permute_ops) == 1:
-                continue
-
-            out_shape = get_shape(graph_module, node)
-            assert out_shape is not None
-            out_dims = len(out_shape)
-            # This is the trivial dimension order
-            dims = list(range(out_dims))
-            # Compute the effect of the chain on dims
-            for tp in cascaded_transpose_or_permute_ops:
-                dims = (
-                    get_transposed_dims(tp, dims)
-                    if tp.target == exir_ops.edge.aten.transpose_copy.int
-                    else get_permuted_dims(tp, dims)
-                )
-
-            # In case the permute chain cancelled each other, the final dims will
-            # be the same as the initial order. In that case, the chain was nop.
-            # Otherwise create a new permute op that encompasses the effect of the
-            # chain.
-            if dims == list(range(out_dims)):
-                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
-                    node.args[0]
-                )
-            else:
-                with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
-                    new_permute = graph.call_function(
-                        exir_ops.edge.aten.permute_copy.default,
-                        args=(node.args[0], dims),
-                    )
-                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
-
-            # Now erase the chain
-            for tp in reversed(cascaded_transpose_or_permute_ops):
-                graph.erase_node(tp)
-
-        graph_module.recompile()
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.fuse_cascaded_transpose_or_permute_ops(graph_module)
-        result = super().call(graph_module)
-        return result
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Get the cascaded chain of transpose/permute ops starting at node
+        cascaded_transpose_or_permute_ops = get_cascaded_ops(
+            [node], self.transpose_or_permute_target
+        )
+        # The chain must have more than 1 node
+        if len(cascaded_transpose_or_permute_ops) == 1:
+            return False
+
+        # Get shape from node metadata
+        val = node.meta.get("val")
+        if val is None:
+            return False
+        out_shape = val.shape
+        out_dims = len(out_shape)
+
+        # This is the trivial dimension order
+        dims = list(range(out_dims))
+        # Compute the effect of the chain on dims
+        for tp in cascaded_transpose_or_permute_ops:
+            dims = (
+                get_transposed_dims(tp, dims)
+                if tp.target == exir_ops.edge.aten.transpose_copy.int
+                else get_permuted_dims(tp, dims)
+            )
+        graph = node.graph
+
+        # In case the permute chain cancelled each other, the final dims will
+        # be the same as the initial order. In that case, the chain was nop.
+        # Otherwise create a new permute op that encompasses the effect of the
+        # chain.
+        if dims == list(range(out_dims)):
+            cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
+                cast(torch.fx.Node, node.args[0])
+            )
+        else:
+            with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
+                new_permute = graph.call_function(
+                    exir_ops.edge.aten.permute_copy.default,
+                    args=(node.args[0], dims),
+                )
+            new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta
+            cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
+
+        # Now erase the chain (except the first node which will be handled by the interface)
+        for tp in reversed(cascaded_transpose_or_permute_ops[1:]):
+            graph.erase_node(tp)
+
+        # Return True to indicate the first node in the chain should be removed
+        return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
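
Note: the fusion above boils down to composing permutation orders. Below is a minimal, dependency-free sketch of that composition; `transpose_dims` and `permute_dims` are illustrative stand-ins for the `get_transposed_dims`/`get_permuted_dims` helpers used in the pass, not the actual Cadence utilities.

```python
# Illustrative sketch only: standalone re-implementations of the dim-composition
# logic, not the helpers imported by fuse_ops.py.

def transpose_dims(dims: list[int], dim0: int, dim1: int) -> list[int]:
    # A transpose is a permute that swaps two axes.
    out = list(dims)
    out[dim0], out[dim1] = out[dim1], out[dim0]
    return out

def permute_dims(dims: list[int], order: list[int]) -> list[int]:
    # Applying `order` on top of the mapping in `dims` composes the two permutations.
    return [dims[i] for i in order]

# Example chain on a rank-3 tensor: transpose(0, 2) followed by permute(1, 2, 0).
dims = list(range(3))                 # start from the trivial order [0, 1, 2]
dims = transpose_dims(dims, 0, 2)     # -> [2, 1, 0]
dims = permute_dims(dims, [1, 2, 0])  # -> [1, 0, 2]

if dims == list(range(3)):
    print("chain cancels out: forward the chain's input directly")
else:
    print(f"replace the chain with a single permute_copy, dims={dims}")
```

If the chain cancels (for example, transpose(0, 2) applied twice), `dims` ends up back at the trivial order and the pass rewires users to `node.args[0]` instead of emitting a new permute.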
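The graph surgery itself (insert one fused permute before the tail of the chain, rewire its users, erase the dead nodes) follows the standard torch.fx rewrite pattern. Here is a self-contained sketch on vanilla torch.fx with plain `torch.permute` nodes; it assumes nothing about the Edge dialect, `RemoveOrReplacePassInterface`, or the Cadence pass infrastructure.

```python
# Illustrative sketch only: the same rewrite pattern on vanilla torch.fx,
# using torch.permute instead of Edge-dialect permute_copy ops.
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        y = torch.permute(x, (2, 1, 0))
        z = torch.permute(y, (1, 2, 0))
        return z + 1


gm = torch.fx.symbolic_trace(M())
graph = gm.graph


def compose(inner, outer):
    # permute(permute(x, inner), outer) == permute(x, [inner[i] for i in outer])
    return tuple(inner[i] for i in outer)


for node in graph.nodes:
    if node.op != "call_function" or node.target is not torch.permute:
        continue
    user = next(iter(node.users), None)
    if user is None or user.op != "call_function" or user.target is not torch.permute:
        continue
    fused_dims = compose(node.args[1], user.args[1])
    # Insert the fused permute before the tail of the chain, rewire, then erase.
    with graph.inserting_before(user):
        fused = graph.call_function(torch.permute, args=(node.args[0], fused_dims))
    user.replace_all_uses_with(fused)
    graph.erase_node(user)  # erase the chain back-to-front so each node is dead
    graph.erase_node(node)
    break

gm.recompile()

x = torch.randn(2, 3, 4)
assert torch.equal(gm(x), M()(x))
```

One difference in the PR's version: the first op of the chain is not erased by the pass itself; `maybe_remove_or_replace` returns `True` and removal of the visited node is left to `RemoveOrReplacePassInterface`, as the added comment notes.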