diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index a4a015cc879..a896a4bde36 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -161,10 +161,24 @@ void setup_output_storage( inline std::unique_ptr load_module_from_buffer( const void* ptr, size_t ptr_len, + std::optional data_map_ptr, + std::optional data_map_len, std::unique_ptr event_tracer, Program::Verification program_verification) { EXECUTORCH_SCOPE_PROF("load_module_from_buffer"); auto loader = std::make_unique(ptr, ptr_len); + + if (data_map_ptr.has_value() && data_map_len.has_value()) { + auto data_map_loader = std::make_unique( + data_map_ptr.value(), data_map_len.value()); + return std::make_unique( + std::move(loader), + nullptr, // memory_allocator + nullptr, // temp_allocator + std::move(event_tracer), // event_tracer + std::move(data_map_loader)); // data_map_loader + } + return std::make_unique( std::move(loader), nullptr, // memory_allocator @@ -504,6 +518,7 @@ struct PyMethodMeta final { struct PyModule final { explicit PyModule( const py::bytes& buffer, + std::optional data_map_buffer, bool enable_etdump, size_t debug_buffer_size = 0, Program::Verification program_verification = @@ -512,12 +527,21 @@ struct PyModule final { module_(load_module_from_buffer( buffer.cast().data(), py::len(buffer), + data_map_buffer.has_value() + ? std::optional( + data_map_buffer.value().cast().data()) + : std::nullopt, + data_map_buffer.has_value() + ? std::optional(py::len(data_map_buffer.value())) + : std::nullopt, setup_event_tracer(enable_etdump, debug_buffer_size), program_verification)) {} explicit PyModule( const void* ptr, size_t ptr_len, + std::optional data_map_ptr, + std::optional data_map_ptr_len, bool enable_etdump, size_t debug_buffer_size = 0, Program::Verification program_verification = @@ -526,6 +550,8 @@ struct PyModule final { module_(load_module_from_buffer( ptr, ptr_len, + data_map_ptr, + data_map_ptr_len, setup_event_tracer(enable_etdump, debug_buffer_size), program_verification)) {} @@ -551,12 +577,17 @@ struct PyModule final { // Module is only valid as long as the python buffer is alive. static std::unique_ptr load_from_buffer( const py::bytes& buffer, + std::optional data_map_buffer, bool enable_etdump, size_t debug_buffer_size = 0, Program::Verification program_verification = Program::Verification::InternalConsistency) { return std::make_unique( - buffer, enable_etdump, debug_buffer_size, program_verification); + buffer, + data_map_buffer, + enable_etdump, + debug_buffer_size, + program_verification); } static std::unique_ptr load_from_file( @@ -576,13 +607,25 @@ struct PyModule final { static std::unique_ptr load_from_bundled_program( PyBundledModule& m, + std::optional data_map_buffer, bool enable_etdump, size_t debug_buffer_size = 0) { + std::optional data_map_ptr = std::nullopt; + std::optional data_map_len = std::nullopt; + + if (data_map_buffer.has_value()) { + data_map_ptr = data_map_buffer.value().cast().data(); + data_map_len = py::len(data_map_buffer.value()); + } + return std::make_unique( m.get_program_ptr(), m.get_program_len(), + data_map_ptr, + data_map_len, enable_etdump, - debug_buffer_size); + debug_buffer_size, + Program::Verification::InternalConsistency); } py::list run_method( @@ -1423,6 +1466,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { "_load_for_executorch_from_buffer", &PyModule::load_from_buffer, py::arg("buffer"), + py::arg("data_map_buffer") = std::nullopt, py::arg("enable_etdump") = false, py::arg("debug_buffer_size") = 0, py::arg("program_verification") = @@ -1432,6 +1476,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { "_load_for_executorch_from_bundled_program", &PyModule::load_from_bundled_program, py::arg("ptr"), + py::arg("data_map_buffer") = std::nullopt, py::arg("enable_etdump") = false, py::arg("debug_buffer_size") = 0, call_guard); diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi index 27e523eb4d7..a3b75780369 100644 --- a/extension/pybindings/pybindings.pyi +++ b/extension/pybindings/pybindings.pyi @@ -185,6 +185,7 @@ def _load_for_executorch( @experimental("This API is experimental and subject to change without notice.") def _load_for_executorch_from_buffer( buffer: bytes, + data_map_buffer: Optional[bytes] = None, enable_etdump: bool = False, debug_buffer_size: int = 0, program_verification: Verification = Verification.InternalConsistency, @@ -199,7 +200,10 @@ def _load_for_executorch_from_buffer( @experimental("This API is experimental and subject to change without notice.") def _load_for_executorch_from_bundled_program( - module: BundledModule, enable_etdump: bool = False, debug_buffer_size: int = 0 + module: BundledModule, + data_map_buffer: Optional[bytes] = None, + enable_etdump: bool = False, + debug_buffer_size: int = 0, ) -> ExecuTorchModule: """Same as _load_for_executorch, but takes a bundled program instead of a file path. diff --git a/extension/pybindings/test/TARGETS b/extension/pybindings/test/TARGETS index e368e7c2404..c6a77c9d64e 100644 --- a/extension/pybindings/test/TARGETS +++ b/extension/pybindings/test/TARGETS @@ -17,6 +17,9 @@ runtime.python_library( deps = [ "//caffe2:torch", "//caffe2:torch_fx", + "//executorch/devtools/bundled_program:config", + "//executorch/devtools/bundled_program:core", + "//executorch/devtools/bundled_program/serialize:lib", "//executorch/exir:lib", "//executorch/exir:pass_manager", "//executorch/exir:scalar_type", diff --git a/extension/pybindings/test/test_pybindings.py b/extension/pybindings/test/test_pybindings.py index 12aec38cec6..02ad6b5e327 100644 --- a/extension/pybindings/test/test_pybindings.py +++ b/extension/pybindings/test/test_pybindings.py @@ -635,6 +635,9 @@ def test_program_data_separation(self) -> None: external_constants=True, ) ) + program_buffer = exec_program.buffer + assert len(exec_program._tensor_data) == 1 + data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant")) import os import tempfile @@ -642,17 +645,74 @@ def test_program_data_separation(self) -> None: with tempfile.TemporaryDirectory() as tmpdir: pte_file = os.path.join(tmpdir, "linear.pte") with open(pte_file, "wb") as f: - f.write(exec_program.buffer) - + f.write(program_buffer) ptd_file = os.path.join(tmpdir, "linear.ptd") with open(ptd_file, "wb") as ptd: - tensor_data = bytes( - exec_program._tensor_data.pop("_default_external_constant") - ) - ptd.write(tensor_data) + ptd.write(data_buffer) + expected = eager_module(inputs[0]) + # Test 1: File-based loading with external data file + executorch_module_file = self.runtime._load_for_executorch( + pte_file, ptd_file + ) + executorch_output_file = executorch_module_file.forward(inputs)[0] + self.assertTrue(torch.allclose(expected, executorch_output_file)) - executorch_program = self.runtime._load_for_executorch(pte_file, ptd_file) + # Test 2: Buffer-based loading with external data buffer + executorch_module_buffer = self.load_fn(program_buffer, data_buffer) + executorch_output_buffer = executorch_module_buffer.forward(inputs)[0] + self.assertTrue(torch.allclose(expected, executorch_output_buffer)) - expected = eager_module(inputs[0]) - executorch_output = executorch_program.forward(inputs)[0] - self.assertTrue(torch.allclose(expected, executorch_output)) + # Test 3: Buffer-based loading without external data file (should fail or work differently) + # This should fail because the program expects external data + executorch_module_no_data = self.load_fn(program_buffer) + with self.assertRaises(RuntimeError): + executorch_module_no_data.forward(inputs) + + # Test 4: Test with invalid data buffer (should fail) + invalid_bytes = b"invalid bytes" + executorch_module_invalid_data = self.load_fn(program_buffer, invalid_bytes) + with self.assertRaises(RuntimeError): + executorch_module_invalid_data.forward(inputs) + + # Test 5: Test bundled program loading with external data + # First create a bundled program with external constants + from executorch.devtools.bundled_program.config import ( + MethodTestCase, + MethodTestSuite, + ) + from executorch.devtools.bundled_program.core import BundledProgram + from executorch.devtools.bundled_program.serialize import ( + serialize_from_bundled_program_to_flatbuffer, + ) + + method_test_suites = [ + MethodTestSuite( + method_name="forward", + test_cases=[ + MethodTestCase( + inputs=input, + expected_outputs=expected, + ) + for input in inputs + ], + ), + ] + bundled_program = BundledProgram(exec_program, method_test_suites) + bundled_buffer = serialize_from_bundled_program_to_flatbuffer(bundled_program) + bundled_module = self.runtime._load_bundled_program_from_buffer(bundled_buffer) + + # Load module from bundled program with external data + executorch_module_bundled = ( + self.runtime._load_for_executorch_from_bundled_program( + bundled_module, data_buffer + ) + ) + executorch_output_bundled = executorch_module_bundled.forward(inputs)[0] + self.assertTrue(torch.allclose(expected, executorch_output_bundled)) + + # Test 6: Bundled program without external data should fail + executorch_module_bundled_no_data = ( + self.runtime._load_for_executorch_from_bundled_program(bundled_module) + ) + with self.assertRaises(RuntimeError): + executorch_module_bundled_no_data.forward(inputs)