From f876fd4bdd45a57368e46dc4624e14ab0a22ea11 Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Fri, 1 Aug 2025 09:55:37 -0700 Subject: [PATCH] Call ExportPass() inside ReplaceNopTransposeOrPermuteWithViewPass::call(). (#13005) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/13005 Call ExportPass() inside ReplaceNopTransposeOrPermuteWithViewPass::call(). Reviewed By: hsharma35 Differential Revision: D79212506 --- backends/cadence/aot/pass_utils.py | 7 ++++++- backends/cadence/aot/replace_ops.py | 5 +++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index b004f714f2b..9aedef2ce2f 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -13,7 +13,7 @@ from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import PassBase +from executorch.exir.pass_base import PassBase, PassResult from torch._ops import OpOverloadPacket @@ -224,3 +224,8 @@ def set_arg( node.update_arg(idx, value) else: node.update_kwarg(kwarg_name, value) + + +def none_throws(x: Optional[PassResult]) -> PassResult: + assert x is not None + return x diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 8e6516cadba..61ab7b4c40f 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -39,6 +39,7 @@ ) from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, + none_throws, register_cadence_pass, ) from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass @@ -1661,8 +1662,8 @@ def call_operator(self, op, args, kwargs, meta): def call(self, graph_module: torch.fx.GraphModule) -> PassResult: result = super().call(graph_module) - result = FuseCascadedViewOps()(result.graph_module) - assert result is not None + fuse_cascaded_result = none_throws(FuseCascadedViewOps()(result.graph_module)) + result = none_throws(ExportPass()(fuse_cascaded_result.graph_module)) return result