diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 7b4ebf36a56..559b4e04892 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -84,7 +84,11 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     with torch.no_grad():
         runner = runner_class(args)  # pyre-ignore: Missing argument [20]
         generated_tokens = (
-            runner.chat_completion(temperature=args.temperature)
+            runner.chat_completion(
+                max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length,
+                temperature=args.temperature,
+                show_progress=args.show_tokens,
+            )
             if args.chat
             else runner.text_completion(
                 prompt=args.prompt,
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 13ac750305f..891ce20db3e 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -168,18 +168,19 @@ def text_completion(
 
     def chat_completion(
         self,
+        max_seq_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
+        show_progress: bool = False,
     ) -> List[int]:
         """
         Perform multi-turn chat with the language model.
 
         Args:
-            prompt (str): Text prompt for completion.
+            max_seq_len (int): Maximum number of tokens to generate for each prompt.
             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
-
+            show_progress (bool, optional): Flag indicating whether to show number of tokens generated.
 
         Returns:
             Generated list of tokens.
@@ -188,20 +189,26 @@ def chat_completion(
         """
         exit_prompt = "exit"
         tokens = []
+        pre_stop_token = []
         prompt = input("Me: ")
         while prompt and prompt != exit_prompt:
             print("LLM: ", end="", flush=True)
-            new_tokens = self.generate(
-                prompt_tokens=self.tokenizer.encode(
-                    self._format_prompt(prompt), bos=True, eos=False
-                ),
-                max_seq_len=self.max_seq_len,
+            prompt_tokens = self.tokenizer.encode(
+                self._format_prompt(prompt), bos=True, eos=False
+            )
+            generated_tokens = self.generate(
+                prompt_tokens=pre_stop_token + prompt_tokens,
+                max_seq_len=max_seq_len,
                 temperature=temperature,
                 top_p=top_p,
-                echo=True,
+                echo=False,
                 pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
-            tokens.extend(new_tokens)
+            pre_stop_token = generated_tokens[-1:]
+            tokens.extend(prompt_tokens)
+            tokens.extend(generated_tokens)
+            if show_progress:
+                print(f"[Generated {len(tokens)} tokens]")
             prompt = input("Me: ")
         return tokens
 
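
The core of the `generation.py` change is the bookkeeping across turns: the last token emitted on the previous turn (`pre_stop_token`) is re-fed as the first input token of the next turn, and `pos_base` is advanced by the number of tokens already accumulated. The following is a minimal standalone sketch of that bookkeeping only, not the runner itself; `fake_encode` and `fake_generate` are hypothetical stand-ins for the tokenizer and for `self.generate`, and only the list arithmetic mirrors the diff.

```python
from typing import List


def fake_encode(prompt: str) -> List[int]:
    # Stand-in for tokenizer.encode(); one "token" per character.
    return [ord(c) for c in prompt]


def fake_generate(prompt_tokens: List[int], pos_base: int) -> List[int]:
    # Stand-in for self.generate(); always returns a fixed three-token reply.
    return [1, 2, 3]


def chat_loop(turns: List[str]) -> List[int]:
    tokens: List[int] = []          # all tokens seen so far across turns
    pre_stop_token: List[int] = []  # last token emitted on the previous turn
    for prompt in turns:
        prompt_tokens = fake_encode(prompt)
        generated = fake_generate(
            # The previous stop token is prepended to the new prompt.
            prompt_tokens=pre_stop_token + prompt_tokens,
            # First turn starts at position 0; later turns continue from
            # len(tokens) - 1, i.e. the position of the resubmitted stop token.
            pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
        )
        pre_stop_token = generated[-1:]
        tokens.extend(prompt_tokens)
        tokens.extend(generated)
    return tokens


print(len(chat_loop(["hi", "tell me more"])))  # total token count across turns
```

With `echo=False`, `generate` returns only the newly produced tokens, so the loop now appends the encoded prompt and the generated tokens to `tokens` itself rather than relying on the model output to include the prompt.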