From a011764eaa6880458c070e0278a25d2379f62f5c Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Mon, 9 Jun 2025 12:55:23 -0700
Subject: [PATCH 1/2] Add new export LLM config

Pull Request resolved: https://github.com/pytorch/executorch/pull/11028
@imported-using-ghimport

Differential Revision: [D75263991](https://our.internmc.facebook.com/intern/diff/D75263991/)

ghstack-source-id: 289208257
---
 examples/models/llama/config/TARGETS        |   9 +
 examples/models/llama/config/llm_config.py  | 495 ++++++++++++++++++
 examples/models/llama/config/targets.bzl    |  26 +
 .../models/llama/config/test_llm_config.py  | 104 ++++
 examples/models/llama/export_llama_lib.py   |   2 +-
 5 files changed, 635 insertions(+), 1 deletion(-)
 create mode 100644 examples/models/llama/config/TARGETS
 create mode 100644 examples/models/llama/config/llm_config.py
 create mode 100644 examples/models/llama/config/targets.bzl
 create mode 100644 examples/models/llama/config/test_llm_config.py

diff --git a/examples/models/llama/config/TARGETS b/examples/models/llama/config/TARGETS
new file mode 100644
index 00000000000..2ba1b55a3dd
--- /dev/null
+++ b/examples/models/llama/config/TARGETS
@@ -0,0 +1,9 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py
new file mode 100644
index 00000000000..b929e756c3e
--- /dev/null
+++ b/examples/models/llama/config/llm_config.py
@@ -0,0 +1,495 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+"""
+Configurations for exporting Llama.
+
+Uses dataclasses, which integrate with OmegaConf and Hydra.
+"""
+
+import ast
+import re
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import ClassVar, List, Optional
+
+
+################################################################################
+################################## BaseConfig ##################################
+################################################################################
+
+
+class ModelType(str, Enum):
+    STORIES110M = "stories110m"
+    LLAMA2 = "llama2"
+    LLAMA3 = "llama3"
+    LLAMA3_1 = "llama3_1"
+    LLAMA3_2 = "llama3_2"
+    LLAMA3_2_VISION = "llama3_2_vision"
+    STATIC_LLAMA = "static_llama"
+    QWEN2_5 = "qwen2_5"
+    QWEN3_0_6B = "qwen3-0_6b"
+    QWEN3_1_7B = "qwen3-1_7b"
+    QWEN3_4B = "qwen3-4b"
+    PHI_4_MINI = "phi_4_mini"
+    SMOLLM2 = "smollm2"
+
+
+class PreqMode(str, Enum):
+    """
+    If you are dealing with pre-quantized checkpoints, this used to
+    be the way to specify them. Now you don't need to specify these
+    options if you use a TorchAo-prequantized checkpoint, but they
+    are still around to preserve backward compatibility.
+    """
+
+    PREQ_8DA4W = "8da4w"
+    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
+
+
+@dataclass
+class BaseConfig:
+    """
+    Configurations specific to the model, e.g. whether it's Qwen3 or Phi-4-mini;
+    these are the minimal set of parameters needed to load the pretrained
+    eager model and its weights.
+
+    Attributes:
+        model_class: Which model to export.
+        params: Model parameters, such as n_layers, hidden_size, etc.
+            If left empty will use defaults specified in model_args.py.
+        checkpoint: Path to the checkpoint file.
+            If left empty, the model will be initialized with random weights.
+        checkpoint_dir: Path to directory containing sharded checkpoint files.
+        tokenizer_path: Path to the tokenizer file.
+        metadata: JSON string containing metadata information.
+            e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+        use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
+        fairseq2: For legacy internal use cases, this is safe to ignore.
+        preq_mode: Legacy option to specify how prequantized weights are loaded.
+            Going forward, ExecuTorch supports loading weights prequantized through
+            TorchAo as-is, without any special handling.
+        preq_group_size: Legacy option to specify the group size of prequantized weights.
+        preq_embedding_quantize: Legacy option to specify how prequantized embeddings
+            are loaded.
+    """
+
+    model_class: ModelType = ModelType.LLAMA3
+    params: Optional[str] = None
+    checkpoint: Optional[str] = None
+    checkpoint_dir: Optional[str] = None
+    tokenizer_path: Optional[str] = None
+    metadata: Optional[str] = None
+    use_lora: int = 0
+    fairseq2: bool = False
+    preq_mode: Optional[PreqMode] = None
+    preq_group_size: int = 32
+    preq_embedding_quantize: str = "8,0"
+
+
+################################################################################
+################################# ModelConfig ##################################
+################################################################################
+
+
+class DtypeOverride(str, Enum):
+    """
+    DType of the model. Highly recommended to use "fp32", unless you want to
+    export without a backend, in which case you can also use "bf16". "fp16"
+    is not recommended.
+    """
+
+    FP32 = "fp32"
+    FP16 = "fp16"
+    BF16 = "bf16"
+
+
+@dataclass
+class ModelConfig:
+    """
+    Configurations not necessarily specific to the model, but needed to
+    finish off the rest of the model configuration in eager mode. You can think
+    of these as optimizations / actual configurations. The same ModelConfig
+    can be applied to multiple models.
+
+    Attributes:
+        dtype_override: dtype to cast the model to.
+        enable_dynamic_shape: whether to enable dynamic shapes on the sequence
+            length so that the model can handle arbitrary prefill lengths and
+            token generation.
+        use_shared_embedding: whether the embedding/output weights should be
+            shared. Only available with torchao kernels, e.g. when
+            qmode is set to use a "torchao:8da(\\d+)w" pattern.
+        use_sdpa_with_kv_cache: Whether to use flash attention by substituting
+            for our custom SDPA op. Note that the naming is poor and this
+            doesn't actually have anything to do with the kv_cache at the moment.
+        expand_rope_table: Temporary workaround to expand sin/cos table in head
+            dim to take vectorized path in optimized kernels.
+        use_attention_sink: Whether to use attention sink to support multi-round
+            conversation. Structured as:
+            '<sink_size>,<window_size>,<eviction_batch_size>',
+            e.g., '4,2044,1024'.
+        output_prune_map: Path to the output pruning token mapping file (token_map.json).
+        input_prune_map: Path to the input pruning token mapping file (token_map.json).
+        use_kv_cache: Whether to use KV cache.
+        quantize_kv_cache: Whether to perform int8 per token quantization on the KV cache.
+        local_global_attention: List of integers specifying local and global attention pattern.
+            e.g., [0, 16, 0, 16] to specify that every other layer is a sliding window of 16.
+            [0, 16, 32] pattern specifies 2nd and 3rd layers have sliding windows of 16 and 32.
+            [16] pattern specifies all layers have a sliding window of 16.
+    """
+
+    dtype_override: DtypeOverride = DtypeOverride.FP32
+    enable_dynamic_shape: bool = True
+    use_shared_embedding: bool = False
+    use_sdpa_with_kv_cache: bool = False
+    expand_rope_table: bool = False
+    use_attention_sink: Optional[str] = None
+    output_prune_map: Optional[str] = None
+    input_prune_map: Optional[str] = None
+    use_kv_cache: bool = False
+    quantize_kv_cache: bool = False
+    local_global_attention: Optional[List[int]] = None
+
+    def __post_init__(self):
+        self._validate_attention_sink()
+        self._validate_local_global_attention()
+
+        if self.quantize_kv_cache and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"
+            )
+
+        if self.local_global_attention and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)"
+            )
+
+    def _validate_attention_sink(self):
+        if self.use_attention_sink:
+            attention_sink_params = self.use_attention_sink.split(",")
+            if len(attention_sink_params) != 3:
+                raise ValueError(
+                    "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<eviction_batch_size>'"
+                )
+
+    def _validate_local_global_attention(self):
+        if self.local_global_attention:
+            local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]"
+            try:
+                parsed = ast.literal_eval(self.local_global_attention)
+                if not (
+                    isinstance(parsed, list) and all(isinstance(i, int) for i in parsed)
+                ):
+                    raise ValueError(local_global_err)
+            except Exception:
+                raise ValueError(local_global_err)
+
+
+################################################################################
+################################ ExportConfig ##################################
+################################################################################
+
+
+@dataclass
+class ExportConfig:
+    """
+    Configures properties relevant to the export process.
+
+    Attributes:
+        max_seq_length: Maximum length of sequence to evaluate.
+        max_context_length: Maximum length of context for the model to remember.
+        output_dir: Output dir to save the exported .pte file to.
+        output_name: File name to override the exported .pte file.
+        so_library: Shared library to specify custom quantized operators.
+        export_only: Whether to stop right after torch.export() and
+            just save the exported .pt2 graph file.
+    """
+
+    max_seq_length: int = 128
+    max_context_length: int = 128
+    output_dir: Optional[str] = None
+    output_name: Optional[str] = None
+    so_library: Optional[str] = None
+    export_only: bool = False
+
+    def __post_init__(self):
+        if self.max_context_length > self.max_seq_length:
+            raise ValueError(
+                f"max_context_length of {self.max_context_length} cannot be greater than max_seq_length of {self.max_seq_length}"
+            )
+
+
+################################################################################
+################################# DebugConfig ##################################
+################################################################################
+
+
+@dataclass
+class DebugConfig:
+    """
+    Configures options to debug the export process.
+
+    Attributes:
+        profile_memory: Whether to generate a chrome trace of activation memory
+            for intermediate tensors.
+        profile_path: Use cProfile to profile the export. Results are saved to
+            profile_path as an HTML file.
+        generate_etrecord: Whether to generate an ETRecord debug artifact.
+        generate_full_logits: Whether to keep the full logits, potentially useful
+            for debugging purposes. Kept off by default to save memory.
+        verbose: Whether to log the export process verbosely (log level >= INFO).
+    """
+
+    profile_memory: bool = False
+    profile_path: Optional[str] = None
+    generate_etrecord: bool = False
+    generate_full_logits: bool = False
+    verbose: bool = False
+
+
+################################################################################
+############################# QuantizationConfig ###############################
+################################################################################
+
+
+class Pt2eQuantize(str, Enum):
+    """
+    Type of backend-specific Pt2e quantization strategy to use.
+
+    Pt2e uses a different quantization library that is graph-based
+    compared to `qmode`, which is also specified in the QuantizationConfig
+    and is source transform-based.
+    """
+
+    XNNPACK_DYNAMIC = "xnnpack_dynamic"
+    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
+    QNN_8A8W = "qnn_8a8w"
+    QNN_16A16W = "qnn_16a16w"
+    QNN_16A4W = "qnn_16a4w"
+    COREML_C4W = "coreml_c4w"
+    COREML_8A_C8W = "coreml_8a_c8w"
+    COREML_8A_C4W = "coreml_8a_c4w"
+    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
+    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
+    VULKAN_8W = "vulkan_8w"
+
+
+class SpinQuant(str, Enum):
+    CUDA = "cuda"
+    NATIVE = "native"
+
+
+@dataclass
+class QuantizationConfig:
+    """
+    Configures how the model should be quantized (PTQ).
+
+    Attributes:
+        qmode: Quantization mode using TorchAo, expressed as a string.
+            See the __post_init__ validation for available qmode options.
+        embedding_quantize: Type of embedding quantization.
+            Must be of the format '<bitwidth>,<groupsize>', e.g., '8,1024'.
+        pt2e_quantize: Quantization mode using pt2e, which is an alternative
+            to TorchAo that uses backend-aware graph mode quantization rather
+            than source transformation quantization.
+        group_size: Group size for quantization.
+        use_spin_quant: Which spin quant mode to use. If unspecified, don't use
+            spin quant.
+        use_qat: Whether the checkpoint was trained with quantization-aware training (QAT).
+        calibration_tasks: Tasks for GPTQ calibration from lm_eval.
+        calibration_limit: Number of samples used for calibration from lm_eval.
+        calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
+        calibration_data: Prompts used for calibration.
+    """
+
+    # Constants.
+    QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+    AO_QUANT_PATTERNS: ClassVar[List[str]] = [
+        r"torchao:8da(\d+)w",
+        r"torchao:fpa(\d+)w",
+    ]
+
+    qmode: Optional[str] = None
+    embedding_quantize: Optional[str] = None
+    pt2e_quantize: Optional[Pt2eQuantize] = None
+    group_size: Optional[int] = None
+    use_spin_quant: Optional[SpinQuant] = None
+    use_qat: bool = False
+    calibration_tasks: Optional[List[str]] = None
+    calibration_limit: Optional[int] = None
+    calibration_seq_length: Optional[int] = None
+    calibration_data: str = "Once upon a time"
+
+    def __post_init__(self):
+        if self.qmode:
+            self._validate_qmode()
+
+    def _validate_qmode(self) -> None:
+        if not self.qmode:
+            return
+
+        if self.qmode in self.QMODE_OPTIONS:
+            return
+
+        # If qmode matches one of the patterns below, this means that we
+        # are using ARM-based torchao ops.
+        for pattern in self.AO_QUANT_PATTERNS:
+            matches = re.findall(pattern, self.qmode)
+            if len(matches) == 1:
+                return
+
+        raise ValueError(
+            f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}."
+        )
+
+    def _validate_embedding_quantize(self):
+        if len(self.embedding_quantize.split(",")) != 2:
+            raise ValueError(
+                f'embedding_quantize of {self.embedding_quantize} must follow the format: "<bitwidth>,<groupsize>"'
+            )
+
+
+################################################################################
+############################### BackendConfig ##################################
+################################################################################
+
+
+@dataclass
+class XNNPackConfig:
+    """
+    Configures the XNNPack backend.
+
+    Attributes:
+        enabled: Whether the XNNPack backend is enabled.
+        extended_ops: Whether to match more types of ops to delegate to XNNPack.
+    """
+
+    enabled: bool = False
+    extended_ops: bool = False
+
+
+class CoreMLQuantize(str, Enum):
+    B4W = "b4w"
+    C4W = "c4w"
+
+
+class CoreMLComputeUnit(str, Enum):
+    CPU_ONLY = "cpu_only"
+    CPU_AND_GPU = "cpu_and_gpu"
+    CPU_AND_NE = "cpu_and_ne"
+    ALL = "all"
+
+
+@dataclass
+class CoreMLConfig:
+    """
+    Configures the CoreML backend.
+    """
+
+    enabled: bool = False
+    enable_state: bool = False
+    preserve_sdpa: bool = False
+    quantize: Optional[CoreMLQuantize] = None
+    ios: int = 15
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+
+    def __post_init__(self):
+        if self.ios not in (15, 16, 17, 18):
+            raise ValueError(f"Invalid coreml ios version: {self.ios}")
+
+
+@dataclass
+class VulkanConfig:
+    """
+    Configures the Vulkan backend.
+    """
+
+    enabled: bool = False
+
+
+@dataclass
+class QNNConfig:
+    """
+    Configures the QNN backend.
+    """
+
+    enabled: bool = False
+    use_sha: bool = False
+    soc_model: str = "SM8650"
+    use_qnn_sha: bool = False
+    optimized_rotation_path: Optional[str] = None
+    num_sharding: int = 0
+
+
+@dataclass
+class MPSConfig:
+    """
+    Configures the MPS backend.
+    """
+
+    enabled: bool = False
+
+
+@dataclass
+class BackendConfig:
+    """
+    Configures which backends should be used and how the backends
+    should be set up.
+    """
+
+    xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
+    coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
+    vulkan: VulkanConfig = field(default_factory=VulkanConfig)
+    qnn: QNNConfig = field(default_factory=QNNConfig)
+    mps: MPSConfig = field(default_factory=MPSConfig)
+
+
+################################################################################
+################################## LlmConfig ###################################
+################################################################################
+
+
+@dataclass
+class LlmConfig:
+    """
+    The overall configuration for customizing the LLM export process.
+ """ + + base: BaseConfig = field(default_factory=BaseConfig) + model: ModelConfig = field(default_factory=ModelConfig) + export: ExportConfig = field(default_factory=ExportConfig) + debug: DebugConfig = field(default_factory=DebugConfig) + quantization: QuantizationConfig = field(default_factory=QuantizationConfig) + backend: BackendConfig = field(default_factory=BackendConfig) + + def __post_init__(self): + self._validate_low_bit() + + def _validate_low_bit(self): + if not self.quantization.qmode: + return + + using_lowbit_ops = False + for pattern in self.quantization.AO_QUANT_PATTERNS: + matches = re.findall(pattern, self.quantization.qmode) + if len(matches) == 1: + using_lowbit_ops = True + + # If we are using Ao's low bit quantization kernels for ARM, + # we do not want to also be delegating to a CPU backend (XNNPack). + if using_lowbit_ops and self.backend.xnnpack.enabled: + raise ValueError( + "Cannot use low-bit Ao ops (from qmode=torchao:...) while also delegating to XNNPack." + ) + + # Also we can only use shared embeddings if we are using low bit kernels. + if self.model.use_shared_embedding and not using_lowbit_ops: + raise ValueError( + "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)." + ) diff --git a/examples/models/llama/config/targets.bzl b/examples/models/llama/config/targets.bzl new file mode 100644 index 00000000000..8b85ce6d107 --- /dev/null +++ b/examples/models/llama/config/targets.bzl @@ -0,0 +1,26 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") + +def define_common_targets(): + runtime.python_library( + name = "llm_config", + srcs = [ + "llm_config.py", + ], + _is_external_target = True, + base_module = "executorch.examples.models.llama.config", + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + python_unittest( + name = "test_llm_config", + srcs = [ + "test_llm_config.py", + ], + deps = [ + ":llm_config", + ], + ) diff --git a/examples/models/llama/config/test_llm_config.py b/examples/models/llama/config/test_llm_config.py new file mode 100644 index 00000000000..0853e9dbbd8 --- /dev/null +++ b/examples/models/llama/config/test_llm_config.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +from executorch.examples.models.llama.config.llm_config import ( + BackendConfig, + BaseConfig, + CoreMLComputeUnit, + CoreMLConfig, + DebugConfig, + ExportConfig, + LlmConfig, + ModelConfig, + QuantizationConfig, + XNNPackConfig, +) + + +class TestValidation(unittest.TestCase): + def test_invalid_attention_sink(self): + with self.assertRaises(ValueError): + ModelConfig(use_attention_sink="4,2048") + + def test_invalid_local_global_attention_format(self): + with self.assertRaises(ValueError): + ModelConfig(local_global_attention="notalist") + + def test_quantize_kv_without_kv(self): + with self.assertRaises(ValueError): + ModelConfig(quantize_kv_cache=True) + + def test_local_global_attention_without_kv(self): + with self.assertRaises(ValueError): + ModelConfig(local_global_attention="[16]", use_kv_cache=False) + + def test_invalid_export_config_context_length(self): + with self.assertRaises(ValueError): + ExportConfig(max_seq_length=128, max_context_length=256) + + def test_invalid_qmode(self): + with self.assertRaises(ValueError): + QuantizationConfig(qmode="unknown") + + def test_invalid_coreml_ios(self): + with self.assertRaises(ValueError): + CoreMLConfig(ios=14) + + def test_lowbit_conflict_with_xnnpack(self): + qcfg = QuantizationConfig(qmode="torchao:8da4w") + bcfg = BackendConfig(xnnpack=XNNPackConfig(enabled=True)) + model_cfg = ModelConfig(use_shared_embedding=True) + + with self.assertRaises(ValueError): + LlmConfig(model=model_cfg, quantization=qcfg, backend=bcfg) + + def test_shared_embedding_without_lowbit(self): + model_cfg = ModelConfig(use_shared_embedding=True) + qcfg = QuantizationConfig(qmode="int8") + + with self.assertRaises(ValueError): + LlmConfig(model=model_cfg, quantization=qcfg) + + +class TestValidConstruction(unittest.TestCase): + + def test_valid_llm_config(self): + LlmConfig( + base=BaseConfig( + model_class="llama3", + checkpoint="checkpoints/model.pt", + tokenizer_path="tokenizer.json", + use_lora=8, + ), + model=ModelConfig( + dtype_override="fp32", + use_attention_sink="4,2048,1024", + use_kv_cache=True, + local_global_attention="[16, 32]", + ), + export=ExportConfig( + max_seq_length=256, + max_context_length=128, + output_dir="/tmp/export", + output_name="model.pte", + ), + debug=DebugConfig(profile_memory=True, verbose=True), + quantization=QuantizationConfig(qmode="torchao:8da4w"), + backend=BackendConfig( + xnnpack=XNNPackConfig(enabled=False), + coreml=CoreMLConfig( + enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL + ), + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 96faf64475e..11fb2fa3cbb 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -486,7 +486,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--use_qat", default=False, action="store_true", - help="Whether the checkpoin is pre-quantized with QAT or not.", + help="Whether the checkpoint is pre-quantized with QAT or not.", ) parser.add_argument( From 8f1c751d3e9d5f9b1c6abb38a7b0517e86a46717 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 9 Jun 2025 13:56:25 -0700 Subject: [PATCH 2/2] Introduce hydra framework with backwards compatibility Pull Request resolved: https://github.com/pytorch/executorch/pull/11029 @imported-using-ghimport Differential Revision: 
[D75263989](https://our.internmc.facebook.com/intern/diff/D75263989/) ghstack-source-id: 289227700 --- examples/models/llama/TARGETS | 4 ++ examples/models/llama/config/llm_config.py | 13 +++++++ examples/models/llama/export_llama.py | 38 ++++++++++++++----- examples/models/llama/export_llama_args.py | 21 ++++++++++ examples/models/llama/export_llama_hydra.py | 27 +++++++++++++ examples/models/llama/export_llama_lib.py | 21 +++++++++- examples/models/llama/install_requirements.sh | 2 +- 7 files changed, 115 insertions(+), 11 deletions(-) create mode 100644 examples/models/llama/export_llama_args.py create mode 100644 examples/models/llama/export_llama_hydra.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 872eccce872..b51e164d483 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -132,6 +132,8 @@ runtime.python_library( name = "export_library", srcs = [ "export_llama.py", + "export_llama_args.py", + "export_llama_hydra.py", "export_llama_lib.py", "model.py", ], @@ -148,6 +150,8 @@ runtime.python_library( ":source_transformation", "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform", "//caffe2:torch", + "//executorch/examples/models/llama/config:llm_config", + "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/exir/passes:init_mutable_pass", "//executorch/examples/models:model_base", "//executorch/examples/models:models", diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py index b929e756c3e..d9a8a8e6192 100644 --- a/examples/models/llama/config/llm_config.py +++ b/examples/models/llama/config/llm_config.py @@ -12,6 +12,7 @@ Uses dataclasses, which integrate with OmegaConf and Hydra. """ +import argparse import ast import re from dataclasses import dataclass, field @@ -468,6 +469,18 @@ class LlmConfig: quantization: QuantizationConfig = field(default_factory=QuantizationConfig) backend: BackendConfig = field(default_factory=BackendConfig) + @classmethod + def from_args(cls, args: argparse.Namespace) -> "LlmConfig": + """ + To support legacy purposes, this function converts CLI args from + argparse to an LlmConfig, which is used by the LLM export process. + """ + llm_config = LlmConfig() + + # TODO: conversion code. + + return llm_config + def __post_init__(self): self._validate_low_bit() diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index e25a8a007eb..63e76e28ba9 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -4,30 +4,50 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Example script for exporting Llama2 to flatbuffer - -import logging - # force=True to ensure logging while in debugger. Set up logger before any # other imports. 
+import logging
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
 
+import argparse
+import runpy
 import sys
 
 import torch
 
-from .export_llama_lib import build_args_parser, export_llama
-
 sys.setrecursionlimit(4096)
 
 
+def parse_hydra_arg():
+    """First parse out the arg for whether to use Hydra or the old CLI."""
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument("--hydra", action="store_true")
+    args, remaining = parser.parse_known_args()
+    return args.hydra, remaining
+
+
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    parser = build_args_parser()
-    args = parser.parse_args()
-    export_llama(args)
+
+    use_hydra, remaining_args = parse_hydra_arg()
+    if use_hydra:
+        # runpy runs the main function of export_llama_hydra with the remaining args
+        # under the Hydra framework.
+        sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
+        print(f"running with {sys.argv}")
+        runpy.run_module(
+            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
+        )
+    else:
+        # Use the legacy version of the export_llama script which uses argparse.
+        from executorch.examples.models.llama.export_llama_args import (
+            main as export_llama_args_main,
+        )
+
+        export_llama_args_main(remaining_args)
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/export_llama_args.py b/examples/models/llama/export_llama_args.py
new file mode 100644
index 00000000000..7a176d9b7d0
--- /dev/null
+++ b/examples/models/llama/export_llama_args.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama with the legacy argparse setup.
+"""
+
+from .export_llama_lib import build_args_parser, export_llama
+
+
+def main(args=None) -> None:
+    parser = build_args_parser()
+    args = parser.parse_args(args)
+    export_llama(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py
new file mode 100644
index 00000000000..73eca7e2a5a
--- /dev/null
+++ b/examples/models/llama/export_llama_hydra.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama using the new Hydra CLI.
+""" + +import hydra + +from executorch.examples.models.llama.config.llm_config import LlmConfig +from executorch.examples.models.llama.export_llama_lib import export_llama +from hydra.core.config_store import ConfigStore + +cs = ConfigStore.instance() +cs.store(name="llm_config", node=LlmConfig) + + +@hydra.main(version_base=None, config_name="llm_config") +def main(llm_config: LlmConfig) -> None: + export_llama(llm_config) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 11fb2fa3cbb..12406cc762e 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -27,6 +27,8 @@ from executorch.devtools.backend_debug import print_delegation_info from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func + +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.hf_download import ( download_and_convert_hf_checkpoint, ) @@ -50,6 +52,7 @@ get_vulkan_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace +from omegaconf.dictconfig import DictConfig from ..model_factory import EagerModelFactory from .source_transformation.apply_spin_quant_r1_r2 import ( @@ -567,7 +570,23 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: return return_val -def export_llama(args) -> str: +def export_llama( + export_options: Union[argparse.Namespace, DictConfig], +) -> str: + if isinstance(export_options, argparse.Namespace): + # Legacy CLI. + args = export_options + llm_config = LlmConfig.from_args(export_options) # noqa: F841 + elif isinstance(export_options, DictConfig): + # Hydra CLI. + llm_config = export_options # noqa: F841 + else: + raise ValueError( + "Input to export_llama must be either of type argparse.Namespace or LlmConfig" + ) + + # TODO: refactor rest of export_llama to use llm_config instead of args. + # If a checkpoint isn't provided for an HF OSS model, download and convert the # weights first. if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS: diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh index b9e0f9210c5..580a152a322 100755 --- a/examples/models/llama/install_requirements.sh +++ b/examples/models/llama/install_requirements.sh @@ -10,7 +10,7 @@ # Install tokenizers for hf .json tokenizer. # Install snakeviz for cProfile flamegraph # Install lm-eval for Model Evaluation with lm-evalution-harness. -pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile +pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile # Call the install helper for further setup python examples/models/llama/install_requirement_helper.py
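
Usage sketch (illustrative only, not part of the patches above): the snippet below approximates how the structured LlmConfig that export_llama_hydra.py registers with Hydra's ConfigStore resolves dotted CLI overrides, using OmegaConf directly (hydra-core, added to install_requirements.sh in the second patch, bundles omegaconf). The override values are placeholders and are not defaults taken from this PR.

    from omegaconf import OmegaConf

    from executorch.examples.models.llama.config.llm_config import LlmConfig

    # Build a schema from the LlmConfig dataclass, then merge dotted overrides,
    # which is roughly what Hydra does for `export_llama --hydra <overrides>`.
    schema = OmegaConf.structured(LlmConfig)
    overrides = OmegaConf.from_dotlist(
        [
            "base.checkpoint=/tmp/llama/consolidated.00.pth",  # placeholder path
            "model.use_kv_cache=True",
            "model.use_sdpa_with_kv_cache=True",
            "export.max_seq_length=256",
            "backend.xnnpack.enabled=True",
        ]
    )
    cfg = OmegaConf.merge(schema, overrides)  # values are type-checked against LlmConfig
    print(OmegaConf.to_yaml(cfg))

On the command line, the same overrides would look roughly like `python -m examples.models.llama.export_llama --hydra model.use_kv_cache=True export.max_seq_length=256 ...`, since export_llama.py strips the --hydra flag and re-runs export_llama_hydra.py with the remaining dotted arguments.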