diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 70470890317..037b2bb8bbd 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -112,6 +112,8 @@ from executorch.exir.pass_manager import PassManager from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassResult +from torch.nn.modules import Module class ArmPassManager(PassManager): @@ -355,3 +357,20 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeMaskedFill()) return self._transform(graph_module) + + def __call__(self, module: Module) -> PassResult: + try: + return super().__call__(module) + except Exception as e: + first_exception = e.__cause__ or e.__context__ or e + import re + + message = e.args[0] + m = re.search(r"An error occurred when running the '([^']+)' pass", message) + if m: + pass_name = m.group(1) + first_exception.args = ( + f"{pass_name}: {first_exception.args[0]}", + *first_exception.args[1:], + ) + raise first_exception