diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index cd6d9c9e7cc..7c6176a464c 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -130,133 +130,101 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
   }
 }
 
-Result<torch::executor::Tensor> Runner::prefill(
-    const std::vector<uint64_t>& tokens,
-    ManagedTensor& managed_tokens,
-    ManagedTensor& managed_start_pos,
+Result<uint64_t> Runner::prefill(
+    std::vector<uint64_t>& prompt_tokens,
+    int64_t start_pos,
     std::function<void(const std::string&)> token_callback) {
   // enable_parallel_prefill_ maybe set even when not using kv cache
   // When kv cache is not used, start pos is ignored
-  int32_t num_tokens = tokens.size();
-  if (enable_parallel_prefill_) {
-    managed_tokens.resize({1, num_tokens});
-    int64_t* tokens_ptr =
-        managed_tokens.get_aliasing_tensor().mutable_data_ptr<int64_t>();
-    for (int i = 0; i < num_tokens; i++) {
-      // The following assumes batch size = 1
-      tokens_ptr[i] = tokens[i];
-    }
-    std::vector<EValue> inputs;
-    auto tokens_tensor = managed_tokens.get_aliasing_tensor();
-    auto start_pos = managed_start_pos.get_aliasing_tensor();
+  int32_t num_prompt_tokens = prompt_tokens.size();
+
+  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
+  ET_CHECK_MSG(
+      num_prompt_tokens < max_seq_len_,
+      "Max seq length exceeded - please increase max seq len value");
+
+  // store the token
+  uint64_t cur_token;
+  if (enable_parallel_prefill_ || !use_kv_cache_) {
+    // initialize tensor wrappers
+    ManagedTensor managed_tokens(
+        prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);
 
-    // inputs:[tokens, start_pos]
-    inputs.push_back(tokens_tensor);
-    inputs.push_back(start_pos);
+    ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);
+
+    Result<exec_aten::Tensor> outputs_res =
+        run_model_step(managed_tokens, managed_start_pos);
 
-    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_LOG(
+        Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
     ET_CHECK_MSG(
-        outputs_res.get()[0].isTensor(),
-        "Non Tensor Output returned from executing LLM");
-    ET_CHECK_MSG(
-        outputs_res.get()[0].toTensor().size(1) == num_tokens,
+        outputs_res.get().size(1) == num_prompt_tokens,
         "Expected number of output tokens %d does not match returned value %zu.",
-        num_tokens,
-        outputs_res.get()[0].toTensor().size(1));
-
-    start_pos.mutable_data_ptr<int64_t>()[0] = num_tokens;
-
-    uint64_t prev = tokens[0];
+        num_prompt_tokens,
+        outputs_res.get().size(1));
+    // insert new token into prompt_tokens
+    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
+    uint64_t prev = prompt_tokens[0];
     uint64_t cur;
-    for (int i = 1; i < num_tokens; i++) {
-      cur = tokens[i];
-      auto piece_res = tokenizer_->decode(prev, cur);
-      ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
-      util::safe_printf(piece_res.get().c_str());
-      fflush(stdout);
+    for (int i = 1; i < prompt_tokens.size(); i++) {
+      cur = prompt_tokens[i];
+      token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
       prev = cur;
-      if (token_callback) {
-        token_callback(piece_res.get().c_str());
-      }
-    }
-    cur = logitsToToken(outputs_res.get()[0].toTensor());
-    auto piece_res = tokenizer_->decode(prev, cur);
-    ET_CHECK(piece_res.ok());
-    const char* piece = piece_res.get().c_str();
-    util::safe_printf(piece);
-    fflush(stdout);
-    if (token_callback) {
-      token_callback(piece_res.get().c_str());
     }
-
-    // Return the logits tensor
-    stats_.first_token_ms = util::time_in_ms();
-    stats_.prompt_eval_end_ms = util::time_in_ms();
-    return outputs_res.get()[0].toTensor();
+    cur_token = logitsToToken(outputs_res.get());
   } else { // sequential prefill
     int64_t pos = 0; // position in the sequence
-    int64_t cur_token = tokens[0];
-    int64_t prev_token;
-    // This is a hack to enable returning a logits tensor from prefill
-    auto logits_tensor = managed_tokens.get_aliasing_tensor();
-    while (pos < num_tokens) {
+    uint64_t prev_token;
+    // token & pos
+    int64_t pos_data = 0;
+    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
+    cur_token = prompt_tokens[0];
+
+    // initialize tensor wrappers
+    ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);
+
+    ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);
+
+    while (pos < num_prompt_tokens) {
       // Run the model
-      Result<exec_aten::Tensor> logits_res = run_model_step(
-          cur_token, managed_tokens, managed_start_pos, num_tokens);
+      pos_data = start_pos + pos;
 
-      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
-      logits_tensor = logits_res.get();
-      // Hack to enable returning a logits tensor from prefill
+      Result<exec_aten::Tensor> logits_res =
+          run_model_step(managed_tokens, managed_start_pos);
+      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
       prev_token = cur_token;
+      pos++;
+
       long sample_start_time_ms = util::time_in_ms();
-      cur_token = logitsToToken(logits_tensor);
+      cur_token = pos == num_prompt_tokens ? logitsToToken(logits_res.get())
+                                           : prompt_tokens[pos];
+
      stats_.aggregate_sampling_time_ms +=
          util::time_in_ms() - sample_start_time_ms;
 
-      // advance the state machine
-      if (pos < num_tokens - 1) {
-        // prefill, force the next token to be the next prompt token
-        cur_token = tokens[pos + 1];
-      }
-      pos++;
-      // print the token as string, decode it with the Tokenizer object
-      auto piece_res = tokenizer_->decode(prev_token, cur_token);
-      ET_CHECK(piece_res.ok());
-      const char* piece = piece_res.get().c_str();
-      util::safe_printf(piece);
-      fflush(stdout);
-      if (token_callback) {
-        token_callback(piece_res.get().c_str());
-      }
+      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
     }
-    auto start_pos = managed_start_pos.get_aliasing_tensor();
-    start_pos.mutable_data_ptr<int64_t>()[0] = num_tokens;
-    stats_.first_token_ms = util::time_in_ms();
-    stats_.prompt_eval_end_ms = util::time_in_ms();
-    return logits_tensor;
   }
+  // Return the next token
+  stats_.first_token_ms = util::time_in_ms();
+  stats_.prompt_eval_end_ms = util::time_in_ms();
+  return cur_token;
 }
 
 // Given an input token. Set up the inputs for the model and execute a single
 // step. Returning the logits tensor.
 Result<torch::executor::Tensor> Runner::run_model_step(
-    int64_t input_token,
     ManagedTensor& managed_tokens,
-    ManagedTensor& managed_start_pos,
-    size_t max_seq_len) {
+    ManagedTensor& managed_start_pos) {
   // ET_LOG(Info, "Input token %" PRIu64, input_token);
+  auto tokens = managed_tokens.get_aliasing_tensor();
   if (use_kv_cache_) {
-    auto tokens = managed_tokens.get_aliasing_tensor();
     auto start_pos = managed_start_pos.get_aliasing_tensor();
 
-    // When using kv-cache our input is always 1 token, so just update to the
-    // latest.
-    tokens.mutable_data_ptr<int64_t>()[0] = input_token;
-
     Result<std::vector<EValue>> outputs_res =
         module_->forward({tokens, start_pos});
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
@@ -267,25 +235,12 @@ Result<torch::executor::Tensor> Runner::run_model_step(
         outputs_res.get()[0].isTensor(),
         "Non Tensor Output returned from executing LLM");
 
-    // Bump start_pos by 1
-    start_pos.mutable_data_ptr<int64_t>()[0]++;
-
     // Return the logits tensor
     return outputs_res.get()[0].toTensor();
   } else { // no kv cache
-    std::vector<EValue> inputs;
-    auto tokens = managed_tokens.get_aliasing_tensor();
     (void)managed_start_pos; // unused
 
-    // When not using kv-cache our input is the entire history of tokens we have
-    // seen, so resize input to be 1 larger and append the new token to the end.
-    // TODO does this work in ATen mode?
-    tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;
-
-    // inputs:[tokens]
-    inputs.push_back(tokens);
-
-    Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+    Result<std::vector<EValue>> outputs_res = module_->forward({tokens});
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_CHECK_MSG(
         outputs_res.get().size() == 1,
@@ -294,14 +249,6 @@ Result<torch::executor::Tensor> Runner::run_model_step(
         outputs_res.get()[0].isTensor(),
         "Non Tensor Output returned from executing LLM");
 
-    if (tokens.size(1) < max_seq_len) {
-      // Resize the tokens tensor to be 1 larger for next step.
-      // Note that this relies on the fact that underlying memory is the same
-      // such that previous tokens stored there will still exist.
-      // Not a good thing to rely upon.
-      managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
-    }
-
     // Return the logits tensor
     return outputs_res.get()[0].toTensor();
   }
@@ -321,6 +268,15 @@ Error Runner::generate(
     stats_.model_load_end_ms = util::time_in_ms();
   }
 
+  // Wrap the token_callback with print function
+  std::function<void(const std::string&)> wrapped_callback =
+      [token_callback](const std::string& piece) {
+        util::safe_printf(piece.c_str());
+        fflush(stdout);
+        if (token_callback) {
+          token_callback(piece);
+        }
+      };
 
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
@@ -349,70 +305,48 @@ Error Runner::generate(
       num_prompt_tokens < seq_len,
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
+  // Prefill first
+  // Here feed all tokens to the model and get the next predicted token
+  // after the prompt. After that we will enter generate loop.
+  auto prefill_res = prefill(prompt_tokens, 0, wrapped_callback);
+  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+  int64_t cur_token = prefill_res.get();
+
+  // print the first token from prefill. No prev_token so use cur_token for it.
+  wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
+
   // start the main loop
-  int64_t pos = 0; // position in the sequence
+  int64_t pos = num_prompt_tokens; // position in the sequence
 
+  // Generate the rest of the sequence
   std::vector<int64_t> token_data; // allocate space for the tokens
-  std::vector<exec_aten::SizesType> token_shape = {1, seq_len};
-
-  std::vector<int64_t> start_pos_data; // allocate space for the tokens
-  std::vector<exec_aten::SizesType> start_pos_shape = {1};
+  std::vector<exec_aten::SizesType> token_shape;
 
-  token_data.resize(seq_len);
   if (use_kv_cache_) {
     // hard code these to size 1 as kv cache is locked to static size right now.
-    start_pos_data.resize(1);
-    start_pos_data.push_back(0);
+    token_data = {cur_token};
+    token_shape = {1, 1};
+  } else {
+    for (auto tok : prompt_tokens) {
+      token_data.push_back(tok);
+    }
+    token_data.push_back(cur_token);
+    token_shape = {1, num_prompt_tokens + 1};
   }
 
   // initialize tensor wrappers
   ManagedTensor tokens_managed(
       token_data.data(), token_shape, ScalarType::Long);
-  // Create with the max shape to approapriately set the capacity of this
-  // tensor, then resize back to 1 for first input.
-  tokens_managed.resize({1, 1});
 
-  ManagedTensor start_pos_managed(
-      start_pos_data.data(), start_pos_shape, ScalarType::Long);
+  ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);
 
   int64_t prev_token;
-  int64_t cur_token = prompt_tokens[0];
-
-  // Prefill first
-  // Here feed all tokens to the model and get the next predicted token
-  // after the prompt. After that we will enter generate loop.
-  auto prefill_res =
-      prefill(prompt_tokens, tokens_managed, start_pos_managed, token_callback);
-  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-  exec_aten::Tensor& prefill_res_tensor = prefill_res.get();
-  cur_token = logitsToToken(prefill_res_tensor);
-  if (use_kv_cache_) {
-    // Prefill could be parallel or sequential.
-    // Parallel:
-    //   kv cache:
-    //     - tokens_managed should resized to 1 as inference expects one token at
-    //       a time.
-    //   no kv cache:
-    //     - tokens_managed should be resized to prompt length + 1, as inference
-    //       expects all tokens at once.
-    // Sequential prefill:
-    //   kv cache:
-    //     - tokens_managed should be resized to 1, as inference expects one
-    //       token at a time.
-    //   no kv cache:
-    //     - tokens_managed should be resized to prompt length + 1, as inference
-    //       expects all tokens at once.
-    tokens_managed.resize({1, 1});
-  } else {
-    tokens_managed.resize({1, num_prompt_tokens + 1});
-  }
-  pos = num_prompt_tokens;
 
   // Generate our tokens
   while (pos < seq_len - 1) {
     // Run the model
     Result<exec_aten::Tensor> logits_res =
-        run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
+        run_model_step(tokens_managed, start_pos_managed);
 
     ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
     exec_aten::Tensor& logits_tensor = logits_res.get();
@@ -426,19 +360,19 @@ Error Runner::generate(
 
     pos++;
 
-    // print the token as string, decode it with the Tokenizer object
-    auto piece_res = tokenizer_->decode(prev_token, cur_token);
-    ET_CHECK(piece_res.ok());
-    const char* piece = piece_res.get().c_str();
-
-    // same as printf("%s", piece), but skips "unsafe" bytes
-    util::safe_printf(piece);
-    fflush(stdout);
-
-    if (token_callback) {
-      token_callback(piece);
+    if (use_kv_cache_) {
+      // update the token tensor. token_data will not be empty.
+      // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
+      token_data[0] = cur_token;
+    } else {
+      // push it to the back
+      token_data.push_back(cur_token);
+      tokens_managed.resize({1, static_cast<int>(token_data.size())});
     }
 
+    // print the token as string, decode it with the Tokenizer object
+    wrapped_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
+
     if (shouldStop_) {
      break;
    }
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 2c313fd6fe7..fe38d4e0404 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -45,16 +45,13 @@ class Runner {
 
  private:
   int32_t logitsToToken(const exec_aten::Tensor& logits_tensor);
-  Result<torch::executor::Tensor> prefill(
-      const std::vector<uint64_t>& tokens,
-      ManagedTensor& managed_tokens,
-      ManagedTensor& managed_start_pos,
+  Result<uint64_t> prefill(
+      std::vector<uint64_t>& prompt_tokens,
+      int64_t start_pos,
       std::function<void(const std::string&)> token_callback);
   Result<torch::executor::Tensor> run_model_step(
-      int64_t input_token,
-      ManagedTensor& tokens,
-      ManagedTensor& start_pos,
-      size_t max_seq_len);
+      ManagedTensor& managed_tokens,
+      ManagedTensor& managed_start_pos);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;
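
For readers skimming the patch, a minimal standalone sketch of the new control flow follows: `prefill()` now consumes the whole prompt and hands back the first generated token, and `generate()` then decodes one token per step starting at `pos = num_prompt_tokens`. Everything below (`fake_forward`, the deterministic next-token rule, the bracketed printing) is a stand-in for illustration only; it is not the ExecuTorch API and not part of the patch.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for one forward pass: deterministically picks a "next token" from
// the current token. This replaces the real model call purely for illustration.
static uint64_t fake_forward(uint64_t cur_token, int64_t /*start_pos*/) {
  return (cur_token * 7 + 3) % 100;
}

// Mirrors the refactored prefill(): feed every prompt token, echo the prompt
// through the callback, and return the first *generated* token (the patch
// returns Result<uint64_t>; a plain uint64_t stands in here).
static uint64_t prefill(
    const std::vector<uint64_t>& prompt_tokens,
    int64_t start_pos,
    const std::function<void(const std::string&)>& token_callback) {
  uint64_t next = 0;
  for (size_t i = 0; i < prompt_tokens.size(); ++i) {
    token_callback("[" + std::to_string(prompt_tokens[i]) + "]");
    next = fake_forward(prompt_tokens[i], start_pos + static_cast<int64_t>(i));
  }
  return next;
}

int main() {
  const std::vector<uint64_t> prompt = {1, 5, 9};
  const int64_t seq_len = 10;
  auto print_cb = [](const std::string& piece) { std::cout << piece << ' '; };

  // Prefill consumes the prompt and hands back the first generated token.
  uint64_t cur_token = prefill(prompt, /*start_pos=*/0, print_cb);
  print_cb("[" + std::to_string(cur_token) + "]");

  // Decode loop: one token per step; the position starts right after the
  // prompt, matching `pos = num_prompt_tokens` in the patched generate().
  int64_t pos = static_cast<int64_t>(prompt.size());
  while (pos < seq_len - 1) {
    cur_token = fake_forward(cur_token, pos);
    pos++;
    print_cb("[" + std::to_string(cur_token) + "]");
  }
  std::cout << '\n';
  return 0;
}
```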