2 changes: 1 addition & 1 deletion extension/llm/runner/multimodal_decoder_runner.h
@@ -48,7 +48,7 @@ class ET_EXPERIMENTAL MultimodalDecoderRunner
&start_pos, {1}, executorch::aten::ScalarType::Long);
// run text model
auto outputs_res = ET_UNWRAP(
- module_->execute(kTextModelMethod, {start_pos_tensor, embeddings}));
+ module_->execute(kTextModelMethod, {embeddings, start_pos_tensor}));

ET_CHECK_MSG(
outputs_res.size() == 1,
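For context: the swap above assumes the exported text_decoder method now takes the embeddings as its first input and the position tensor as its second, which is also the convention the new helper in util.h relies on when it inspects input index 1. A minimal sketch of the resulting call, reusing names from the surrounding diff (module_, embeddings, start_pos, kTextModelMethod) as assumptions:

auto start_pos_tensor = ::executorch::extension::from_blob(
    &start_pos, {1}, executorch::aten::ScalarType::Long);
// Embeddings first, position tensor second.
auto outputs_res = ET_UNWRAP(
    module_->execute(kTextModelMethod, {embeddings, start_pos_tensor}));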
22 changes: 10 additions & 12 deletions extension/llm/runner/multimodal_prefiller.cpp
@@ -91,24 +91,22 @@ Result<uint64_t> MultimodalPrefiller::prefill(
}

// 2. Run decoder model for prefill.
// `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
// e.g. if start_pos = 2 and encoder_output.size(1) = 5,
// cache_position_tensor should be [2, 3, 4, 5, 6].

// Get expected shape of cache position tensor, which should be the second
// argument

int64_t seq_len = encoder_output.toTensor().size(1);
if (seq_len == 0) {
ET_LOG(Error, "The encoder returned an empty output.");
return ::executorch::runtime::Error::InvalidState;
}
std::vector<int64_t> cache_positions(seq_len);
for (int64_t i = 0; i < seq_len; ++i) {
cache_positions[i] = start_pos + i;
}
auto cache_position_tensor = ::executorch::extension::from_blob(
cache_positions.data(),
{static_cast<int>(seq_len)},
executorch::aten::ScalarType::Long);
std::vector<int64_t> cache_positions;

auto cache_position_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
module_, start_pos, cache_positions, seq_len, kTextModelMethod));

auto prefill_result = module_->execute(
- kTextModelMethod, {cache_position_tensor, encoder_output});
+ kTextModelMethod, {encoder_output, cache_position_tensor});
if (prefill_result.error() != ::executorch::runtime::Error::Ok) {
return prefill_result.error();
}
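As a worked illustration of what the helper fills in for this prefill call (the same example used by the comment that moved into util.h): with start_pos = 2 and an encoder output of sequence length 5, cache_positions ends up holding [2, 3, 4, 5, 6]. A small standalone sketch of that fill, with make_cache_positions as a hypothetical name:

#include <cstdint>
#include <vector>

// Standalone version of the cache-position fill performed inside
// populate_start_pos_or_cache_position for the prefill path.
std::vector<int64_t> make_cache_positions(int64_t start_pos, int64_t seq_len) {
  std::vector<int64_t> cache_positions(seq_len);
  for (int64_t i = 0; i < seq_len; ++i) {
    cache_positions[i] = start_pos + i;
  }
  return cache_positions;
}

// make_cache_positions(2, 5) returns {2, 3, 4, 5, 6}.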
34 changes: 4 additions & 30 deletions extension/llm/runner/text_decoder_runner.cpp
@@ -36,37 +36,11 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
// If only 1 input, we are not using kv cache
bool use_kv_cache = method_meta.num_inputs() > 1;

std::vector<int64_t> cache_positions;

if (use_kv_cache) {
// Size of the second argument. This could be either input_pos or
// cache_positions

// Check if we are using cache positions instead of input pos.
auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
// For input_pos, numel is 1, for cache_positions, numel is max_seq_len
auto sizes = second_input_info.sizes();
// Assuming 1D tensor
ET_CHECK_OR_RETURN_ERROR(
sizes.size() == 1,
InvalidProgram,
"The second input tensor is not 1D tensor. Got dimension (%zu)",
sizes.size());
auto numel = sizes[0];
std::vector<::executorch::aten::SizesType> sizes_vec = {numel};

TensorPtr start_pos_tensor;
if (numel > 1) {
// If we are here, model is exported with cache_positions, create a tensor
// with the same length as input_ids. Assuming the last dimension is the
// one with the variable token length, for example [1, S] or [1, 1, S]
sizes_vec[sizes_vec.size() - 1] = tokens->numel();
start_pos_tensor = empty(sizes_vec, ::executorch::aten::ScalarType::Long);
torch::executor::native::arange_out_impl(
start_pos, start_pos + tokens->numel(), 1.0, *start_pos_tensor);
} else {
// Assuming model is exported with input_pos, create a tensor with size 1
start_pos_tensor = from_blob(
&start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
}
auto start_pos_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
module_, start_pos, cache_positions, tokens->numel(), "forward"));

std::vector<runtime::EValue> inputs;
auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);
45 changes: 45 additions & 0 deletions extension/llm/runner/util.h
@@ -7,6 +7,9 @@
*/

#pragma once
#include <executorch/extension/llm/runner/constants.h>
#include <executorch/extension/llm/runner/multimodal_prefiller.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/platform/compiler.h>
#include <stdio.h>
#include <time.h>
@@ -99,6 +102,48 @@ ET_EXPERIMENTAL size_t inline get_rss_bytes() {
// when this changed.
return 0;
}

// Builds the position input for the given start_pos and seq_len. Returns
// either a size-1 tensor wrapping start_pos (when the method_name
// [`text_decoder` or `forward`] expects a single input_pos and populates the
// cache positions internally), or a tensor of seq_len cache positions,
// [start_pos, start_pos + seq_len), backed by cache_positions_vec.
inline runtime::Result<TensorPtr> populate_start_pos_or_cache_position(
Module* module,
int64_t& start_pos,
std::vector<int64_t>& cache_positions_vec,
int seq_len,
const char* method_name = "forward") {
// Get expected shape of cache position tensor, which should be the second
// argument
auto method_meta = ET_UNWRAP(module->method_meta(method_name));
auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
auto second_input_sizes = second_input_info.sizes();
auto numel = second_input_sizes[0];

for (int i = 0; i < second_input_sizes.size(); ++i) {
ET_LOG(Error, "second_input_sizes[%d] = %d", i, second_input_sizes[i]);
}

TensorPtr start_pos_tensor;
if (numel > 1) {
// `cache_position` goes from start_pos to start_pos +
// encoder_output.size(1). e.g. if start_pos = 2 and encoder_output.size(1)
// = 5, cache_position_tensor should be [2, 3, 4, 5, 6].
cache_positions_vec.resize(seq_len);
for (int64_t i = 0; i < seq_len; ++i) {
cache_positions_vec[i] = start_pos + i;
}
return ::executorch::extension::from_blob(
cache_positions_vec.data(),
{static_cast<int>(seq_len)},
executorch::aten::ScalarType::Long);
} else {
// Cache position is size 1.
return ::executorch::extension::from_blob(
&start_pos, {1}, executorch::aten::ScalarType::Long);
}
}

} // namespace llm
} // namespace extension
} // namespace executorch
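A hedged usage sketch of the new helper from a caller's perspective, mirroring the two call sites in this PR and assuming the same namespaces as the helper above; decode_once, module, and embeddings are hypothetical names, and the decode method is assumed to take embeddings first and the position tensor second:

// Sketch only. For a model exported with cache_positions (second input
// numel > 1), pos_tensor holds [start_pos, ..., start_pos + seq_len - 1];
// for a model exported with a size-1 input_pos, it simply wraps start_pos.
runtime::Result<executorch::aten::Tensor> decode_once(
    Module* module,
    const runtime::EValue& embeddings,
    int64_t start_pos,
    int seq_len) {
  // Backing storage must outlive pos_tensor, since from_blob does not copy.
  std::vector<int64_t> cache_positions;
  auto pos_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
      module, start_pos, cache_positions, seq_len, "forward"));
  auto outputs = ET_UNWRAP(
      module->execute("forward", {embeddings, pos_tensor}));
  return outputs[0].toTensor();
}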