From d2caf751abe5990573cb400a49b2604923973a47 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Fri, 22 Aug 2025 15:28:33 -0700 Subject: [PATCH] Add get_output API. (#13610) Summary: . Reviewed By: JacobSzwejbka Differential Revision: D80845633 --- extension/module/module.cpp | 24 +++++++++++++ extension/module/module.h | 50 +++++++++++++++++++++++++++ extension/module/test/module_test.cpp | 27 +++++++++++++++ runtime/executor/method.cpp | 2 +- 4 files changed, 102 insertions(+), 1 deletion(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 5d8cf6afc72..4b82dbf4954 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -297,6 +297,30 @@ runtime::Error Module::set_outputs( return runtime::Error::Ok; } +runtime::Result> Module::get_outputs( + const std::string& method_name) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); + auto& method = methods_.at(method_name).method; + const auto outputs_size = method->outputs_size(); + std::vector outputs(outputs_size); + ET_CHECK_OK_OR_RETURN_ERROR( + method->get_outputs(outputs.data(), outputs_size)); + return outputs; +} + +runtime::Result Module::get_output( + const std::string& method_name, + size_t output_index) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); + auto& method = methods_.at(method_name).method; + ET_CHECK_OR_RETURN_ERROR( + output_index < method->outputs_size(), + InvalidArgument, + "output index: %zu is out of range", + output_index); + return method->get_output(output_index); +} + } // namespace ET_MODULE_NAMESPACE } // namespace extension } // namespace executorch diff --git a/extension/module/module.h b/extension/module/module.h index faed000f711..37fd78f6fdd 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -533,6 +533,56 @@ class Module { return set_outputs("forward", output_values); } + /** + * Retrieve all current output values of a specific method without executing + * it. Loads the program and method before retrieval if needed. + * + * @param[in] method_name The name of the method. + * + * @returns A Result containing the vector of output values, or an error. + */ + ET_NODISCARD + runtime::Result> get_outputs( + const std::string& method_name); + + /** + * Retrieve all current output values of the "forward" method without + * executing it. Loads the program and method before retrieval if needed. + * + * @returns A Result containing the vector of output values, or an error. + */ + ET_NODISCARD + inline runtime::Result> get_outputs() { + return get_outputs("forward"); + } + + /** + * Retrieve a single current output value of a specific method without + * executing it. Loads the program and method before retrieval if needed. + * + * @param[in] method_name The name of the method. + * @param[in] output_index Zero-based index of the output to retrieve. + * + * @returns A Result containing the requested output value, or an error. + */ + ET_NODISCARD + runtime::Result get_output( + const std::string& method_name, + size_t output_index = 0); + + /** + * Retrieve a single current output value of the "forward" method without + * executing it. Loads the program and method before retrieval if needed. + * + * @param[in] output_index Zero-based index of the output to retrieve. + * + * @returns A Result containing the requested output value, or an error. + */ + ET_NODISCARD + inline runtime::Result get_output(size_t output_index = 0) { + return get_output("forward", output_index); + } + /** * Retrieves the EventTracer instance being used by the Module. * EventTracer is used for tracking and logging events during the execution diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index ecc8f1b3250..1c9fc5628ba 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -495,6 +495,33 @@ TEST_F(ModuleTest, TestSetOutputsMemoryPlanned) { EXPECT_NE(module.set_outputs({empty({1})}), Error::Ok); } +TEST_F(ModuleTest, TestGetOutputAndGetOutputs) { + Module module(model_path_); + + auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f}); + + ASSERT_EQ(module.forward({tensor, tensor, 1.0}).error(), Error::Ok); + + const auto single = module.get_output(); + EXPECT_EQ(single.error(), Error::Ok); + const auto expected = make_tensor_ptr({2, 2}, {2.f, 4.f, 6.f, 8.f}); + EXPECT_TENSOR_CLOSE(single->toTensor(), *expected.get()); + + const auto all = module.get_outputs(); + EXPECT_EQ(all.error(), Error::Ok); + ASSERT_EQ(all->size(), 1); + EXPECT_TENSOR_CLOSE(all->at(0).toTensor(), *expected.get()); +} + +TEST_F(ModuleTest, TestGetOutputInvalidIndex) { + Module module(model_path_); + + ASSERT_EQ(module.load_method("forward"), Error::Ok); + + const auto bad = module.get_output("forward", 99); + EXPECT_NE(bad.error(), Error::Ok); +} + TEST_F(ModuleTest, TestPTD) { Module module(add_mul_path_, add_mul_data_path_); diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 238e150e7bd..e8f3c471b8f 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -1278,7 +1278,7 @@ ET_NODISCARD Error Method::get_outputs(EValue* output_evalues, size_t length) { InvalidArgument, "The given array is not large enough to hold all outputs."); for (size_t i = 0; i < n_output; ++i) { - output_evalues[i] = values_[get_output_index(i)]; + output_evalues[i] = get_output(i); } for (size_t i = n_output; i < length; ++i) { output_evalues[i] = EValue();