46 changes: 46 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -2300,6 +2300,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass):
"""
    Extracts the scalar argument of a mul op into a separate full op feeding the mul.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for mul_node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.mul.Tensor
):
x_arg, const_arg = mul_node.args

            # Swap the arguments if the scalar comes first
if isinstance(const_arg, torch.fx.Node):
x_arg, const_arg = const_arg, x_arg

            # Skip unless const_arg is a scalar and x_arg is a node
if not isinstance(const_arg, (float, int)) or not isinstance(
x_arg, torch.fx.Node
):
continue

            # Cast const_arg to match the dtype of x_arg
full_arg = self.resolve_full_arg(x_arg, const_arg)

            # Extract the scalar argument into a separate full op.
with graph_module.graph.inserting_before(mul_node):
full_tensor = graph_module.graph.call_function(
exir_ops.edge.aten.full.default, args=([1], full_arg)
)
new_mul_node = graph_module.graph.call_function(
torch.ops.aten.mul.Tensor, args=(x_arg, full_tensor)
)
# Replace the old mul with a newly created mul.
mul_node.replace_all_uses_with(new_mul_node)
graph_module.graph.erase_node(mul_node)
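        # Retrace via ExportPass so the rewritten graph is revalidated
        # and returned as a PassResult.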
return super().call(graph_module)

def resolve_full_arg(self, x_arg, const_arg):
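        # Coerce the scalar's Python type to the tensor dtype so the
        # generated full op produces a tensor matching x_arg's dtype.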
if x_arg.meta["val"].dtype == torch.float32 and isinstance(const_arg, int):
const_arg = float(const_arg)
if x_arg.meta["val"].dtype == torch.int32 and isinstance(const_arg, float):
const_arg = int(const_arg)
return const_arg


# This class encapsulates all the functions that replace/switch one op in the
# graph with another.
class CadenceReplaceOpsInGraph:
30 changes: 29 additions & 1 deletion backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -15,7 +15,7 @@
GraphBuilder,
single_op_builder,
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
from executorch.backends.cadence.aot.replace_ops import (
ForceChannelLastForConvPass,
MakeSliceAndCatDimOutermostPass,
@@ -31,6 +31,7 @@
ReplaceLinearWithFullyConnectedOpPass,
ReplaceMatmulWithTransposedMatmulPass,
ReplaceMMWithAddMMPass,
ReplaceMulTensorWithMulAndFullOpsPass,
ReplaceNopTransposeOrPermuteWithViewPass,
ReplacePadWithCatPass,
ReplacePermuteWithTransposePass,
@@ -1875,3 +1876,30 @@ def test_empty_slice(self):
),
1,
)

@parameterized.expand(
[
("int", int(123)),
("float", float(456.0)),
],
)
@torch.no_grad()
def test_extract_mul_argument_to_full(self, _, value) -> None:
x = torch.randn(2, 1, 64)
gm = single_op_builder(
placeholders=(x,),
op=torch.ops.aten.mul.Tensor,
args=(x, value),
kwargs={},
)
p = ReplaceMulTensorWithMulAndFullOpsPass()
graph_after_passes = p.call(gm).graph_module
self.assertTrue(
op_counts_match(
graph_after_passes,
expected_op_counts={
torch.ops.aten.mul.Tensor: 1,
exir_ops.edge.aten.full.default: 1,
},
)
)
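
For a quick end-to-end illustration, here is a minimal sketch (an editor's note, not part of the diff) that mirrors the test above; it assumes `single_op_builder` comes from `executorch.backends.cadence.aot.graph_builder`, as in the test file's imports:

import torch
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.replace_ops import (
    ReplaceMulTensorWithMulAndFullOpsPass,
)

# Build a graph containing a single aten.mul.Tensor(x, scalar) node.
x = torch.randn(2, 1, 64)
gm = single_op_builder(
    placeholders=(x,),
    op=torch.ops.aten.mul.Tensor,
    args=(x, 0.5),
    kwargs={},
)

# Run the pass: the scalar 0.5 should be materialized as a 1-element
# edge.aten.full.default node feeding the mul.
gm = ReplaceMulTensorWithMulAndFullOpsPass().call(gm).graph_module
gm.graph.print_tabular()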