diff --git a/examples/models/llama2/runner/generation.py b/examples/models/llama2/runner/generation.py
index 56a15005ef1..404ff4717ea 100644
--- a/examples/models/llama2/runner/generation.py
+++ b/examples/models/llama2/runner/generation.py
@@ -14,11 +14,7 @@ import torch.nn.functional as F
 
 from executorch.examples.models.llama2.llama_transformer import ModelArgs
 
-from executorch.examples.models.llama2.tokenizer.tiktoken import (
-    Dialog,
-    Message,
-    Tokenizer,
-)
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer
 
 from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
@@ -28,12 +24,6 @@ class CompletionPrediction(TypedDict, total=False):
     logprobs: List[float]  # not required
 
 
-class ChatPrediction(TypedDict, total=False):
-    generation: Message
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -225,72 +215,6 @@ def text_completion(
         ]
         return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
 
-    def chat_completion(
-        self,
-        dialogs: List[Dialog],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-    ) -> List[ChatPrediction]:
-        """
-        Generate assistant responses for a list of conversational dialogs using the language generation model.
-
-        Args:
-            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
-            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
-            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
-                If not provided, it's set to the model's maximum sequence length minus 1.
-            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
-
-        Returns:
-            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
-
-        Raises:
-            AssertionError: If the last message in a dialog is not from the user.
-            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
-
-        Note:
-            This method generates assistant responses for the provided conversational dialogs.
-            It employs nucleus sampling to introduce controlled randomness in text generation.
-            If logprobs is True, token log probabilities are computed for each generated token.
-        """
-        if max_gen_len is None:
-            max_gen_len = self.model.params.max_seq_len - 1
-
-        prompt_tokens = [
-            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
-        ]
-        generation_tokens, generation_logprobs = self.generate(
-            prompt_tokens=prompt_tokens,
-            max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-        )
-        if logprobs:
-            return [
-                {
-                    "generation": {
-                        "role": "assistant",
-                        "content": self.tokenizer.decode(t),
-                    },
-                    "tokens": [self.tokenizer.decode([x]) for x in t],
-                    "logprobs": logprobs_i,
-                }
-                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
-            ]
-        return [
-            {
-                "generation": {
-                    "role": "assistant",
-                    "content": self.tokenizer.decode(t),
-                },
-            }
-            for t in generation_tokens
-        ]
-
 
 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
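
For reference, the untouched `sample_top_p` function in this file (its docstring appears as context in the second hunk) implements top-p (nucleus) sampling. Below is a minimal, self-contained sketch of that technique; it mirrors the common Llama-style formulation but the function and variable names here are illustrative and not taken from this diff:

```python
import torch


def sample_top_p_sketch(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Illustrative top-p (nucleus) sampling over a probability distribution.

    Keeps the smallest set of highest-probability tokens whose cumulative
    mass exceeds p, renormalizes, and samples one token index from that set.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens once the cumulative mass (excluding the current token)
    # has already passed the threshold p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    return torch.gather(probs_idx, -1, next_token)


# Example usage with temperature-scaled logits (shapes are illustrative):
logits = torch.randn(1, 32000)
probs = torch.softmax(logits / 0.6, dim=-1)
token = sample_top_p_sketch(probs, p=0.9)
```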