From 46ee9eb2c831d1d149907e6fcac09325c7772993 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Mon, 18 Nov 2024 10:19:00 -0800
Subject: [PATCH] Fix CUDA out of memory issue for eager runner

Pull Request resolved: https://github.com/pytorch/executorch/pull/6866

This PR updates the eager runner to disable grad and reduce memory usage. It also updates the prompt format to not include bos.

ghstack-source-id: 254139542
Differential Revision: [D65962743](https://our.internmc.facebook.com/intern/diff/D65962743/)
---
 examples/models/llama/runner/eager.py      | 24 +++++++++++++----------
 examples/models/llama/runner/generation.py |  5 ++---
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 677c2ac4da6..7b4ebf36a56 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -80,18 +80,20 @@ def build_args_parser() -> argparse.ArgumentParser:
 def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     parser = build_args_parser()
     args = parser.parse_args()
-    runner = runner_class(args)  # pyre-ignore: Missing argument [20]
-    generated_tokens = (
-        runner.chat_completion(temperature=args.temperature)
-        if args.chat
-        else runner.text_completion(
-            prompt=args.prompt,
-            temperature=args.temperature,
-            echo=True,
+
+    with torch.no_grad():
+        runner = runner_class(args)  # pyre-ignore: Missing argument [20]
+        generated_tokens = (
+            runner.chat_completion(temperature=args.temperature)
+            if args.chat
+            else runner.text_completion(
+                prompt=args.prompt,
+                temperature=args.temperature,
+                echo=True,
+            )
         )
-    )
-    if args.show_tokens:
-        print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
+        if args.show_tokens:
+            print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
 
 
 def main() -> None:
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 6d0da4bc1e5..13ac750305f 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -199,15 +199,14 @@ def chat_completion(
                 temperature=temperature,
                 top_p=top_p,
                 echo=True,
-                pos_base=len(tokens),
+                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
             tokens.extend(new_tokens)
             prompt = input("Me: ")
         return tokens
 
     def _format_prompt(self, prompt: str) -> str:
-        return f"""
-<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
 
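
Editor's note (illustration only, not part of the patch): the core of the fix is running eager inference under torch.no_grad(), so no autograd graph is recorded while tokens are generated. Below is a minimal, self-contained sketch of that pattern; the nn.Linear model and the small loop are placeholders standing in for the real Llama runner and its generation loop.

# Sketch of the no_grad() pattern applied in execute_runner (toy model only).
import torch
import torch.nn as nn

model = nn.Linear(16, 16).eval()   # stand-in for the Llama model
x = torch.randn(1, 16)             # stand-in for the prompt state

# Without no_grad(), each forward pass records an autograd graph and keeps
# intermediate activations alive across the loop, which is what eventually
# exhausts GPU memory during long eager-mode generation runs.
with torch.no_grad():
    for _ in range(8):             # stand-in for generating 8 tokens
        x = model(x)

print(x.requires_grad)             # False: no graph was built, activations are freed

The patch applies the same idea to the runner itself: both model construction and the chat/text completion calls sit inside the no_grad() block, so activations are released as soon as each forward pass finishes.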