diff --git a/extension/module/module.cpp b/extension/module/module.cpp
index b06fe1279f0..ba90b4e87f3 100644
--- a/extension/module/module.cpp
+++ b/extension/module/module.cpp
@@ -154,6 +154,7 @@ runtime::Error Module::load_method(
         temp_allocator_.get());
     method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
         method_name.c_str(), method_holder.memory_manager.get(), tracer));
+    method_holder.inputs.resize(method_holder.method->inputs_size());
     methods_.emplace(method_name, std::move(method_holder));
   }
   return runtime::Error::Ok;
@@ -170,10 +171,29 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
     const std::vector<runtime::EValue>& input_values) {
   ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
   auto& method = methods_.at(method_name).method;
+  auto& inputs = methods_.at(method_name).inputs;
 
-  ET_CHECK_OK_OR_RETURN_ERROR(
-      method->set_inputs(exec_aten::ArrayRef<runtime::EValue>(
-          input_values.data(), input_values.size())));
+  // Reject surplus arguments up front: the merge loop below indexes `inputs`
+  // by position and would otherwise write past its end.
+  ET_CHECK_OR_RETURN_ERROR(
+      input_values.size() <= inputs.size(),
+      InvalidArgument,
+      "input size: %zu exceeds method input size: %zu",
+      input_values.size(),
+      inputs.size());
+  // Non-None arguments override any values stored via set_input(s).
+  for (size_t i = 0; i < input_values.size(); ++i) {
+    if (!input_values[i].isNone()) {
+      inputs[i] = input_values[i];
+    }
+  }
+  // Every input must be populated before the method can run.
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    ET_CHECK_OR_RETURN_ERROR(
+        !inputs[i].isNone(), InvalidArgument, "input %zu is none", i);
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(method->set_inputs(
+      exec_aten::ArrayRef<runtime::EValue>(inputs.data(), inputs.size())));
   ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
 
   const auto outputs_size = method->outputs_size();
@@ -184,6 +204,39 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
   return outputs;
 }
 
