From 78d6b88854fda60ffbea4d688b41e9d9945dcedb Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Fri, 22 Aug 2025 14:54:49 -0700 Subject: [PATCH] Add set_outputs() API. (#13609) Summary: . Reviewed By: JacobSzwejbka Differential Revision: D80845634 --- extension/module/module.cpp | 19 +++++++++++++++ extension/module/module.h | 35 +++++++++++++++++++++++++++ extension/module/test/module_test.cpp | 18 ++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 76304d20e25..5d8cf6afc72 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -278,6 +278,25 @@ runtime::Error Module::set_output( output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index); } +runtime::Error Module::set_outputs( + const std::string& method_name, + const std::vector& output_values) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); + auto& method = methods_.at(method_name).method; + const auto outputs_size = method->outputs_size(); + ET_CHECK_OR_RETURN_ERROR( + output_values.size() == outputs_size, + InvalidArgument, + "output size: %zu is not equal to method output size: %zu", + output_values.size(), + outputs_size); + for (auto index = 0; index < outputs_size; ++index) { + ET_CHECK_OK_OR_RETURN_ERROR( + set_output(method_name, output_values[index], index)); + } + return runtime::Error::Ok; +} + } // namespace ET_MODULE_NAMESPACE } // namespace extension } // namespace executorch diff --git a/extension/module/module.h b/extension/module/module.h index 9350cdd3026..faed000f711 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -498,6 +498,41 @@ class Module { return set_output("forward", std::move(output_value), output_index); } + /** + * Sets all output tensors for a specific method. + * + * Loads the program and method if needed, and for each output uses + * the provided tensor's data buffer as the method's output buffer. + * + * @param[in] method_name The name of the method. + * @param[in] output_values A vector of EValues to set as the method outputs. + * + * @returns An Error to indicate success or failure. + * + * @note Only Tensor outputs are currently supported for setting. + * @note Will fail for outputs that are memory-planned or constants. + */ + ET_NODISCARD + runtime::Error set_outputs( + const std::string& method_name, + const std::vector& output_values); + + /** + * Sets all output tensors for the "forward" method. + * + * @param[in] output_values A vector of EValues to set as the method outputs. + * + * @returns An Error to indicate success or failure. + * + * @note Only Tensor outputs are currently supported for setting. + * @note Will fail for outputs that are memory-planned or constants. + */ + ET_NODISCARD + inline runtime::Error set_outputs( + const std::vector& output_values) { + return set_outputs("forward", output_values); + } + /** * Retrieves the EventTracer instance being used by the Module. * EventTracer is used for tracking and logging events during the execution diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 9623e5a6745..ecc8f1b3250 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -477,6 +477,24 @@ TEST_F(ModuleTest, TestSetOutputInvalidType) { EXPECT_NE(module.set_output(EValue()), Error::Ok); } +TEST_F(ModuleTest, TestSetOutputsCountMismatch) { + Module module(model_path_); + + EXPECT_NE(module.set_outputs(std::vector{}), Error::Ok); +} + +TEST_F(ModuleTest, TestSetOutputsInvalidType) { + Module module(model_path_); + + EXPECT_NE(module.set_outputs({EValue()}), Error::Ok); +} + +TEST_F(ModuleTest, TestSetOutputsMemoryPlanned) { + Module module(model_path_); + + EXPECT_NE(module.set_outputs({empty({1})}), Error::Ok); +} + TEST_F(ModuleTest, TestPTD) { Module module(add_mul_path_, add_mul_data_path_);