From ffe5229a76170b9debe66a1cd6375ece7859b510 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Thu, 17 Apr 2025 12:53:36 -0700 Subject: [PATCH] [exir] Allow verifiers in _transform This is to allow users to transform an ExportedProgram using passes in places where it may result in a dialect that is not compliant with the original creation context. For example, if an ExportedProgram was created in an edge dialect and now needs to be run and transformed in a way that is not compliant with the EdgeDialectVerifier, such as in a delegate preprocess() function, then the user may want to override the verifier with their own or simply disable it. Differential Revision: [D73205727](https://our.internmc.facebook.com/intern/diff/D73205727/) [ghstack-poisoned] --- exir/program/_program.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/exir/program/_program.py b/exir/program/_program.py index 687852234dc..1df3d71a8d3 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -212,7 +212,23 @@ def _get_updated_graph_signature( return new_signature -def _transform(self, *passes: PassType) -> "ExportedProgram": +def _transform( + self, + *passes: PassType, + override_verifiers: None | list[Type[Verifier]] = None, +) -> "ExportedProgram": + """ + Transforms the program according to the provided passes. + + Args: + self: The ExportedProgram instance to transform + *passes: A sequence of passes to apply to the program + override_verifiers: Optional list of verifier classes to use instead of the default verifiers. + This is needed if the transforms yields illegal graph that the default verifier cannot handle. + + Returns: + ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made + """ pm = PassManager(list(passes)) res = pm(self.graph_module) transformed_gm = res.graph_module if res is not None else self.graph_module @@ -221,7 +237,9 @@ def _transform(self, *passes: PassType) -> "ExportedProgram": if transformed_gm is self.graph_module and not res.modified: return self - return _update_exported_program_graph_module(self, transformed_gm) + return _update_exported_program_graph_module( + self, transformed_gm, override_verifiers + ) def _update_exported_program_graph_module(