diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py
new file mode 100644
index 00000000000..58bc0859c79
--- /dev/null
+++ b/examples/apple/coreml/llama/export.py
@@ -0,0 +1,285 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-strict
+
+import argparse
+import json
+
+import sys
+
+import coremltools as ct
+import torch
+from executorch.backends.apple.coreml.compiler import CoreMLBackend  # pyre-ignore
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner  # pyre-ignore
+from executorch.examples.models.llama.source_transformation.quantize import (
+    EmbeddingQuantHandler,
+)
+
+from executorch.exir.backend.utils import format_delegated_graph
+from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
+from executorch.exir.passes import MemoryPlanningPass
+from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
+from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
+from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+
+sys.path.insert(0, ".")
+from llama_transformer import InputManager, ModelArgs, Transformer
+
+
+class SplitLinearModule(torch.nn.Module):
+    def __init__(self, in_features, out_features, target_split_size, max_splits):
+        super().__init__()
+        num_splits = max(out_features // target_split_size, 1)
+        if num_splits > max_splits:
+            num_splits = max_splits
+
+        self.split_size = out_features // num_splits
+        self.split_remainder = out_features % num_splits
+        self.splits = torch.nn.ModuleList(
+            [torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)]
+        )
+        print(
+            f"Splitting out_features={out_features} into {num_splits} splits of size {self.split_size}"
+        )
+        if self.split_remainder > 0:
+            print(
+                f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} splits of size {self.split_size}; adding one extra split for the remainder"
+            )
+            self.splits.append(torch.nn.Linear(in_features, self.split_remainder))
+
+    def split_sizes(self):
+        return [split.out_features for split in self.splits]
+
+    def forward(self, x):
+        return torch.cat([split(x) for split in self.splits], dim=-1)
+
+
+def replace_linear_with_split_linear(model, target_split_size, max_splits):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            new_module = SplitLinearModule(
+                module.in_features, module.out_features, target_split_size, max_splits
+            )
+            split_sizes = new_module.split_sizes()
+            if module.bias is not None:
+                split_bias = module.bias.split(split_sizes)
+            split_weights = module.weight.split(split_sizes, dim=0)
+            for i, split in enumerate(new_module.splits):
+                split.weight = torch.nn.Parameter(split_weights[i])
+                if module.bias is not None:
+                    split.bias = torch.nn.Parameter(split_bias[i])
+                else:
+                    split.bias = None
+            setattr(model, name, new_module)
+        else:
+            replace_linear_with_split_linear(module, target_split_size, max_splits)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-n",
+        "--output_name",
+        default="model.pte",
+        help="Override the output filename of the saved pte model file.",
+    )
+    parser.add_argument(
+        "-p",
+        "--params",
+        help="params.json path",
+    )
+    parser.add_argument(
+        "-c",
+        "--checkpoint",
+        help="checkpoint path",
+    )
+    parser.add_argument(
+        "--seq_length",
+        type=int,
+        default=1,
+        help="sequence length to evaluate",
+    )
+    parser.add_argument(
+        "--max_seq_length",
+        type=int,
+        default=128,
+        help="maximum sequence length to evaluate",
+    )
+    parser.add_argument(
+        "--cache_size",
+        type=int,
+        default=None,
+        help="Cache size. Old items are evicted from the cache when it is full.",
+    )
+    parser.add_argument(
+        "-E",
+        "--embedding-quantize",
+        default=None,
+        type=str,
+        help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
+    )
+    parser.add_argument(
+        "--coreml-quantize",
+        default=None,
+        choices=["b4w", "c4w"],
+        help="Use CoreML quantization: b4w (blockwise 4-bit weights) or c4w (channelwise 4-bit weights).",
+    )
+    parser.add_argument(
+        "--use_cache_list",
+        action="store_true",
+        help="Use cache list to speed up model computation (does not work in pybindings)",
+    )
+    parser.add_argument(
+        "--target_split_size",
+        type=int,
+        default=None,
+        help="Split linear layers into smaller chunks of target_split_size.",
+    )
+    parser.add_argument(
+        "--max_splits",
+        type=int,
+        default=8,
+        help="Maximum number of splits when dividing linear layers.",
+    )
+
+    export_args = parser.parse_args()
+    params_path = export_args.params
+    checkpoint_path = export_args.checkpoint
+
+    # Load model args
+    with open(params_path, "r") as f:
+        params = json.loads(f.read())
+
+    args = ModelArgs(
+        max_seq_len=export_args.max_seq_length,
+        generate_full_logits=False,
+        use_cache_list=export_args.use_cache_list,
+        **params,
+    )
+
+    with torch.device("meta"):
+        model = Transformer(args)
+
+    checkpoint = torch.load(
+        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
+    )
+    if "model" in checkpoint:
+        checkpoint = checkpoint["model"]
+
+    missing, unexpected = model.load_state_dict(
+        checkpoint,
+        strict=False,
+        assign=True,
+    )
+    print("Missing keys: ", missing)
+    print("Unexpected keys: ", unexpected)
+
+    float_dtype = torch.float16  # dtype for model/inputs
+    model.eval()
+    model.to(float_dtype)
+
+    if export_args.embedding_quantize:
+        bitwidth, group_size = export_args.embedding_quantize.split(",")
+        if group_size == "none" or group_size == "None" or group_size == "0":
+            group_size = None
+        else:
+            group_size = int(group_size)
+        bitwidth = int(bitwidth)
+        model = EmbeddingQuantHandler(
+            model,
+            bitwidth=bitwidth,
+            group_size=group_size,
+            packed=(bitwidth in [2, 4]),
+        ).quantized_model()
+
+    if export_args.target_split_size is not None:
+        replace_linear_with_split_linear(
+            model, export_args.target_split_size, export_args.max_splits
+        )
+
+    model = model.to(float_dtype)
+
+    op_linear_quantizer_config = None
+    if export_args.coreml_quantize == "b4w":
+        op_linear_quantizer_config = {
+            "mode": "linear_symmetric",
+            "dtype": "int4",
+            "granularity": "per_block",
+            "block_size": 32,
+            "weight_threshold": 512,
+        }
+    elif export_args.coreml_quantize == "c4w":
+        op_linear_quantizer_config = {
+            "mode": "linear_symmetric",
+            "dtype": "int4",
+            "granularity": "per_channel",
+        }
+
+    compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
+        minimum_deployment_target=ct.target.iOS18,
+        compute_precision=ct.precision(ct.precision.FLOAT16.value),
+        compute_unit=ct.ComputeUnit.CPU_AND_NE,
+        model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
+        op_linear_quantizer_config=op_linear_quantizer_config,
+    )
+    partitioner = CoreMLPartitioner(  # pyre-fixme[16]
+        compile_specs=compile_specs,
+        take_over_mutable_buffer=False,
+        skip_ops_for_coreml_delegation=[
+            "quantized_decomposed.embedding_4bit.dtype",
+            "aten.embedding.default",
+        ],
+    )
+
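+    # InputManager (from llama_transformer.py) owns the KV cache buffers and the
+    # attention mask; at export time it is used only to produce static-shape
+    # example inputs (tokens, input_pos, input_length, KV caches, attn_mask).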
+    input_manager = InputManager(
+        n_layers=args.n_layers,
+        max_batch_size=args.max_batch_size,
+        n_kv_heads=args.n_kv_heads,
+        max_seq_length=args.max_seq_len,
+        head_dim=args.head_dim,
+        use_cache_list=export_args.use_cache_list,
+        seq_length=export_args.seq_length,
+        dtype=float_dtype,
+        minus_infinity=-30000,
+        cache_size=export_args.cache_size,
+    )
+    example_inputs = input_manager.get_inputs(tokens=[0])
+
+    edge_manager = export_to_edge(
+        model,
+        example_inputs,
+        edge_compile_config=EdgeCompileConfig(
+            _check_ir_validity=False,
+            _skip_type_promotion=(float_dtype == torch.float16),
+            _skip_dim_order=True,
+        ),
+    )
+    print("Edge program")
+    print(edge_manager.exported_program())
+
+    for node in edge_manager.exported_program().graph_module.graph.nodes:
+        print(node.name, node.target, node.args, node.kwargs)
+
+    edge_manager = edge_manager.to_backend(partitioner)
+
+    print("Delegated program")
+
+    print(format_delegated_graph(edge_manager.exported_program().graph_module))
+
+    executorch_program = edge_manager.to_executorch(
+        ExecutorchBackendConfig(
+            extract_delegate_segments=True,
+            passes=[
+                QuantFusionPass(),
+            ],
+            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+        )
+    )
+
+    filename = save_pte_program(executorch_program, export_args.output_name)
+    print(f"Saved ExecuTorch program to {filename}")
+
+
+if __name__ == "__main__":
+    main()  # pragma: no cover
diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py
new file mode 100644
index 00000000000..5788bcd5e5a
--- /dev/null
+++ b/examples/apple/coreml/llama/llama_transformer.py
@@ -0,0 +1,570 @@
+# @lint-ignore-every LICENSELINT
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+# Please refer to README.md in the same folder for more information.
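+#
+# This file defines an ANE-friendly Llama implementation: a Transformer with
+# static input shapes, KV caches that are passed in and out of forward(), a
+# custom coreml::sdpa op that keeps SDPA from being decomposed during export,
+# and an InputManager helper that prepares model inputs at runtime.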
+ +from dataclasses import dataclass +from functools import partial +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from executorch.examples.models.llama.llama_transformer import RMSNorm + +from executorch.examples.models.llama.rope import ( + hf_apply_rotary_emb, + hf_precompute_freqs_cis, + precompute_freqs_cis, + RotaryEmbedding, +) + +from torch import nn + + +# These are just to prevent to_edge from decomposing SDPA +# A better method is to use the to_edge_transform_and_lower API for CoreML +# and not decompose SDPA +@torch.library.custom_op("coreml::sdpa", mutates_args=()) +def sdpa( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor +) -> torch.Tensor: + """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion.""" + return torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, attn_mask=attn_mask + ) + + +@torch.library.register_fake("coreml::sdpa") +def _( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor +) -> torch.Tensor: + """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing.""" + expected_shape = list(q.shape) + expected_shape[-1] = v.shape[-1] + return q.new_empty(expected_shape) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + dim: int = 2048 + n_layers: int = 16 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = 128256 + hidden_dim: Optional[int] = None + head_dim: Optional[int] = None # Optional customized head_dim + multiple_of: int = 256 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 1 + max_seq_len: int = 128 + max_context_len: int = 2048 + moe: bool = False # True to enable the MoE (Mixture of Experts) + num_experts: int = 8 # Number of experts + num_activated_experts: int = 2 # Number of experts to activate + + # Generate logits for all inputs. When it's True, it would take big memory usage + # at runtime. Enable it only necessary (e.g., use perplexity tools that requires + # logits for all input tokens.) + generate_full_logits: bool = False + # A dictionary mapping from pruned token-id to original token-id + input_prune_map: Optional[Dict[int, int]] = None + # A dictionary mapping from pruned token-id to original token-id + output_prune_map: Optional[Dict[int, int]] = None + use_hf_rope: bool = False # Use HuggingFace's RoPE implementation + rope_theta: Optional[float] = ( + None # The official name to override self.rope_freq_base. + ) + rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC. + use_scaled_rope: bool = True # Use scaled RoPE, introduced in llama3.1. + # Additional Model Metadata needed at runtime + rope_scale_factor: int = 8 + bos_idx: int = 1 + eos_idx: int = 3 + bos_count: int = -1 # i.e., a single EOS is used as BOS + eos_count: int = 2 + + quantization_args: Optional[dict] = None + lora_args: Optional[dict] = None + + use_cache_list: bool = True + + def __post_init__(self): + if self.n_kv_heads is None: + self.n_kv_heads = self.n_heads + + # rope_theta overrides rope_freq_base since it's the official name. 
+ if self.rope_theta is not None: + self.rope_freq_base = self.rope_theta + + if self.hidden_dim is None: + # If hidden_dim is not explicitly set in the ModelArgs, + # then calculate implicitly based on dim and also multiple of `args.multiple_of` + multiple_of = self.multiple_of + hidden_dim = 4 * self.dim + hidden_dim = int(2 * hidden_dim / 3) + if self.ffn_dim_multiplier is not None: + hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) + self.hidden_dim = find_multiple(hidden_dim, multiple_of) + + if self.head_dim is None: + self.head_dim = self.dim // self.n_heads + + +class Rope(torch.nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + if self.params.use_hf_rope: + self.precompute_freqs_cis = hf_precompute_freqs_cis + else: + self.precompute_freqs_cis = partial( + precompute_freqs_cis, use_scaled=self.params.use_scaled_rope + ) + freqs_cos, freqs_sin = self.precompute_freqs_cis( + self.params.head_dim, + ( + self.params.max_context_len # Normal llama2. + if self.params.ffn_dim_multiplier is None + else self.params.max_context_len * 2 # Sharded checkpoint. + ), + self.params.rope_freq_base, + scale_factor=8, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + if self.params.use_hf_rope: + self.apply_rotary_emb = hf_apply_rotary_emb + else: + self.apply_rotary_emb = RotaryEmbedding() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + + def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): + """ + Get the precomputed frequencies for the given input position and sequence length. + + Args: + input_pos (torch.Tensor): The input position tensor. + seq_len (int): The sequence length. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length. + """ + assert ( + input_pos is not None + ), "input_pos must be provided when use_kv_cache is True" + input_pos_item = input_pos[-1].item() + + # CoreML partitioner is not picking up _check_is_size + # So instead use _check as workaround. 
Should be easy fix for partitioner + # torch._check_is_size(input_pos_item) + torch._check(input_pos_item >= 0) + torch._check(input_pos_item + seq_len <= self.params.max_seq_len) + # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor + freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) + # pyre-ignore: Incompatible parameter type [6] + freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) + + return freqs_cos, freqs_sin + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + assert args.hidden_dim is not None + hidden_dim: int = args.hidden_dim + self.w1 = nn.Linear(args.dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, args.dim, bias=False) + self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class ConditionalFeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.dim = args.dim + hidden_dim = args.hidden_dim + if hidden_dim is None: + # If hidden_dim is not explicitly set in the ModelArgs, + # then calculate implicitly based on dim and also multiple of `args.multiple_of` + multiple_of = args.multiple_of + hidden_dim = 4 * self.dim + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) + self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) + self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim)) + self.num_experts = args.num_experts + + def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor: + w1_weights = self.w1[expert_indices].transpose(-1, -2) # [T, A, D, D] + w3_weights = self.w3[expert_indices].transpose(-1, -2) # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + x1 = F.silu(torch.einsum("ti,taio -> tao", x, w1_weights)) + x3 = torch.einsum("ti, taio -> tao", x, w3_weights) + expert_outs = torch.einsum("tao, taoi -> tai", (x1 * x3), w2_weights) + return expert_outs + + +class MOEFeedForward(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForward(config) + self.dim = config.dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights, expert_indices = torch.topk(scores, 2, dim=-1) # [T, A], [T, A] + expert_weights = expert_weights.softmax(dim=-1) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + return torch.einsum("tai,ta -> ti", expert_outs, expert_weights) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + + assert self.n_heads % self.n_kv_heads == 0 + model_parallel_size = 1 + self.n_local_heads = self.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_batch_size = args.max_batch_size + self.max_seq_len = args.max_seq_len + self.dim = args.dim + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) 
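+        # Note: wk/wv project to n_kv_heads * head_dim, which may be smaller than
+        # n_heads * head_dim; forward() expands K/V back to n_heads via
+        # repeat_interleave (grouped-query attention).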
+ self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.layer_id = layer_id + + self.rope = rope + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_mask: torch.Tensor, + ): + bsz, seqlen, _ = x.shape + # QKV + q, k, v = self.wq(x), self.wk(x), self.wv(x) + # We need view_copy elimination + q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + # RoPE relative positional embeddings + q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + new_k = k + new_v = v + + k = torch.concat([k_cache, k], dim=2) + v = torch.concat([v_cache, v], dim=2) + + # grouped multiquery attention: expand out keys and values + if self.n_rep > 1: + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + output = torch.ops.coreml.sdpa(q, k, v, attn_mask) + + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + + output = self.wo(output) + + return output, new_k, new_v + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs, rope: Rope): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.head_dim + self.attention = Attention(args, layer_id, rope) + if args.moe: + self.block_sparse_moe = MOEFeedForward(args) + else: + self.feed_forward = FeedForward(args) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x, + freqs_cos, + freqs_sin, + k_cache, + v_cache, + attn_mask, + ): # x: 1xN + norm_emb = self.attention_norm(x) + h, new_k, new_v = self.attention.forward( + norm_emb, freqs_cos, freqs_sin, k_cache, v_cache, attn_mask + ) + + h = x + h + out = h + self.feed_forward(self.ffn_norm(h)) + return out, new_k, new_v + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) + self.rope = Rope(params) + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params, self.rope)) + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = nn.Linear(params.dim, params.vocab_size, bias=False) + self.generate_full_logits = params.generate_full_logits + self.max_seq_len = params.max_seq_len + self.input_prune_map = params.input_prune_map + self.output_prune_map = params.output_prune_map + self.use_cache_list = params.use_cache_list + + def forward( + self, + tokens: torch.LongTensor, # tokens + input_pos: torch.LongTensor, + input_length: torch.LongTensor, # input_length + k_caches: List[torch.FloatTensor], + v_caches: List[torch.FloatTensor], + attn_mask: torch.LongTensor, + h: Optional[torch.FloatTensor] = None, # embeddings + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if (tokens is None) ^ (h is not None): + raise ValueError( + "You cannot specify both tokens and h at the same time, and 
must specify either one" + ) + if tokens is not None and h is None: + h = self.tok_embeddings(tokens) + seqlen = h.shape[1] + freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen) + + k_out = [] + v_out = [] + for i, layer in enumerate(self.layers): + h, new_k, new_v = layer( + h, + freqs_cos, + freqs_sin, + k_caches[i] if self.use_cache_list else k_caches[i, :, :, :, :], + v_caches[i] if self.use_cache_list else v_caches[i, :, :, :, :], + attn_mask, + ) + k_out.append(new_k) + v_out.append(new_v) + + if not self.generate_full_logits: + # Only the last logit is used for the new generated token + h = h[:, input_length - 1, :].squeeze(1) + + h = self.norm(h) + + logits = self.output(h) + + if not self.use_cache_list: + k_out = torch.stack(k_out, dim=0) + v_out = torch.stack(v_out, dim=0) + return logits, k_out, v_out + + +class InputManager: + def __init__( + self, + n_layers: int, + max_batch_size: int, + n_kv_heads: int, + max_seq_length: int, + head_dim: int, + use_cache_list: bool, + seq_length: int, + dtype=torch.float16, + minus_infinity=-torch.inf, + cache_size=None, + ): + if cache_size is None: + cache_size = max_seq_length - seq_length + self.cache_size = cache_size + assert self.cache_size + seq_length <= max_seq_length + + self.n_layers = n_layers + self.max_batch_size = max_batch_size + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + + self.seq_length = seq_length + self.use_cache_list = use_cache_list + + if self.use_cache_list: + self.k_caches = [ + torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) + for _ in range(self.n_layers) + ] + self.v_caches = [ + torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) + for _ in range(self.n_layers) + ] + else: + self.k_caches = torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) + self.v_caches = torch.zeros(self.get_cache_shape(self.cache_size)).to(dtype) + + attn_cache = minus_infinity * torch.ones( + seq_length, self.cache_size + ) # attn for past tokens + attn_seq = torch.triu( + minus_infinity * torch.ones(self.seq_length, self.seq_length), diagonal=1 + ) # attn for current tokens + self.attn_mask = torch.concat([attn_cache, attn_seq], dim=-1).to(dtype) + assert self.attn_mask.shape == ( + self.seq_length, + self.cache_size + self.seq_length, + ) + + self.input_pos = 0 + self.cache_pos = 0 + + def get_cache_shape(self, length): + if self.use_cache_list: + return ( + self.max_batch_size, + self.n_kv_heads, + length, + self.head_dim, + ) + return ( + self.n_layers, + self.max_batch_size, + self.n_kv_heads, + length, + self.head_dim, + ) + + def _update_cache(self, start, length, new_k_caches, new_v_caches): + """ + Copies new cache data from start to start + length to cache + """ + assert self.cache_pos + length <= self.cache_size + assert start + length <= self.seq_length + + if self.use_cache_list: + for i in range(self.n_layers): + assert new_k_caches[i].shape == self.get_cache_shape(self.seq_length) + assert new_v_caches[i].shape == self.get_cache_shape(self.seq_length) + + self.k_caches[i][ + :, :, (self.cache_pos) : (self.cache_pos + length), : + ] = new_k_caches[i][:, :, start : (start + length), :] + self.v_caches[i][ + :, :, (self.cache_pos) : (self.cache_pos + length), : + ] = new_v_caches[i][:, :, start : (start + length), :] + else: + assert new_k_caches.shape == self.get_cache_shape(self.seq_length) + assert new_v_caches.shape == self.get_cache_shape(self.seq_length) + self.k_caches[:, :, :, (self.cache_pos) : (self.cache_pos + length), :] = ( + new_k_caches[:, :, :, start : 
(start + length), :]
+            )
+            self.v_caches[:, :, :, (self.cache_pos) : (self.cache_pos + length), :] = (
+                new_v_caches[:, :, :, start : (start + length), :]
+            )
+
+        self.cache_pos += length
+        if self.cache_pos == self.cache_size:
+            self.cache_pos = 0
+
+    def update(self, input_length, new_k_caches, new_v_caches):
+        # Copy as much new cache data into the cache as possible without wrapping
+        amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
+        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        if self.input_pos <= self.cache_size:
+            self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
+                0.0
+            )
+
+        # Copy the remainder (the cache has now wrapped around and has more room).
+        # The attention mask needs no further updates; the whole cache
+        # participates in attention.
+        remaining_to_copy = min(
+            input_length - amount_to_copy, self.cache_size - self.cache_pos
+        )
+        if remaining_to_copy > 0:
+            self._update_cache(
+                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+            )
+
+        self.input_pos += input_length
+
+    def get_inputs(self, tokens: List[int]):
+        input_length = len(tokens)
+        assert input_length <= self.seq_length
+
+        return (
+            # tokens
+            torch.concat(
+                [
+                    torch.tensor(tokens, dtype=torch.int64),
+                    torch.zeros(self.seq_length - input_length, dtype=torch.int64),
+                ],
+                dim=-1,
+            ).reshape(1, -1),
+            # input_pos
+            torch.tensor([self.input_pos], dtype=torch.long),
+            # input_length
+            torch.tensor([input_length], dtype=torch.long),
+            # k_cache
+            self.k_caches,
+            # v_cache
+            self.v_caches,
+            # attn_mask
+            self.attn_mask,
+        )
+
+    def get_inputs_and_remaining_tokens(self, tokens: List[int]):
+        processed_tokens = min(self.seq_length, len(tokens))
+        return (
+            self.get_inputs(tokens[0:processed_tokens]),
+            tokens[processed_tokens:],
+        )
diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md
new file mode 100644
index 00000000000..353f0b56307
--- /dev/null
+++ b/examples/apple/coreml/llama/readme.md
@@ -0,0 +1,39 @@
+# ANE-friendly Llama models
+
+This directory contains ANE-friendly Llama models.
+
+Export the model with:
+```
+python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
+```
+
+(Note: the script should be run from the executorch/examples/apple/coreml/llama directory.)
+
+The runner is written in Python and is only intended to serve as an example of how the model inputs should be processed; it is not performant.
+
+
+Run the model with:
+```
+python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
+```
+
+(Note: the script should be run from the executorch/examples/apple/coreml/llama directory.)
+
+
+## Export args
+* seq_length: the number of tokens processed by the model in one call. Sequences shorter than seq_length must be padded, and sequences longer than it must be chunked (see the sketch after this list).
+* max_seq_length: the maximum number of context tokens that can be processed.
+* cache_size: the size of the KV cache. This parameter is optional and defaults to max_seq_length - seq_length. If a smaller cache_size is used, older tokens are evicted from the cache and no longer play a role in attention. For example, if max_seq_length=1024 but cache_size is 512, the model can generate up to 1024 tokens, but only the current tokens and the previous 512 participate in attention. In terms of computation, cache_size plays a similar role to max_seq_length in models without cache eviction.
+* use_cache_list: boolean option that controls whether KV caches are passed as a list of 4D tensors, one per layer, or as one 5D tensor. (Note that use_cache_list does not work with ExecuTorch pybindings.)
+* target_split_size: this option splits linear layers into chunks of the target size. For example, if target_split_size is 1024, a linear layer with (in_features=512, out_features=8096) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatenated. If not specified, the default is no splitting.
+* max_splits: this controls the maximum number of splits for linear layers. It is only relevant if target_split_size is passed and defaults to 8.
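+
+For example, the following sketch shows the chunking arithmetic the runner's InputManager applies to a hypothetical 150-token prompt with seq_length=64 (illustration only, not a supported API):
+
+```
+seq_length = 64
+prompt = list(range(150))  # hypothetical 150-token prompt
+chunks = [prompt[i : i + seq_length] for i in range(0, len(prompt), seq_length)]
+print([len(c) for c in chunks])  # [64, 64, 22]
+# The final 22-token chunk is zero-padded up to seq_length before the model
+# runs, and input_length=22 tells the model how many positions are real.
+```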
+
+## Llama1B on iPhone 15
+
+We are actively experimenting with different settings, but here are ones that we've found work well for Llama1B on iPhone 15 Pro:
+
+* Set use_cache_list.
+* Split linear layers with target_split_size=1024, max_splits=8.
+* Use seq_length=32 or seq_length=64, both of which offer reasonable tradeoffs for prefill and decode performance; seq_length=32 is better at decode and seq_length=64 is better at prefill.
+
+In our tests, we set max_seq_length=1024, but if your application allows for it, performance can improve with max_seq_length=512 or by keeping max_seq_length=1024 and setting cache_size=512-seq_length.
diff --git a/examples/apple/coreml/llama/run.py b/examples/apple/coreml/llama/run.py
new file mode 100644
index 00000000000..65026e1f6bc
--- /dev/null
+++ b/examples/apple/coreml/llama/run.py
@@ -0,0 +1,134 @@
+import argparse
+import sys
+
+import sentencepiece as spm
+
+import torch
+
+from executorch.runtime import Runtime
+
+
+sys.path.insert(0, ".")
+from executorch.examples.models.llama.runner.generation import next_token
+from executorch.examples.models.llama.tokenizer import tiktoken
+from llama_transformer import InputManager
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        # Try sentencepiece first; fall back to tiktoken
+        try:
+            print("Trying to load sentencepiece")
+            sp = spm.SentencePieceProcessor()
+            sp.load(model_path)
+            self.tokenizer = sp
+        except Exception:
+            print("Trying to load tiktoken")
+            self.tokenizer = tiktoken.Tokenizer(model_path)
+
+    def encode(self, text, bos, eos):
+        if isinstance(self.tokenizer, spm.SentencePieceProcessor):
+            bos_string = "<s>" if bos else ""
+            eos_string = "</s>" if eos else ""
+            return self.tokenizer.encode(f"{bos_string}{text}{eos_string}")
+        return self.tokenizer.encode(text, bos=bos, eos=eos)
+
+    def decode_token(self, token):
+        if isinstance(self.tokenizer, spm.SentencePieceProcessor):
+            return f"{self.tokenizer.decode(token)} "
+        return self.tokenizer.decode_token(token)
+
+    def stop_tokens(self):
+        if isinstance(self.tokenizer, spm.SentencePieceProcessor):
+            return [self.tokenizer.eos_id()]
+        return self.tokenizer.stop_tokens
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--model",
+        help="model.pte",
+    )
+    parser.add_argument(
+        "-t",
+        "--tokenizer",
+        help="tokenizer.model path",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="Once upon a time,",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.6,
+    )
+    parser.add_argument(
+        "--top_p",
+        type=float,
+        default=0.9,
+    )
+
+    args = parser.parse_args()
+
+    tokenizer = Tokenizer(args.tokenizer)
+
+    runtime = Runtime.get()
+    program = runtime.load_program(args.model)
+    method = program.load_method("forward")
+
+    metadata = method.metadata
+    print("Method metadata: ", metadata, "\n\n")
+
+    assert (
+        metadata.num_inputs() == 6
+    ), "Do not export with --use_cache_list for use in pybindings"
+    # k_cache input
+    n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = (
+        metadata.input_tensor_meta(3).sizes()
+    )
+
+    # mask input
+    seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes()
+
+    input_manager = InputManager(
+        n_layers=n_layers,
+        max_batch_size=max_batch_size,
+        n_kv_heads=n_kv_heads,
+        max_seq_length=max_seq_length,
+        head_dim=head_dim,
+        use_cache_list=False,
+        seq_length=seq_length,
+        dtype=torch.float16,
+        minus_infinity=-30000.0,
+        cache_size=cache_size,
+    )
+
+    print(args.prompt, end="")
+    tokens = tokenizer.encode(args.prompt, bos=True, eos=False)
+    while input_manager.input_pos + seq_length < max_seq_length:
+        # Prefill: feed pending tokens through the model in seq_length-sized
+        # chunks, updating the KV caches after each chunk.
+        while len(tokens) > 0 and (
+            input_manager.input_pos + seq_length < max_seq_length
+        ):
+            inputs, remaining_tokens = input_manager.get_inputs_and_remaining_tokens(
+                tokens
+            )
+            processed_tokens = len(tokens) - len(remaining_tokens)
+            logits, k, v = method.execute(inputs)
+            input_manager.update(
+                input_length=processed_tokens, new_k_caches=k, new_v_caches=v
+            )
+            tokens = remaining_tokens
+
+        # Decode: sample the next token from the last logits; it becomes the
+        # input for the next iteration.
+        tokens = [next_token(logits, args.temperature, args.top_p)]
+
+        if tokens[-1] in tokenizer.stop_tokens():
+            break
+        print(tokenizer.decode_token(tokens[-1]), end="", flush=True)
+
+
+if __name__ == "__main__":
+    main()
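A note on sampling: run.py delegates token selection to next_token from executorch.examples.models.llama.runner.generation. For readers who want a self-contained picture, the sketch below shows a typical temperature plus top-p (nucleus) sampler; the helper name sample_next_token and its exact behavior are illustrative assumptions, not the library's implementation.

```
import torch

def sample_next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    # Greedy decoding when temperature is zero (or negative).
    if temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens outside the smallest set whose cumulative mass exceeds top_p.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(sorted_idx.gather(-1, choice).item())
```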