6 changes: 5 additions & 1 deletion examples/models/llama/runner/eager.py
@@ -84,7 +84,11 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     with torch.no_grad():
         runner = runner_class(args) # pyre-ignore: Missing argument [20]
         generated_tokens = (
-            runner.chat_completion(temperature=args.temperature)
+            runner.chat_completion(
+                max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length,
+                temperature=args.temperature,
+                show_progress=args.show_tokens,
+            )
             if args.chat
             else runner.text_completion(
                 prompt=args.prompt,
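For context, a minimal sketch of how the updated call site resolves the sequence-length budget. Only the flag names visible in the diff (use_attention_sink, show_tokens, max_seq_length, temperature, chat) come from the PR; the argparse wiring below is an assumption for illustration, not the actual ExecuTorch argument parser.

# Hypothetical sketch of the flags the updated eager.py call site relies on.
# The real parser lives elsewhere in examples/models/llama; this only shows
# how max_seq_len is chosen before chat_completion is invoked.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--chat", action="store_true")
parser.add_argument("--use_attention_sink", action="store_true")
parser.add_argument("--show_tokens", action="store_true")
parser.add_argument("--max_seq_length", type=int, default=128)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args(["--chat", "--use_attention_sink", "--show_tokens"])

# Mirrors the diff: with an attention sink the per-conversation budget is made
# effectively unbounded (1,000,000 tokens); otherwise the model's fixed
# max_seq_length applies.
max_seq_len = 1000000 if args.use_attention_sink else args.max_seq_length
print(max_seq_len)  # -> 1000000 when --use_attention_sink is set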
27 changes: 17 additions & 10 deletions examples/models/llama/runner/generation.py
@@ -168,18 +168,19 @@ def text_completion(
 
     def chat_completion(
         self,
+        max_seq_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
+        show_progress: bool = False,
     ) -> List[int]:
         """
         Perform multi-turn chat with the language model.
 
         Args:
-            prompt (str): Text prompt for completion.
+            max_seq_len (int): Maximum number of tokens to generate for each prompt.
             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
-
+            show_progress (bool, optional): Flag indicating whether to show number of tokens generated.
         Returns:
             Generated list of tokens.
 
@@ -188,20 +189,26 @@ def chat_completion(
         """
         exit_prompt = "exit"
         tokens = []
+        pre_stop_token = []
         prompt = input("Me: ")
         while prompt and prompt != exit_prompt:
             print("LLM: ", end="", flush=True)
-            new_tokens = self.generate(
-                prompt_tokens=self.tokenizer.encode(
-                    self._format_prompt(prompt), bos=True, eos=False
-                ),
-                max_seq_len=self.max_seq_len,
+            prompt_tokens = self.tokenizer.encode(
+                self._format_prompt(prompt), bos=True, eos=False
+            )
+            generated_tokens = self.generate(
+                prompt_tokens=pre_stop_token + prompt_tokens,
+                max_seq_len=max_seq_len,
                 temperature=temperature,
                 top_p=top_p,
-                echo=True,
+                echo=False,
+                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
-            tokens.extend(new_tokens)
+            pre_stop_token = generated_tokens[-1:]
+            tokens.extend(prompt_tokens)
+            tokens.extend(generated_tokens)
+            if show_progress:
+                print(f"[Generated {len(tokens)} tokens]")
             prompt = input("Me: ")
         return tokens
 
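To make the new bookkeeping concrete, here is a self-contained sketch of the loop's token accounting, with a stubbed generate so it runs without a model or tokenizer. pre_stop_token carries the previous turn's final (stop) token into the next call, echo=False means only newly generated tokens are returned, and pos_base appears to continue generation from the positions already consumed by the conversation (the minus one accounting for the re-fed stop token). The stub, token IDs, and helper names below are illustrative assumptions, not part of the PR.

from typing import List

# Stand-in for LlamaRunner.generate: returns only newly generated tokens
# (echo=False behavior) and always ends with a stop token (id 2 here).
def fake_generate(prompt_tokens: List[int], max_seq_len: int, pos_base: int) -> List[int]:
    print(f"  generate: {len(prompt_tokens)} prompt tokens, pos_base={pos_base}")
    return [101, 102, 2]

def chat_turns(encoded_turns: List[List[int]], max_seq_len: int = 1024) -> List[int]:
    tokens: List[int] = []          # full conversation history so far
    pre_stop_token: List[int] = []  # last token emitted by the previous turn
    for prompt_tokens in encoded_turns:
        generated_tokens = fake_generate(
            # Re-feed the previous stop token ahead of the new prompt.
            prompt_tokens=pre_stop_token + prompt_tokens,
            max_seq_len=max_seq_len,
            # Offset by the history already processed; 0 on the first turn.
            pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
        )
        pre_stop_token = generated_tokens[-1:]
        tokens.extend(prompt_tokens)
        tokens.extend(generated_tokens)
    return tokens

history = chat_turns([[1, 5, 6], [1, 7, 8]])  # two pre-encoded user turns
print(f"[Generated {len(history)} tokens]")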