Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ Error Module::load(const Program::Verification verification) {
return Error::Ok;
}

bool Module::is_loaded() const {
return program_ != nullptr;
}

Result<std::unordered_set<std::string>> Module::method_names() {
ET_CHECK_OK_OR_RETURN_ERROR(load());
const auto method_count = program_->num_methods();
Expand Down Expand Up @@ -181,10 +177,6 @@ Error Module::load_method(const std::string& method_name) {
return Error::Ok;
}

bool Module::is_method_loaded(const std::string& method_name) const {
return methods_.count(method_name);
}

Result<MethodMeta> Module::method_meta(const std::string& method_name) {
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
return methods_.at(method_name).method->method_meta();
Expand Down
50 changes: 27 additions & 23 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,17 @@ class Module final {
*
* @returns true if the program is loaded, false otherwise.
*/
bool is_loaded() const;
inline bool is_loaded() const {
return program_ != nullptr;
}

/**
* Get the program. The data loader used by the program is guaranteed to be
* valid for the lifetime of the program.
*
* @returns Shared pointer to the program or nullptr if it's not yet loaded.
*/
std::shared_ptr<::executorch::runtime::Program> program() const {
inline std::shared_ptr<::executorch::runtime::Program> program() const {
return program_;
}

Expand Down Expand Up @@ -151,7 +153,9 @@ class Module final {
* @returns true if the method specified by method_name is loaded, false
* otherwise.
*/
bool is_method_loaded(const std::string& method_name) const;
inline bool is_method_loaded(const std::string& method_name) const {
return methods_.count(method_name);
}

/**
* Get a method metadata struct by method name.
Expand Down Expand Up @@ -191,8 +195,8 @@ class Module final {
* @returns A Result object containing either a vector of output values
* from the method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
ET_NODISCARD inline ::executorch::runtime::Result<
std::vector<::executorch::runtime::EValue>>
execute(
const std::string& method_name,
const ::executorch::runtime::EValue& input) {
Expand All @@ -209,8 +213,8 @@ class Module final {
* @returns A Result object containing either a vector of output values
* from the method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
ET_NODISCARD inline ::executorch::runtime::Result<
std::vector<::executorch::runtime::EValue>>
execute(const std::string& method_name) {
return execute(method_name, std::vector<::executorch::runtime::EValue>{});
}
Expand All @@ -225,9 +229,9 @@ class Module final {
* @returns A Result object containing either the first output value from the
* method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<::executorch::runtime::EValue> get(
const std::string& method_name,
ET_NODISCARD inline ::executorch::runtime::Result<
::executorch::runtime::EValue>
get(const std::string& method_name,
const std::vector<::executorch::runtime::EValue>& input) {
auto result = ET_UNWRAP(execute(method_name, input));
if (result.empty()) {
Expand All @@ -246,9 +250,9 @@ class Module final {
* @returns A Result object containing either the first output value from the
* method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<::executorch::runtime::EValue> get(
const std::string& method_name,
ET_NODISCARD inline ::executorch::runtime::Result<
::executorch::runtime::EValue>
get(const std::string& method_name,
const ::executorch::runtime::EValue& input) {
return get(method_name, std::vector<::executorch::runtime::EValue>{input});
}
Expand All @@ -262,9 +266,9 @@ class Module final {
* @returns A Result object containing either the first output value from the
* method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<::executorch::runtime::EValue> get(
const std::string& method_name) {
ET_NODISCARD inline ::executorch::runtime::Result<
::executorch::runtime::EValue>
get(const std::string& method_name) {
return get(method_name, std::vector<::executorch::runtime::EValue>{});
}

Expand All @@ -277,8 +281,8 @@ class Module final {
* @returns A Result object containing either a vector of output values
* from the 'forward' method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
ET_NODISCARD inline ::executorch::runtime::Result<
std::vector<::executorch::runtime::EValue>>
forward(const std::vector<::executorch::runtime::EValue>& input) {
return execute("forward", input);
}
Expand All @@ -292,8 +296,8 @@ class Module final {
* @returns A Result object containing either a vector of output values
* from the 'forward' method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
ET_NODISCARD inline ::executorch::runtime::Result<
std::vector<::executorch::runtime::EValue>>
forward(const ::executorch::runtime::EValue& input) {
return forward(std::vector<::executorch::runtime::EValue>{input});
}
Expand All @@ -305,8 +309,8 @@ class Module final {
* @returns A Result object containing either a vector of output values
* from the 'forward' method or an error to indicate failure.
*/
ET_NODISCARD
::executorch::runtime::Result<std::vector<::executorch::runtime::EValue>>
ET_NODISCARD inline ::executorch::runtime::Result<
std::vector<::executorch::runtime::EValue>>
forward() {
return forward(std::vector<::executorch::runtime::EValue>{});
}
Expand All @@ -319,7 +323,7 @@ class Module final {
* @returns A pointer to the EventTracer instance. Returns nullptr if no
* EventTracer is set.
*/
::executorch::runtime::EventTracer* event_tracer() const {
inline ::executorch::runtime::EventTracer* event_tracer() const {
return event_tracer_.get();
}

Expand Down