diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index d78bdfeba6e..d85a0cc9be4 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -2300,6 +2300,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return result
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass):
+    """
+    Extracts a scalar argument of a mul op into a separate full op.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for mul_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.mul.Tensor
+        ):
+            x_arg, const_arg = mul_node.args
+
+            # Swap arguments if the scalar comes first
+            if isinstance(const_arg, torch.fx.Node):
+                x_arg, const_arg = const_arg, x_arg
+
+            # Skip unless const_arg is a scalar and x_arg is a node
+            if not isinstance(const_arg, (float, int)) or not isinstance(
+                x_arg, torch.fx.Node
+            ):
+                continue
+
+            # Cast const_arg to match the dtype of x_arg
+            full_arg = self.resolve_full_arg(x_arg, const_arg)
+
+            # Extract the scalar into a separate full op.
+            with graph_module.graph.inserting_before(mul_node):
+                full_tensor = graph_module.graph.call_function(
+                    exir_ops.edge.aten.full.default, args=([1], full_arg)
+                )
+                new_mul_node = graph_module.graph.call_function(
+                    torch.ops.aten.mul.Tensor, args=(x_arg, full_tensor)
+                )
+                # Replace the old mul with the newly created mul.
+                mul_node.replace_all_uses_with(new_mul_node)
+                graph_module.graph.erase_node(mul_node)
+        return super().call(graph_module)
+
+    def resolve_full_arg(self, x_arg, const_arg):
+        if x_arg.meta["val"].dtype == torch.float32 and isinstance(const_arg, int):
+            const_arg = float(const_arg)
+        if x_arg.meta["val"].dtype == torch.int32 and isinstance(const_arg, float):
+            const_arg = int(const_arg)
+        return const_arg
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index 4ff84a296e8..41002cda009 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -15,7 +15,7 @@
     GraphBuilder,
     single_op_builder,
 )
-from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
 from executorch.backends.cadence.aot.replace_ops import (
     ForceChannelLastForConvPass,
     MakeSliceAndCatDimOutermostPass,
@@ -31,6 +31,7 @@
     ReplaceLinearWithFullyConnectedOpPass,
     ReplaceMatmulWithTransposedMatmulPass,
     ReplaceMMWithAddMMPass,
+    ReplaceMulTensorWithMulAndFullOpsPass,
     ReplaceNopTransposeOrPermuteWithViewPass,
     ReplacePadWithCatPass,
     ReplacePermuteWithTransposePass,
@@ -1875,3 +1876,30 @@ def test_empty_slice(self):
             ),
             1,
         )
+
+    @parameterized.expand(
+        [
+            ("int", int(123)),
+            ("float", float(456.0)),
+        ],
+    )
+    @torch.no_grad()
+    def test_extract_mul_argument_to_full(self, _, value) -> None:
+        x = torch.randn(2, 1, 64)
+        gm = single_op_builder(
+            placeholders=(x,),
+            op=torch.ops.aten.mul.Tensor,
+            args=(x, value),
+            kwargs={},
+        )
+        p = ReplaceMulTensorWithMulAndFullOpsPass()
+        graph_after_passes = p.call(gm).graph_module
+        self.assertTrue(
+            op_counts_match(
+                graph_after_passes,
+                expected_op_counts={
+                    torch.ops.aten.mul.Tensor: 1,
+                    exir_ops.edge.aten.full.default: 1,
+                },
+            )
+        )
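
Note for reviewers: below is a minimal, self-contained sketch of the rewrite this pass performs, using plain torch.fx with no executorch imports. The graph is built by hand (mirroring what single_op_builder produces in the test), and torch.ops.aten.full.default stands in for the edge-dialect exir_ops.edge.aten.full.default target used by the actual pass; everything else follows the logic of ReplaceMulTensorWithMulAndFullOpsPass.call.

    import torch
    import torch.fx

    # Build a graph equivalent to: out = aten.mul.Tensor(x, 3.0)
    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    mul = graph.call_function(torch.ops.aten.mul.Tensor, args=(x, 3.0))
    graph.output(mul)
    gm = torch.fx.GraphModule(torch.nn.Module(), graph)

    # Rewrite: pull the scalar out into a full([1], 3.0) node, then multiply by it.
    for mul_node in list(gm.graph.nodes):
        if mul_node.op != "call_function" or mul_node.target != torch.ops.aten.mul.Tensor:
            continue
        x_arg, const_arg = mul_node.args
        if isinstance(const_arg, torch.fx.Node):  # scalar may come first
            x_arg, const_arg = const_arg, x_arg
        if not isinstance(const_arg, (int, float)) or not isinstance(x_arg, torch.fx.Node):
            continue  # not a tensor-times-scalar mul
        with gm.graph.inserting_before(mul_node):
            full = gm.graph.call_function(
                torch.ops.aten.full.default, args=([1], const_arg)
            )
            new_mul = gm.graph.call_function(
                torch.ops.aten.mul.Tensor, args=(x_arg, full)
            )
        mul_node.replace_all_uses_with(new_mul)
        gm.graph.erase_node(mul_node)

    gm.recompile()
    print(gm.graph)
    # The scalar now enters the graph as a [1]-shaped tensor, which broadcasts
    # against x, so gm(torch.randn(2, 1, 64)) still computes x * 3.0.

The [1]-shaped full broadcasts against any input shape, which is presumably why the pass never needs to know the shape of x; the dtype fixup in resolve_full_arg keeps the fill value consistent with x's dtype so the rewritten mul does not introduce an unintended type promotion.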