From 7680294220397e31846c3e632fa716d096505baa Mon Sep 17 00:00:00 2001
From: yifan_shen3
Date: Thu, 12 Sep 2024 14:46:43 -0700
Subject: [PATCH] polish UI: use --coreml-ios 18 to mean enable state + preserve sdpa

---
 examples/models/llama2/export_llama_lib.py | 13 ++++-
 extension/llm/export/partitioner_lib.py    | 66 +++++++++++++---------
 2 files changed, 50 insertions(+), 29 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 7117c492745..5cef72c1e6e 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -317,6 +317,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         choices=["b4w"],
         help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight)",
     )
+    parser.add_argument(
+        "--coreml-ios",
+        type=int,
+        default=15,
+        choices=(15, 16, 17, 18),
+        help="This option is only for coreml: The minimum iOS version to deploy",
+    )
     parser.add_argument(
         "--qnn",
         action="store_true",
@@ -533,8 +540,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
 
     if args.coreml:
         coreml_partitioner = get_coreml_partitioner(
-            args.use_kv_cache and args.coreml_enable_state,
-            args.coreml_preserve_sdpa,
+            args.coreml_ios,
             args.embedding_quantize,
             args.pt2e_quantize,
             args.coreml_quantize,
@@ -810,7 +816,8 @@ def _get_source_transforms(  # noqa
             transforms.append(replace_causal_mask)
 
     elif args.coreml:
-        if args.coreml_preserve_sdpa:
+        # iOS 18 introduced fused sdpa op
+        if args.coreml_ios >= 18:
             transforms.append(replace_sdpa_with_coreml_sdpa)
         else:
             transforms.append(replace_sdpa_with_simple_sdpa)
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 7bbe7f5fa52..bba16dd8a4d 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -56,8 +56,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):
 
 
 def get_coreml_partitioner(
-    enable_state: bool = False,
-    preserve_sdpa: bool = True,
+    ios: int = 15,
     embedding_quantize: Optional[str] = None,
     pt2e_quantize: Optional[str] = None,
     coreml_quantize: Optional[str] = None,
@@ -75,29 +74,42 @@ def get_coreml_partitioner(
             "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"
         )
 
-    minimum_deployment_target = ct.target.iOS15
-    # In Core ML, stateful execution is introduced in iOS 18
-    if enable_state:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, sdpa op is introduced in iOS 18
-    if preserve_sdpa:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
-    # In Core ML, quantization is introduced in iOS 16
-    if embedding_quantize is not None or pt2e_quantize is not None:
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
-    # In Core ML, 8-bit activation quantization is introduced in iOS 17
-    if (
-        embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
-    ) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
-    # In Core ML, 4-bit weight compression is introduced in iOS 18
-    if (
-        (embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4)
-        or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w")
-        or coreml_quantize == "b4w"
-    ):
-        minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
+    def _validate_ios_version() -> None:
+        assert ios in (15, 16, 17, 18)
+        if embedding_quantize is not None and ios < 18:
+            raise ValueError(
+                "In Core ML, per-block quantization is introduced in iOS 18"
+            )
+
+        use_quantization = pt2e_quantize is not None or coreml_quantize is not None
+        if use_quantization and ios < 16:
+            raise ValueError("In Core ML, quantization is introduced in iOS 16")
+
+        use_8a = (pt2e_quantize is not None and "8a" in pt2e_quantize) or (
+            coreml_quantize is not None and "8a" in coreml_quantize
+        )
+        if use_8a and ios < 17:
+            raise ValueError(
+                "In Core ML, 8-bit activation quantization is introduced in iOS 17"
+            )
+
+        use_4w = (pt2e_quantize is not None and "4w" in pt2e_quantize) or (
+            coreml_quantize is not None and "4w" in coreml_quantize
+        )
+        if use_4w and ios < 18:
+            raise ValueError(
+                "In Core ML, 4-bit weight compression is introduced in iOS 18"
+            )
+
+    _validate_ios_version()
+
+    minimum_deployment_target = {
+        15: ct.target.iOS15,
+        16: ct.target.iOS16,
+        17: ct.target.iOS17,
+        18: ct.target.iOS18,
+    }[ios]
 
     op_linear_quantizer_config = None
     if coreml_quantize == "b4w":
         op_linear_quantizer_config = {
@@ -107,7 +119,6 @@ def get_coreml_partitioner(
             "block_size": 32,
             "weight_threshold": 512,
         }
-
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
@@ -116,9 +127,12 @@ def get_coreml_partitioner(
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
         op_linear_quantizer_config=op_linear_quantizer_config,
     )
+
+    take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18
+
     return CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
-        take_over_mutable_buffer=enable_state,
+        take_over_mutable_buffer=take_over_mutable_buffer,
     )
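
Editor's note, not part of the patch: a minimal usage sketch of the reworked API. The function name, its ios parameter, the --coreml-ios flag, and the error messages come from the diff above; the import path and the concrete argument values are assumptions for illustration.

    # Sketch: a single iOS-version knob replaces the enable_state / preserve_sdpa
    # parameters removed above. Import path is assumed; adjust to your checkout.
    from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner

    # Per the diff, ios=18 implies stateful execution: the partitioner is built
    # with take_over_mutable_buffer=True, and export_llama_lib keeps the fused
    # Core ML sdpa op instead of rewriting it to the simple decomposition.
    partitioner = get_coreml_partitioner(ios=18, pt2e_quantize="coreml_8a_c4w")

    # The new _validate_ios_version helper fails fast when the deployment
    # target is too old for the requested quantization scheme:
    try:
        get_coreml_partitioner(ios=17, pt2e_quantize="coreml_8a_c4w")
    except ValueError as err:
        print(err)  # In Core ML, 4-bit weight compression is introduced in iOS 18

On the command line, the same knob is exposed as --coreml-ios (default 15), so an iOS 18 export is requested with --coreml --coreml-ios 18 instead of the removed args.coreml_enable_state / args.coreml_preserve_sdpa plumbing.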