From 2663ae3f88b2c89a2f821f30e03e4da94174efd2 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Fri, 22 Aug 2025 14:06:34 -0700 Subject: [PATCH] Make IOManager use Module instead of Method. (#13542) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/13542 Let's not expose Method from Module so that it's not getting misused beyond its owner. Reviewed By: mergennachin Differential Revision: D80595261 --- examples/models/llava/runner/llava_runner.h | 2 +- extension/llm/runner/io_manager/io_manager.h | 158 ++++++++++++++---- extension/llm/runner/io_manager/targets.bzl | 5 +- extension/llm/runner/io_manager/test/TARGETS | 10 +- .../io_manager/test/test_io_manager.cpp | 138 ++++----------- extension/llm/runner/llm_runner_helper.cpp | 4 +- .../runner/test/test_text_decoder_runner.cpp | 7 +- .../llm/runner/test/test_text_llm_runner.cpp | 28 +++- extension/llm/runner/text_decoder_runner.cpp | 9 +- extension/llm/runner/text_llm_runner.cpp | 9 +- 10 files changed, 190 insertions(+), 180 deletions(-) diff --git a/examples/models/llava/runner/llava_runner.h b/examples/models/llava/runner/llava_runner.h index 184522c2cf1..62df890b46d 100644 --- a/examples/models/llava/runner/llava_runner.h +++ b/examples/models/llava/runner/llava_runner.h @@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner { const float temperature = 0.8f) : temperature_(temperature), module_(std::make_unique(model_path, Module::LoadMode::File)), - io_manager_(std::make_unique()), + io_manager_(std::make_unique(*module_)), tokenizer_path_(tokenizer_path) { ET_LOG( Info, diff --git a/extension/llm/runner/io_manager/io_manager.h b/extension/llm/runner/io_manager/io_manager.h index ce158c23b6e..fc9a8f0641b 100644 --- a/extension/llm/runner/io_manager/io_manager.h +++ b/extension/llm/runner/io_manager/io_manager.h @@ -8,12 +8,8 @@ #pragma once -#include - +#include #include -#include -#include -#include namespace executorch { namespace extension { @@ -29,6 +25,13 @@ namespace llm { */ class ET_EXPERIMENTAL IOManager { public: + /** + * @brief Construct an IOManager bound to a Module. + * + * @param module The Module used for querying method metadata and execution. + */ + explicit IOManager(ET_MODULE_NAMESPACE::Module& module) : module_(module) {} + /** * @brief Virtual destructor to allow proper cleanup in derived classes. */ @@ -38,20 +41,28 @@ class ET_EXPERIMENTAL IOManager { * @brief Load the IO manager with method metadata for prefill and * decode operations. * - * @param program The program prefill and decode methods are loaded from. * @param prefill_method The prefill method to initialize with. * @param decode_method The decode method to initialize with. */ ET_NODISCARD virtual runtime::Error load( - const executorch::ET_RUNTIME_NAMESPACE::Program& program, - executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method, - executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) { - (void)program; + const std::string& prefill_method, + const std::string& decode_method) { (void)prefill_method; (void)decode_method; return runtime::Error::Ok; } + /** + * @brief Load the IO manager using the default method names. + * + * Uses "forward" for both prefill and decode. + * + * @return Error code. + */ + ET_NODISCARD runtime::Error load() { + return load("forward", "forward"); + } + /** * @brief Reset the IO manager state. * @@ -59,13 +70,24 @@ class ET_EXPERIMENTAL IOManager { * @param decode_method The decode method to reset with. */ ET_NODISCARD virtual runtime::Error reset( - executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method, - executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) { + const std::string& prefill_method, + const std::string& decode_method) { (void)prefill_method; (void)decode_method; return runtime::Error::Ok; } + /** + * @brief Reset the IO manager state using the default method names. + * + * Uses "forward" for both prefill and decode. + * + * @return Error code. + */ + ET_NODISCARD runtime::Error reset() { + return reset("forward", "forward"); + } + /** * @brief Prepare inputs for the prefill phase of LLM inference. * @@ -73,19 +95,22 @@ class ET_EXPERIMENTAL IOManager { * @param start_pos The tensor containing the starting position of the current * input within the context. * @param prefill_method The prefill method to prepare inputs for. - * @return std::vector Vector of prepared inputs + * @return std::vector Vector of prepared inputs * for the prefill method. */ - virtual runtime::Result> - prepare_prefill( - const executorch::extension::TensorPtr& input, - const executorch::extension::TensorPtr& start_pos, - executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method) { - if (prefill_method.inputs_size() != 2) { + virtual runtime::Result> prepare_prefill( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& prefill_method) { + auto method_meta = module_.method_meta(prefill_method); + if (!method_meta.ok()) { + return method_meta.error(); + } + if (method_meta->num_inputs() != 2) { ET_LOG( Error, "Expected 2 inputs for prefill method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.", - prefill_method.inputs_size()); + method_meta->num_inputs()); return runtime::Error::InvalidState; } // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done @@ -93,6 +118,21 @@ class ET_EXPERIMENTAL IOManager { return std::vector{input, start_pos}; } + /** + * @brief Prepare inputs for the prefill phase using the default method name. + * + * Uses "forward" as the prefill method. + * + * @param input The input tensor containing token IDs. + * @param start_pos The tensor containing the starting position. + * @return Vector of prepared inputs for the prefill method. + */ + runtime::Result> prepare_prefill( + const TensorPtr& input, + const TensorPtr& start_pos) { + return prepare_prefill(input, start_pos, "forward"); + } + /** * @brief Prepare inputs for the decode phase of LLM inference. * @@ -100,19 +140,22 @@ class ET_EXPERIMENTAL IOManager { * @param start_pos The tensor containing the starting position of the current * input within the context. * @param decode_method The decode method to prepare inputs for. - * @return std::vector Vector of prepared inputs + * @return std::vector Vector of prepared inputs * for the decode method. */ - virtual runtime::Result> - prepare_decode( - const executorch::extension::TensorPtr& input, - const executorch::extension::TensorPtr& start_pos, - executorch::ET_RUNTIME_NAMESPACE::Method& decode_method) { - if (decode_method.inputs_size() != 2) { + virtual runtime::Result> prepare_decode( + const TensorPtr& input, + const TensorPtr& start_pos, + const std::string& decode_method) { + auto method_meta = module_.method_meta(decode_method); + if (!method_meta.ok()) { + return method_meta.error(); + } + if (method_meta->num_inputs() != 2) { ET_LOG( Error, "Expected 2 inputs for decode method, got %zu. Likely the model takes the caches or mask as an argument which this IOManager does not support.", - decode_method.inputs_size()); + method_meta->num_inputs()); return runtime::Error::InvalidState; } // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done @@ -120,6 +163,21 @@ class ET_EXPERIMENTAL IOManager { return std::vector{input, start_pos}; } + /** + * @brief Prepare inputs for the decode phase using the default method name. + * + * Uses "forward" as the decode method. + * + * @param input The input tensor containing token IDs. + * @param start_pos The tensor containing the starting position. + * @return Vector of prepared inputs for the decode method. + */ + runtime::Result> prepare_decode( + const TensorPtr& input, + const TensorPtr& start_pos) { + return prepare_decode(input, start_pos, "forward"); + } + /** * @brief Process and update internal state with outputs from the prefill * phase. @@ -128,14 +186,27 @@ class ET_EXPERIMENTAL IOManager { * @param model_outputs Vector of outputs from the prefill method execution. */ ET_NODISCARD virtual runtime::Error update_prefill( - executorch::ET_RUNTIME_NAMESPACE::Method& prefill_method, - const std::vector& model_outputs) { - (void)prefill_method; + const std::vector& model_outputs, + const std::string& prefill_method) { (void)model_outputs; + (void)prefill_method; // No post inference work to do. return runtime::Error::Ok; } + /** + * @brief Process outputs from the prefill phase using the default method. + * + * Uses "forward" as the prefill method. + * + * @param model_outputs Vector of outputs from the prefill execution. + * @return Error code. + */ + ET_NODISCARD runtime::Error update_prefill( + const std::vector& model_outputs) { + return update_prefill(model_outputs, "forward"); + } + /** * @brief Process and update internal state with outputs from the decode * phase. @@ -144,13 +215,32 @@ class ET_EXPERIMENTAL IOManager { * @param model_outputs Vector of outputs from the decode method execution. */ ET_NODISCARD virtual runtime::Error update_decode( - const executorch::ET_RUNTIME_NAMESPACE::Method& decode_method, - const std::vector& model_outputs) { - (void)decode_method; + const std::vector& model_outputs, + const std::string& decode_method) { (void)model_outputs; + (void)decode_method; // No post inference work to do. return runtime::Error::Ok; } + + /** + * @brief Process outputs from the decode phase using the default method. + * + * Uses "forward" as the decode method. + * + * @param model_outputs Vector of outputs from the decode execution. + * @return Error code. + */ + ET_NODISCARD runtime::Error update_decode( + const std::vector& model_outputs) { + return update_decode(model_outputs, "forward"); + } + + private: + /** + * @brief Reference to the Module used for method metadata and execution. + */ + ET_MODULE_NAMESPACE::Module& module_; }; } // namespace llm diff --git a/extension/llm/runner/io_manager/targets.bzl b/extension/llm/runner/io_manager/targets.bzl index ef93d541098..5b891b24376 100644 --- a/extension/llm/runner/io_manager/targets.bzl +++ b/extension/llm/runner/io_manager/targets.bzl @@ -11,10 +11,9 @@ def define_common_targets(): exported_headers = [ "io_manager.h", ], - deps = [ + exported_deps = [ "//executorch/extension/tensor:tensor" + aten_suffix, - "//executorch/runtime/core/exec_aten:lib" + aten_suffix, - "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix, + "//executorch/extension/module:module" + aten_suffix, ], visibility = [ "@EXECUTORCH_CLIENTS", diff --git a/extension/llm/runner/io_manager/test/TARGETS b/extension/llm/runner/io_manager/test/TARGETS index 6db0a7c590b..e214060942a 100644 --- a/extension/llm/runner/io_manager/test/TARGETS +++ b/extension/llm/runner/io_manager/test/TARGETS @@ -10,14 +10,12 @@ define_common_targets() runtime.cxx_test( name = "test_io_manager", - srcs = ["test_io_manager.cpp"], + srcs = [ + "test_io_manager.cpp", + ], deps = [ "//executorch/extension/llm/runner/io_manager:io_manager", - "//executorch/extension/llm/runner/io_manager:io_manager", - "//executorch/extension/module:module", - "//executorch/extension/tensor:tensor", - "//executorch/runtime/executor:program", - "//executorch/kernels/portable:generated_lib", + "//executorch/kernels/portable:generated_lib", ], env = { "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])", diff --git a/extension/llm/runner/io_manager/test/test_io_manager.cpp b/extension/llm/runner/io_manager/test/test_io_manager.cpp index bc265e8d083..7c31ff9ea18 100644 --- a/extension/llm/runner/io_manager/test/test_io_manager.cpp +++ b/extension/llm/runner/io_manager/test/test_io_manager.cpp @@ -7,74 +7,45 @@ */ #include -#include -#include -#include -#include + #include using namespace ::testing; -using executorch::extension::Module; -using executorch::extension::llm::IOManager; -using executorch::runtime::Error; -using executorch::runtime::EValue; -using executorch::runtime::Method; -using executorch::runtime::Program; -using executorch::runtime::Result; +using namespace ::executorch::extension; +using namespace ::executorch::runtime; // Test fixture for IOManager tests class IOManagerTest : public Test { protected: void SetUp() override { - executorch::runtime::runtime_init(); - module_ = std::make_unique(std::getenv("KVCACHE_CACHE_POS")); - io_manager_ = std::make_unique(); - auto err = module_->load_method("forward"); - EXPECT_EQ(err, Error::Ok); + io_manager_ = std::make_unique(*module_); + EXPECT_EQ(module_->load_forward(), Error::Ok); } protected: std::unique_ptr module_; - - std::unique_ptr io_manager_; + std::unique_ptr io_manager_; }; // Test that load() returns Error::Ok (no-op) TEST_F(IOManagerTest, LoadReturnsOk) { - auto* program = module_->program().get(); - auto* prefill_method = module_->method("forward").get(); - auto* decode_method = module_->method("forward").get(); - - auto result = io_manager_->load(*program, *prefill_method, *decode_method); - - EXPECT_EQ(result, Error::Ok); + EXPECT_EQ(io_manager_->load(), Error::Ok); } // Test that reset() returns Error::Ok (no-op) TEST_F(IOManagerTest, ResetReturnsOk) { - auto* prefill_method = module_->method("forward").get(); - auto* decode_method = module_->method("forward").get(); - - auto result = io_manager_->reset(*prefill_method, *decode_method); - - EXPECT_EQ(result, Error::Ok); + EXPECT_EQ(io_manager_->reset(), Error::Ok); } // Test that prepare_prefill() returns the input tensors when method has 2 // inputs TEST_F(IOManagerTest, PreparePrefillReturnsInputsWhenValidInputCount) { - auto* prefill_method = module_->method("forward").get(); - - // Create test tensors std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f}; std::vector start_pos_data = {0}; - auto input_ptr = executorch::extension::make_tensor_ptr({1, 4}, input_data); - auto start_pos_ptr = - executorch::extension::make_tensor_ptr({1}, start_pos_data); - - auto result = - io_manager_->prepare_prefill(input_ptr, start_pos_ptr, *prefill_method); + auto input_ptr = make_tensor_ptr({1, 4}, input_data); + auto start_pos_ptr = make_tensor_ptr({1}, start_pos_data); + auto result = io_manager_->prepare_prefill(input_ptr, start_pos_ptr); EXPECT_EQ(result.error(), Error::Ok); auto outputs = result.get(); @@ -87,17 +58,12 @@ TEST_F(IOManagerTest, PreparePrefillReturnsInputsWhenValidInputCount) { // Test that prepare_decode() returns the input tensors when method has 2 inputs TEST_F(IOManagerTest, PrepareDecodeReturnsInputsWhenValidInputCount) { - auto* decode_method = module_->method("forward").get(); - - // Create test tensors std::vector input_data = {5.0f, 6.0f, 7.0f, 8.0f}; std::vector start_pos_data = {10}; - auto input_ptr = executorch::extension::make_tensor_ptr({1, 4}, input_data); - auto start_pos_ptr = - executorch::extension::make_tensor_ptr({1}, start_pos_data); + auto input_ptr = make_tensor_ptr({1, 4}, input_data); + auto start_pos_ptr = make_tensor_ptr({1}, start_pos_data); - auto result = - io_manager_->prepare_decode(input_ptr, start_pos_ptr, *decode_method); + auto result = io_manager_->prepare_decode(input_ptr, start_pos_ptr); EXPECT_EQ(result.error(), Error::Ok); auto outputs = result.get(); @@ -110,49 +76,31 @@ TEST_F(IOManagerTest, PrepareDecodeReturnsInputsWhenValidInputCount) { // Test that update_prefill() returns Error::Ok (no-op) TEST_F(IOManagerTest, UpdatePrefillReturnsOk) { - auto* prefill_method = module_->method("forward").get(); - - // Create dummy model outputs std::vector model_outputs; std::vector output_data = {0.1f, 0.2f, 0.3f}; - auto output_tensor = - executorch::extension::make_tensor_ptr({1, 3}, output_data); + auto output_tensor = make_tensor_ptr({1, 3}, output_data); model_outputs.emplace_back(*output_tensor); - auto result = io_manager_->update_prefill(*prefill_method, model_outputs); - - EXPECT_EQ(result, Error::Ok); + EXPECT_EQ(io_manager_->update_prefill(model_outputs), Error::Ok); } // Test that update_decode() returns Error::Ok (no-op) TEST_F(IOManagerTest, UpdateDecodeReturnsOk) { - auto* decode_method = module_->method("forward").get(); - - // Create dummy model outputs std::vector model_outputs; std::vector output_data = {0.4f, 0.5f, 0.6f}; - auto output_tensor = - executorch::extension::make_tensor_ptr({1, 3}, output_data); + auto output_tensor = make_tensor_ptr({1, 3}, output_data); model_outputs.emplace_back(*output_tensor); - auto result = io_manager_->update_decode(*decode_method, model_outputs); - - EXPECT_EQ(result, Error::Ok); + EXPECT_EQ(io_manager_->update_decode(model_outputs), Error::Ok); } // Test that prepare_prefill() correctly passes through different tensor shapes TEST_F(IOManagerTest, PreparePrefillPassesThroughDifferentTensorShapes) { - auto* prefill_method = module_->method("forward").get(); - - // Create test tensors with different shapes std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector start_pos_data = {5, 10}; - auto input_ptr = executorch::extension::make_tensor_ptr({2, 3}, input_data); - auto start_pos_ptr = - executorch::extension::make_tensor_ptr({2}, start_pos_data); - - auto result = - io_manager_->prepare_prefill(input_ptr, start_pos_ptr, *prefill_method); + auto input_ptr = make_tensor_ptr({2, 3}, input_data); + auto start_pos_ptr = make_tensor_ptr({2}, start_pos_data); + auto result = io_manager_->prepare_prefill(input_ptr, start_pos_ptr); EXPECT_EQ(result.error(), Error::Ok); auto outputs = result.get(); @@ -165,18 +113,12 @@ TEST_F(IOManagerTest, PreparePrefillPassesThroughDifferentTensorShapes) { // Test that prepare_decode() correctly passes through different tensor shapes TEST_F(IOManagerTest, PrepareDecodePassesThroughDifferentTensorShapes) { - auto* decode_method = module_->method("forward").get(); - - // Create test tensors with different shapes std::vector input_data = { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f}; std::vector start_pos_data = {15, 20, 25}; - auto input_ptr = executorch::extension::make_tensor_ptr({2, 4}, input_data); - auto start_pos_ptr = - executorch::extension::make_tensor_ptr({3}, start_pos_data); - - auto result = - io_manager_->prepare_decode(input_ptr, start_pos_ptr, *decode_method); + auto input_ptr = make_tensor_ptr({2, 4}, input_data); + auto start_pos_ptr = make_tensor_ptr({3}, start_pos_data); + auto result = io_manager_->prepare_decode(input_ptr, start_pos_ptr); EXPECT_EQ(result.error(), Error::Ok); auto outputs = result.get(); @@ -189,42 +131,22 @@ TEST_F(IOManagerTest, PrepareDecodePassesThroughDifferentTensorShapes) { // Test that update methods handle empty model outputs TEST_F(IOManagerTest, UpdateMethodsHandleEmptyModelOutputs) { - auto* prefill_method = module_->method("forward").get(); - auto* decode_method = module_->method("forward").get(); - - // Create empty model outputs std::vector empty_outputs; - auto prefill_result = - io_manager_->update_prefill(*prefill_method, empty_outputs); - auto decode_result = - io_manager_->update_decode(*decode_method, empty_outputs); - - EXPECT_EQ(prefill_result, Error::Ok); - EXPECT_EQ(decode_result, Error::Ok); + EXPECT_EQ(io_manager_->update_prefill(empty_outputs), Error::Ok); + EXPECT_EQ(io_manager_->update_decode(empty_outputs), Error::Ok); } // Test that update methods handle multiple model outputs TEST_F(IOManagerTest, UpdateMethodsHandleMultipleModelOutputs) { - auto* prefill_method = module_->method("forward").get(); - auto* decode_method = module_->method("forward").get(); - - // Create multiple model outputs std::vector model_outputs; std::vector output1_data = {0.1f, 0.2f}; std::vector output2_data = {0.3f, 0.4f, 0.5f}; - auto output1_tensor = - executorch::extension::make_tensor_ptr({1, 2}, output1_data); - auto output2_tensor = - executorch::extension::make_tensor_ptr({1, 3}, output2_data); + auto output1_tensor = make_tensor_ptr({1, 2}, output1_data); + auto output2_tensor = make_tensor_ptr({1, 3}, output2_data); model_outputs.emplace_back(*output1_tensor); model_outputs.emplace_back(*output2_tensor); - auto prefill_result = - io_manager_->update_prefill(*prefill_method, model_outputs); - auto decode_result = - io_manager_->update_decode(*decode_method, model_outputs); - - EXPECT_EQ(prefill_result, Error::Ok); - EXPECT_EQ(decode_result, Error::Ok); + EXPECT_EQ(io_manager_->update_prefill(model_outputs), Error::Ok); + EXPECT_EQ(io_manager_->update_decode(model_outputs), Error::Ok); } diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 2e17e518c4a..ec2e335b7d6 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -171,7 +171,7 @@ std::unique_ptr create_text_llm_runner( llm::get_eos_ids(tokenizer.get(), module.get())); // Create IOManager - std::unique_ptr io_manager = std::make_unique(); + std::unique_ptr io_manager = std::make_unique(*module); // Create text_decoder_runner. Use a shared_ptr so that it can be shared with // TextPrefiller and TextTokenGenerator @@ -234,7 +234,7 @@ std::unique_ptr create_multimodal_runner( get_eos_ids(tokenizer.get(), module.get())); // Create IOManager - std::unique_ptr io_manager = std::make_unique(); + std::unique_ptr io_manager = std::make_unique(*module); // Create text_decoder_runner auto text_decoder_runner = diff --git a/extension/llm/runner/test/test_text_decoder_runner.cpp b/extension/llm/runner/test/test_text_decoder_runner.cpp index 9b1c57216e6..0001509ec55 100644 --- a/extension/llm/runner/test/test_text_decoder_runner.cpp +++ b/extension/llm/runner/test/test_text_decoder_runner.cpp @@ -36,7 +36,8 @@ class TextDecoderRunnerTest : public Test { protected: void SetUp() override { mock_module_ = std::make_unique(); - io_manager_ = std::make_unique(); + io_manager_ = + std::make_unique(*mock_module_); runner_ = std::make_unique( mock_module_.get(), io_manager_.get()); } @@ -162,8 +163,8 @@ TEST_F(TextDecoderRunnerTest, StepWithAllModels) { << model_path << " with error: " << (int)load_result; continue; } - std::unique_ptr io_manager = - std::make_unique(); + auto io_manager = + std::make_unique(*module); // Create TextDecoderRunner TextDecoderRunner runner(module.get(), io_manager.get()); auto runner_load_result = runner.load(); diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 05c11bfe16b..8ec48b48ec3 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -219,14 +219,17 @@ TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) { tokenizer.get(), text_decoder_runner.get(), stats.get()); // Create a Runner with our mocked components + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); TextLLMRunner runner( createDefaultMetadata(), std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), - std::make_unique(), + std::move(module), std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), - std::make_unique(), + std::move(io_manager), std::move(text_token_generator), std::move(stats)); @@ -284,14 +287,17 @@ TEST_F(RunnerTest, WarmupCallsGenerateWithWarmingFlag) { tokenizer.get(), text_decoder_runner.get(), stats.get()); // Create a Runner with our mocked components + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); TextLLMRunner runner( createDefaultMetadata(), std::move(tokenizer), - std::make_unique(), + std::move(module), std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), - std::make_unique(), + std::move(io_manager), std::move(text_token_generator), std::move(stats)); @@ -319,14 +325,17 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) { tokenizer.get(), text_decoder_runner.get(), stats.get()); // Create a Runner with our mocked components + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); TextLLMRunner runner( createDefaultMetadata(), std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), - std::make_unique(), + std::move(module), std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), - std::make_unique(), + std::move(io_manager), std::move(text_token_generator), std::move(stats)); @@ -361,6 +370,9 @@ TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) { tokenizer.get(), text_decoder_runner.get(), stats.get()); // Create a Runner with our mocked components + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); TextLLMRunner runner( { {"enable_dynamic_shape", false}, @@ -369,11 +381,11 @@ TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) { {"use_kv_cache", true}, }, std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), - std::make_unique(), + std::move(module), std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), - std::make_unique(), + std::move(io_manager), std::move(text_token_generator), std::move(stats)); diff --git a/extension/llm/runner/text_decoder_runner.cpp b/extension/llm/runner/text_decoder_runner.cpp index bffd140eade..27c00c19089 100644 --- a/extension/llm/runner/text_decoder_runner.cpp +++ b/extension/llm/runner/text_decoder_runner.cpp @@ -69,18 +69,13 @@ ::executorch::runtime::Result TextDecoderRunner::step( } std::vector inputs; - auto method_err = module_->method("forward"); - ET_CHECK_OK_OR_RETURN_ERROR(method_err.error()); - auto& method = *(method_err.get()); - - auto inputs_res = - io_manager_->prepare_decode(tokens, start_pos_tensor, method); + auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor); ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error()); inputs = inputs_res.get(); auto outputs_res = module_->forward(inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); - auto update_err = io_manager_->update_decode(method, outputs_res.get()); + auto update_err = io_manager_->update_decode(outputs_res.get()); ET_CHECK_OK_OR_RETURN_ERROR(update_err); ET_CHECK_MSG( diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 2220a84ff0f..f0ac9ed0781 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -57,14 +57,7 @@ Error TextLLMRunner::load() { return Error::Ok; } ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); - auto method_res = module_->method("forward"); - - Program& program = *module_->program(); - - ET_CHECK_OK_OR_RETURN_ERROR(method_res.error()); - auto& forward = *(method_res.get()); - ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load(program, forward, forward)); + ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load()); ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); return Error::Ok; }