diff --git a/.ci/scripts/test_llama_runner_eager.sh b/.ci/scripts/test_llama_runner_eager.sh
index 537d835ba1c..0f2cb7b3769 100644
--- a/.ci/scripts/test_llama_runner_eager.sh
+++ b/.ci/scripts/test_llama_runner_eager.sh
@@ -42,11 +42,12 @@ run_and_verify() {
     -d fp32 \
     --max_seq_length 32 \
     --temperature 0 \
+    --show_tokens \
     --prompt "Once upon a time," > result.txt

   # Verify result.txt
   RESULT=$(cat result.txt)
-  EXPECTED_RESULT="there was a little girl"
+  EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
   if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
     echo "Actual result: ${RESULT}"
     echo "Success"
diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index b8792151a09..abac920c6b2 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=0,
     )

+    parser.add_argument(
+        "--show_tokens",
+        action="store_true",
+        default=False,
+        help="Show the tokens that were generated",
+    )
+
     return parser


@@ -71,15 +78,12 @@ def main() -> None:
     args = parser.parse_args()

     runner = EagerLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    if args.show_tokens:
+        print(f"Tokens: {generated_tokens}")


 if __name__ == "__main__":
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 867c41aabea..159bc5f5017 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from abc import ABC, abstractmethod
-from typing import List, Optional, TypedDict
+from typing import List, Optional

 import torch

@@ -13,11 +13,6 @@
 from executorch.extension.llm.tokenizer.utils import get_tokenizer


-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[int]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -84,6 +79,7 @@ def generate(  # noqa: C901
         )

         current_token = next_token(logits, temperature, top_p)
+        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

         while len(tokens) < self.params.max_seq_len:
@@ -101,12 +97,14 @@ def generate(  # noqa: C901
                     tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                 )
             current_token = next_token(logits, temperature, top_p)
+            tokens.append(current_token)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
             ):
                 break
-            tokens.append(current_token)
+            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
+        print("\n")

         return tokens if echo else tokens[len(prompt_tokens) :]

@@ -116,7 +114,7 @@ def text_completion(
         temperature: float = 0.6,
         top_p: float = 0.9,
         echo: bool = False,
-    ) -> CompletionPrediction:
+    ) -> List[int]:
        """
        Perform text completion for a prompt using the language model.

@@ -127,19 +125,14 @@ def text_completion(
             echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

         Returns:
-            CompletionPrediction: Completion prediction, which contains the generated text completion.
+            Generated list of tokens.

         Note:
             This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
-        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
-        generation_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+        return self.generate(
+            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
             temperature=temperature,
             top_p=top_p,
             echo=echo,
         )
-        return {
-            "generation": self.tokenizer.decode(generation_tokens),
-            "tokens": generation_tokens,
-        }
diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py
index 73005d93330..19e57915982 100644
--- a/examples/models/llama/runner/native.py
+++ b/examples/models/llama/runner/native.py
@@ -107,15 +107,11 @@ def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
     runner = NativeLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    print(f"Response: {generated_tokens}")


 if __name__ == "__main__":
diff --git a/examples/models/llama/tokenizer/tiktoken.py b/examples/models/llama/tokenizer/tiktoken.py
index 1d74e5e3aa5..b48cb4dc890 100644
--- a/examples/models/llama/tokenizer/tiktoken.py
+++ b/examples/models/llama/tokenizer/tiktoken.py
@@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str:
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
         return self.model.decode(cast(List[int], t))

+    def decode_token(self, t: int) -> str:
+        """
+        Decodes a single token ID into a string.
+
+        Args:
+            t (int): The token ID to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.model.decode_single_token_bytes(t).decode("utf-8")
+
     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(
         s: str, max_consecutive_slice_len: int
diff --git a/extension/llm/tokenizer/tokenizer.py b/extension/llm/tokenizer/tokenizer.py
index ecd0231fb6d..78377230b9f 100644
--- a/extension/llm/tokenizer/tokenizer.py
+++ b/extension/llm/tokenizer/tokenizer.py
@@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str:
         # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
         return self.sp_model.decode(t)

+    def decode_token(self, t: int) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        return self.sp_model.decode(t)
+
     def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
         """
         Export tokenizer.model to another serialization format. Here we did some lightweight
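
For context, the sketch below illustrates the streaming pattern this diff introduces in LlamaRunner.generate(): each sampled token is decoded with decode_token() and printed immediately with flush=True, instead of being collected into the removed CompletionPrediction dict and printed at the end. This also appears to be why the CI script above now asserts on raw token IDs from the new --show_tokens flag rather than on a decoded phrase. This is a minimal standalone sketch, not part of the change itself: StubTokenizer and stream_decode are hypothetical names used only for illustration, while the real runners use the SentencePiece and Tiktoken wrappers touched above.

# Minimal sketch of streaming token decoding (hypothetical, for illustration only).
from typing import Dict, List


class StubTokenizer:
    """Tiny stand-in tokenizer with a fixed vocabulary; not a real ExecuTorch tokenizer."""

    _vocab: Dict[int, str] = {0: "Once", 1: " upon", 2: " a", 3: " time", 4: ","}

    def decode_token(self, t: int) -> str:
        # Decode a single token ID into its string piece.
        return self._vocab[t]


def stream_decode(token_ids: List[int], tokenizer: StubTokenizer) -> None:
    # Print each token as soon as it is available instead of buffering the
    # whole completion; flush=True pushes each piece to the terminal right away.
    for t in token_ids:
        print(tokenizer.decode_token(t), end="", flush=True)
    print("\n")


if __name__ == "__main__":
    stream_decode([0, 1, 2, 3, 4], StubTokenizer())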