From 7ec2374e3365e029ecd79a505c714714d6c9ebb1 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Wed, 6 Nov 2024 18:09:10 -0800
Subject: [PATCH] update llama runner to decode single token

Right now, we don't print the generated response in the eager runner until
all tokens have been generated. This is not a good experience, since we have
to wait for the whole generation to finish before seeing any of the response.
This PR updates the runner to decode and print each new token immediately
after it is generated.

Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/)

[ghstack-poisoned]
---
 examples/models/llama/runner/eager.py       | 16 ++++++++++------
 examples/models/llama/runner/generation.py  | 23 ++++++++---------------
 examples/models/llama/runner/native.py      |  8 ++------
 examples/models/llama/tokenizer/tiktoken.py | 12 ++++++++++++
 extension/llm/tokenizer/tokenizer.py        |  4 ++++
 5 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index b8792151a09..abac920c6b2 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=0,
     )
 
+    parser.add_argument(
+        "--show_tokens",
+        action="store_true",
+        default=False,
+        help="Show the tokens that were generated",
+    )
+
     return parser
 
 
@@ -71,15 +78,12 @@ def main() -> None:
     args = parser.parse_args()
 
     runner = EagerLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    if args.show_tokens:
+        print(f"Tokens: {generated_tokens}")
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 867c41aabea..549aea360d9 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, TypedDict
+from typing import List, Optional
 
 import torch
 
@@ -13,11 +13,6 @@
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
 
-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[int]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -84,6 +79,7 @@ def generate(  # noqa: C901
         )
 
         current_token = next_token(logits, temperature, top_p)
+        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
         while len(tokens) < self.params.max_seq_len:
@@ -101,12 +97,14 @@ def generate(  # noqa: C901
                 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
             )
             current_token = next_token(logits, temperature, top_p)
+            tokens.append(current_token)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
             ):
                 break
-            tokens.append(current_token)
+            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
+        print("\n")
 
         return tokens if echo else tokens[len(prompt_tokens) :]
 
@@ -116,7 +114,7 @@ def text_completion(
         temperature: float = 0.6,
         top_p: float = 0.9,
         echo: bool = False,
-    ) -> CompletionPrediction:
+    ) -> List[int]:
         """
         Perform text completion for a prompt using the language model.
 
@@ -132,14 +130,9 @@ def text_completion(
         Note:
             This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
-        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
-        generation_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+        return self.generate(
+            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
             temperature=temperature,
             top_p=top_p,
             echo=echo,
         )
-        return {
-            "generation": self.tokenizer.decode(generation_tokens),
-            "tokens": generation_tokens,
-        }
diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py
index 73005d93330..19e57915982 100644
--- a/examples/models/llama/runner/native.py
+++ b/examples/models/llama/runner/native.py
@@ -107,15 +107,11 @@ def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
     runner = NativeLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    print(f"Tokens: {generated_tokens}")
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/tokenizer/tiktoken.py b/examples/models/llama/tokenizer/tiktoken.py
index 1d74e5e3aa5..051911238b5 100644
--- a/examples/models/llama/tokenizer/tiktoken.py
+++ b/examples/models/llama/tokenizer/tiktoken.py
@@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str:
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
         return self.model.decode(cast(List[int], t))
 
+    def decode_token(self, t: int) -> str:
+        """
+        Decodes a single token ID into a string.
+
+        Args:
+            t (int): The token ID to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.model.decode_single_token_bytes(t).decode("utf-8")
+
     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(
         s: str, max_consecutive_slice_len: int
diff --git a/extension/llm/tokenizer/tokenizer.py b/extension/llm/tokenizer/tokenizer.py
index ecd0231fb6d..573a1104eb1 100644
--- a/extension/llm/tokenizer/tokenizer.py
+++ b/extension/llm/tokenizer/tokenizer.py
@@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str:
         # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
         return self.sp_model.decode(t)
 
+    def decode_token(self, t: int) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
+        return self.sp_model.decode(t)
+
     def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
         """
         Export tokenizer.model to another serialization format. Here we did some lightweight
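
For reference, and not part of the patch itself: the sketch below illustrates the streaming pattern the runner now follows, decoding and printing each token as soon as it is produced instead of waiting for the full sequence. `ToyTokenizer`, `stream_print`, and the hard-coded token IDs are hypothetical stand-ins for the ExecuTorch tokenizer and sampling loop, not real APIs.

```python
# Standalone sketch of streaming single-token decoding: print each token as
# soon as it is produced instead of buffering the whole response.
# ToyTokenizer and the hard-coded token stream are illustrative stand-ins.
from typing import Iterable, List


class ToyTokenizer:
    """Toy tokenizer: each token ID maps to a fixed string fragment."""

    def __init__(self, vocab: List[str]) -> None:
        self.vocab = vocab
        self.eos_id = len(vocab)  # one past the last real token ID

    def decode_token(self, t: int) -> str:
        """Decode a single token ID into a string."""
        return self.vocab[t]


def stream_print(token_ids: Iterable[int], tokenizer: ToyTokenizer) -> List[int]:
    """Print each token immediately after it is 'generated', then return all tokens."""
    tokens: List[int] = []
    for t in token_ids:
        tokens.append(t)
        if t == tokenizer.eos_id:
            break
        # end="" keeps the fragments on one line; flush=True makes them appear
        # right away instead of waiting for the stdout buffer to fill.
        print(tokenizer.decode_token(t), end="", flush=True)
    print()  # final newline once generation finishes
    return tokens


if __name__ == "__main__":
    tok = ToyTokenizer(["Stream", "ing", " one", " token", " at", " a", " time", "."])
    stream_print([0, 1, 2, 3, 4, 5, 6, 7, tok.eos_id], tok)
```

Printing with `flush=True` is what makes the text appear incrementally; without it, stdout buffering could still delay the output until generation completes.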