Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,30 @@ runtime::Error Module::set_outputs(
return runtime::Error::Ok;
}

runtime::Result<std::vector<runtime::EValue>> Module::get_outputs(
const std::string& method_name) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& method = methods_.at(method_name).method;
const auto outputs_size = method->outputs_size();
std::vector<runtime::EValue> outputs(outputs_size);
ET_CHECK_OK_OR_RETURN_ERROR(
method->get_outputs(outputs.data(), outputs_size));
return outputs;
}

runtime::Result<runtime::EValue> Module::get_output(
const std::string& method_name,
size_t output_index) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& method = methods_.at(method_name).method;
ET_CHECK_OR_RETURN_ERROR(
output_index < method->outputs_size(),
InvalidArgument,
"output index: %zu is out of range",
output_index);
return method->get_output(output_index);
}

} // namespace ET_MODULE_NAMESPACE
} // namespace extension
} // namespace executorch
50 changes: 50 additions & 0 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,56 @@ class Module {
return set_outputs("forward", output_values);
}

/**
* Retrieve all current output values of a specific method without executing
* it. Loads the program and method before retrieval if needed.
*
* @param[in] method_name The name of the method.
*
* @returns A Result containing the vector of output values, or an error.
*/
ET_NODISCARD
runtime::Result<std::vector<runtime::EValue>> get_outputs(
const std::string& method_name);

/**
* Retrieve all current output values of the "forward" method without
* executing it. Loads the program and method before retrieval if needed.
*
* @returns A Result containing the vector of output values, or an error.
*/
ET_NODISCARD
inline runtime::Result<std::vector<runtime::EValue>> get_outputs() {
return get_outputs("forward");
}

/**
* Retrieve a single current output value of a specific method without
* executing it. Loads the program and method before retrieval if needed.
*
* @param[in] method_name The name of the method.
* @param[in] output_index Zero-based index of the output to retrieve.
*
* @returns A Result containing the requested output value, or an error.
*/
ET_NODISCARD
runtime::Result<runtime::EValue> get_output(
const std::string& method_name,
size_t output_index = 0);

/**
* Retrieve a single current output value of the "forward" method without
* executing it. Loads the program and method before retrieval if needed.
*
* @param[in] output_index Zero-based index of the output to retrieve.
*
* @returns A Result containing the requested output value, or an error.
*/
ET_NODISCARD
inline runtime::Result<runtime::EValue> get_output(size_t output_index = 0) {
return get_output("forward", output_index);
}

/**
* Retrieves the EventTracer instance being used by the Module.
* EventTracer is used for tracking and logging events during the execution
Expand Down
27 changes: 27 additions & 0 deletions extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,33 @@ TEST_F(ModuleTest, TestSetOutputsMemoryPlanned) {
EXPECT_NE(module.set_outputs({empty({1})}), Error::Ok);
}

TEST_F(ModuleTest, TestGetOutputAndGetOutputs) {
Module module(model_path_);

auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f});

ASSERT_EQ(module.forward({tensor, tensor, 1.0}).error(), Error::Ok);

const auto single = module.get_output();
EXPECT_EQ(single.error(), Error::Ok);
const auto expected = make_tensor_ptr({2, 2}, {2.f, 4.f, 6.f, 8.f});
EXPECT_TENSOR_CLOSE(single->toTensor(), *expected.get());

const auto all = module.get_outputs();
EXPECT_EQ(all.error(), Error::Ok);
ASSERT_EQ(all->size(), 1);
EXPECT_TENSOR_CLOSE(all->at(0).toTensor(), *expected.get());
}

TEST_F(ModuleTest, TestGetOutputInvalidIndex) {
Module module(model_path_);

ASSERT_EQ(module.load_method("forward"), Error::Ok);

const auto bad = module.get_output("forward", 99);
EXPECT_NE(bad.error(), Error::Ok);
}

TEST_F(ModuleTest, TestPTD) {
Module module(add_mul_path_, add_mul_data_path_);

Expand Down
2 changes: 1 addition & 1 deletion runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1278,7 +1278,7 @@ ET_NODISCARD Error Method::get_outputs(EValue* output_evalues, size_t length) {
InvalidArgument,
"The given array is not large enough to hold all outputs.");
for (size_t i = 0; i < n_output; ++i) {
output_evalues[i] = values_[get_output_index(i)];
output_evalues[i] = get_output(i);
}
for (size_t i = n_output; i < length; ++i) {
output_evalues[i] = EValue();
Expand Down
Loading