Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,25 @@ runtime::Error Module::set_output(
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
}

runtime::Error Module::set_outputs(
const std::string& method_name,
const std::vector<runtime::EValue>& output_values) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
auto& method = methods_.at(method_name).method;
const auto outputs_size = method->outputs_size();
ET_CHECK_OR_RETURN_ERROR(
output_values.size() == outputs_size,
InvalidArgument,
"output size: %zu is not equal to method output size: %zu",
output_values.size(),
outputs_size);
for (auto index = 0; index < outputs_size; ++index) {
ET_CHECK_OK_OR_RETURN_ERROR(
set_output(method_name, output_values[index], index));
}
return runtime::Error::Ok;
}

} // namespace ET_MODULE_NAMESPACE
} // namespace extension
} // namespace executorch
35 changes: 35 additions & 0 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,41 @@ class Module {
return set_output("forward", std::move(output_value), output_index);
}

/**
* Sets all output tensors for a specific method.
*
* Loads the program and method if needed, and for each output uses
* the provided tensor's data buffer as the method's output buffer.
*
* @param[in] method_name The name of the method.
* @param[in] output_values A vector of EValues to set as the method outputs.
*
* @returns An Error to indicate success or failure.
*
* @note Only Tensor outputs are currently supported for setting.
* @note Will fail for outputs that are memory-planned or constants.
*/
ET_NODISCARD
runtime::Error set_outputs(
const std::string& method_name,
const std::vector<runtime::EValue>& output_values);

/**
* Sets all output tensors for the "forward" method.
*
* @param[in] output_values A vector of EValues to set as the method outputs.
*
* @returns An Error to indicate success or failure.
*
* @note Only Tensor outputs are currently supported for setting.
* @note Will fail for outputs that are memory-planned or constants.
*/
ET_NODISCARD
inline runtime::Error set_outputs(
const std::vector<runtime::EValue>& output_values) {
return set_outputs("forward", output_values);
}

/**
* Retrieves the EventTracer instance being used by the Module.
* EventTracer is used for tracking and logging events during the execution
Expand Down
18 changes: 18 additions & 0 deletions extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,24 @@ TEST_F(ModuleTest, TestSetOutputInvalidType) {
EXPECT_NE(module.set_output(EValue()), Error::Ok);
}

TEST_F(ModuleTest, TestSetOutputsCountMismatch) {
Module module(model_path_);

EXPECT_NE(module.set_outputs(std::vector<EValue>{}), Error::Ok);
}

TEST_F(ModuleTest, TestSetOutputsInvalidType) {
Module module(model_path_);

EXPECT_NE(module.set_outputs({EValue()}), Error::Ok);
}

TEST_F(ModuleTest, TestSetOutputsMemoryPlanned) {
Module module(model_path_);

EXPECT_NE(module.set_outputs({empty({1})}), Error::Ok);
}

TEST_F(ModuleTest, TestPTD) {
Module module(add_mul_path_, add_mul_data_path_);

Expand Down
Loading