diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 3950f1852df..0dac813675c 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2329,7 +2329,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # Extract an argument to a separate full op. with graph_module.graph.inserting_before(mul_node): full_tensor = graph_module.graph.call_function( - exir_ops.edge.aten.full.default, args=([1], full_arg) + torch.ops.aten.full.default, args=([1], full_arg) ) new_mul_node = graph_module.graph.call_function( torch.ops.aten.mul.Tensor, args=(x_arg, full_tensor) @@ -2449,4 +2449,5 @@ class CadenceReplaceOpsInGraph: ReplaceAtenApproxGeluWithApproxGeluPass, ReplaceSplitWithSlicePass, ReplacePowWithMulPass, + ReplaceMulTensorWithMulAndFullOpsPass, ] diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 0537889d2c2..d778cd5b898 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -1933,7 +1933,7 @@ def test_extract_mul_argument_to_full( graph_after_passes, expected_op_counts={ torch.ops.aten.mul.Tensor: 1, - exir_ops.edge.aten.full.default: 1, + torch.ops.aten.full.default: 1, }, ) )