From b891fba64b2ddf353fed7fe002fd59a18c0796b8 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Thu, 19 Mar 2026 14:40:43 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- .../coreml/llama/export_static_llm_coreml.py | 332 ++++++++++++------ examples/apple/coreml/llama/utils.py | 14 +- examples/models/llama/static_attention.py | 62 ++-- 3 files changed, 264 insertions(+), 144 deletions(-) diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index 2aac3200dfb..da04af4c8f6 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -111,6 +111,8 @@ def load_model( params_path: str, max_context_len: int, generate_full_logits: bool = True, + adapter_checkpoint: str = None, + adapter_config: str = None, ): """Load the model from checkpoint with static_mha attention type. @@ -121,6 +123,8 @@ def load_model( generate_full_logits: If True, output logits for all tokens (needed for lookahead decoding). If False, only output logits for the last token (more efficient for standard autoregressive generation). + adapter_checkpoint: Path to LoRA adapter weights (.safetensors) + adapter_config: Path to adapter_config.json """ with open(params_path, "r") as f: params = json.loads(f.read()) @@ -133,6 +137,13 @@ def load_model( args.attention_type = "static_mha" args.attention_kwargs = {"decompose_sdpa_in_mha": True} + if adapter_config is not None: + with open(adapter_config, "r") as f: + lora_config = json.loads(f.read()) + args.r = lora_config["r"] + args.lora_alpha = lora_config["lora_alpha"] + args.target_modules = lora_config["target_modules"] + with torch.device("meta"): model = construct_transformer(args) @@ -142,20 +153,41 @@ def load_model( if "model" in checkpoint: checkpoint = checkpoint["model"] - # Rename attention weight keys for static attention + # Rename attention weight keys for static attention: + # wq.* -> wqs.0.*, wk.* -> wks.0.*, wv.* -> wvs.0.* + # LoRALinear._load_from_state_dict remaps weight -> linear.weight automatically. for i in range(len(model.layers)): - if f"layers.{i}.attention.wq.weight" in checkpoint: - checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop( - f"layers.{i}.attention.wq.weight" - ) - if f"layers.{i}.attention.wk.weight" in checkpoint: - checkpoint[f"layers.{i}.attention.wks.0.weight"] = checkpoint.pop( - f"layers.{i}.attention.wk.weight" - ) - if f"layers.{i}.attention.wv.weight" in checkpoint: - checkpoint[f"layers.{i}.attention.wvs.0.weight"] = checkpoint.pop( - f"layers.{i}.attention.wv.weight" - ) + prefix = f"layers.{i}.attention" + for old_proj, new_proj in [("wq", "wqs.0"), ("wk", "wks.0"), ("wv", "wvs.0")]: + for key in list(checkpoint.keys()): + old_prefix = f"{prefix}.{old_proj}." + if key.startswith(old_prefix): + suffix = key[len(old_prefix):] + new_key = f"{prefix}.{new_proj}.{suffix}" + checkpoint[new_key] = checkpoint.pop(key) + + if adapter_checkpoint is not None: + from executorch.examples.models.llama.convert_weights import ( + load_and_convert_unsloth_to_meta, + ) + + adapter_weights = load_and_convert_unsloth_to_meta(adapter_checkpoint) + # Rename adapter keys for static attention the same way + for i in range(len(model.layers)): + prefix = f"layers.{i}.attention" + for old_proj, new_proj in [ + ("wq", "wqs.0"), + ("wk", "wks.0"), + ("wv", "wvs.0"), + ]: + for key in list(adapter_weights.keys()): + old_prefix = f"{prefix}.{old_proj}." + if key.startswith(old_prefix): + suffix = key[len(old_prefix):] + new_key = f"{prefix}.{new_proj}.{suffix}" + adapter_weights[new_key] = adapter_weights.pop(key) + + checkpoint.update(adapter_weights) missing, unexpected = model.load_state_dict( checkpoint, @@ -309,8 +341,26 @@ def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype) } -def _prepare_model(model, args, float_dtype): - """Apply splitting, quantization, and graph breaks to a model.""" +_TARGET_MODULE_TO_ATTR = { + "q_proj": "wqs", + "k_proj": "wks", + "v_proj": "wvs", + "o_proj": "wo", + "output_proj": "wo", + "gate_proj": "w1", + "up_proj": "w3", + "down_proj": "w2", +} + + +def _prepare_model(model, args, float_dtype, skip_split_names=None): + """Apply splitting, quantization, and graph breaks to a model. + + This is shared across base and adapter models so the same transformations + are applied consistently. + """ + from executorch.examples.models.llama.lora import LoRALinear + model = model.to(float_dtype).eval() if args.target_split_size is not None: @@ -321,6 +371,7 @@ def _prepare_model(model, args, float_dtype): out_max_splits=args.max_splits, in_target_split_size=1, in_max_splits=1, + skip_names=skip_split_names, ) if args.embedding_quantize: @@ -342,6 +393,20 @@ def _prepare_model(model, args, float_dtype): lambda m, fqn: isinstance(m, torch.nn.Embedding), ) + has_lora_modules = any( + isinstance(m, LoRALinear) for m in model.modules() + ) + + def _exclude_lora(m, fqn): + if isinstance(m, LoRALinear): + return False + parts = fqn.split(".") + if "lora_a" in parts or "lora_b" in parts: + return False + return isinstance(m, nn.Linear) + + linear_filter = _exclude_lora if has_lora_modules else None + if args.linear_quantize == "b4w": print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") quantize_( @@ -350,6 +415,7 @@ def _prepare_model(model, args, float_dtype): weight_dtype=torch.int4, granularity=PerGroup(32), ), + linear_filter, ) elif args.linear_quantize == "c4w": print("\nQuantizing linear layers: 4-bit channelwise...") @@ -359,6 +425,7 @@ def _prepare_model(model, args, float_dtype): weight_dtype=torch.int4, granularity=PerAxis(0), ), + linear_filter, ) if not args.no_graph_breaks: @@ -464,9 +531,20 @@ def main(): "and generate_full_logits=True for lookahead decoding support.", ) + # LoRA adapter options + parser.add_argument( + "--adapter", + nargs=3, + action="append", + metavar=("NAME", "CHECKPOINT", "CONFIG"), + help="LoRA adapter: method name, path to adapter.safetensors, " + "path to adapter_config.json. Can be repeated for multiple adapters.", + ) + args = parser.parse_args() # Compute cache length + has_adapters = args.adapter is not None print("Export mode:") if args.multifunction: @@ -475,6 +553,8 @@ def main(): ) else: print("\tSingle method: fixed seqlen, generate_full_logits=True (lookahead)") + if has_adapters: + print(f"\tAdapters: {[a[0] for a in args.adapter]}") print("\nQuantization and datatype:") print(f"\tEmbedding quantize: {args.embedding_quantize}") @@ -491,7 +571,7 @@ def main(): print(f"\tTarget split size: {args.target_split_size}") print(f"\tMax splits: {args.max_splits}") - # Load model + # Load base model # For multifunction: generate_full_logits=False (efficient, only last token) # For single method: generate_full_logits=True (needed for lookahead decoding) generate_full_logits = not args.multifunction @@ -505,7 +585,55 @@ def main(): print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] - model = _prepare_model(model, args, float_dtype) + + # Compute skip_split_names: union of all adapter target_modules mapped to + # model attribute names. Both base and adapter models skip splitting these + # so POSITIONAL weight sharing can deduplicate the base weights. + skip_split_names = None + if has_adapters: + all_targets = set() + for _, _, adapter_cfg in args.adapter: + with open(adapter_cfg, "r") as f: + cfg = json.loads(f.read()) + all_targets.update(cfg.get("target_modules", [])) + skip_split_names = { + _TARGET_MODULE_TO_ATTR[t] for t in all_targets if t in _TARGET_MODULE_TO_ATTR + } + print(f"\nSkipping split for LoRA-targeted modules: {skip_split_names}") + + model = _prepare_model(model, args, float_dtype, skip_split_names=skip_split_names) + + # Load adapter models + lora_models = {} + if has_adapters: + for name, adapter_ckpt, adapter_cfg in args.adapter: + print(f"\nLoading adapter '{name}' from {adapter_ckpt}...") + lora_model, _ = load_model( + args.checkpoint, + args.params, + args.max_context_len, + generate_full_logits=generate_full_logits, + adapter_checkpoint=adapter_ckpt, + adapter_config=adapter_cfg, + ) + lora_model = _prepare_model( + lora_model, args, float_dtype, skip_split_names=skip_split_names + ) + lora_models[name] = lora_model + + def _export_model(m, inputs, label="model"): + print(f"\nTesting eager execution ({label})...") + with torch.no_grad(): + m(*inputs) + print(f"Eager execution successful ({label})!") + + print(f"\nExporting {label}...") + ep = torch.export.export(m, inputs) + print(f"Export successful ({label})!") + print(ep) + return ep + + use_multimethod = args.multifunction or has_adapters if args.multifunction: # Multifunction mode: separate prefill and decode graphs with weight sharing @@ -537,31 +665,22 @@ def main(): cache_len=shared_cache_len, ) - # Test eager execution for both - print("\nTesting eager execution (decode, seqlen=1)...") - with torch.no_grad(): - model(*decode_inputs) - print("Decode eager execution successful!") + # Export base model + methods = { + "forward": _export_model(model, decode_inputs, "base decode"), + "prefill": _export_model(model, prefill_inputs, "base prefill"), + } - print(f"\nTesting eager execution (prefill, seqlen={prefill_input_len})...") - with torch.no_grad(): - model(*prefill_inputs) - print("Prefill eager execution successful!") - - # Export both graphs - print("\nExporting decode model (seqlen=1)...") - decode_ep = torch.export.export(model, decode_inputs) - print("Decode export successful!") - print(decode_ep) - - print(f"\nExporting prefill model (seqlen={prefill_input_len})...") - prefill_ep = torch.export.export(model, prefill_inputs) - print("Prefill export successful!") - print(prefill_ep) - - # Generate metadata for C++ runner - # constant_methods are shared across all methods, so we prefix method-specific - # metadata with the method name + # Export adapter models + for name, lora_model in lora_models.items(): + methods[f"{name}_forward"] = _export_model( + lora_model, decode_inputs, f"{name} decode" + ) + methods[f"{name}_prefill"] = _export_model( + lora_model, prefill_inputs, f"{name} prefill" + ) + + # Generate metadata print("\nGenerating metadata for C++ runner...") decode_metadata = _get_metadata( model_args, decode_inputs, decode_input_len, decode_cache_len, float_dtype @@ -574,71 +693,43 @@ def main(): float_dtype, ) - # Combine metadata - shared values go without prefix, method-specific values get prefixed constant_methods = { - # Shared metadata (same for both methods) "vocab_size": decode_metadata["vocab_size"], "head_dim": decode_metadata["head_dim"], "n_heads_per_cache": decode_metadata["n_heads_per_cache"], "freqs_cos": decode_metadata["freqs_cos"], "freqs_sin": decode_metadata["freqs_sin"], - # Decode-specific metadata (forward method) "decode_input_len": decode_metadata["forward_input_len"], "decode_freqs_cos_input_index": decode_metadata["freqs_cos_input_index"], "decode_freqs_sin_input_index": decode_metadata["freqs_sin_input_index"], "decode_mask_specs": decode_metadata["mask_specs"], "decode_kv_cache_specs": decode_metadata["kv_cache_specs"], - # Prefill-specific metadata "prefill_input_len": prefill_metadata["forward_input_len"], "prefill_freqs_cos_input_index": prefill_metadata["freqs_cos_input_index"], "prefill_freqs_sin_input_index": prefill_metadata["freqs_sin_input_index"], "prefill_mask_specs": prefill_metadata["mask_specs"], "prefill_kv_cache_specs": prefill_metadata["kv_cache_specs"], } - - # Setup CoreML partitioner with multimethod weight sharing - print("\nSetting up CoreML partitioner (multifunction with weight sharing)...") - compile_specs = CoreMLBackend.generate_compile_specs( - minimum_deployment_target=ct.target.iOS18, - compute_precision={ - torch.float16: ct.precision.FLOAT16, - torch.float32: ct.precision.FLOAT32, - }[float_dtype], - compute_unit=ct.ComputeUnit.CPU_AND_NE, - model_type=CoreMLBackend.MODEL_TYPE.MODEL, - ) - compile_specs.append( - CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( - MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL - ) - ) - partitioner = CoreMLPartitioner( - compile_specs=compile_specs, - take_over_mutable_buffer=False, - skip_ops_for_coreml_delegation=[], - ) - - # Lower to edge with both decode and prefill methods - print("\nLowering to edge (multi-method: decode + prefill)...") - edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) - - # Create multi-method edge manager with decode as "forward" and prefill as "prefill" - edge_manager = to_edge_transform_and_lower( - {"forward": decode_ep, "prefill": prefill_ep}, - partitioner=[partitioner], - constant_methods=constant_methods, - compile_config=edge_compile_config, + if has_adapters: + constant_methods["has_lora"] = True + elif has_adapters: + # Adapter-only mode (no multifunction): base + adapter methods, same seqlen + print(f"\nCreating example inputs (seqlen={args.input_len})...") + example_inputs, example_cache_len = _create_example_inputs( + model_args, args.input_len, args.max_context_len, float_dtype ) - print("\nDelegated program (decode/forward):") - print(format_delegated_graph(edge_manager.exported_program().graph_module)) + methods = { + "forward": _export_model(model, example_inputs, "base"), + } + for name, lora_model in lora_models.items(): + methods[name] = _export_model(lora_model, example_inputs, name) - print("\nDelegated program (prefill):") - print( - format_delegated_graph( - edge_manager.exported_program("prefill").graph_module - ) + print("\nGenerating metadata for C++ runner...") + constant_methods = _get_metadata( + model_args, example_inputs, args.input_len, example_cache_len, float_dtype ) + constant_methods["has_lora"] = True else: # Single method mode: fixed seqlen with generate_full_logits=True for lookahead print(f"\nCreating example inputs (seqlen={args.input_len})...") @@ -646,51 +737,60 @@ def main(): model_args, args.input_len, args.max_context_len, float_dtype ) - # Test eager execution - print("\nTesting eager execution...") - with torch.no_grad(): - model(*example_inputs) - print("Eager execution successful!") - - # Export the model - print("\nExporting model...") - ep = torch.export.export(model, example_inputs) - print("Export successful!") - print(ep) + ep = _export_model(model, example_inputs, "model") - # Generate metadata for C++ runner print("\nGenerating metadata for C++ runner...") constant_methods = _get_metadata( model_args, example_inputs, args.input_len, example_cache_len, float_dtype ) - # Setup CoreML partitioner - print("\nSetting up CoreML partitioner...") - compile_specs = CoreMLBackend.generate_compile_specs( - minimum_deployment_target=ct.target.iOS18, - compute_precision={ - torch.float16: ct.precision.FLOAT16, - torch.float32: ct.precision.FLOAT32, - }[float_dtype], - compute_unit=ct.ComputeUnit.CPU_AND_NE, - model_type=CoreMLBackend.MODEL_TYPE.MODEL, - ) - partitioner = CoreMLPartitioner( - compile_specs=compile_specs, - take_over_mutable_buffer=False, - skip_ops_for_coreml_delegation=[], + # Setup CoreML partitioner + print("\nSetting up CoreML partitioner...") + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + if use_multimethod: + compile_specs.append( + CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( + MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL + ) ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[], + ) - # Lower to edge with constant methods for C++ runner - print("\nLowering to edge...") - edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + # Lower to edge + print("\nLowering to edge...") + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + if use_multimethod: + edge_manager = to_edge_transform_and_lower( + methods, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=edge_compile_config, + ) + for method_name in methods: + print(f"\nDelegated program ({method_name}):") + print( + format_delegated_graph( + edge_manager.exported_program(method_name).graph_module + ) + ) + else: edge_manager = to_edge_transform_and_lower( ep, partitioner=[partitioner], constant_methods=constant_methods, compile_config=edge_compile_config, ) - print("\nDelegated program:") print(format_delegated_graph(edge_manager.exported_program().graph_module)) diff --git a/examples/apple/coreml/llama/utils.py b/examples/apple/coreml/llama/utils.py index 1e5a842fed5..b282be8c4f4 100644 --- a/examples/apple/coreml/llama/utils.py +++ b/examples/apple/coreml/llama/utils.py @@ -91,9 +91,20 @@ def forward(self, x): def replace_linear_with_split_linear( - model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1 + model, + out_target_split_size, + out_max_splits, + in_target_split_size, + in_max_splits=1, + skip_names=None, ): + from executorch.examples.models.llama.lora import LoRALinear + for name, module in model.named_children(): + if skip_names and name in skip_names: + continue + if isinstance(module, LoRALinear): + continue if isinstance(module, torch.nn.Linear): assert module.bias is None, "SplitLinearModule does not support bias" new_module = SplitLinearModule( @@ -113,4 +124,5 @@ def replace_linear_with_split_linear( out_max_splits, in_target_split_size, in_max_splits, + skip_names=skip_names, ) diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 99809a063bc..c36f5a70358 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -812,38 +812,46 @@ def __init__( [StaticVCache(layer_id, i) for i in range(self.n_kv_heads)] ) else: - self.wqs = nn.ModuleList( - [ - nn.Linear( - self.dim, - self.head_dim * self.n_heads, - bias=self.attention_qkv_bias, - ) - ] - ) - self.wks = nn.ModuleList( - [ - nn.Linear( - self.dim, - self.head_dim * self.n_kv_heads, - bias=self.attention_qkv_bias, - ) - ] - ) - self.wvs = nn.ModuleList( - [ - nn.Linear( - self.dim, - self.head_dim * self.n_kv_heads, - bias=self.attention_qkv_bias, + has_lora = config.target_modules is not None + _PROJ_TARGET = { + "wqs": ("q_proj", self.dim, self.head_dim * self.n_heads), + "wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads), + "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads), + } + for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items(): + if has_lora and target in config.target_modules: + proj = LoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=config.r, + alpha=config.lora_alpha, + use_bias=self.attention_qkv_bias, ) - ] - ) + else: + proj = nn.Linear(in_dim, out_dim, bias=self.attention_qkv_bias) + setattr(self, attr, nn.ModuleList([proj])) self.k_caches = nn.ModuleList([StaticKCache(layer_id, 0)]) self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)]) - self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + wo_use_lora = ( + not self.split_mha + and config.target_modules is not None + and ( + "output_proj" in config.target_modules + or "o_proj" in config.target_modules + ) + ) + if wo_use_lora: + self.wo = LoRALinear( + in_dim=self.n_heads * self.head_dim, + out_dim=self.dim, + rank=config.r, + alpha=config.lora_alpha, + use_bias=False, + ) + else: + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.rope = _Rope(rope.params) self.layer_id = layer_id