diff --git a/exir/program/_program.py b/exir/program/_program.py index 7b0ecccca9d..e0484f4f4ff 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -212,7 +212,30 @@ def _get_updated_graph_signature( return new_signature -def _transform(self, *passes: PassType) -> "ExportedProgram": +def _transform( + self, + *passes: PassType, + override_verifiers: None | list[Type[Verifier]] = None, +) -> "ExportedProgram": + """ + Transforms the program according to the provided passes. + + Args: + self: The ExportedProgram instance to transform + *passes: A sequence of passes to apply to the program + override_verifiers: Optional list of verifier classes to use instead of the default verifiers. + This is needed if the transforms yields illegal graph that the default verifier cannot handle. + + Returns: + ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made + """ + # A user friendly check to avoid vararg surprises, PEP 3102 + assert not any( + isinstance(p, (list, Verifier)) for p in passes + ), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}" + + for p in list(passes): + print(type(p)) pm = PassManager(list(passes)) res = pm(self.graph_module) transformed_gm = res.graph_module if res is not None else self.graph_module @@ -221,7 +244,9 @@ def _transform(self, *passes: PassType) -> "ExportedProgram": if transformed_gm is self.graph_module and not res.modified: return self - return _update_exported_program_graph_module(self, transformed_gm) + return _update_exported_program_graph_module( + self, transformed_gm, override_verifiers + ) def _update_exported_program_graph_module( diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index fd7c51fdccb..bf8683ebb69 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -22,6 +22,7 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes import MemoryPlanningPass from executorch.exir.program._program import ( + _transform, EdgeProgramManager, ExecutorchProgramManager, to_edge, @@ -34,6 +35,7 @@ from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) +from torch._export.verifier import Verifier from torch.export import Dim, export, ExportedProgram from torch.export._trace import _export @@ -273,7 +275,6 @@ def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]: for output_val in method.outputs: evalue = method.values[output_val] self.assertNotEqual(evalue.val.allocation_info, None) - else: for input_val in method.inputs: evalue = method.values[input_val] self.assertEqual(evalue.val.allocation_info, None) @@ -847,3 +848,19 @@ def test_save_fails(self): et = edge.to_executorch() with self.assertRaises(ValueError): _ = et.save("/tmp/test_save.pt") + + def test__transform_override_verifiers(self): + """Test that _transform can override verifiers in the exported program.""" + class MyVerifier(Verifier): + dialect: str = "MY_DIALECT" + def __init__(self): + super().__init__() + + model = TestLinear() + program = torch.export.export(model, model._get_random_inputs(), strict=True) + self.assertFalse(issubclass(program.verifiers[0], MyVerifier)) + + # Apply transformation with custom verifier + transformed = _transform(program, AddToMulPassEdge(), override_verifiers=[MyVerifier]) + self.assertTrue(issubclass(transformed.verifiers[0], MyVerifier)) + self.assertFalse(issubclass(program.verifiers[0], MyVerifier))