From db93e9b431d629d7d5143c9d9a18e2252cb4c31f Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Sat, 10 Aug 2024 13:53:38 -0700
Subject: [PATCH 1/2] [llama] Fix text prefiller

In sequential prefill, we should print out the prompt, but we shouldn't
print the first token (``) or the token generated by prefill. The runner
should have control over what to do with the token generated by prefill.
Unit tests will be added later; for now this fixes the
test-llama-runner-mac CI job.
---
 extension/llm/runner/text_prefiller.cpp | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 961c43d8c93..4ad95821c4a 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -78,25 +78,29 @@ Result TextPrefiller::prefill(
 
     ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);
 
+    // run the first token and get back logits tensor. Assuming the first token is bos so don't callback.
+    exec_aten::Tensor logits_tensor = ET_UNWRAP(text_decoder_runner_->step(managed_tokens, managed_start_pos));
+    pos = 1; // start from index 1
+
     while (pos < num_prompt_tokens) {
       // Run the model
       pos_data = start_pos + pos;
 
-      Result logits_res =
-          text_decoder_runner_->step(managed_tokens, managed_start_pos);
-
-      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
       prev_token = cur_token;
 
-      pos++;
+      // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
+      cur_token = prompt_tokens[pos];
 
-      cur_token = pos == num_prompt_tokens
-          ? text_decoder_runner_->logits_to_token(logits_res.get())
-          : prompt_tokens[pos];
+      logits_tensor =
+          ET_UNWRAP(text_decoder_runner_->step(managed_tokens, managed_start_pos));
 
       // print the token as string, decode it with the Tokenizer object
       token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
+
+      pos++;
     }
+
+    cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
   }
   return cur_token;
 }

From f6bbe4315f14b273f878918bbbe74db14da48c90 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Sat, 10 Aug 2024 15:46:30 -0700
Subject: [PATCH 2/2] Lint

---
 extension/llm/runner/text_prefiller.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 4ad95821c4a..2bd2ff7cb1d 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -78,8 +78,10 @@ Result TextPrefiller::prefill(
 
     ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);
 
-    // run the first token and get back logits tensor. Assuming the first token is bos so don't callback.
-    exec_aten::Tensor logits_tensor = ET_UNWRAP(text_decoder_runner_->step(managed_tokens, managed_start_pos));
+    // run the first token and get back logits tensor. Assuming the first token
+    // is bos so don't callback.
+    exec_aten::Tensor logits_tensor = ET_UNWRAP(
+        text_decoder_runner_->step(managed_tokens, managed_start_pos));
     pos = 1; // start from index 1
 
     while (pos < num_prompt_tokens) {
@@ -91,8 +93,8 @@ Result TextPrefiller::prefill(
       // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
       cur_token = prompt_tokens[pos];
 
-      logits_tensor =
-          ET_UNWRAP(text_decoder_runner_->step(managed_tokens, managed_start_pos));
+      logits_tensor = ET_UNWRAP(
+          text_decoder_runner_->step(managed_tokens, managed_start_pos));
 
       // print the token as string, decode it with the Tokenizer object
       token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
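
Editorial note, not part of the patches above: after this change the sequential prefiller only echoes the decoded prompt through token_callback and returns the token predicted from the last prompt position, leaving the runner to decide what to do with that token. The C++ sketch below illustrates that caller-side contract with a stand-in prefiller; FakePrefiller, its token values, and main() are assumptions for illustration, not ExecuTorch APIs.

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the caller-visible behavior of TextPrefiller::prefill().
struct FakePrefiller {
  uint64_t prefill(
      const std::vector<uint64_t>& prompt_tokens,
      const std::function<void(const std::string&)>& token_callback) {
    // Echo the prompt (skipping the assumed bos at index 0), as the fixed
    // sequential prefill does, and return a made-up "predicted" token.
    for (size_t i = 1; i < prompt_tokens.size(); ++i) {
      token_callback("<" + std::to_string(prompt_tokens[i]) + ">");
    }
    return 42; // pretend prediction from the last prompt position
  }
};

int main() {
  FakePrefiller prefiller;
  std::vector<uint64_t> prompt = {1, 15043, 29892}; // bos + two prompt tokens
  // The runner owns the callback; here it just prints the echoed prompt.
  uint64_t cur_token = prefiller.prefill(
      prompt, [](const std::string& piece) { std::cout << piece; });
  std::cout << "\n";
  // The runner, not the prefiller, decides whether to print cur_token or feed
  // it straight into the decode loop as the first generated token.
  std::cout << "first generated token id: " << cur_token << "\n";
  return 0;
}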