From 57973d8a8864f5450ad2dd70646cb6d3706fca22 Mon Sep 17 00:00:00 2001 From: Lunwen He Date: Mon, 11 Nov 2024 11:16:00 -0800 Subject: [PATCH 1/2] update llama runner to decode single token Pull Request resolved: https://github.com/pytorch/executorch/pull/6703 Right now, we don't print the generated response in the eager runner until all tokens are generated. This is not good experience as we need to wait until all tokens are generated to see the response. This PR updates it to decode each new token immediately after it is generated. ghstack-source-id: 252924039 Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/) --- .ci/scripts/test_llama_runner_eager.sh | 3 ++- examples/models/llama/runner/eager.py | 16 ++++++++----- examples/models/llama/runner/generation.py | 25 ++++++++------------- examples/models/llama/runner/native.py | 8 ++----- examples/models/llama/tokenizer/tiktoken.py | 12 ++++++++++ extension/llm/tokenizer/tokenizer.py | 4 ++++ 6 files changed, 39 insertions(+), 29 deletions(-) diff --git a/.ci/scripts/test_llama_runner_eager.sh b/.ci/scripts/test_llama_runner_eager.sh index 537d835ba1c..0f2cb7b3769 100644 --- a/.ci/scripts/test_llama_runner_eager.sh +++ b/.ci/scripts/test_llama_runner_eager.sh @@ -42,11 +42,12 @@ run_and_verify() { -d fp32 \ --max_seq_length 32 \ --temperature 0 \ + --show_tokens \ --prompt "Once upon a time," > result.txt # Verify result.txt RESULT=$(cat result.txt) - EXPECTED_RESULT="there was a little girl" + EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263" if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then echo "Actual result: ${RESULT}" echo "Success" diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index b8792151a09..abac920c6b2 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser: default=0, ) + parser.add_argument( + "--show_tokens", + action="store_true", + default=False, + help="Show the tokens that were generated", + ) + return parser @@ -71,15 +78,12 @@ def main() -> None: args = parser.parse_args() runner = EagerLlamaRunner(args) - result = runner.text_completion( + generated_tokens = runner.text_completion( prompt=args.prompt, temperature=args.temperature, ) - print( - "Response: \n{response}\n Tokens:\n {tokens}".format( - response=result["generation"], tokens=result["tokens"] - ) - ) + if args.show_tokens: + print(f"Tokens: {generated_tokens}") if __name__ == "__main__": diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 867c41aabea..159bc5f5017 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import List, Optional, TypedDict +from typing import List, Optional import torch @@ -13,11 +13,6 @@ from executorch.extension.llm.tokenizer.utils import get_tokenizer -class CompletionPrediction(TypedDict, total=False): - generation: str - tokens: List[int] # not required - - def sample_top_p(probs, p): """ Perform top-p (nucleus) sampling on a probability distribution. @@ -84,6 +79,7 @@ def generate( # noqa: C901 ) current_token = next_token(logits, temperature, top_p) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) tokens = prompt_tokens + [current_token] while len(tokens) < self.params.max_seq_len: @@ -101,12 +97,14 @@ def generate( # noqa: C901 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device), ) current_token = next_token(logits, temperature, top_p) + tokens.append(current_token) if current_token == self.tokenizer.eos_id or ( hasattr(self.tokenizer, "stop_tokens") and current_token in self.tokenizer.stop_tokens ): break - tokens.append(current_token) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + print("\n") return tokens if echo else tokens[len(prompt_tokens) :] @@ -116,7 +114,7 @@ def text_completion( temperature: float = 0.6, top_p: float = 0.9, echo: bool = False, - ) -> CompletionPrediction: + ) -> List[int]: """ Perform text completion for a prompt using the language model. @@ -127,19 +125,14 @@ def text_completion( echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. Returns: - CompletionPrediction: Completion prediction, which contains the generated text completion. + Generated list of tokens. Note: This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. """ - prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False) - generation_tokens = self.generate( - prompt_tokens=prompt_tokens, + return self.generate( + prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), temperature=temperature, top_p=top_p, echo=echo, ) - return { - "generation": self.tokenizer.decode(generation_tokens), - "tokens": generation_tokens, - } diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index 73005d93330..19e57915982 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -107,15 +107,11 @@ def main() -> None: parser = build_args_parser() args = parser.parse_args() runner = NativeLlamaRunner(args) - result = runner.text_completion( + generated_tokens = runner.text_completion( prompt=args.prompt, temperature=args.temperature, ) - print( - "Response: \n{response}\n Tokens:\n {tokens}".format( - response=result["generation"], tokens=result["tokens"] - ) - ) + print(f"Response: {generated_tokens}") if __name__ == "__main__": diff --git a/examples/models/llama/tokenizer/tiktoken.py b/examples/models/llama/tokenizer/tiktoken.py index 1d74e5e3aa5..b48cb4dc890 100644 --- a/examples/models/llama/tokenizer/tiktoken.py +++ b/examples/models/llama/tokenizer/tiktoken.py @@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str: # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. return self.model.decode(cast(List[int], t)) + def decode_token(self, t: int) -> str: + """ + Decodes a single token ID into a string. + + Args: + t (int): The token ID to be decoded. + + Returns: + str: The decoded string. + """ + return self.model.decode_single_token_bytes(t).decode("utf-8") + @staticmethod def _split_whitespaces_or_nonwhitespaces( s: str, max_consecutive_slice_len: int diff --git a/extension/llm/tokenizer/tokenizer.py b/extension/llm/tokenizer/tokenizer.py index ecd0231fb6d..78377230b9f 100644 --- a/extension/llm/tokenizer/tokenizer.py +++ b/extension/llm/tokenizer/tokenizer.py @@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str: # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. return self.sp_model.decode(t) + def decode_token(self, t: int) -> str: + # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. + return self.sp_model.decode(t) + def export(self, output_path: str, *, prepend_padding: bool = False) -> None: """ Export tokenizer.model to another serialization format. Here we did some lightweight From 8e2d359f0cee5dbd24f77868f7a10b86ecba09bd Mon Sep 17 00:00:00 2001 From: Lunwen He Date: Mon, 11 Nov 2024 11:54:17 -0800 Subject: [PATCH 2/2] add the ability to have multi-round conversation with llama Ad the ability to have multi-round conversations with LLM. This will be helpful for testing long context length. Differential Revision: [D65771122](https://our.internmc.facebook.com/intern/diff/D65771122/) ghstack-source-id: 252934165 Pull Request resolved: https://github.com/pytorch/executorch/pull/6758 --- examples/models/llama/runner/eager.py | 19 ++++++-- examples/models/llama/runner/generation.py | 53 +++++++++++++++++++++- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index abac920c6b2..9745fdd5428 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument( "--prompt", type=str, - default="Hello", + default=None, ) parser.add_argument( @@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser: help="Show the tokens that were generated", ) + parser.add_argument( + "--chat", + action="store_true", + default=False, + help="Have multi-turn chat with the model", + ) + return parser @@ -78,9 +85,13 @@ def main() -> None: args = parser.parse_args() runner = EagerLlamaRunner(args) - generated_tokens = runner.text_completion( - prompt=args.prompt, - temperature=args.temperature, + generated_tokens = ( + runner.chat_completion(temperature=args.temperature) + if args.chat + else runner.text_completion( + prompt=args.prompt, + temperature=args.temperature, + ) ) if args.show_tokens: print(f"Tokens: {generated_tokens}") diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 159bc5f5017..ed25d44b6fb 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -67,12 +67,13 @@ def generate( # noqa: C901 temperature: float = 0.8, top_p: float = 0.9, echo: bool = False, + pos_base: int = 0, ) -> List[int]: # prefill logits = self.forward( tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), input_pos=( - torch.tensor([0], dtype=torch.long, device=self.device) + torch.tensor([pos_base], dtype=torch.long, device=self.device) if self.params.use_kv_cache else None ), @@ -89,7 +90,9 @@ def generate( # noqa: C901 [[current_token]], dtype=torch.long, device=self.device ), input_pos=torch.tensor( - [len(tokens) - 1], dtype=torch.long, device=self.device + [pos_base + len(tokens) - 1], + dtype=torch.long, + device=self.device, ), ) else: @@ -136,3 +139,49 @@ def text_completion( top_p=top_p, echo=echo, ) + + def chat_completion( + self, + temperature: float = 0.6, + top_p: float = 0.9, + ) -> List[int]: + """ + Perform multi-turn chat with the language model. + + Args: + prompt (str): Text prompt for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Generated list of tokens. + + Note: + This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. + """ + exit_prompt = "exit" + tokens = [] + prompt = input("Me: ") + while prompt and prompt != exit_prompt: + print("LLM: ", end="", flush=True) + new_tokens = self.generate( + prompt_tokens=self.tokenizer.encode( + self._format_prompt(prompt), bos=True, eos=False + ), + temperature=temperature, + top_p=top_p, + echo=True, + pos_base=len(tokens), + ) + tokens.extend(new_tokens) + prompt = input("Me: ") + return tokens + + def _format_prompt(self, prompt: str) -> str: + return f""" +<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|> + +{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""