diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 6d9ba750431..3a1f423aa27 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -676,47 +676,62 @@ def _validate_args(args):
         )
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-
-    builder_exported.run_canonical_optimizations()
-
-    if args.export_only:
-        exit()
-
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
-
-    modelname = builder_exported_to_edge.modelname
-
-    # to_backend
+def _to_edge_and_lower_llama_xnnpack(
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    if (
-        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
-    ) or (args.xnnpack):
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
-        )
+    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
 
-        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-        modelname = f"xnnpack_dq_{modelname}"
+    modelname = f"xnnpack_dq_{modelname}"
 
     if args.xnnpack_extended_ops:
-        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
         modelname = f"xnnpack_{modelname}"
 
+    logging.info("Lowering model using following partitioner(s): ")
+    for partitioner in partitioners:
+        logging.info(f"--> {partitioner.__class__.__name__}")
+
+    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
+    if args.generate_etrecord:
+        raise NotImplementedError(
+            "export_llama does not support XNNPack and generating ETRecord at the moment."
+        )
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+    if args.verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
+def _to_edge_and_lower_llama(  # noqa: C901
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+):
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
+
+    # to_backend
+    partitioners = []
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -731,7 +746,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"vulkan_{modelname}"
 
         # Need to remove asserts from the graph to prevent graph breaks
-        # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.mps:
@@ -760,13 +774,11 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
 
         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -792,19 +804,15 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                     atten.head_dim,
                 )
             )
-        # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
+            partial(get_custom_quant_ios_dtype, cache_shape),
         )
 
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
    if args.generate_etrecord:
        if not builder_exported_to_edge.edge_manager:
            raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -818,7 +826,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(
@@ -840,11 +847,55 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(passes=additional_passes)
 
+    return builder
+
+
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported.run_canonical_optimizations()
+    modelname = builder_exported.modelname
+
+    if args.export_only:
+        exit()
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
+    if args.xnnpack:
+        builder = _to_edge_and_lower_llama_xnnpack(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+    else:
+        builder = _to_edge_and_lower_llama(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
@@ -866,7 +917,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         output_file = f"{builder.output_dir}/{modelname}.pte"
 
     builder.save_to_pte(output_file)
-
     return builder
 
 
diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index 82c7aca09e0..a5057e5e850 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -67,7 +67,6 @@ def export(self) -> "LlavaEdgeManager":
             dynamic_shapes=dynamic_shape,
             strict=False,
         )
-        # pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
         self.pre_autograd_graph_module = self.export_program.module()
         return self
 
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 88d2bc0cab9..ec6cfa41ad8 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -21,7 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager
+from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.backend.utils import format_delegated_graph
 
@@ -39,7 +39,7 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
-from torch.export import export_for_training
+from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -89,8 +89,8 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        # graph module returned from export()
-        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
+        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
@@ -218,8 +218,8 @@ def export(self) -> "LLMEdgeManager":
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
-            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
-            #  `Module`.
+            self.pre_autograd_exported_program = exported_module
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -330,7 +330,10 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManag
                 assert (
                     self.pre_autograd_graph_module is not None
                 ), "Please run export() first"
-                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+                m = prepare_pt2e(
+                    self.pre_autograd_graph_module,  # pyre-ignore[6]
+                    composed_quantizer,
+                )
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                 )
@@ -430,6 +433,19 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag
 
         return self
 
+    def to_edge_transform_and_lower(
+        self, partitioners: Optional[List[Partitioner]]
+    ) -> "LLMEdgeManager":
+        if partitioners is None:
+            logging.info("No partitioner provided, skipping backend lowering...")
+        edge_config = self._get_edge_config()
+        self.edge_manager = to_edge_transform_and_lower(
+            self.pre_autograd_exported_program,
+            partitioner=partitioners,
+            compile_config=edge_config,
+        )
+        return self
+
     def to_executorch(
         self, passes: Optional[List[ExportPass]] = None
     ) -> "LLMEdgeManager":
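
Note for reviewers: the sketch below is not part of the diff. It illustrates how the new LLMEdgeManager.to_edge_transform_and_lower() step replaces the previous export_to_edge() + to_backend() pair, mirroring _to_edge_and_lower_llama_xnnpack above. The wrapper function name and argument plumbing are hypothetical; in the real flow builder_exported, quantizers, and partitioners come from the existing export_llama_lib.py helpers such as _prepare_for_llama_export(...).export(), get_quantizer_and_quant_params(), and get_xnnpack_partitioner().

# Illustrative sketch only, assuming an ExecuTorch install.
from typing import List, Optional

from executorch.exir.backend.partitioner import Partitioner


def lower_llama_xnnpack_sketch(
    builder_exported,  # an LLMEdgeManager on which export() has already run
    quantizers,  # pt2e quantizers, possibly empty
    partitioners: Optional[List[Partitioner]],  # e.g. XNNPACK partitioner(s)
):
    # Old path: pt2e_quantize(...).export_to_edge() followed by to_backend(partitioners).
    # New path: quantize, then lower to edge and delegate in a single call.
    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
        partitioners
    )
    # to_executorch() is unchanged; extra passes can still be supplied.
    return builder.to_executorch()

Per the diff, only the XNNPACK branch of _export_llama takes this path; the other backends (Vulkan, MPS, CoreML, QNN) keep the existing export_to_edge()/to_backend() flow in _to_edge_and_lower_llama.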