30 changes: 24 additions & 6 deletions examples/models/llama/main.cpp
@@ -35,7 +35,12 @@ DEFINE_double(
DEFINE_int32(
seq_len,
128,
"Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
"DEPRECATED: Please use max_seq_len instead. Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");

DEFINE_int32(
max_new_tokens,
-1,
"Total number of tokens to generate, excluding the prompt, will be capped by max_seq_len - # prompt tokens.");

DEFINE_int32(
cpu_threads,
@@ -100,20 +105,33 @@ int32_t main(int32_t argc, char** argv) {
}

if (warmup) {
auto error = runner->warmup(prompt, /*max_new_tokens=*/seq_len);
int32_t warmup_max_new_tokens =
FLAGS_max_new_tokens != -1 ? FLAGS_max_new_tokens : seq_len;
auto error =
runner->warmup(prompt, /*max_new_tokens=*/warmup_max_new_tokens);
if (error != executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to warmup llama runner");
return 1;
}
// reset kv cache pos to 0
runner->reset();
}
// generate
executorch::extension::llm::GenerationConfig config{
.seq_len = seq_len, .temperature = temperature};
.temperature = temperature};

if (FLAGS_max_new_tokens != -1) {
config.max_new_tokens = FLAGS_max_new_tokens;
} else {
ET_LOG(
Info,
"max_new_tokens not provided, falling back to seq_len=%d. "
"Consider using --max_new_tokens instead of --seq_len for specifying generation length.",
seq_len);
config.seq_len = seq_len;
}

auto error = runner->generate(prompt, config);
if (error != executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to warmup llama runner");
ET_LOG(Error, "Failed to run llama runner");
return 1;
}
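
For reference, the intended interplay between the two flags can be summarized as a small helper. This is a minimal sketch of the semantics stated in the flag descriptions above; `resolve_generation_length` and its parameters are illustrative names, not the runner's actual API:

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper mirroring the semantics described by the two flag
// strings above; name and signature are illustrative only.
int32_t resolve_generation_length(
    int32_t max_seq_len,       // model context limit
    int32_t num_prompt_tokens, // tokens consumed by the prompt
    int32_t max_new_tokens,    // -1 means the flag was not set
    int32_t seq_len) {         // deprecated prompt + output budget
  if (max_new_tokens != -1) {
    // New path: count only generated tokens, capped by the context
    // remaining after the prompt (max_seq_len - # prompt tokens).
    return std::min(max_new_tokens, max_seq_len - num_prompt_tokens);
  }
  // Legacy path: seq_len counts prompt + output and is truncated to
  // max_seq_len before subtracting the prompt.
  return std::min(seq_len, max_seq_len) - num_prompt_tokens;
}
```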

254 changes: 162 additions & 92 deletions extension/llm/runner/test/test_text_prefiller.cpp
@@ -138,113 +138,68 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
TEST_F(
TextPrefillerTest,
PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
// Create a spy TextPrefiller with max_seq_len = 3
// Create a real TextPrefiller with max_seq_len = 3 and parallel prefill
const int64_t max_seq_len = 3;
auto prefiller = createMockTextPrefiller(max_seq_len);
auto prefiller = createTextPrefiller(max_seq_len, true, true);

// Create prompt tokens with size > max_seq_len
std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
int64_t start_pos = 0;

// Set up expectations for prefill_chunk calls
{
InSequence seq; // Ensure calls happen in the expected order

// First chunk: tokens [1, 2, 3]
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
EXPECT_EQ(tokens.size(), 3);
EXPECT_EQ(tokens[0], 1);
EXPECT_EQ(tokens[1], 2);
EXPECT_EQ(tokens[2], 3);
EXPECT_EQ(pos, 0);
return Result<uint64_t>(10);
});

// Second chunk: tokens [4, 5, 6]
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
EXPECT_EQ(tokens.size(), 3);
EXPECT_EQ(tokens[0], 4);
EXPECT_EQ(tokens[1], 5);
EXPECT_EQ(tokens[2], 6);
EXPECT_EQ(pos, 3);
return Result<uint64_t>(20);
});

// Third chunk: tokens [7, 8]
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
EXPECT_EQ(tokens.size(), 2);
EXPECT_EQ(tokens[0], 7);
EXPECT_EQ(tokens[1], 8);
EXPECT_EQ(pos, 6);
return Result<uint64_t>(30);
});
}
// Track all tokens and positions passed to text_decoder_runner step
struct StepCall {
std::vector<uint64_t> tokens;
int64_t pos;
};
std::vector<StepCall> step_calls;

// Set up expectations for text_decoder_runner step calls
EXPECT_CALL(text_decoder_runner_, step(_, _))
.Times(3) // Should be called 3 times for 3 chunks
.WillRepeatedly(
[&](executorch::extension::TensorPtr& tokens, int64_t pos) {
// Extract token values from tensor
std::vector<uint64_t> token_values;
int64_t num_tokens = tokens->size(1);
auto* token_data = tokens->const_data_ptr<int64_t>();
for (int64_t i = 0; i < num_tokens; i++) {
token_values.push_back(static_cast<uint64_t>(token_data[i]));
}
step_calls.push_back({token_values, pos});
return Result<executorch::aten::Tensor>(tensor);
});

// Call prefill
auto result = prefiller->prefill(prompt_tokens, start_pos);

// Verify the result
EXPECT_EQ(result.error(), Error::Ok);
EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk

// Verify that start_pos has been updated correctly
EXPECT_EQ(start_pos, prompt_tokens.size());
}

// Test that prefill() handles edge cases correctly
TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
// Create a spy TextPrefiller with max_seq_len = 1
const int64_t max_seq_len = 1;
auto prefiller = createMockTextPrefiller(max_seq_len);

// Create prompt tokens with size > max_seq_len
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
int64_t start_pos = 5; // Non-zero starting position

// Set up expectations for prefill_chunk calls
{
InSequence seq;

// First chunk: token [1]
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
EXPECT_EQ(tokens.size(), 1);
EXPECT_EQ(tokens[0], 1);
EXPECT_EQ(pos, 5);
return Result<uint64_t>(10);
});

// Second chunk: token [2]
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
EXPECT_EQ(tokens.size(), 1);
EXPECT_EQ(tokens[0], 2);
EXPECT_EQ(pos, 6);
return Result<uint64_t>(20);
});

// Third chunk: token [3]
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
EXPECT_EQ(tokens.size(), 1);
EXPECT_EQ(tokens[0], 3);
EXPECT_EQ(pos, 7);
return Result<uint64_t>(30);
});
}

// Call prefill
auto result = prefiller->prefill(prompt_tokens, start_pos);

// Verify the result
EXPECT_EQ(result.error(), Error::Ok);
EXPECT_EQ(result.get(), 30);
// Verify that step was called 3 times with correct tokens and positions
ASSERT_EQ(step_calls.size(), 3);

// First chunk: tokens [1, 2, 3] at position 0
EXPECT_EQ(step_calls[0].tokens.size(), 3);
EXPECT_EQ(step_calls[0].tokens[0], 1);
EXPECT_EQ(step_calls[0].tokens[1], 2);
EXPECT_EQ(step_calls[0].tokens[2], 3);
EXPECT_EQ(step_calls[0].pos, 0);

// Second chunk: tokens [4, 5, 6] at position 3
EXPECT_EQ(step_calls[1].tokens.size(), 3);
EXPECT_EQ(step_calls[1].tokens[0], 4);
EXPECT_EQ(step_calls[1].tokens[1], 5);
EXPECT_EQ(step_calls[1].tokens[2], 6);
EXPECT_EQ(step_calls[1].pos, 3);

// Third chunk: tokens [7, 8] at position 6
EXPECT_EQ(step_calls[2].tokens.size(), 2);
EXPECT_EQ(step_calls[2].tokens[0], 7);
EXPECT_EQ(step_calls[2].tokens[1], 8);
EXPECT_EQ(step_calls[2].pos, 6);

// Verify that start_pos has been updated correctly
EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens)
EXPECT_EQ(start_pos, prompt_tokens.size());
}

// Test that prefill() handles errors from prefill_chunk correctly
@@ -305,4 +260,119 @@ TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
// Verify that start_pos has been updated correctly
EXPECT_EQ(start_pos, prompt_tokens.size());
}
// Test that prefill_chunk updates start_pos correctly with parallel prefill
TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlyParallel) {
// Create a TextPrefiller with parallel prefill enabled
auto prefiller = createTextPrefiller(10, true, true);

// Set up expectations for the text decoder runner
int64_t captured_pos = -1;
EXPECT_CALL(text_decoder_runner_, step(_, _))
.WillOnce([&](executorch::extension::TensorPtr& tokens, int64_t pos) {
captured_pos = pos;
// Verify tokens shape is [1, num_tokens]
EXPECT_EQ(tokens->dim(), 2);
EXPECT_EQ(tokens->size(0), 1);
EXPECT_EQ(tokens->size(1), 3);
return Result<executorch::aten::Tensor>(tensor);
});

// Create prompt tokens
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
int64_t start_pos = 5; // Non-zero starting position

// Call prefill_chunk directly
auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);

// Verify the result
EXPECT_EQ(result.error(), Error::Ok);

// Verify that step was called with the original start_pos
EXPECT_EQ(captured_pos, 5);

// Verify that start_pos has been updated by the number of tokens
// This is the key test: start_pos should be updated exactly once
EXPECT_EQ(start_pos, 8); // 5 + 3 tokens
}

// Test that prefill_chunk updates start_pos correctly with sequential prefill
TEST_F(TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlySequential) {
// Create a TextPrefiller with sequential prefill (parallel disabled)
auto prefiller = createTextPrefiller(10, true, false);

// Track all positions passed to step
std::vector<int64_t> captured_positions;
EXPECT_CALL(text_decoder_runner_, step(_, _))
.Times(3)
.WillRepeatedly(
[&](executorch::extension::TensorPtr& tokens, int64_t pos) {
captured_positions.push_back(pos);
// Verify tokens shape is [1, 1] for sequential prefill
EXPECT_EQ(tokens->dim(), 2);
EXPECT_EQ(tokens->size(0), 1);
EXPECT_EQ(tokens->size(1), 1);
return Result<executorch::aten::Tensor>(tensor);
});

// Create prompt tokens
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
int64_t start_pos = 10; // Non-zero starting position

// Call prefill_chunk directly
auto result = prefiller->prefill_chunk(prompt_tokens, start_pos);

// Verify the result
EXPECT_EQ(result.error(), Error::Ok);

// Verify that step was called with incrementing positions
ASSERT_EQ(captured_positions.size(), 3);
EXPECT_EQ(captured_positions[0], 10); // First token at initial start_pos
EXPECT_EQ(captured_positions[1], 11); // Second token at start_pos + 1
EXPECT_EQ(captured_positions[2], 12); // Third token at start_pos + 2

// Verify that start_pos has been updated by the number of tokens
// This is the key test: start_pos should be updated exactly once per token
EXPECT_EQ(start_pos, 13); // 10 + 3 tokens
}

// Test that prefill with chunking updates start_pos correctly across chunks.
// This test would have caught the bug where start_pos was being updated twice.
TEST_F(
TextPrefillerTest,
PrefillWithChunkingUpdatesStartPosCorrectlyAcrossChunks) {
// Create a TextPrefiller with max_seq_len = 3 and parallel prefill
auto prefiller = createTextPrefiller(3, true, true);

// Track all positions passed to step
std::vector<int64_t> captured_positions;
EXPECT_CALL(text_decoder_runner_, step(_, _))
.Times(3) // Should be called 3 times: [1,2,3], [4,5,6], [7,8]
.WillRepeatedly(
[&](executorch::extension::TensorPtr& tokens, int64_t pos) {
captured_positions.push_back(pos);
return Result<executorch::aten::Tensor>(tensor);
});

// Create prompt tokens that exceed max_seq_len
std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
int64_t start_pos = 100; // Non-zero starting position

// Call prefill (which will chunk internally)
auto result = prefiller->prefill(prompt_tokens, start_pos);

// Verify the result
EXPECT_EQ(result.error(), Error::Ok);

// Verify that step was called with correct positions for each chunk
// If start_pos were updated twice (the bug), these would be wrong
ASSERT_EQ(captured_positions.size(), 3);
EXPECT_EQ(captured_positions[0], 100); // Chunk 1: tokens [1,2,3]
EXPECT_EQ(captured_positions[1], 103); // Chunk 2: tokens [4,5,6]
EXPECT_EQ(captured_positions[2], 106); // Chunk 3: tokens [7,8]

// Verify that final start_pos is correct
// This is the key test for the bug: start_pos should be exactly
// initial_pos + num_tokens, not double-incremented
EXPECT_EQ(start_pos, 108); // 100 + 8 tokens
}
} // namespace
2 changes: 1 addition & 1 deletion extension/llm/runner/text_llm_runner.cpp
@@ -228,7 +228,7 @@ Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
Error err = generate(prompt, config);

// Reset stats after warmup, not resetting the std::unique_ptr!
stats_->reset();
reset();
return err;
}
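
The warmup fix above swaps the stats-only reset for the runner's full reset(), which, per the `// reset kv cache pos to 0` comment in main.cpp, presumably rewinds the KV-cache position as well as the timing stats. A minimal sketch of that shape, with all member names assumed rather than taken from the actual runner:

```cpp
#include <cstdint>
#include <memory>

// Stand-ins for the runner's members; names and layout are assumptions
// made for illustration only.
struct Stats {
  void reset() { /* zero timing counters */ }
};

struct RunnerSketch {
  std::unique_ptr<Stats> stats_ = std::make_unique<Stats>();
  int64_t pos_ = 0; // KV-cache position

  // Combined reset: clears the warmup timing stats and rewinds the
  // KV-cache position so post-warmup generation starts from a clean
  // context. Note stats_->reset() resets the Stats object itself,
  // not the std::unique_ptr holding it.
  void reset() {
    stats_->reset();
    pos_ = 0;
  }
};
```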

1 change: 0 additions & 1 deletion extension/llm/runner/text_prefiller.cpp
@@ -59,7 +59,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
cur_token = chunk_result.get();

start_pos += num_tokens_to_prefill_with;
num_tokens_to_process += num_tokens_to_prefill_with;
}
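
To make the fix concrete: prefill_chunk() already advances start_pos by the chunk size, so the removed line in the outer loop was a double increment, exactly what the new position-tracking tests catch. A simplified sketch of the corrected chunking loop, with all names illustrative:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Simplified sketch of the chunked prefill loop after this fix. The
// position bookkeeping is delegated to the chunk call, so the outer loop
// must not add the chunk size again.
uint64_t prefill_sketch(
    const std::vector<uint64_t>& prompt_tokens,
    int64_t& start_pos,
    int64_t max_seq_len) {
  const int64_t num_tokens = static_cast<int64_t>(prompt_tokens.size());
  int64_t processed = 0;
  uint64_t cur_token = 0;
  while (processed < num_tokens) {
    const int64_t chunk_len = std::min(max_seq_len, num_tokens - processed);
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + processed,
        prompt_tokens.begin() + processed + chunk_len);
    // The real code calls prefill_chunk(chunk, start_pos), which runs the
    // model and advances start_pos by chunk_len internally; emulate that:
    start_pos += chunk_len;   // exactly once per chunk
    cur_token = chunk.back(); // stand-in for the model's predicted token
    processed += chunk_len;
  }
  return cur_token;
}
```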
