Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@
_load_for_executorch, # noqa: F401
_load_for_executorch_from_buffer, # noqa: F401
_load_for_executorch_from_bundled_program, # noqa: F401
_load_program, # noqa: F401
_load_program_from_buffer, # noqa: F401
_reset_profile_results, # noqa: F401
_unsafe_reset_threadpool, # noqa: F401
BundledModule, # noqa: F401
ExecuTorchModule, # noqa: F401
ExecuTorchProgram, # noqa: F401
MethodMeta, # noqa: F401
Verification, # noqa: F401
)
Expand Down
103 changes: 103 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,89 @@ struct PyModule final {
}
};

inline std::unique_ptr<DataLoader> loader_from_buffer(
const void* ptr,
size_t ptr_len) {
return std::make_unique<BufferDataLoader>(ptr, ptr_len);
}

inline std::unique_ptr<DataLoader> loader_from_file(const std::string& path) {
Result<MmapDataLoader> res = MmapDataLoader::from(
path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
THROW_IF_ERROR(
res.error(),
"Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
path.c_str(),
static_cast<uint32_t>(res.error()));

return std::make_unique<MmapDataLoader>(std::move(res.get()));
}

inline std::unique_ptr<Program> load_program(
DataLoader* loader,
Program::Verification program_verification) {
Result<Program> res = Program::load(loader, program_verification);
THROW_IF_ERROR(
res.error(),
"Failed to load program, error: 0x:%" PRIx32,
static_cast<uint32_t>(res.error()));
return std::make_unique<Program>(std::move(res.get()));
}

struct PyProgram final {
explicit PyProgram(
const py::bytes& buffer,
Program::Verification program_verification =
Program::Verification::Minimal)
: loader_(loader_from_buffer(
buffer.cast<std::string_view>().data(),
py::len(buffer))),
program_(load_program(loader_.get(), program_verification)) {}

explicit PyProgram(
const std::string& path,
Program::Verification program_verification =
Program::Verification::Minimal)
: loader_(loader_from_file(path)),
program_(load_program(loader_.get(), program_verification)) {}

static std::unique_ptr<PyProgram> load_from_buffer(
const py::bytes& buffer,
Program::Verification program_verification =
Program::Verification::Minimal) {
return std::make_unique<PyProgram>(buffer, program_verification);
}

static std::unique_ptr<PyProgram> load_from_file(
const std::string& path,
Program::Verification program_verification =
Program::Verification::Minimal) {
return std::make_unique<PyProgram>(path, program_verification);
}

PyProgram(const PyProgram&) = delete;
PyProgram& operator=(const PyProgram&) = delete;
PyProgram(PyProgram&&) = default;
PyProgram& operator=(PyProgram&&) = default;

size_t num_methods() const {
return program_->num_methods();
}

std::string get_method_name(size_t method_index) const {
Result<const char*> res = program_->get_method_name(method_index);
THROW_IF_ERROR(
res.error(),
"Failed get method name, error: 0x:%" PRIx32,
static_cast<uint32_t>(res.error()));
return std::string(res.get());
}

private:
std::unique_ptr<DataLoader> loader_;
std::unique_ptr<Program> program_;
};

void create_profile_block(const std::string& name) {
EXECUTORCH_PROFILE_CREATE_BLOCK(name.c_str());
}
Expand Down Expand Up @@ -1151,6 +1234,26 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
py::arg("index"),
call_guard)
.def("__repr__", &PyMethodMeta::repr, call_guard);

m.def(
"_load_program",
&PyProgram::load_from_file,
py::arg("path"),
py::arg("program_verification") = Program::Verification::Minimal,
call_guard);
m.def(
"_load_program_from_buffer",
&PyProgram::load_from_buffer,
py::arg("buffer"),
py::arg("program_verification") = Program::Verification::Minimal,
call_guard);
py::class_<PyProgram>(m, "ExecuTorchProgram")
.def("num_methods", &PyProgram::num_methods, call_guard)
.def(
"get_method_name",
&PyProgram::get_method_name,
py::arg("method_index"),
call_guard);
}

namespace {
Expand Down
34 changes: 34 additions & 0 deletions extension/pybindings/test/make_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def make_test( # noqa: C901
subfunction of wrapper.
"""
load_fn: Callable = runtime._load_for_executorch_from_buffer
load_prog_fn: Callable = runtime._load_program_from_buffer

def wrapper(tester: unittest.TestCase) -> None:
######### TEST CASES #########
Expand Down Expand Up @@ -474,6 +475,36 @@ def test_unsupported_input_type(tester):
# This should raise a Python error, not hit a fatal assert in the C++ code.
tester.assertRaises(RuntimeError, executorch_module, inputs)

def test_program_methods_one(tester):
# Create an ExecuTorch program from ModuleAdd.
exported_program, _ = create_program(ModuleAdd())

# Use pybindings to load the program.
executorch_program = load_prog_fn(exported_program.buffer)

tester.assertEqual(executorch_program.num_methods(), 1)
tester.assertEqual(executorch_program.get_method_name(0), "forward")

def test_program_methods_multi(tester):
# Create an ExecuTorch program from ModuleMulti.
exported_program, _ = create_program(ModuleMulti())

# Use pybindings to load the program.
executorch_program = load_prog_fn(exported_program.buffer)

tester.assertEqual(executorch_program.num_methods(), 2)
tester.assertEqual(executorch_program.get_method_name(0), "forward")
tester.assertEqual(executorch_program.get_method_name(1), "forward2")

def test_program_method_index_out_of_bounds(tester):
# Create an ExecuTorch program from ModuleMulti.
exported_program, _ = create_program(ModuleMulti())

# Use pybindings to load the program.
executorch_program = load_prog_fn(exported_program.buffer)

tester.assertRaises(RuntimeError, executorch_program.get_method_name, 2)

######### RUN TEST CASES #########
test_e2e(tester)
test_multiple_entry(tester)
Expand All @@ -490,5 +521,8 @@ def test_unsupported_input_type(tester):
test_bad_name(tester)
test_verification_config(tester)
test_unsupported_input_type(tester)
test_program_methods_one(tester)
test_program_methods_multi(tester)
test_program_method_index_out_of_bounds(tester)

return wrapper
Loading