247 changes: 135 additions & 112 deletions backends/cadence/aot/remove_ops.py
@@ -6,9 +6,8 @@

# pyre-strict

import logging
from dataclasses import dataclass, field
from typing import cast, List, Optional, Sequence, Set, Type
from typing import cast, List, Optional, Set, Type

# Import these for the cadence function signatures.
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
@@ -69,45 +68,57 @@ class RemoveRedundantOps:


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveZeroSizedCatArgsPass(ExportPass):
def call_operator(
self,
op, # pyre-ignore
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op != exir_ops.edge.aten.cat.default:
return super().call_operator(op, args, kwargs, meta)

# Remove any zero-sized tensor arg to form a new args list.
cat_inputs: list[ProxyValue] = []
for arg in cast(Sequence[ProxyValue], args[0]):
if arg.to_tensor().numel() > 0:
cat_inputs.append(arg)
class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface):
@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.cat.default]

# If all the tensors were empty, we just return an empty tensor with
# the right shape.
def maybe_remove_or_replace(self, node: Node) -> bool:
# Get the cat inputs (first argument is a list of tensors)
cat_inputs_arg = node.args[0]

# Ensure cat_inputs_arg is a sequence (list or tuple)
assert isinstance(
cat_inputs_arg, (list, tuple)
), "cat_inputs_arg must be a sequence type"

# Filter out zero-sized tensors
cat_inputs: list[Node] = []
for arg in cat_inputs_arg:
if isinstance(arg, Node) and arg.meta.get("val") is not None:
if arg.meta["val"].numel() > 0:
cat_inputs.append(arg)

# If all tensors were empty, create a full op with the right shape
if not cat_inputs:
empty_shape = meta["val"].shape
dtype = meta["val"].dtype
return super().call_operator(
exir_ops.edge.aten.full.default,
(tuple(empty_shape), 0),
{"dtype": dtype},
meta,
)
empty_shape = node.meta["val"].shape
dtype = node.meta["val"].dtype
# Create a new full node
with node.graph.inserting_before(node):
full_node = node.graph.call_function(
exir_ops.edge.aten.full.default,
args=(tuple(empty_shape), 0),
kwargs={"dtype": dtype},
)
full_node.meta = node.meta.copy()
node.replace_all_uses_with(full_node)
return True

# If there was only one tensor in the cat_inputs list,
# we can safely erase this cat op.
# If only one tensor remains, replace with it
if len(cat_inputs) == 1:
return cat_inputs[0]
node.replace_all_uses_with(cat_inputs[0])
return True

# If the number of inputs changed, update the cat args
if len(cat_inputs) < len(cat_inputs_arg):
# Update the first argument with filtered inputs
new_args = list(node.args)
new_args[0] = cat_inputs
node.args = tuple(new_args)
return True

# Otherwise, we replace args[0] with cat_inputs.
new_args = list(args)
# pyre error introduced after D66937105
new_args[0] = cat_inputs # pyre-ignore[6]
return super().call_operator(op, tuple(new_args), kwargs, meta)
# No changes needed
return False


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -151,25 +162,29 @@ def maybe_remove_or_replace(self, node: Node) -> bool:


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveZeroSizedConstantPadNd(ExportPass):
def call_operator(
self,
op, # pyre-ignore
args: tuple[ProxyValue, tuple[int, ...], Argument],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op != exir_ops.edge.aten.constant_pad_nd.default:
return super().call_operator(op, args, kwargs, meta)
class RemoveZeroSizedConstantPadNd(RemoveOrReplacePassInterface):
@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.constant_pad_nd.default]

input_tensor = args[0]
padding = args[1]
def maybe_remove_or_replace(self, node: Node) -> bool:
# Get padding argument (second argument)
if len(node.args) < 2:
return False

padding = node.args[1]
if not isinstance(padding, (list, tuple)):
return False

# If any padding value is non-zero, keep the node
if any(x != 0 for x in padding):
return super().call_operator(op, args, kwargs, meta)
return False

logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}")
return input_tensor
# All padding is zero, replace with input
input_node = node.args[0]
assert isinstance(input_node, Node)
node.replace_all_uses_with(input_node)
return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -721,27 +736,27 @@ def get_squeeze_indices(self, view_node: Node) -> List[int]:

return squeeze_indices

def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None:
def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> bool:
if view_node in visited_view_nodes:
return
return False

squeeze_indices = self.get_squeeze_indices(view_node)
if not squeeze_indices:
return
return False

# Only handle simple chains for now.
if len(view_node.users) != 1:
return
return False
node = next(iter(view_node.users))

# Traverse down from the node until finding another view op.
intermediate_slices = []
while node.target != exir_ops.edge.aten.view_copy.default:
# Only handle simple chains for now
if len(node.users) != 1:
return
return False
if node.target not in self.intermediate_ops:
return
return False
if node.target == exir_ops.edge.aten.slice_copy.Tensor:
intermediate_slices.append(node)
node = next(iter(node.users))
@@ -764,18 +779,22 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None
# Skip the initial view node.
input_node = cast(Node, get_arg(view_node, "input"))
view_node.replace_all_uses_with(input_node)
return True

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
visited_view_nodes = set()
modified = False
for view_node in graph_module.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True
):
self.handle_squeeze(view_node, visited_view_nodes)
modified |= self.handle_squeeze(view_node, visited_view_nodes)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
if modified:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return super().call(graph_module)

return super().call(graph_module)
return PassResult(graph_module, False)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -798,23 +817,27 @@ class RemoveBranchedQuantDequant(ExportPass):
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.remove_branched(
modified = self.remove_branched(
graph_module, self.quantize_op_packets, self.dequantize_op_packets
)
self.remove_branched(
modified |= self.remove_branched(
graph_module, self.dequantize_op_packets, self.quantize_op_packets
)

graph_module.graph.eliminate_dead_code()
result = super().call(graph_module)
return result
if modified:
graph_module.graph.eliminate_dead_code()
result = super().call(graph_module)
return result

return PassResult(graph_module, False)

def remove_branched(
self,
graph_module: torch.fx.GraphModule,
producer_pkts: set[EdgeOpOverloadPacket],
consumer_pkts: set[EdgeOpOverloadPacket],
) -> None:
) -> bool:
modified = False
for node in graph_module.graph.nodes:
if (
node.op != "call_function"
@@ -838,61 +861,62 @@ def remove_branched(
continue

user.replace_all_uses_with(node.args[0])
modified = True

return modified

class RemoveCatFromSliceCopyPass(ExportPass):

@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveCatFromSliceCopyPass(RemoveOrReplacePassInterface):
"""
Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
to the slice_copy.
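For example, a slice_copy with step 1 whose [start, end) range falls entirely
within a single cat input can read from that input directly, bypassing the cat.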
"""

def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
for slice_copy_node in graph_module.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
):
cat_node = cast(Node, get_arg(slice_copy_node, "input"))
slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
start_idx = cast(int, get_arg(slice_copy_node, "start"))
end_idx = cast(int, get_arg(slice_copy_node, "end"))
step = cast(int, get_arg(slice_copy_node, "step"))
@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.slice_copy.Tensor]

if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
continue
def maybe_remove_or_replace(self, node: Node) -> bool:
cat_node = cast(Node, get_arg(node, "input"))
slice_dim = cast(int, get_arg(node, "dim"))
start_idx = cast(int, get_arg(node, "start"))
end_idx = cast(int, get_arg(node, "end"))
step = cast(int, get_arg(node, "step"))

# Make sure the cat and the slice happen on the same dimension.
cat_dim = cast(Node, get_arg(cat_node, "dim"))
if cat_dim != slice_dim:
continue
if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
return False

# Make sure the cat and the slice happen on the same dimension.
cat_dim = cast(int, get_arg(cat_node, "dim"))
if cat_dim != slice_dim:
return False

# Canonicalize slice indices.
cat_output_shape = cat_node.meta["val"].shape
if start_idx is None:
start_idx = 0
elif start_idx < 0:
start_idx += cat_output_shape[cat_dim]
if end_idx is None or end_idx > cat_output_shape[cat_dim]:
end_idx = cat_output_shape[cat_dim]
elif end_idx < 0:
end_idx += cat_output_shape[cat_dim]

offset = 0
for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
cat_input_shape = cat_input_node.meta["val"].shape

# Check if the slice range overlaps with the cat input range.
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
slice_copy_node.replace_input_with(cat_node, cat_input_node)
set_arg(slice_copy_node, "start", start_idx - offset)
set_arg(slice_copy_node, "end", end_idx - offset)
break

offset += cat_input_shape[cat_dim]
# Canonicalize slice indices.
cat_output_shape = cat_node.meta["val"].shape
if start_idx is None:
start_idx = 0
elif start_idx < 0:
start_idx += cat_output_shape[cat_dim]
if end_idx is None or end_idx > cat_output_shape[cat_dim]:
end_idx = cat_output_shape[cat_dim]
elif end_idx < 0:
end_idx += cat_output_shape[cat_dim]

offset = 0
for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
cat_input_shape = cat_input_node.meta["val"].shape

# Check if the slice range overlaps with the cat input range.
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
node.replace_input_with(cat_node, cat_input_node)
set_arg(node, "start", start_idx - offset)
set_arg(node, "end", end_idx - offset)
return True

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self._remove_unused_cat(graph_module)
graph_module.recompile()
graph_module.graph.eliminate_dead_code()
return super().call(graph_module)
offset += cat_input_shape[cat_dim]

return False


class CommonRemovePasses:
Expand All @@ -901,7 +925,6 @@ class CommonRemovePasses:
RemoveAliasCopyOpPass,
RemoveNopExpandOpPass,
RemoveNopSliceOrViewOpPass,
RemoveNopSelectOpPass,
RemoveToOpsPass,
RemoveZeroSizedCatArgsPass,
RemovePermutesAroundElementwiseOps,