2 changes: 1 addition & 1 deletion examples/models/llava/runner/llava_runner.h
@@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner {
       const float temperature = 0.8f)
       : temperature_(temperature),
         module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
-        io_manager_(std::make_unique<IOManager>()),
+        io_manager_(std::make_unique<IOManager>(*module_)),
         tokenizer_path_(tokenizer_path) {
     ET_LOG(
         Info,
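The one-line change above follows from the new `IOManager` constructor introduced in the header below: an `IOManager` is now bound to the `Module` it serves at construction time instead of receiving `Program` and `Method` references on every call. A minimal construction sketch under that assumption (the model path is a placeholder):

```cpp
#include <memory>

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/module/module.h>

using executorch::extension::Module;
using executorch::extension::llm::IOManager;
using executorch::runtime::Error;

int main() {
  // Placeholder path; any program whose "forward" method takes exactly
  // (tokens, start_pos) works with the base IOManager.
  auto module = std::make_unique<Module>(
      "/path/to/model.pte", Module::LoadMode::File);

  // The IOManager keeps a reference to the Module, mirroring the
  // LlavaRunner initializer-list change above.
  auto io_manager = std::make_unique<IOManager>(*module);

  // load() defaults to "forward" for both prefill and decode.
  return io_manager->load() == Error::Ok ? 0 : 1;
}
```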
158 changes: 124 additions & 34 deletions extension/llm/runner/io_manager/io_manager.h
@@ -8,12 +8,8 @@
 
 #pragma once
 
-#include <vector>
-
+#include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/executor/method.h>
-#include <executorch/runtime/executor/method_meta.h>
 
 namespace executorch {
 namespace extension {
@@ -29,6 +25,13 @@ namespace llm {
  */
 class ET_EXPERIMENTAL IOManager {
  public:
+  /**
+   * @brief Construct an IOManager bound to a Module.
+   *
+   * @param module The Module used for querying method metadata and execution.
+   */
+  explicit IOManager(ET_MODULE_NAMESPACE::Module& module) : module_(module) {}
+
   /**
    * @brief Virtual destructor to allow proper cleanup in derived classes.
    */
@@ -38,88 +41,143 @@ class ET_EXPERIMENTAL IOManager {
    * @brief Load the IO manager with method metadata for prefill and
    * decode operations.
    *
-   * @param program The program prefill and decode methods are loaded from.
    * @param prefill_method The prefill method to initialize with.
    * @param decode_method The decode method to initialize with.
    */
   ET_NODISCARD virtual runtime::Error load(
-      const executorch::ET_RUNTIME_NAMESPACE::Program& program,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    (void)program;
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  /**
+   * @brief Load the IO manager using the default method names.
+   *
+   * Uses "forward" for both prefill and decode.
+   *
+   * @return Error code.
+   */
+  ET_NODISCARD runtime::Error load() {
+    return load("forward", "forward");
+  }
+
   /**
    * @brief Reset the IO manager state.
    *
    * @param prefill_method The prefill method to reset with.
    * @param decode_method The decode method to reset with.
    */
   ET_NODISCARD virtual runtime::Error reset(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
+      const std::string& prefill_method,
+      const std::string& decode_method) {
     (void)prefill_method;
     (void)decode_method;
     return runtime::Error::Ok;
   }
 
+  /**
+   * @brief Reset the IO manager state using the default method names.
+   *
+   * Uses "forward" for both prefill and decode.
+   *
+   * @return Error code.
+   */
+  ET_NODISCARD runtime::Error reset() {
+    return reset("forward", "forward");
+  }
+
   /**
    * @brief Prepare inputs for the prefill phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param prefill_method The prefill method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the prefill method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
-  prepare_prefill(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) {
-    if (prefill_method.inputs_size() != 2) {
+  virtual runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& prefill_method) {
+    auto method_meta = module_.method_meta(prefill_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          prefill_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  /**
+   * @brief Prepare inputs for the prefill phase using the default method name.
+   *
+   * Uses "forward" as the prefill method.
+   *
+   * @param input The input tensor containing token IDs.
+   * @param start_pos The tensor containing the starting position.
+   * @return Vector of prepared inputs for the prefill method.
+   */
+  runtime::Result<std::vector<runtime::EValue>> prepare_prefill(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
+    return prepare_prefill(input, start_pos, "forward");
+  }
+
   /**
    * @brief Prepare inputs for the decode phase of LLM inference.
    *
    * @param input The input tensor containing token IDs.
    * @param start_pos The tensor containing the starting position of the current
    * input within the context.
    * @param decode_method The decode method to prepare inputs for.
-   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
+   * @return std::vector<runtime::EValue> Vector of prepared inputs
    * for the decode method.
    */
-  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
-  prepare_decode(
-      const executorch::extension::TensorPtr& input,
-      const executorch::extension::TensorPtr& start_pos,
-      executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) {
-    if (decode_method.inputs_size() != 2) {
+  virtual runtime::Result<std::vector<runtime::EValue>> prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos,
+      const std::string& decode_method) {
+    auto method_meta = module_.method_meta(decode_method);
+    if (!method_meta.ok()) {
+      return method_meta.error();
+    }
+    if (method_meta->num_inputs() != 2) {
       ET_LOG(
           Error,
           "Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.",
-          decode_method.inputs_size());
+          method_meta->num_inputs());
       return runtime::Error::InvalidState;
     }
     // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
     // here.
     return std::vector<runtime::EValue>{input, start_pos};
   }
 
+  /**
+   * @brief Prepare inputs for the decode phase using the default method name.
+   *
+   * Uses "forward" as the decode method.
+   *
+   * @param input The input tensor containing token IDs.
+   * @param start_pos The tensor containing the starting position.
+   * @return Vector of prepared inputs for the decode method.
+   */
+  runtime::Result<std::vector<runtime::EValue>> prepare_decode(
+      const TensorPtr& input,
+      const TensorPtr& start_pos) {
    return prepare_decode(input, start_pos, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the prefill
    * phase.
@@ -128,14 +186,27 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the prefill method execution.
    */
   ET_NODISCARD virtual runtime::Error update_prefill(
-      executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)prefill_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& prefill_method) {
     (void)model_outputs;
+    (void)prefill_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
 
+  /**
+   * @brief Process outputs from the prefill phase using the default method.
+   *
+   * Uses "forward" as the prefill method.
+   *
+   * @param model_outputs Vector of outputs from the prefill execution.
+   * @return Error code.
+   */
+  ET_NODISCARD runtime::Error update_prefill(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_prefill(model_outputs, "forward");
+  }
+
   /**
    * @brief Process and update internal state with outputs from the decode
    * phase.
@@ -144,13 +215,32 @@ class ET_EXPERIMENTAL IOManager {
    * @param model_outputs Vector of outputs from the decode method execution.
    */
   ET_NODISCARD virtual runtime::Error update_decode(
-      const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method,
-      const std::vector<executorch::runtime::EValue>& model_outputs) {
-    (void)decode_method;
+      const std::vector<runtime::EValue>& model_outputs,
+      const std::string& decode_method) {
     (void)model_outputs;
+    (void)decode_method;
     // No post inference work to do.
     return runtime::Error::Ok;
   }
 
+  /**
+   * @brief Process outputs from the decode phase using the default method.
+   *
+   * Uses "forward" as the decode method.
+   *
+   * @param model_outputs Vector of outputs from the decode execution.
+   * @return Error code.
+   */
+  ET_NODISCARD runtime::Error update_decode(
+      const std::vector<runtime::EValue>& model_outputs) {
+    return update_decode(model_outputs, "forward");
+  }
+
+ private:
+  /**
+   * @brief Reference to the Module used for method metadata and execution.
+   */
+  ET_MODULE_NAMESPACE::Module& module_;
 };
 
 } // namespace llm
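Putting the new surface together: callers now identify methods by name, and the manager pulls arity and metadata from its bound `Module`. Below is a minimal sketch of one prefill pass under the default method name; the sampling and decode loop that a real runner layers on top are omitted, and the token values are arbitrary:

```cpp
#include <vector>

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using namespace ::executorch::extension;
using namespace ::executorch::runtime;

Error prefill_once(Module& module, llm::IOManager& io_manager) {
  // Defaults to "forward" for both prefill and decode.
  ET_CHECK_OK_OR_RETURN_ERROR(io_manager.load());

  // Arbitrary prompt tokens and a starting position of 0.
  auto tokens = make_tensor_ptr({1, 4}, std::vector<int64_t>{5, 6, 7, 8});
  auto start_pos = make_tensor_ptr(std::vector<int64_t>{0});

  // prepare_prefill() validates input arity via method metadata and, in
  // this base class, passes the tensors through unchanged.
  auto inputs = io_manager.prepare_prefill(tokens, start_pos);
  if (!inputs.ok()) {
    return inputs.error();
  }
  auto outputs = module.forward(*inputs);
  if (!outputs.ok()) {
    return outputs.error();
  }
  // A no-op in the base class; subclasses can update caches here.
  return io_manager.update_prefill(*outputs);
}
```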
5 changes: 2 additions & 3 deletions extension/llm/runner/io_manager/targets.bzl
@@ -11,10 +11,9 @@ def define_common_targets():
         exported_headers = [
             "io_manager.h",
         ],
-        deps = [
+        exported_deps = [
             "//executorch/extension/tensor:tensor" + aten_suffix,
-            "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
-            "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,
+            "//executorch/extension/module:module" + aten_suffix,
         ],
         visibility = [
             "@EXECUTORCH_CLIENTS",
10 changes: 4 additions & 6 deletions extension/llm/runner/io_manager/test/TARGETS
@@ -10,14 +10,12 @@ define_common_targets()
 
 runtime.cxx_test(
     name = "test_io_manager",
-    srcs = ["test_io_manager.cpp"],
+    srcs = [
+        "test_io_manager.cpp",
+    ],
     deps = [
         "//executorch/extension/llm/runner/io_manager:io_manager",
+        "//executorch/extension/module:module",
         "//executorch/extension/tensor:tensor",
-        "//executorch/runtime/executor:program",
         "//executorch/kernels/portable:generated_lib",
     ],
     env = {
         "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
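For reference, a test body exercising the Module-bound interface might look like the sketch below. The assertions in the actual test_io_manager.cpp may differ; the .pte path comes from the KVCACHE_CACHE_POS env entry defined in the target above, and the assumption that its "forward" method takes exactly two inputs is what the base IOManager checks:

```cpp
#include <cstdlib>
#include <vector>

#include <gtest/gtest.h>

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using namespace ::executorch::extension;
using namespace ::executorch::runtime;

TEST(IOManagerTest, PreparePrefillPassesInputsThrough) {
  // Model path is injected through the env entry in the target above.
  Module module(std::getenv("KVCACHE_CACHE_POS"));
  llm::IOManager io_manager(module);
  ASSERT_EQ(io_manager.load(), Error::Ok);

  auto tokens = make_tensor_ptr({1, 3}, std::vector<int64_t>{1, 2, 3});
  auto start_pos = make_tensor_ptr(std::vector<int64_t>{0});

  // The base IOManager should hand back exactly the two inputs.
  auto inputs = io_manager.prepare_prefill(tokens, start_pos);
  ASSERT_TRUE(inputs.ok());
  EXPECT_EQ(inputs->size(), 2u);
}
```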