diff --git a/examples/models/llama2/runner/eager.py b/examples/models/llama2/runner/eager.py
index 42357d6e55c..7f324b4cbc8 100644
--- a/examples/models/llama2/runner/eager.py
+++ b/examples/models/llama2/runner/eager.py
@@ -11,9 +11,12 @@
 import torch
 
 from examples.models.llama2.llama_transformer import ModelArgs
-from executorch.examples.models.model_factory import EagerModelFactory
-
-from .generation import LlamaRunner
+from executorch.examples.models.llama2.export_llama_lib import (
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
+)
+from executorch.examples.models.llama2.runner.generation import LlamaRunner
+from executorch.extension.llm.export import LLMEdgeManager
 
 
 class EagerLlamaRunner(LlamaRunner):
@@ -25,21 +28,17 @@ def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=args.max_len,
+            max_seq_len=args.max_seq_length,
             max_batch_size=1,
-            use_kv_cache=True,
+            use_kv_cache=args.use_kv_cache,
             **params,
         )
-        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
-        self.model, _, _, _ = EagerModelFactory.create_model(
-            "llama2",
-            "Llama2Model",
-            checkpoint=args.checkpoint,
-            params=args.params,
-            use_kv_cache=True,
-            fairseq2=False,
-            max_seq_len=args.max_len,
-            enable_dynamic_shape=True,
+        super().__init__(tokenizer_path=args.tokenizer_path, model_args=model_args)
+        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        self.model = (
+            manager.model.eval().to(device="cuda")
+            if torch.cuda.is_available()
+            else manager.model.eval().to(device="cpu")
         )
 
     def forward(
@@ -51,34 +50,7 @@ def forward(
 
 
 def build_args_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        default=None,
-        help="path to model checkpoint file",
-    )
-
-    parser.add_argument(
-        "--params",
-        type=str,
-        default=None,
-        help="model params file",
-    )
-
-    parser.add_argument(
-        "--max_len",
-        type=int,
-        default=128,
-        help="Maximum length of the generated response sequence.",
-    )
-
-    parser.add_argument(
-        "--tokenizer",
-        type=str,
-        default=None,
-    )
+    parser = _build_args_parser()
 
     parser.add_argument(
         "--prompt",
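
Note: after this change the eager runner reuses export_llama_lib's argument parser, so model flags such as --checkpoint, --params, --tokenizer_path, --max_seq_length, and --use_kv_cache come from the shared export parser rather than the eager-only copies deleted above. A minimal sketch of how the updated runner might be invoked, assuming those flags exist on the shared parser as the diff implies (all file paths below are placeholders):

    # Sketch only: the flag set is inferred from the attributes the diff
    # references (args.tokenizer_path, args.max_seq_length, args.use_kv_cache,
    # args.params); paths are hypothetical.
    from executorch.examples.models.llama2.runner.eager import (
        EagerLlamaRunner,
        build_args_parser,
    )

    args = build_args_parser().parse_args(
        [
            "--checkpoint", "/path/to/llama2.pt",         # placeholder path
            "--params", "/path/to/params.json",           # placeholder path
            "--tokenizer_path", "/path/to/tokenizer.model",
            "--max_seq_length", "128",
            "--use_kv_cache",        # assumed to be a store_true flag on the shared parser
            "--prompt", "Hello",     # added by eager.py's build_args_parser below
        ]
    )
    runner = EagerLlamaRunner(args)  # loads the eager model via _prepare_for_llama_export

Generation itself then goes through whatever interface LlamaRunner exposes, which is not shown in this diff.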