From fa7e447b5f6493032406df0df4ea22edc15288c7 Mon Sep 17 00:00:00 2001 From: Conan Jeffrey Truong Date: Tue, 1 Jul 2025 14:18:54 -0700 Subject: [PATCH] Add Pybindings for Program.h/cpp (#12016) Summary: Today our python apis in executorch.runtime are implemented off of extension/pybindings which only offers a module api. We would like to migrate to having the lower level ET api exposed to python directly and then writing the module api in python. The first step to this is adding pybindings for Program. Bindings for the class Program and its methods num_methods and get_method_name were added. Test Plan: Tests were added to `extension/pybindings/test/make_test.py` 1. test_program_methods_one -- verifies num_methods and get_method_name works with one method 2. test_program_methods_multi -- verifies num_methods and get_method_name works with multiple methods 3. test_program_method_index_out_of_bounds -- verifies get_method_name raises a runtime error if index is out of bounds Rollback Plan: Reviewed By: JacobSzwejbka Differential Revision: D77388495 Pulled By: Conarnar --- extension/pybindings/portable_lib.py | 3 + extension/pybindings/pybindings.cpp | 103 +++++++++++++++++++++++++ extension/pybindings/test/make_test.py | 34 ++++++++ 3 files changed, 140 insertions(+) diff --git a/extension/pybindings/portable_lib.py b/extension/pybindings/portable_lib.py index 758e41545d1..de01797a648 100644 --- a/extension/pybindings/portable_lib.py +++ b/extension/pybindings/portable_lib.py @@ -44,10 +44,13 @@ _load_for_executorch, # noqa: F401 _load_for_executorch_from_buffer, # noqa: F401 _load_for_executorch_from_bundled_program, # noqa: F401 + _load_program, # noqa: F401 + _load_program_from_buffer, # noqa: F401 _reset_profile_results, # noqa: F401 _unsafe_reset_threadpool, # noqa: F401 BundledModule, # noqa: F401 ExecuTorchModule, # noqa: F401 + ExecuTorchProgram, # noqa: F401 MethodMeta, # noqa: F401 Verification, # noqa: F401 ) diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index db0871657f6..b71226fd2c5 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -965,6 +965,89 @@ struct PyModule final { } }; +inline std::unique_ptr loader_from_buffer( + const void* ptr, + size_t ptr_len) { + return std::make_unique(ptr, ptr_len); +} + +inline std::unique_ptr loader_from_file(const std::string& path) { + Result res = MmapDataLoader::from( + path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors); + THROW_IF_ERROR( + res.error(), + "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32, + path.c_str(), + static_cast(res.error())); + + return std::make_unique(std::move(res.get())); +} + +inline std::unique_ptr load_program( + DataLoader* loader, + Program::Verification program_verification) { + Result res = Program::load(loader, program_verification); + THROW_IF_ERROR( + res.error(), + "Failed to load program, error: 0x:%" PRIx32, + static_cast(res.error())); + return std::make_unique(std::move(res.get())); +} + +struct PyProgram final { + explicit PyProgram( + const py::bytes& buffer, + Program::Verification program_verification = + Program::Verification::Minimal) + : loader_(loader_from_buffer( + buffer.cast().data(), + py::len(buffer))), + program_(load_program(loader_.get(), program_verification)) {} + + explicit PyProgram( + const std::string& path, + Program::Verification program_verification = + Program::Verification::Minimal) + : loader_(loader_from_file(path)), + program_(load_program(loader_.get(), program_verification)) {} + + static std::unique_ptr load_from_buffer( + const py::bytes& buffer, + Program::Verification program_verification = + Program::Verification::Minimal) { + return std::make_unique(buffer, program_verification); + } + + static std::unique_ptr load_from_file( + const std::string& path, + Program::Verification program_verification = + Program::Verification::Minimal) { + return std::make_unique(path, program_verification); + } + + PyProgram(const PyProgram&) = delete; + PyProgram& operator=(const PyProgram&) = delete; + PyProgram(PyProgram&&) = default; + PyProgram& operator=(PyProgram&&) = default; + + size_t num_methods() const { + return program_->num_methods(); + } + + std::string get_method_name(size_t method_index) const { + Result res = program_->get_method_name(method_index); + THROW_IF_ERROR( + res.error(), + "Failed get method name, error: 0x:%" PRIx32, + static_cast(res.error())); + return std::string(res.get()); + } + + private: + std::unique_ptr loader_; + std::unique_ptr program_; +}; + void create_profile_block(const std::string& name) { EXECUTORCH_PROFILE_CREATE_BLOCK(name.c_str()); } @@ -1151,6 +1234,26 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { py::arg("index"), call_guard) .def("__repr__", &PyMethodMeta::repr, call_guard); + + m.def( + "_load_program", + &PyProgram::load_from_file, + py::arg("path"), + py::arg("program_verification") = Program::Verification::Minimal, + call_guard); + m.def( + "_load_program_from_buffer", + &PyProgram::load_from_buffer, + py::arg("buffer"), + py::arg("program_verification") = Program::Verification::Minimal, + call_guard); + py::class_(m, "ExecuTorchProgram") + .def("num_methods", &PyProgram::num_methods, call_guard) + .def( + "get_method_name", + &PyProgram::get_method_name, + py::arg("method_index"), + call_guard); } namespace { diff --git a/extension/pybindings/test/make_test.py b/extension/pybindings/test/make_test.py index f3087d112ed..f0f941d0eb3 100644 --- a/extension/pybindings/test/make_test.py +++ b/extension/pybindings/test/make_test.py @@ -168,6 +168,7 @@ def make_test( # noqa: C901 subfunction of wrapper. """ load_fn: Callable = runtime._load_for_executorch_from_buffer + load_prog_fn: Callable = runtime._load_program_from_buffer def wrapper(tester: unittest.TestCase) -> None: ######### TEST CASES ######### @@ -474,6 +475,36 @@ def test_unsupported_input_type(tester): # This should raise a Python error, not hit a fatal assert in the C++ code. tester.assertRaises(RuntimeError, executorch_module, inputs) + def test_program_methods_one(tester): + # Create an ExecuTorch program from ModuleAdd. + exported_program, _ = create_program(ModuleAdd()) + + # Use pybindings to load the program. + executorch_program = load_prog_fn(exported_program.buffer) + + tester.assertEqual(executorch_program.num_methods(), 1) + tester.assertEqual(executorch_program.get_method_name(0), "forward") + + def test_program_methods_multi(tester): + # Create an ExecuTorch program from ModuleMulti. + exported_program, _ = create_program(ModuleMulti()) + + # Use pybindings to load the program. + executorch_program = load_prog_fn(exported_program.buffer) + + tester.assertEqual(executorch_program.num_methods(), 2) + tester.assertEqual(executorch_program.get_method_name(0), "forward") + tester.assertEqual(executorch_program.get_method_name(1), "forward2") + + def test_program_method_index_out_of_bounds(tester): + # Create an ExecuTorch program from ModuleMulti. + exported_program, _ = create_program(ModuleMulti()) + + # Use pybindings to load the program. + executorch_program = load_prog_fn(exported_program.buffer) + + tester.assertRaises(RuntimeError, executorch_program.get_method_name, 2) + ######### RUN TEST CASES ######### test_e2e(tester) test_multiple_entry(tester) @@ -490,5 +521,8 @@ def test_unsupported_input_type(tester): test_bad_name(tester) test_verification_config(tester) test_unsupported_input_type(tester) + test_program_methods_one(tester) + test_program_methods_multi(tester) + test_program_method_index_out_of_bounds(tester) return wrapper