From 23330e850035b0ea3076b9ca2a5423be08d953a4 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 10 Dec 2025 16:54:37 -0800 Subject: [PATCH] Expose static llama in OSS (#16184) Summary: This exposes a static llama model for CoreML. We want to unify development behind one static model so that we do not need to apply fixes in multiple places (e.g., iOS 26 fixes). Reviewed By: billmguo Differential Revision: D88875745 --- .ci/scripts/test_ane_static_llama.sh | 7 + .../coreml/llama/export_static_llm_coreml.py | 504 ++++++++++++++++++ examples/apple/coreml/llama/readme.md | 36 ++ examples/apple/coreml/llama/run_static_llm.py | 323 +++++++++++ 4 files changed, 870 insertions(+) create mode 100644 examples/apple/coreml/llama/export_static_llm_coreml.py create mode 100644 examples/apple/coreml/llama/run_static_llm.py diff --git a/.ci/scripts/test_ane_static_llama.sh b/.ci/scripts/test_ane_static_llama.sh index 3081c7ffe52..73a9c4ca54b 100644 --- a/.ci/scripts/test_ane_static_llama.sh +++ b/.ci/scripts/test_ane_static_llama.sh @@ -28,6 +28,13 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama # Download stories llama110m artifacts download_stories_model_artifacts +# Test static ANE llama model +python export_static_llm_coreml.py --checkpoint stories110M.pt --params params.json --output model.pte + +# The ANE cannot run in github CI +# python run_static_llm.py --model model.pte --params params.json --tokenizer tokenizer.model --prompt "Once upon a time," --lookahead + +# Test export of deprecated model python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32 popd diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py new file mode 100644 index 00000000000..a3fd8201414 --- /dev/null +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -0,0 +1,504 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export script for static attention LLM models to CoreML via ExecuTorch. 
+ +Usage: + python export_static_llm_coreml.py \ + --checkpoint /path/to/model.pth \ + --params /path/to/params.json \ + --output static_llm_coreml_model.pte \ + --max_context_len 1024 \ + --input_len 32 \ + --embedding_quantize 4,32 \ + --coreml_quantize c4w \ + --target_split_size 1048 +""" + +import argparse +import json + +import coremltools as ct +import torch +import torch.nn as nn +import torch.utils._pytree as pytree + +from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.examples.apple.coreml.llama.utils import ( + replace_linear_with_split_linear, +) +from executorch.examples.models.llama.llama_transformer import construct_transformer +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope +from executorch.examples.models.llama.static_attention import StaticAttentionIOManager +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.exir.backend.utils import format_delegated_graph +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.extension.export_util.utils import save_pte_program +from torch.library import impl, Library +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + +# Define custom graph break op +lib = Library("executorch_utils", "DEF") +lib.define("graph_break.Tensor(Tensor x) -> Tensor") + + +@impl(lib, "graph_break.Tensor", "CompositeExplicitAutograd") +def graph_break_impl(x): + return x + + +class ExecutorchGraphBreakModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return tuple( + ( + torch.ops.executorch_utils.graph_break.Tensor(a) + if isinstance(a, torch.Tensor) + else a + ) + for a in args + ) + + +class BlockWithGraphBreak(nn.Module): + def __init__(self, block: nn.Module, break_before: bool = True): + super().__init__() + self.graph_break = ExecutorchGraphBreakModule() + self.block = block + self.break_before = break_before + + def forward(self, *args, **kwargs): + if self.break_before: + new_args = self.graph_break(*args) + out = self.block(*new_args, **kwargs) + return out + else: + out = self.block(*args, **kwargs) + out = self.graph_break(*out) + return out + + +def remove_graph_break_(edge_manager): + from executorch.exir.dialects._ops import ops as exir_ops + + for n in edge_manager.exported_program().graph_module.graph.nodes: + if n.target == exir_ops.edge.executorch_utils.graph_break.Tensor: + n.replace_all_uses_with(n.args[0]) + edge_manager.exported_program().graph_module.graph.eliminate_dead_code() + + +def load_model(checkpoint_path: str, params_path: str, max_context_len: int): + """Load the model from checkpoint with static_mha attention type.""" + with open(params_path, "r") as f: + params = json.loads(f.read()) + + # TODO: to support lookahead decoding, the static model outputs + # full logits, but if we are not using lookahead decoding, we can have a + # more efficient model by setting generate_full_logits=False and supplying the last + # valid token + args = ModelArgs( + max_context_len=max_context_len, + generate_full_logits=True, + **params, + ) + args.attention_type = "static_mha" + args.attention_kwargs = 
{"decompose_sdpa_in_mha": True} + + with torch.device("meta"): + model = construct_transformer(args) + + checkpoint = torch.load( + checkpoint_path, map_location="cpu", mmap=True, weights_only=True + ) + if "model" in checkpoint: + checkpoint = checkpoint["model"] + + # Rename attention weight keys for static attention + for i in range(len(model.layers)): + if f"layers.{i}.attention.wq.weight" in checkpoint: + checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop( + f"layers.{i}.attention.wq.weight" + ) + if f"layers.{i}.attention.wk.weight" in checkpoint: + checkpoint[f"layers.{i}.attention.wks.0.weight"] = checkpoint.pop( + f"layers.{i}.attention.wk.weight" + ) + if f"layers.{i}.attention.wv.weight" in checkpoint: + checkpoint[f"layers.{i}.attention.wvs.0.weight"] = checkpoint.pop( + f"layers.{i}.attention.wv.weight" + ) + + missing, unexpected = model.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) + if missing: + print(f"Missing keys: {missing}") + if unexpected: + print(f"Unexpected keys: {unexpected}") + + return model, args + + +def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype): + """ + Generate metadata methods for the C++ runner. + + The C++ runner needs these constant methods to understand the model structure: + - vocab_size: Vocabulary size + - head_dim: Head dimension + - n_heads_per_cache: Number of KV heads + - freqs_cos, freqs_sin: Pre-computed RoPE frequencies + - freqs_cos_input_index, freqs_sin_input_index: Input indices for RoPE + - kv_cache_specs: Tensor describing cache input/output indices and lengths + - mask_specs: Tensor describing mask input indices + - forward_input_len: Input length for forward method + """ + # Pre-compute RoPE frequencies for the full context + rope = Rope(model_args) + freqs_cos, freqs_sin = rope.get_freqs(None, model_args.max_context_len) + print(f"Pre-computed RoPE frequencies shape: {freqs_cos.shape}, {freqs_sin.shape}") + + # Flatten example inputs to get the pytree spec + flat_inputs, in_spec = pytree.tree_flatten(example_inputs) + + # Reconstruct input indices from the pytree spec + input_indices = pytree.tree_unflatten( + list(range(in_spec.num_leaves)), + in_spec, + ) + + # input_indices structure: + # (token_idx, { + # "masks": {cache_len: mask_idx}, + # "freqs_cos_override": freqs_cos_idx, + # "freqs_sin_override": freqs_sin_idx, + # "in_cache_state": ({k_cache_ids: k_cache_idx}, {v_cache_ids: v_cache_idx}) + # }) + + # Get the options dict indices + opts_indices = input_indices[1] + + # Build KV cache specs: [k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len] + # For static_mha, output cache indices follow the same order as inputs + # Output structure: (logits, {"out_cache_state": ({k_ids: k_out}, {v_ids: v_out})}) + k_cache_in_indices = opts_indices["in_cache_state"][0] + v_cache_in_indices = opts_indices["in_cache_state"][1] + + # Sort by layer to ensure consistent ordering + sorted_k_cache_ids = sorted(k_cache_in_indices.keys()) + + # Output indices are in the same order (after logits) + # Logits is output 0, then k_caches, then v_caches + kv_cache_specs = [] + for i, cache_id in enumerate(sorted_k_cache_ids): + k_in_idx = k_cache_in_indices[cache_id] + v_in_idx = v_cache_in_indices[cache_id] + # Output indices: k_caches come after logits (idx 1 to n_layers), + # v_caches come after k_caches (idx n_layers+1 to 2*n_layers) + k_out_idx = 1 + i + v_out_idx = 1 + len(sorted_k_cache_ids) + i + kv_cache_specs.append([k_in_idx, k_out_idx, v_in_idx, v_out_idx, cache_len]) + + 
print(f"KV cache specs (k_in, k_out, v_in, v_out, cache_len): {kv_cache_specs}") + + # Build mask specs: [mask_idx, cache_len] + mask_specs = [ + [mask_idx, c_len] for c_len, mask_idx in opts_indices["masks"].items() + ] + print(f"Mask specs (mask_idx, cache_len): {mask_specs}") + + return { + "vocab_size": model_args.vocab_size, + "head_dim": model_args.head_dim, + "n_heads_per_cache": model_args.n_kv_heads, + "freqs_cos": freqs_cos.to(float_dtype), + "freqs_sin": freqs_sin.to(float_dtype), + "freqs_cos_input_index": torch.tensor( + [opts_indices["freqs_cos_override"]], dtype=torch.int64 + ), + "freqs_sin_input_index": torch.tensor( + [opts_indices["freqs_sin_override"]], dtype=torch.int64 + ), + "mask_specs": torch.tensor(mask_specs, dtype=torch.int64), + "kv_cache_specs": torch.tensor(kv_cache_specs, dtype=torch.int64), + "forward_input_len": input_len, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Export static attention Llama model to CoreML" + ) + + # Model paths + parser.add_argument( + "-c", + "--checkpoint", + required=True, + help="Path to model checkpoint (.pth)", + ) + parser.add_argument( + "-p", + "--params", + required=True, + help="Path to params.json", + ) + parser.add_argument( + "-o", + "--output", + default="model.pte", + help="Output filename for the .pte model", + ) + + # Model configuration + parser.add_argument( + "--max_context_len", + type=int, + default=1024, + help="Maximum context length", + ) + parser.add_argument( + "--input_len", + type=int, + default=32, + help="Input sequence length per forward pass", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["fp16", "fp32"], + default="fp16", + help="Model dtype. The ANE requires fp16.", + ) + + # Quantization options + parser.add_argument( + "-E", + "--embedding_quantize", + default="8,0", + type=str, + help="Embedding quantization: ',', e.g., '4,32' or '8,0' for per-channel", + ) + parser.add_argument( + "--linear_quantize", + default="c4w", + choices=["b4w", "c4w"], + help="CoreML linear quantization: b4w (blockwise 4-bit) or c4w (channelwise 4-bit). 
The ANE requires channelwise.", + ) + + # Linear splitting options + parser.add_argument( + "--target_split_size", + type=int, + default=1024, + help="Split linear layers into chunks of this size (helps with ANE performance)", + ) + parser.add_argument( + "--max_splits", + type=int, + default=8, + help="Maximum number of splits for linear layers", + ) + + # Graph break options + parser.add_argument( + "--no_graph_breaks", + action="store_true", + help="Disable graph breaks between transformer blocks", + ) + + args = parser.parse_args() + + # Compute cache length + + print("Quantization and datatype:") + print(f"\tEmbedding quantize: {args.embedding_quantize}") + print(f"\tLinear quantize: {args.linear_quantize}") + print(f"\tDtype: {args.dtype}") + + cache_len = args.max_context_len - args.input_len + print("\nGeneration configuration:") + print(f"\tMax context length: {args.max_context_len}") + print(f"\tInput length: {args.input_len}") + print(f"\tCache length: {cache_len}") + + print("\nLinear splitting:") + print(f"\tTarget split size: {args.target_split_size}") + print(f"\tMax splits: {args.max_splits}") + + # Load model + print(f"\nLoading model from {args.checkpoint}...") + model, model_args = load_model( + args.checkpoint, + args.params, + args.max_context_len, + ) + print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") + + # Set dtype + float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] + model = model.to(float_dtype).eval() + + # Apply linear splitting (before quantization) + if args.target_split_size is not None: + print(f"\nSplitting linear layers with target size {args.target_split_size}...") + replace_linear_with_split_linear( + model, + out_target_split_size=args.target_split_size, + out_max_splits=args.max_splits, + in_target_split_size=1, + in_max_splits=1, + ) + + # Apply embedding quantization + if args.embedding_quantize: + bitwidth, group_size = args.embedding_quantize.split(",") + bitwidth = int(bitwidth) + group_size = int(group_size) + assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" + + print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...") + if group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bitwidth}") + + quantize_( + model, + IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + # Apply linear quantization + if args.linear_quantize == "b4w": + print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + ), + ) + elif args.linear_quantize == "c4w": + print("\nQuantizing linear layers: 4-bit channelwise...") + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + ), + ) + + # Add graph breaks between transformer blocks + # Keeping model pieces smaller helps with ANE performance + if not args.no_graph_breaks: + print("\nAdding graph breaks between before/after the transformer blocks...") + n_layers = len(model.layers) + model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) + model.layers[n_layers - 1] = BlockWithGraphBreak( + model.layers[n_layers - 1], break_before=False + ) + + # Create IO manager and example inputs + mgr = StaticAttentionIOManager( + model_args, + input_len=args.input_len, + cache_lens=cache_len, + batch_size=1, + 
dtype=float_dtype, + style="smart_mask", # Use smart_mask to match C++ StaticTransformerRunner + mask_val=float("-inf"), + ) + example_inputs = ( + torch.zeros(1, args.input_len, dtype=torch.int32), + { + "masks": mgr.masks, + "freqs_cos_override": mgr.freqs_cos[: args.input_len], + "freqs_sin_override": mgr.freqs_sin[: args.input_len], + "in_cache_state": (mgr.k_caches, mgr.v_caches), + }, + ) + + # Test eager execution + print("\nTesting eager execution...") + with torch.no_grad(): + model(*example_inputs) + print("Eager execution successful!") + + # Export the model + print("\nExporting model...") + ep = torch.export.export(model, example_inputs) + print("Export successful!") + print(ep) + + # Generate metadata for C++ runner + print("\nGenerating metadata for C++ runner...") + constant_methods = _get_metadata( + model_args, example_inputs, args.input_len, cache_len, float_dtype + ) + + # Setup CoreML partitioner + print("\nSetting up CoreML partitioner...") + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[], + ) + + # Lower to edge with constant methods for C++ runner + print("\nLowering to edge...") + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + edge_manager = to_edge_transform_and_lower( + ep, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=edge_compile_config, + ) + + print("\nDelegated program:") + print(format_delegated_graph(edge_manager.exported_program().graph_module)) + + # Convert to ExecuTorch + print("\nConverting to ExecuTorch...") + remove_graph_break_(edge_manager) + executorch_program = edge_manager.to_executorch( + ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, alloc_graph_output=False + ), + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + ) + ) + + # Save the program + filename = save_pte_program(executorch_program, args.output) + print(f"\nSaved ExecuTorch program to {filename}") + + +if __name__ == "__main__": + main() diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 14dff0c8580..46e9043a5fc 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -1,5 +1,41 @@ # ANE-friendly Llama models +To export a static, ANE-friendly model use: + +``` +python export_static_llm_coreml.py \ + --checkpoint /path/to/model.pth \ + --params /path/to/params.json \ + --output static_llm_coreml_model.pte +``` + +To test in python, use: + +``` +python run_static_llm.py \ + --model static_llm_coreml_model.pte \ + --params /path/to/params.json \ + --tokenizer /path/to/tokenizer.model \ + --prompt "Once upon a time" \ + --max_new_tokens 100 \ + --lookahead +``` + +(Enabling lookahead decoding is optional, but does improve performance.) 
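+
+Lookahead decoding can be tuned with the --ngram_size, --window_size, and --n_verifications flags of run_static_llm.py (defaults: 5, 4, and 4). As a sketch, assuming the model was exported with the command above:
+
+```
+python run_static_llm.py \
+    --model static_llm_coreml_model.pte \
+    --params /path/to/params.json \
+    --tokenizer /path/to/tokenizer.model \
+    --prompt "Once upon a time" \
+    --lookahead \
+    --ngram_size 5 \
+    --window_size 4 \
+    --n_verifications 4
+```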
+
+The static model has several ANE optimizations, including:
+* Splitting linear layers for improved performance (controlled by the target_split_size and max_splits args)
+* Splitting the .pte into multiple Core ML pieces for improved performance (can be disabled with the no_graph_breaks flag)
+* Rewriting SDPA to avoid 5-D tensors, which improves performance and also fixes an accuracy bug introduced in iOS 26 (see https://github.com/pytorch/executorch/issues/15833)
+
+
+We are working on adding a C++ runner as well.
+
+
+# Deprecated (export.py, run.py, and run_lookahead.py)
+
+Below we describe export.py, run.py, and run_lookahead.py. These are deprecated and will eventually be removed because we are unifying around the static model formulation.
+
 This directory contains ANE-friendly Llama models.
 
 Export model with:
diff --git a/examples/apple/coreml/llama/run_static_llm.py b/examples/apple/coreml/llama/run_static_llm.py
new file mode 100644
index 00000000000..2cd526aec42
--- /dev/null
+++ b/examples/apple/coreml/llama/run_static_llm.py
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run script for static attention Llama models exported with export_static_llm_coreml.py.
+
+Usage:
+    python run_static_llm.py \
+        --model llama1b_static.pte \
+        --params $HOME/models/llama1b/params.json \
+        --tokenizer $HOME/models/llama1b/tokenizer.model \
+        --prompt "Once upon a time" \
+        --max_new_tokens 100
+"""
+
+import argparse
+import json
+import time
+from typing import Any, Dict, List, Tuple
+
+import sentencepiece as spm
+import torch
+import torch.utils._pytree as pytree
+
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.runner.generation import next_token
+from executorch.examples.models.llama.static_attention import StaticAttentionIOManager
+from executorch.runtime import Runtime
+
+
+class Tokenizer:
+    """Wrapper to support both SentencePiece and Tiktoken tokenizers."""
+
+    def __init__(self, model_path: str):
+        try:
+            print("Trying to load sentencepiece")
+            sp = spm.SentencePieceProcessor()
+            sp.load(model_path)
+            self.tokenizer = sp
+            self._is_sentencepiece = True
+        except Exception:
+            print("Trying to load tiktoken")
+            from executorch.examples.models.llama.tokenizer import tiktoken
+
+            self.tokenizer = tiktoken.Tokenizer(model_path)
+            self._is_sentencepiece = False
+
+    def encode(self, text: str, bos: bool = True, eos: bool = False) -> List[int]:
+        if self._is_sentencepiece:
+            # Llama-style SentencePiece models use <s>/</s> as BOS/EOS markers
+            bos_string = "<s>" if bos else ""
+            eos_string = "</s>" if eos else ""
+            return self.tokenizer.encode(f"{bos_string}{text}{eos_string}")
+        return self.tokenizer.encode(text, bos=bos, eos=eos)
+
+    def decode(self, tokens: List[int]) -> str:
+        # Both SentencePiece and Tiktoken expose decode() over a list of token ids
+        return self.tokenizer.decode(tokens)
+
+    def decode_token(self, token: int) -> str:
+        if self._is_sentencepiece:
+            return self.tokenizer.decode([token])
+        try:
+            return self.tokenizer.decode_token(token)
+        except UnicodeDecodeError:
+            return f"<{token}>"
+
+    @property
+    def stop_tokens(self) -> List[int]:
+        if self._is_sentencepiece:
+            return [self.tokenizer.eos_id()]
+        return self.tokenizer.stop_tokens
+
+
+def create_pte_wrapper(
+    method,
+    k_cache_keys: List[str],
+    v_cache_keys: List[str],
+):
+    """
+    Create a wrapper function that adapts PTE execution to the
interface + expected by StaticAttentionIOManager. + + The wrapper: + - Takes (tokens, options_dict) like the eager model + - Flattens inputs using pytree + - Executes the PTE method + - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) + """ + + def wrapper( + tokens: torch.Tensor, options: Dict[str, Any] + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + # Build the same input structure as during export + inputs = (tokens, options) + + # Flatten using pytree (same order as torch.export) + flat_inputs, _ = pytree.tree_flatten(inputs) + + # Execute PTE model + outputs = method.execute(flat_inputs) + + # First output is logits + logits = outputs[0] + + # Remaining outputs are k_cache updates then v_cache updates + num_layers = len(k_cache_keys) + k_updates = outputs[1 : 1 + num_layers] + v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] + + # Reconstruct the output cache state dicts + k_cache_dict = dict(zip(k_cache_keys, k_updates)) + v_cache_dict = dict(zip(v_cache_keys, v_updates)) + + attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} + + return logits, attn_updates + + return wrapper + + +def main(): + parser = argparse.ArgumentParser(description="Run static attention Llama model") + + parser.add_argument( + "-m", + "--model", + required=True, + help="Path to exported .pte model", + ) + parser.add_argument( + "-p", + "--params", + required=True, + help="Path to params.json", + ) + parser.add_argument( + "-t", + "--tokenizer", + required=True, + help="Path to tokenizer model", + ) + parser.add_argument( + "--prompt", + type=str, + default="Once upon a time,", + help="Input prompt", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.6, + help="Sampling temperature", + ) + parser.add_argument( + "--top_p", + type=float, + default=0.9, + help="Top-p (nucleus) sampling threshold", + ) + parser.add_argument( + "--input_len", + type=int, + default=32, + help="Input sequence length (must match export)", + ) + parser.add_argument( + "--cache_len", + type=int, + default=992, + help="Cache length (must match export: max_context_len - input_len)", + ) + parser.add_argument( + "--lookahead", + action="store_true", + help="Enable lookahead (speculative) decoding", + ) + parser.add_argument( + "--ngram_size", + type=int, + default=5, + help="N-gram size for lookahead decoding", + ) + parser.add_argument( + "--window_size", + type=int, + default=4, + help="Window size for lookahead decoding", + ) + parser.add_argument( + "--n_verifications", + type=int, + default=4, + help="Number of verification branches for lookahead decoding", + ) + + args = parser.parse_args() + + # Load tokenizer + tokenizer = Tokenizer(args.tokenizer) + + # Load model params + with open(args.params, "r") as f: + params = json.loads(f.read()) + + # Create model args + model_args = ModelArgs( + max_context_len=args.cache_len + args.input_len, + generate_full_logits=True, + **params, + ) + model_args.attention_type = "static_mha" + + print(f"Model config: {model_args.n_layers} layers, dim={model_args.dim}") + print(f"Input length: {args.input_len}, Cache length: {args.cache_len}") + + # Create StaticAttentionIOManager + mgr = StaticAttentionIOManager( + model_args, + input_len=args.input_len, + cache_lens=args.cache_len, + batch_size=1, + dtype=torch.float16, + style="smart_mask", # Use smart_mask to match C++ 
StaticTransformerRunner + mask_val=float("-inf"), + ) + + # Load PTE model + print(f"Loading model from {args.model}...") + runtime = Runtime.get() + program = runtime.load_program(args.model) + method = program.load_method("forward") + + metadata = method.metadata + print( + f"Method metadata: num_inputs={metadata.num_inputs()}, num_outputs={metadata.num_outputs()}" + ) + + # Get cache keys in insertion order (NOT sorted alphabetically!) + # Pytree preserves dict insertion order in Python 3.7+ + # The caches are created in layer order (0, 1, 2, ..., n_layers-1) + k_cache_keys = list(mgr.k_caches.keys()) + v_cache_keys = list(mgr.v_caches.keys()) + + # Create wrapper function that adapts PTE to eager interface + model_fn = create_pte_wrapper(method, k_cache_keys, v_cache_keys) + + # Encode prompt + prompt_tokens = tokenizer.encode(args.prompt, bos=True, eos=False) + print(f"\nPrompt: {args.prompt}") + print(f"Prompt tokens: {len(prompt_tokens)}") + print("-" * 50) + + # Reset manager + mgr.reset() + + # Prefill using StaticAttentionIOManager.prefill + print("Prefilling...", end=" ", flush=True) + start_time = time.time() + logits = mgr.prefill(model_fn, prompt_tokens) + prefill_time = time.time() - start_time + print(f"done in {prefill_time:.2f}s") + + # Get first token from prefill logits + first_token = next_token(logits[:, -1, :], args.temperature, args.top_p) + + # Decode using StaticAttentionIOManager.decode or lookahead_decode + print(f"\n{args.prompt}", end="", flush=True) + print(tokenizer.decode_token(first_token), end="", flush=True) + + decode_start = time.time() + + if args.lookahead: + # Use lookahead (speculative) decoding + print( + f"\n[Using lookahead decoding: ngram={args.ngram_size}, window={args.window_size}, verifications={args.n_verifications}]" + ) + generated_tokens = mgr.lookahead_decode( + model_fn, + first_token, + n=args.max_new_tokens - 1, # -1 because first_token counts + ngram_size=args.ngram_size, + window_size=args.window_size, + n_verifications=args.n_verifications, + stop_tokens=tokenizer.stop_tokens, + ) + else: + # Use standard autoregressive decoding + generated_tokens = mgr.decode( + model_fn, + first_token, + n=args.max_new_tokens - 1, # -1 because first_token counts + stop_tokens=tokenizer.stop_tokens, + ) + + # Print generated tokens (skip first as it's the init_token we already printed) + for token in generated_tokens[1:]: + if token in tokenizer.stop_tokens: + break + print(tokenizer.decode_token(token), end="", flush=True) + + decode_time = time.time() - decode_start + total_generated = len(generated_tokens) + tokens_per_sec = total_generated / decode_time if decode_time > 0 else 0 + + print("\n" + "-" * 50) + print(f"Prefill: {len(prompt_tokens)} tokens in {prefill_time:.2f}s") + print( + f"Decode: {total_generated} tokens in {decode_time:.2f}s ({tokens_per_sec:.2f} tok/s)" + ) + + +if __name__ == "__main__": + main()