From 602eb1c300dba933f9ebb83bfc85ed4f2ad4a3c7 Mon Sep 17 00:00:00 2001
From: "Jack Zhang (aider)" <32371937+jackzhxng@users.noreply.github.com>
Date: Mon, 12 May 2025 14:45:37 -0700
Subject: [PATCH] Refactor _to_edge_and_lower_llama to remove args (#10520)

Summary:
Refactor `_to_edge_and_lower_llama_xnnpack` and `_to_edge_and_lower_llama` to
take explicit keyword arguments instead of the full `args` namespace. (A
minimal sketch of the resulting calling convention follows the diff.)

Reviewed By: iseeyuan

Differential Revision: D73785343

Pulled By: jackzhxng
---
 examples/models/llama/export_llama_lib.py | 98 +++++++++++++++--------
 1 file changed, 66 insertions(+), 32 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 0e48a8520d7..3b926550b9f 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -27,7 +27,7 @@
 from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 
 from executorch.devtools.backend_debug import print_delegation_info
-from executorch.devtools.etrecord import generate_etrecord
+from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
@@ -749,7 +749,9 @@ def _to_edge_and_lower_llama_xnnpack(
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    xnnpack_extended_ops: bool = False,
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -758,7 +760,7 @@ def _to_edge_and_lower_llama_xnnpack(
 
         modelname = f"xnnpack_dq_{modelname}"
 
-    if args.xnnpack_extended_ops:
+    if xnnpack_extended_ops:
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
@@ -769,7 +771,7 @@ def _to_edge_and_lower_llama_xnnpack(
         logging.info(f"--> {partitioner.__class__.__name__}")
 
     # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
+    if generate_etrecord:
         raise NotImplementedError(
             "export_llama does not support XNNPack and generating ETRecord at the moment."
         )
@@ -777,7 +779,7 @@ def _to_edge_and_lower_llama_xnnpack(
     builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
         partitioners
     )
-    if args.verbose:
+    if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
     return builder.to_executorch(passes=additional_passes)
@@ -790,7 +792,23 @@ def _to_edge_and_lower_llama(  # noqa: C901
     pt2e_quant_params,
     quantizers,
     quant_dtype,
-    args,
+    vulkan: bool = False,
+    mps: bool = False,
+    coreml: bool = False,
+    qnn: bool = False,
+    dtype_override: str = "fp32",
+    enable_dynamic_shape: bool = True,
+    use_kv_cache: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
+    coreml_ios: int = 15,
+    coreml_quantize: Optional[str] = None,
+    coreml_compute_units: str = "cpu_only",
+    use_qnn_sha: bool = False,
+    num_sharding: int = 0,
+    soc_model: str = "SM8650",
+    generate_etrecord: bool = False,
+    verbose: bool = False,
 ):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
@@ -798,11 +816,11 @@ def _to_edge_and_lower_llama(  # noqa: C901
 
     # to_backend
     partitioners = []
-    if args.vulkan:
+    if vulkan:
         partitioners.append(
             get_vulkan_partitioner(
-                args.dtype_override,
-                args.enable_dynamic_shape,
+                dtype_override,
+                enable_dynamic_shape,
             )
         )
         modelname = f"vulkan_{modelname}"
@@ -810,22 +828,22 @@ def _to_edge_and_lower_llama(  # noqa: C901
         # Need to remove asserts from the graph to prevent graph breaks
         remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
-    if args.mps:
-        partitioners.append(get_mps_partitioner(args.use_kv_cache))
+    if mps:
+        partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"
 
-    if args.coreml:
+    if coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.coreml_ios,
-            args.embedding_quantize,
-            args.pt2e_quantize,
-            args.coreml_quantize,
-            args.coreml_compute_units,
+            coreml_ios,
+            embedding_quantize,
+            pt2e_quantize,
+            coreml_quantize,
+            coreml_compute_units,
         )
         partitioners.append(coreml_partitioner)
         modelname = f"coreml_{modelname}"
 
-    if args.qnn:
+    if qnn:
         logging.warning(
             "The model definition in current repro is not performant, please refer to the instruction"
             " in https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama/README.md for better performance."
@@ -833,9 +851,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         from executorch.extension.llm.custom_ops import model_sharding
 
         partitioners.append(
-            get_qnn_partitioner(
-                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
-            )
+            get_qnn_partitioner(use_kv_cache, pt2e_quantize, num_sharding, soc_model)
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
         from executorch.backends.qualcomm._passes import (
@@ -864,7 +880,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         )
 
         atten = builder_exported_to_edge.model.layers[0].attention
-        if args.use_qnn_sha:
+        if use_qnn_sha:
             cache_shape = torch.Size(
                 (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
@@ -887,10 +903,10 @@ def _to_edge_and_lower_llama(  # noqa: C901
         passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
             "get_quant_io_dtype_fn"
         ] = partial(get_custom_quant_ios_dtype, cache_shape)
-        if args.num_sharding > 0:
+        if num_sharding > 0:
             SplitGraph, setting = model_sharding.get_split_graph_pass(
                 builder_exported_to_edge.metadata["get_n_layers"],
-                shares=args.num_sharding,
+                shares=num_sharding,
             )
             passes_job[SplitGraph] = setting
             dep_table[SplitGraph] = [FoldQDQ]
@@ -905,7 +921,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    if args.generate_etrecord:
+    if generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
 
@@ -913,9 +929,9 @@ def _to_edge_and_lower_llama(  # noqa: C901
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`.
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
@@ -927,7 +943,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
 
         # Generate ETRecord
         if edge_manager_copy:
-            generate_etrecord(
+            generate_etrecord_func(
                 et_record="etrecord.bin",
                 edge_dialect_program=edge_manager_copy,
                 executorch_program=builder.export_program,
@@ -935,9 +951,9 @@ def _to_edge_and_lower_llama(  # noqa: C901
             logging.info("Generated etrecord.bin")
     else:
         builder = builder_exported_to_edge.to_backend(partitioners)
-        if args.verbose:
+        if verbose:
             print_delegation_info(builder.edge_manager.exported_program().graph_module)
-        if args.num_sharding > 0 and args.qnn:
+        if num_sharding > 0 and qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
             canonicalize_program(builder.edge_manager.exported_program())
@@ -976,7 +992,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            xnnpack_extended_ops=args.xnnpack_extended_ops,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -986,7 +1004,23 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            args,
+            vulkan=args.vulkan,
+            mps=args.mps,
+            coreml=args.coreml,
+            qnn=args.qnn,
+            dtype_override=args.dtype_override,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+            use_kv_cache=args.use_kv_cache,
+            embedding_quantize=args.embedding_quantize,
+            pt2e_quantize=args.pt2e_quantize,
+            coreml_ios=args.coreml_ios,
+            coreml_quantize=args.coreml_quantize,
+            coreml_compute_units=args.coreml_compute_units,
+            use_qnn_sha=args.use_qnn_sha,
+            num_sharding=args.num_sharding,
+            soc_model=args.soc_model,
+            generate_etrecord=args.generate_etrecord,
+            verbose=args.verbose,
        )
 
     if args.profile_memory:
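---

Note: below is the minimal, self-contained sketch of the calling convention
this patch introduces, referenced from the Summary above. The names `lower`
and `export_main` and the reduced flag set are illustrative assumptions for
this note only, not ExecuTorch APIs; the real signatures and the full
parameter lists are in the diff.

    import argparse


    def lower(
        modelname: str,
        vulkan: bool = False,
        num_sharding: int = 0,
        verbose: bool = False,
    ) -> str:
        # After the refactor: every flag the helper consumes is an explicit
        # keyword argument with a default, instead of being pulled from an
        # opaque `args` namespace (the pre-refactor shape was
        # `def lower(modelname, args)` with `args.vulkan` lookups inside).
        if vulkan:
            modelname = f"vulkan_{modelname}"
        if num_sharding > 0:
            modelname = f"{modelname}_sharded{num_sharding}"
        if verbose:
            print(f"--> {modelname}")
        return modelname


    def export_main(args: argparse.Namespace) -> str:
        # The CLI entry point is now the only place that touches `args`;
        # it forwards each field explicitly, mirroring _export_llama above.
        return lower(
            modelname="llama",
            vulkan=args.vulkan,
            num_sharding=args.num_sharding,
            verbose=args.verbose,
        )


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--vulkan", action="store_true")
        parser.add_argument("--num-sharding", type=int, default=0)
        parser.add_argument("--verbose", action="store_true")
        print(export_main(parser.parse_args()))

The payoff mirrors the diff: each helper's dependencies are visible in its
signature, keyword defaults keep existing call sites working, and the
lowering helpers become callable from tests or other entry points without
constructing an argparse.Namespace.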