diff --git a/extension/pybindings/portable_lib.py b/extension/pybindings/portable_lib.py index 758e41545d1..de01797a648 100644 --- a/extension/pybindings/portable_lib.py +++ b/extension/pybindings/portable_lib.py @@ -44,10 +44,13 @@ _load_for_executorch, # noqa: F401 _load_for_executorch_from_buffer, # noqa: F401 _load_for_executorch_from_bundled_program, # noqa: F401 + _load_program, # noqa: F401 + _load_program_from_buffer, # noqa: F401 _reset_profile_results, # noqa: F401 _unsafe_reset_threadpool, # noqa: F401 BundledModule, # noqa: F401 ExecuTorchModule, # noqa: F401 + ExecuTorchProgram, # noqa: F401 MethodMeta, # noqa: F401 Verification, # noqa: F401 ) diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index db0871657f6..b71226fd2c5 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -965,6 +965,89 @@ struct PyModule final { } }; +inline std::unique_ptr loader_from_buffer( + const void* ptr, + size_t ptr_len) { + return std::make_unique(ptr, ptr_len); +} + +inline std::unique_ptr loader_from_file(const std::string& path) { + Result res = MmapDataLoader::from( + path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors); + THROW_IF_ERROR( + res.error(), + "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32, + path.c_str(), + static_cast(res.error())); + + return std::make_unique(std::move(res.get())); +} + +inline std::unique_ptr load_program( + DataLoader* loader, + Program::Verification program_verification) { + Result res = Program::load(loader, program_verification); + THROW_IF_ERROR( + res.error(), + "Failed to load program, error: 0x:%" PRIx32, + static_cast(res.error())); + return std::make_unique(std::move(res.get())); +} + +struct PyProgram final { + explicit PyProgram( + const py::bytes& buffer, + Program::Verification program_verification = + Program::Verification::Minimal) + : loader_(loader_from_buffer( + buffer.cast().data(), + py::len(buffer))), + program_(load_program(loader_.get(), program_verification)) {} + + explicit PyProgram( + const std::string& path, + Program::Verification program_verification = + Program::Verification::Minimal) + : loader_(loader_from_file(path)), + program_(load_program(loader_.get(), program_verification)) {} + + static std::unique_ptr load_from_buffer( + const py::bytes& buffer, + Program::Verification program_verification = + Program::Verification::Minimal) { + return std::make_unique(buffer, program_verification); + } + + static std::unique_ptr load_from_file( + const std::string& path, + Program::Verification program_verification = + Program::Verification::Minimal) { + return std::make_unique(path, program_verification); + } + + PyProgram(const PyProgram&) = delete; + PyProgram& operator=(const PyProgram&) = delete; + PyProgram(PyProgram&&) = default; + PyProgram& operator=(PyProgram&&) = default; + + size_t num_methods() const { + return program_->num_methods(); + } + + std::string get_method_name(size_t method_index) const { + Result res = program_->get_method_name(method_index); + THROW_IF_ERROR( + res.error(), + "Failed get method name, error: 0x:%" PRIx32, + static_cast(res.error())); + return std::string(res.get()); + } + + private: + std::unique_ptr loader_; + std::unique_ptr program_; +}; + void create_profile_block(const std::string& name) { EXECUTORCH_PROFILE_CREATE_BLOCK(name.c_str()); } @@ -1151,6 +1234,26 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { py::arg("index"), call_guard) .def("__repr__", &PyMethodMeta::repr, call_guard); + + m.def( + "_load_program", + &PyProgram::load_from_file, + py::arg("path"), + py::arg("program_verification") = Program::Verification::Minimal, + call_guard); + m.def( + "_load_program_from_buffer", + &PyProgram::load_from_buffer, + py::arg("buffer"), + py::arg("program_verification") = Program::Verification::Minimal, + call_guard); + py::class_(m, "ExecuTorchProgram") + .def("num_methods", &PyProgram::num_methods, call_guard) + .def( + "get_method_name", + &PyProgram::get_method_name, + py::arg("method_index"), + call_guard); } namespace { diff --git a/extension/pybindings/test/make_test.py b/extension/pybindings/test/make_test.py index f3087d112ed..f0f941d0eb3 100644 --- a/extension/pybindings/test/make_test.py +++ b/extension/pybindings/test/make_test.py @@ -168,6 +168,7 @@ def make_test( # noqa: C901 subfunction of wrapper. """ load_fn: Callable = runtime._load_for_executorch_from_buffer + load_prog_fn: Callable = runtime._load_program_from_buffer def wrapper(tester: unittest.TestCase) -> None: ######### TEST CASES ######### @@ -474,6 +475,36 @@ def test_unsupported_input_type(tester): # This should raise a Python error, not hit a fatal assert in the C++ code. tester.assertRaises(RuntimeError, executorch_module, inputs) + def test_program_methods_one(tester): + # Create an ExecuTorch program from ModuleAdd. + exported_program, _ = create_program(ModuleAdd()) + + # Use pybindings to load the program. + executorch_program = load_prog_fn(exported_program.buffer) + + tester.assertEqual(executorch_program.num_methods(), 1) + tester.assertEqual(executorch_program.get_method_name(0), "forward") + + def test_program_methods_multi(tester): + # Create an ExecuTorch program from ModuleMulti. + exported_program, _ = create_program(ModuleMulti()) + + # Use pybindings to load the program. + executorch_program = load_prog_fn(exported_program.buffer) + + tester.assertEqual(executorch_program.num_methods(), 2) + tester.assertEqual(executorch_program.get_method_name(0), "forward") + tester.assertEqual(executorch_program.get_method_name(1), "forward2") + + def test_program_method_index_out_of_bounds(tester): + # Create an ExecuTorch program from ModuleMulti. + exported_program, _ = create_program(ModuleMulti()) + + # Use pybindings to load the program. + executorch_program = load_prog_fn(exported_program.buffer) + + tester.assertRaises(RuntimeError, executorch_program.get_method_name, 2) + ######### RUN TEST CASES ######### test_e2e(tester) test_multiple_entry(tester) @@ -490,5 +521,8 @@ def test_unsupported_input_type(tester): test_bad_name(tester) test_verification_config(tester) test_unsupported_input_type(tester) + test_program_methods_one(tester) + test_program_methods_multi(tester) + test_program_method_index_out_of_bounds(tester) return wrapper