diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 078d938ffde..55104d7f3dc 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -35,7 +35,12 @@ DEFINE_double(
 DEFINE_int32(
     seq_len,
     128,
-    "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+    "DEPRECATED: Please use max_new_tokens instead. Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+
+DEFINE_int32(
+    max_new_tokens,
+    -1,
+    "Total number of tokens to generate, excluding the prompt; will be capped by max_seq_len - # prompt tokens.");
 
 DEFINE_int32(
     cpu_threads,
@@ -100,20 +105,33 @@ int32_t main(int32_t argc, char** argv) {
   }
 
   if (warmup) {
-    auto error = runner->warmup(prompt, /*max_new_tokens=*/seq_len);
+    int32_t warmup_max_new_tokens =
+        FLAGS_max_new_tokens != -1 ? FLAGS_max_new_tokens : seq_len;
+    auto error =
+        runner->warmup(prompt, /*max_new_tokens=*/warmup_max_new_tokens);
     if (error != executorch::runtime::Error::Ok) {
       ET_LOG(Error, "Failed to warmup llama runner");
       return 1;
     }
-    // reset kv cache pos to 0
-    runner->reset();
   }
 
   // generate
   executorch::extension::llm::GenerationConfig config{
-      .seq_len = seq_len, .temperature = temperature};
+      .temperature = temperature};
+
+  if (FLAGS_max_new_tokens != -1) {
+    config.max_new_tokens = FLAGS_max_new_tokens;
+  } else {
+    ET_LOG(
+        Info,
+        "max_new_tokens not provided, falling back to seq_len=%d. "
+        "Consider using --max_new_tokens instead of --seq_len for specifying generation length.",
+        seq_len);
+    config.seq_len = seq_len;
+  }
+
   auto error = runner->generate(prompt, config);
   if (error != executorch::runtime::Error::Ok) {
-    ET_LOG(Error, "Failed to warmup llama runner");
+    ET_LOG(Error, "Failed to run llama runner");
    return 1;
  }
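The net effect of the two main.cpp hunks above: --seq_len (a prompt + output budget) is deprecated in favor of --max_new_tokens (an output-only budget). A minimal standalone sketch of the resolution rule, assuming only the capping behavior stated in the flag help strings; function and parameter names here are illustrative, not the runner's API:

// Sketch (illustrative, not part of the diff): how the two flags map to a
// generation budget. -1 means --max_new_tokens was not set, per its default.
#include <algorithm>
#include <cstdint>

int32_t resolve_generation_budget(
    int32_t max_new_tokens, // --max_new_tokens; -1 if unset
    int32_t seq_len, // deprecated --seq_len: prompt + output
    int32_t max_seq_len, // model context limit
    int32_t num_prompt_tokens) {
  if (max_new_tokens != -1) {
    // New flag: capped by the room left after the prompt.
    return std::min(max_new_tokens, max_seq_len - num_prompt_tokens);
  }
  // Deprecated path: seq_len includes the prompt, so clamp it to the
  // context limit and subtract the prompt length.
  return std::min(seq_len, max_seq_len) - num_prompt_tokens;
}

// E.g. a run like `llama_main --prompt "..." --max_new_tokens 64` (invocation
// illustrative) now requests 64 generated tokens regardless of prompt length.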
diff --git a/extension/llm/runner/test/test_text_prefiller.cpp b/extension/llm/runner/test/test_text_prefiller.cpp
index 78edc96ca94..5ed5031dace 100644
--- a/extension/llm/runner/test/test_text_prefiller.cpp
+++ b/extension/llm/runner/test/test_text_prefiller.cpp
@@ -138,113 +138,68 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
 TEST_F(
     TextPrefillerTest,
     PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
-  // Create a spy TextPrefiller with max_seq_len = 3
+  // Create a real TextPrefiller with max_seq_len = 3 and parallel prefill
   const int64_t max_seq_len = 3;
-  auto prefiller = createMockTextPrefiller(max_seq_len);
+  auto prefiller = createTextPrefiller(max_seq_len, true, true);
 
   // Create prompt tokens with size > max_seq_len
   std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
   int64_t start_pos = 0;
 
-  // Set up expectations for prefill_chunk calls
-  {
-    InSequence seq; // Ensure calls happen in the expected order
-
-    // First chunk: tokens [1, 2, 3]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 3);
-          EXPECT_EQ(tokens[0], 1);
-          EXPECT_EQ(tokens[1], 2);
-          EXPECT_EQ(tokens[2], 3);
-          EXPECT_EQ(pos, 0);
-          return Result<uint64_t>(10);
-        });
-
-    // Second chunk: tokens [4, 5, 6]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 3);
-          EXPECT_EQ(tokens[0], 4);
-          EXPECT_EQ(tokens[1], 5);
-          EXPECT_EQ(tokens[2], 6);
-          EXPECT_EQ(pos, 3);
-          return Result<uint64_t>(20);
-        });
-
-    // Third chunk: tokens [7, 8]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 2);
-          EXPECT_EQ(tokens[0], 7);
-          EXPECT_EQ(tokens[1], 8);
-          EXPECT_EQ(pos, 6);
-          return Result<uint64_t>(30);
-        });
-  }
+  // Track all tokens and positions passed to text_decoder_runner step
+  struct StepCall {
+    std::vector<uint64_t> tokens;
+    int64_t pos;
+  };
+  std::vector<StepCall> step_calls;
+
+  // Set up expectations for text_decoder_runner step calls
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3) // Should be called 3 times for 3 chunks
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            // Extract token values from tensor
+            std::vector<uint64_t> token_values;
+            int64_t num_tokens = tokens->size(1);
+            auto* token_data = tokens->const_data_ptr<int64_t>();
+            for (int64_t i = 0; i < num_tokens; i++) {
+              token_values.push_back(static_cast<uint64_t>(token_data[i]));
+            }
+            step_calls.push_back({token_values, pos});
+            return Result<executorch::aten::Tensor>(tensor);
+          });
 
   // Call prefill
   auto result = prefiller->prefill(prompt_tokens, start_pos);
 
   // Verify the result
   EXPECT_EQ(result.error(), Error::Ok);
-  EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk
-
-  // Verify that start_pos has been updated correctly
-  EXPECT_EQ(start_pos, prompt_tokens.size());
-}
-
-// Test that prefill() handles edge cases correctly
-TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
-  // Create a spy TextPrefiller with max_seq_len = 1
-  const int64_t max_seq_len = 1;
-  auto prefiller = createMockTextPrefiller(max_seq_len);
-
-  // Create prompt tokens with size > max_seq_len
-  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
-  int64_t start_pos = 5; // Non-zero starting position
-
-  // Set up expectations for prefill_chunk calls
-  {
-    InSequence seq;
-
-    // First chunk: token [1]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 1);
-          EXPECT_EQ(pos, 5);
-          return Result<uint64_t>(10);
-        });
-
-    // Second chunk: token [2]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 2);
-          EXPECT_EQ(pos, 6);
-          return Result<uint64_t>(20);
-        });
-
-    // Third chunk: token [3]
-    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
-        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
-          EXPECT_EQ(tokens.size(), 1);
-          EXPECT_EQ(tokens[0], 3);
-          EXPECT_EQ(pos, 7);
-          return Result<uint64_t>(30);
-        });
-  }
-
-  // Call prefill
-  auto result = prefiller->prefill(prompt_tokens, start_pos);
-
-  // Verify the result
-  EXPECT_EQ(result.error(), Error::Ok);
-  EXPECT_EQ(result.get(), 30);
+  // Verify that step was called 3 times with correct tokens and positions
+  ASSERT_EQ(step_calls.size(), 3);
+
+  // First chunk: tokens [1, 2, 3] at position 0
+  EXPECT_EQ(step_calls[0].tokens.size(), 3);
+  EXPECT_EQ(step_calls[0].tokens[0], 1);
+  EXPECT_EQ(step_calls[0].tokens[1], 2);
+  EXPECT_EQ(step_calls[0].tokens[2], 3);
+  EXPECT_EQ(step_calls[0].pos, 0);
+
+  // Second chunk: tokens [4, 5, 6] at position 3
+  EXPECT_EQ(step_calls[1].tokens.size(), 3);
+  EXPECT_EQ(step_calls[1].tokens[0], 4);
+  EXPECT_EQ(step_calls[1].tokens[1], 5);
+  EXPECT_EQ(step_calls[1].tokens[2], 6);
+  EXPECT_EQ(step_calls[1].pos, 3);
+
+  // Third chunk: tokens [7, 8] at position 6
+  EXPECT_EQ(step_calls[2].tokens.size(), 2);
+  EXPECT_EQ(step_calls[2].tokens[0], 7);
+  EXPECT_EQ(step_calls[2].tokens[1], 8);
+  EXPECT_EQ(step_calls[2].pos, 6);
 
   // Verify that start_pos has been updated correctly
-  EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens)
+  EXPECT_EQ(start_pos, prompt_tokens.size());
 }
 
 // Test that prefill() handles errors from prefill_chunk correctly
@@ -305,4 +260,119 @@ TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
   // Verify that start_pos has been updated correctly
   EXPECT_EQ(start_pos, prompt_tokens.size());
 }
+// Test that prefill_chunk updates start_pos correctly with parallel prefill
+TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlyParallel) {
+  // Create a TextPrefiller with parallel prefill enabled
+  auto prefiller = createTextPrefiller(10, true, true);
+
+  // Set up expectations for the text decoder runner
+  int64_t captured_pos = -1;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .WillOnce([&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+        captured_pos = pos;
+        // Verify tokens shape is [1, num_tokens]
+        EXPECT_EQ(tokens->dim(), 2);
+        EXPECT_EQ(tokens->size(0), 1);
+        EXPECT_EQ(tokens->size(1), 3);
+        return Result<executorch::aten::Tensor>(tensor);
+      });
+
+  // Create prompt tokens
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 5; // Non-zero starting position
+
+  // Call prefill_chunk directly
+  auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with the original start_pos
+  EXPECT_EQ(captured_pos, 5);
+
+  // Verify that start_pos has been updated by the number of tokens
+  // This is the key test: start_pos should be updated exactly once
+  EXPECT_EQ(start_pos, 8); // 5 + 3 tokens
+}
+
+// Test that prefill_chunk updates start_pos correctly with sequential prefill
+TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlySequential) {
+  // Create a TextPrefiller with sequential prefill (parallel disabled)
+  auto prefiller = createTextPrefiller(10, true, false);
+
+  // Track all positions passed to step
+  std::vector<int64_t> captured_positions;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3)
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            captured_positions.push_back(pos);
+            // Verify tokens shape is [1, 1] for sequential prefill
+            EXPECT_EQ(tokens->dim(), 2);
+            EXPECT_EQ(tokens->size(0), 1);
+            EXPECT_EQ(tokens->size(1), 1);
+            return Result<executorch::aten::Tensor>(tensor);
+          });
+
+  // Create prompt tokens
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 10; // Non-zero starting position
+
+  // Call prefill_chunk directly
+  auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with incrementing positions
+  ASSERT_EQ(captured_positions.size(), 3);
+  EXPECT_EQ(captured_positions[0], 10); // First token at initial start_pos
+  EXPECT_EQ(captured_positions[1], 11); // Second token at start_pos + 1
+  EXPECT_EQ(captured_positions[2], 12); // Third token at start_pos + 2
+
+  // Verify that start_pos has been updated by the number of tokens
+  // This is the key test: start_pos should be updated exactly once per token
+  EXPECT_EQ(start_pos, 13); // 10 + 3 tokens
+}
+
+// Test that prefill with chunking updates start_pos correctly across chunks.
+// This test would have caught the bug where start_pos was being updated twice.
+TEST_F(
+    TextPrefillerTest,
+    PrefillWithChunkingUpdatesStartPosCorrectlyAcrossChunks) {
+  // Create a TextPrefiller with max_seq_len = 3 and parallel prefill
+  auto prefiller = createTextPrefiller(3, true, true);
+
+  // Track all positions passed to step
+  std::vector<int64_t> captured_positions;
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(3) // Should be called 3 times: [1,2,3], [4,5,6], [7,8]
+      .WillRepeatedly(
+          [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
+            captured_positions.push_back(pos);
+            return Result<executorch::aten::Tensor>(tensor);
+          });
+
+  // Create prompt tokens that exceed max_seq_len
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
+  int64_t start_pos = 100; // Non-zero starting position
+
+  // Call prefill (which will chunk internally)
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that step was called with correct positions for each chunk
+  // If start_pos were updated twice (the bug), these would be wrong
+  ASSERT_EQ(captured_positions.size(), 3);
+  EXPECT_EQ(captured_positions[0], 100); // Chunk 1: tokens [1,2,3]
+  EXPECT_EQ(captured_positions[1], 103); // Chunk 2: tokens [4,5,6]
+  EXPECT_EQ(captured_positions[2], 106); // Chunk 3: tokens [7,8]
+
+  // Verify that final start_pos is correct
+  // This is the key test for the bug: start_pos should be exactly
+  // initial_pos + num_tokens, not double-incremented
+  EXPECT_EQ(start_pos, 108); // 100 + 8 tokens
+}
 } // namespace
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 0106fd5c250..3fd9320f2b8 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -228,7 +228,7 @@ Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);
 
   // Reset stats after warmup, not resetting the std::unique_ptr!
-  stats_->reset();
+  reset();
   return err;
 }
 
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index de092b6b05d..14e032a6b1e 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -59,7 +59,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
     cur_token = chunk_result.get();
 
-    start_pos += num_tokens_to_prefill_with;
     num_tokens_to_process += num_tokens_to_prefill_with;
   }
 
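The single deleted line in text_prefiller.cpp is the fix the new tests pin down: prefill_chunk() already advances start_pos through its int64_t& parameter, so the caller's extra increment moved the KV-cache position twice per chunk. A self-contained sketch of the invariant, assuming (per the test expectations above) that prefill_chunk is the sole owner of the position update; names are illustrative, not the runner's API:

// Sketch (illustrative, not part of the diff) of the double-increment bug.
#include <algorithm>
#include <cstdint>
#include <vector>

// Stand-in for TextPrefiller::prefill_chunk: consumes `chunk` and advances
// `pos` past it. This is the one and only position update.
void prefill_chunk_sketch(const std::vector<uint64_t>& chunk, int64_t& pos) {
  pos += static_cast<int64_t>(chunk.size());
}

void prefill_sketch(
    const std::vector<uint64_t>& prompt,
    int64_t& start_pos,
    size_t max_seq_len) {
  size_t num_processed = 0;
  while (num_processed < prompt.size()) {
    size_t n = std::min(max_seq_len, prompt.size() - num_processed);
    std::vector<uint64_t> chunk(
        prompt.begin() + num_processed, prompt.begin() + num_processed + n);
    prefill_chunk_sketch(chunk, start_pos); // advances start_pos by n
    // The buggy version also did `start_pos += n;` here, so full chunks ran
    // at positions 0, 2n, 4n, ... instead of 0, n, 2n, ...
    num_processed += n;
  }
}

// With prompt size 8, max_seq_len 3, and start_pos 100, the chunks run at
// positions 100, 103, 106 and start_pos ends at 108, exactly what the last
// test checks; the buggy version would have stepped at 100, 106, 112.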