diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index 961c43d8c93..2bd2ff7cb1d 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -78,25 +78,31 @@ Result TextPrefiller::prefill( ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long); + // run the first token and get back logits tensor. Assuming the first token + // is bos so don't callback. + exec_aten::Tensor logits_tensor = ET_UNWRAP( + text_decoder_runner_->step(managed_tokens, managed_start_pos)); + pos = 1; // start from index 1 + while (pos < num_prompt_tokens) { // Run the model pos_data = start_pos + pos; - Result logits_res = - text_decoder_runner_->step(managed_tokens, managed_start_pos); - - ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); prev_token = cur_token; - pos++; + // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds) + cur_token = prompt_tokens[pos]; - cur_token = pos == num_prompt_tokens - ? text_decoder_runner_->logits_to_token(logits_res.get()) - : prompt_tokens[pos]; + logits_tensor = ET_UNWRAP( + text_decoder_runner_->step(managed_tokens, managed_start_pos)); // print the token as string, decode it with the Tokenizer object token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token))); + + pos++; } + + cur_token = text_decoder_runner_->logits_to_token(logits_tensor); } return cur_token; }