+runtime::Error Module::set_input(
+    const std::string& method_name,
+    const runtime::EValue& input_value,
+    size_t input_index) {
+  ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
+  auto& inputs = methods_.at(method_name).inputs;
+  // Report a bad index as an Error; vector::at() would throw instead.
+  ET_CHECK_OR_RETURN_ERROR(
+      input_index < inputs.size(),
+      InvalidArgument,
+      "input index: %zu is out of range for method input size: %zu",
+      input_index,
+      inputs.size());
+  inputs[input_index] = input_value;
+  return runtime::Error::Ok;
+}
+
+runtime::Error Module::set_inputs(
+    const std::string& method_name,
+    const std::vector<runtime::EValue>& input_values) {
+  ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
+  auto& inputs = methods_.at(method_name).inputs;
+  // Replaces the whole stored input list, so the sizes must agree exactly.
+  ET_CHECK_OR_RETURN_ERROR(
+      inputs.size() == input_values.size(),
+      InvalidArgument,
+      "input size: %zu does not match method input size: %zu",
+      input_values.size(),
+      inputs.size());
+  inputs = input_values;
+  return runtime::Error::Ok;
+}
+
 runtime::Error Module::set_output_data_ptr(
     runtime::EValue output_value,
     size_t output_index,
diff --git a/extension/module/module.h b/extension/module/module.h
index 1197eace331..4c0dddc559e 100644
--- a/extension/module/module.h
+++ b/extension/module/module.h
@@ -300,6 +300,62 @@ class Module {
     return forward(std::vector<runtime::EValue>{});
   }
 
+  /**
+   * Sets a single input value for a specific method.
+   *
+   * @param[in] method_name The name of the method.
+   * @param[in] input_value The EValue to set as the method input.
+   * @param[in] input_index Zero-based index of the input to set.
+   *
+   * @returns An Error to indicate success or failure.
+   */
+  ET_NODISCARD
+  runtime::Error set_input(
+      const std::string& method_name,
+      const runtime::EValue& input_value,
+      size_t input_index);
+
+  /**
+   * Sets a single input value for the "forward" method.
+   *
+   * @param[in] input_value The EValue to set as the method input.
+   * @param[in] input_index Zero-based index of the input to set.
+   *
+   * @returns An Error to indicate success or failure.
+   */
+  ET_NODISCARD
+  inline runtime::Error set_input(
+      const runtime::EValue& input_value,
+      size_t input_index) {
+    return set_input("forward", input_value, input_index);
+  }
+
+  /**
+   * Sets all input values for a specific method.
+   *
+   * @param[in] method_name The name of the method.
+   * @param[in] input_values A vector of EValues to set as the method inputs.
+   *
+   * @returns An Error to indicate success or failure.
+   */
+  ET_NODISCARD
+  runtime::Error set_inputs(
+      const std::string& method_name,
+      const std::vector<runtime::EValue>& input_values);
+
+  /**
+   * Sets all input values for the "forward" method.
+   *
+   * @param[in] input_values A vector of EValues to set as the method inputs.
+   *
+   * @returns An Error to indicate success or failure.
+   */
+  ET_NODISCARD
+  inline runtime::Error set_inputs(
+      const std::vector<runtime::EValue>& input_values) {
+    return set_inputs("forward", input_values);
+  }
+
   /**
    * Retrieves the EventTracer instance being used by the Module.
    * EventTracer is used for tracking and logging events during the execution
@@ -332,6 +388,7 @@ class Module {
     std::unique_ptr<runtime::HierarchicalAllocator> planned_memory;
     std::unique_ptr<runtime::MemoryManager> memory_manager;
     std::unique_ptr<runtime::Method> method;
+    std::vector<runtime::EValue> inputs;
   };
 
  private:
diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp
index 6f18c8d9cbf..38cbfb39d68 100644
--- a/extension/module/test/module_test.cpp
+++ b/extension/module/test/module_test.cpp
@@ -373,3 +373,59 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) {
   t4.join();
   t5.join();
 }
+
+TEST_F(ModuleTest, TestSetInputsBeforeExecute) {
+  Module module(model_path_);
+
+  auto tensor1 = make_tensor_ptr({4.f});
+  auto tensor2 = make_tensor_ptr({5.f});
+
+  EXPECT_EQ(module.set_inputs({tensor1, tensor2}), Error::Ok);
+
+  const auto result = module.forward();
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  const auto data = result->at(0).toTensor().const_data_ptr<float>();
+  EXPECT_NEAR(data[0], 9, 1e-5);
+}
+
+TEST_F(ModuleTest, TestSetInputCombinedWithExecute) {
+  Module module(model_path_);
+
+  auto tensor1 = make_tensor_ptr({2.f});
+  auto tensor2 = make_tensor_ptr({3.f});
+
+  EXPECT_EQ(module.set_input(tensor2, 1), Error::Ok);
+
+  const auto result = module.forward(tensor1);
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  const auto data = result->at(0).toTensor().const_data_ptr<float>();
+  EXPECT_NEAR(data[0], 5, 1e-5);
+}
+
+TEST_F(ModuleTest, TestPartiallySetInputs) {
+  Module module(model_path_);
+
+  auto tensor = make_tensor_ptr({1.f});
+
+  EXPECT_EQ(module.set_input(tensor, 0), Error::Ok);
+
+  const auto result = module.forward();
+  EXPECT_NE(result.error(), Error::Ok);
+}
+
+TEST_F(ModuleTest, TestUnsetInputs) {
+  Module module(model_path_);
+
+  const auto result = module.forward();
+  EXPECT_NE(result.error(), Error::Ok);
+}
+
+TEST_F(ModuleTest, TestSetInputOutOfRange) {
+  Module module(model_path_);
+
+  auto tensor = make_tensor_ptr({1.f});
+
+  EXPECT_NE(module.set_input(tensor, 2), Error::Ok);
+}