diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 677c2ac4da6..7b4ebf36a56 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -80,18 +80,20 @@ def build_args_parser() -> argparse.ArgumentParser:
 def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     parser = build_args_parser()
     args = parser.parse_args()
-    runner = runner_class(args)  # pyre-ignore: Missing argument [20]
-    generated_tokens = (
-        runner.chat_completion(temperature=args.temperature)
-        if args.chat
-        else runner.text_completion(
-            prompt=args.prompt,
-            temperature=args.temperature,
-            echo=True,
+
+    with torch.no_grad():
+        runner = runner_class(args)  # pyre-ignore: Missing argument [20]
+        generated_tokens = (
+            runner.chat_completion(temperature=args.temperature)
+            if args.chat
+            else runner.text_completion(
+                prompt=args.prompt,
+                temperature=args.temperature,
+                echo=True,
+            )
         )
-    )
-    if args.show_tokens:
-        print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
+        if args.show_tokens:
+            print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
 
 
 def main() -> None:
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 6d0da4bc1e5..13ac750305f 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -199,15 +199,14 @@ def chat_completion(
                 temperature=temperature,
                 top_p=top_p,
                 echo=True,
-                pos_base=len(tokens),
+                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
             tokens.extend(new_tokens)
             prompt = input("Me: ")
         return tokens
 
     def _format_prompt(self, prompt: str) -> str:
-        return f"""
-<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>