diff --git a/export/TARGETS b/export/TARGETS index ae8be8a5e98..bf1002a701e 100644 --- a/export/TARGETS +++ b/export/TARGETS @@ -12,6 +12,7 @@ python_library( "//executorch/exir/backend:backend_api", "//executorch/exir:pass_manager", "//executorch/devtools/backend_debug:delegation_info", + "//executorch/extension/export_util:export_util", ] ) diff --git a/export/export.py b/export/export.py index 593f9b91157..7dd6b239d0a 100644 --- a/export/export.py +++ b/export/export.py @@ -4,16 +4,19 @@ import torch from executorch.devtools.backend_debug import get_delegation_info from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.program import ( EdgeProgramManager, ExecutorchProgramManager, to_edge_transform_and_lower, ) from executorch.exir.schema import Program +from executorch.extension.export_util.utils import save_pte_program from executorch.runtime import Runtime, Verification from tabulate import tabulate from torch import nn from torch.ao.quantization import allow_exported_model_train_eval +from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer from torch.export import ExportedProgram from torchao.quantization import quantize_ from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -145,15 +148,15 @@ def run( model, self._example_inputs_dict[method_name][0], dynamic_shapes=dynamic_shapes, + strict=True, ) # Apply pre-edge transform passes if available if self._pre_edge_transform_passes is not None: - self._exported_program[method_name] = ( - self._pre_edge_transform_passes( + for pre_edge_transform_pass in self._pre_edge_transform_passes: + self._exported_program[method_name] = pre_edge_transform_pass( self._exported_program[method_name] ) - ) def get_artifacts(self) -> Dict[str, ExportedProgram]: """ @@ -210,13 +213,14 @@ def run( self._constant_methods = transform_config.get("constant_methods", None) # Process inputs - self._edge_program_manager = to_edge_transform_and_lower( - self._exported_program, - partitioner=self._partitioners, - transform_passes=self._transform_passes, - constant_methods=self._constant_methods, - compile_config=self._compile_config, - ) + with validation_disabled(): + self._edge_program_manager = to_edge_transform_and_lower( + self._exported_program, + partitioner=self._partitioners, + transform_passes=self._transform_passes, + constant_methods=self._constant_methods, + compile_config=self._compile_config, + ) self._delegation_info = get_delegation_info( self._edge_program_manager.exported_program().graph_module ) @@ -345,8 +349,8 @@ class QuantizeStage(Stage): Optional stage: Perform post-training quantization on the model. """ - def __init__(self, quantizer: Any) -> None: - self._quantizer = quantizer + def __init__(self, quantizers: Any) -> None: + self._quantizers = quantizers self._quantized_models: Dict[str, nn.Module] = {} self._model_dict: Dict[str, nn.Module] = {} self._exported_program_dict: Dict[str, ExportedProgram] = {} @@ -394,7 +398,8 @@ def run( model = exported_program.module() # Prepare the model for quantization - prepared_model = prepare_pt2e(model, self._quantizer) # type: ignore + composed_quantizer = ComposableQuantizer(self._quantizers) + prepared_model = prepare_pt2e(model, composed_quantizer) # type: ignore # Allow the model to switch between train and eval modes allow_exported_model_train_eval(prepared_model) @@ -546,9 +551,9 @@ def __init__( # Create the quantize stage if a quantizer is provided if self._export_recipe.quantization_recipe is not None: - quantizer = self._export_recipe.quantization_recipe.get_quantizer() - if quantizer is not None: - quantize_stage = QuantizeStage(quantizer=quantizer) + quantizers = self._export_recipe.quantization_recipe.get_quantizers() + if quantizers is not None: + quantize_stage = QuantizeStage(quantizers=quantizers) self._pipeline.append(quantize_stage) # Create the edge transform and lower stage @@ -661,6 +666,22 @@ def get_executorch_program(self) -> Program: ) return self._executorch_program_manager.executorch_program + def get_executorch_program_manager(self) -> ExecutorchProgramManager: + """ + Get the ExecutorchProgramManager. + + Returns: + The ExecutorchProgramManager + + Raises: + RuntimeError: If the executorch program manager is not initialized + """ + if self._executorch_program_manager is None: + raise RuntimeError( + "Executorch program manager is not initialized. Run export() first." + ) + return self._executorch_program_manager + def get_pte_buffer(self) -> bytes: """ Get the PTE buffer as bytes. @@ -677,6 +698,20 @@ def get_pte_buffer(self) -> bytes: ) return self._executorch_program_manager.buffer + def save_to_pte(self, output_name: str) -> None: + """ + Save the model to a .pte file. + + Args: + output_name (Optional[str]): The name of the .pte file. + """ + assert output_name, "Need a valid output name" + if self._executorch_program_manager is None: + raise RuntimeError( + "Executorch program manager is not initialized. Run export() first." + ) + save_pte_program(self._executorch_program_manager, output_name) + def get_example_input( self, method_name: str = "forward" ) -> Tuple[torch.Tensor, ...]: diff --git a/export/recipe.py b/export/recipe.py index 7b743c0aa4c..b993fce26e3 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -49,17 +49,17 @@ class QuantizationRecipe: quantizer: Optional quantizer for model quantization """ - quantizer: Optional[Quantizer] = None + quantizers: Optional[List[Quantizer]] = None ao_base_config: Optional[List[AOBaseConfig]] = None - def get_quantizer(self) -> Optional[Quantizer]: + def get_quantizers(self) -> Optional[Quantizer]: """ Get the quantizer associated with this recipe. Returns: The quantizer if one is set, otherwise None """ - return self.quantizer + return self.quantizers @experimental( @@ -94,10 +94,11 @@ class ExportRecipe: ) pre_edge_transform_passes: Optional[ Callable[[ExportedProgram], ExportedProgram] + | List[Callable[[ExportedProgram], ExportedProgram]] ] = None edge_transform_passes: Optional[Sequence[PassType]] = None transform_check_ir_validity: bool = True - partitioners: Optional[list[Partitioner]] = None + partitioners: Optional[List[Partitioner]] = None executorch_backend_config: Optional[ExecutorchBackendConfig] = ( None # pyre-ignore[11]: Type not defined )