diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp
index 06a6cb475d9..37d36dc9349 100644
--- a/extension/pybindings/pybindings.cpp
+++ b/extension/pybindings/pybindings.cpp
@@ -174,22 +174,42 @@ inline std::unique_ptr<Module> load_module_from_buffer(
 }
 
 inline std::unique_ptr<Module> load_module_from_file(
-    const std::string& path,
+    const std::string& program_path,
+    std::optional<std::string>& data_map_path,
     std::unique_ptr event_tracer,
     Program::Verification program_verification) {
   EXECUTORCH_SCOPE_PROF("load_module_from_file");
 
-  Result<MmapDataLoader> res = MmapDataLoader::from(
-      path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
+  Result<MmapDataLoader> program_loader_res = MmapDataLoader::from(
+      program_path.c_str(), MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
   THROW_IF_ERROR(
-      res.error(),
+      program_loader_res.error(),
       "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
-      path.c_str(),
-      static_cast<uint32_t>(res.error()));
-
-  auto loader = std::make_unique<MmapDataLoader>(std::move(res.get()));
+      program_path.c_str(),
+      static_cast<uint32_t>(program_loader_res.error()));
+  auto program_loader =
+      std::make_unique<MmapDataLoader>(std::move(program_loader_res.get()));
+
+  if (data_map_path.has_value()) {
+    Result<MmapDataLoader> data_map_loader_res = MmapDataLoader::from(
+        data_map_path->c_str(),
+        MmapDataLoader::MlockConfig::UseMlockIgnoreErrors);
+    THROW_IF_ERROR(
+        data_map_loader_res.error(),
+        "Failed to create MmapDataLoader from file %s, error: 0x:%" PRIx32,
+        data_map_path->c_str(),
+        static_cast<uint32_t>(data_map_loader_res.error()));
+    auto data_map_loader =
+        std::make_unique<MmapDataLoader>(std::move(data_map_loader_res.get()));
+    return std::make_unique<Module>(
+        std::move(program_loader),
+        nullptr, // memory_allocator
+        nullptr, // temp_allocator
+        std::move(event_tracer), // event_tracer
+        std::move(data_map_loader)); // data_map_loader
+  }
   return std::make_unique<Module>(
-      std::move(loader),
+      std::move(program_loader),
       nullptr, // memory_allocator
       nullptr, // temp_allocator
       std::move(event_tracer), // event_tracer
@@ -510,14 +530,16 @@ struct PyModule final {
           program_verification)) {}
 
   explicit PyModule(
-      const std::string& path,
+      const std::string& program_path,
+      std::optional<std::string>& data_path,
       bool enable_etdump,
       size_t debug_buffer_size = 0,
       Program::Verification program_verification =
          Program::Verification::InternalConsistency)
       : debug_buffer_size_(debug_buffer_size),
         module_(load_module_from_file(
-            path,
+            program_path,
+            data_path,
             setup_event_tracer(enable_etdump, debug_buffer_size),
             program_verification)) {}
 
@@ -536,14 +558,20 @@ struct PyModule final {
     return std::make_unique<PyModule>(
         buffer, enable_etdump, debug_buffer_size, program_verification);
   }
+
   static std::unique_ptr<PyModule> load_from_file(
-      const std::string& path,
+      const std::string& program_path,
+      std::optional<std::string>& data_path,
       bool enable_etdump,
       size_t debug_buffer_size = 0,
       Program::Verification program_verification =
          Program::Verification::InternalConsistency) {
     return std::make_unique<PyModule>(
-        path, enable_etdump, debug_buffer_size, program_verification);
+        program_path,
+        data_path,
+        enable_etdump,
+        debug_buffer_size,
+        program_verification);
   }
 
   static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -1351,7 +1379,8 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
   m.def(
       "_load_for_executorch",
       PyModule::load_from_file,
-      py::arg("path"),
+      py::arg("program_path"),
+      py::arg("data_path") = std::nullopt,
       py::arg("enable_etdump") = false,
       py::arg("debug_buffer_size") = 0,
       py::arg("program_verification") =
diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi
index 1978d22ea96..27e523eb4d7 100644
--- a/extension/pybindings/pybindings.pyi
+++ b/extension/pybindings/pybindings.pyi
@@ -156,7 +156,8 @@ class MethodMeta:
 
 @experimental("This API is experimental and subject to change without notice.")
 def _load_for_executorch(
-    path: str,
+    program_path: str,
+    data_path: Optional[str] = None,
     enable_etdump: bool = False,
     debug_buffer_size: int = 0,
     program_verification: Verification = Verification.InternalConsistency,
@@ -168,7 +169,8 @@
     This API is experimental and subject to change without notice.
 
     Args:
-        path: File path to the ExecuTorch program as a string.
+        program_path: File path to the ExecuTorch program as a string.
+        data_path: File path to a .ptd file containing data used by the program.
         enable_etdump: If true, enables an ETDump which can store profiling information.
            See documentation at https://pytorch.org/executorch/main/etdump
            for how to use it.
diff --git a/extension/pybindings/test/make_test.py b/extension/pybindings/test/make_test.py
index e2aba346944..a1bf4b980e0 100644
--- a/extension/pybindings/test/make_test.py
+++ b/extension/pybindings/test/make_test.py
@@ -133,6 +133,21 @@ def get_inputs(self):
         return (torch.ones(2, 2), torch.ones(2, 2))
 
 
+class ModuleLinear(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+    def get_methods_to_export(self):
+        return ("forward",)
+
+    def get_inputs(self):
+        return (torch.randn(3),)
+
+
 def create_program(
     eager_module: torch.nn.Module,
     et_config: Optional[ExecutorchBackendConfig] = None,
diff --git a/extension/pybindings/test/test_pybindings.py b/extension/pybindings/test/test_pybindings.py
index 8bbdb0d86d4..12aec38cec6 100644
--- a/extension/pybindings/test/test_pybindings.py
+++ b/extension/pybindings/test/test_pybindings.py
@@ -22,6 +22,7 @@
     ModuleAddWithAttributes,
     ModuleChannelsLast,
     ModuleChannelsLastInDefaultOut,
+    ModuleLinear,
     ModuleMulti,
 )
 from torch.export import export
@@ -623,3 +624,35 @@ def test_method_method_meta(self) -> None:
         self.assertEqual(output_tensor.is_memory_planned(), True)
         self.assertEqual(output_tensor.nbytes(), 16)
         self.assertEqual(str(output_tensor), tensor_info)
+
+    def test_program_data_separation(self) -> None:
+        eager_module = ModuleLinear()
+        inputs = eager_module.get_inputs()
+        exported_program = export(eager_module, inputs, strict=True)
+        exec_program = to_edge(exported_program).to_executorch(
+            config=ExecutorchBackendConfig(
+                # Move all tensor data to '_default_external_constant' file.
+                external_constants=True,
+            )
+        )
+
+        import os
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pte_file = os.path.join(tmpdir, "linear.pte")
+            with open(pte_file, "wb") as f:
+                f.write(exec_program.buffer)
+
+            ptd_file = os.path.join(tmpdir, "linear.ptd")
+            with open(ptd_file, "wb") as ptd:
+                tensor_data = bytes(
+                    exec_program._tensor_data.pop("_default_external_constant")
+                )
+                ptd.write(tensor_data)
+
+            executorch_program = self.runtime._load_for_executorch(pte_file, ptd_file)
+
+            expected = eager_module(inputs[0])
+            executorch_output = executorch_program.forward(inputs)[0]
+            self.assertTrue(torch.allclose(expected, executorch_output))
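
For reference, a minimal end-to-end sketch of the Python surface this diff adds, mirroring the new test_program_data_separation test: export with external_constants=True, write the program to a .pte and the constant data to a .ptd, then pass both paths to _load_for_executorch. The executorch.extension.pybindings.portable_lib import path, the model definition, and the file names are illustrative assumptions, not part of this change (the test above reaches the binding through its test harness's runtime object instead).

# Sketch only: assumes the portable_lib module exposes _load_for_executorch.
import os
import tempfile

import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.extension.pybindings.portable_lib import _load_for_executorch
from torch.export import export


class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)


model = Linear()
inputs = (torch.randn(3),)

# external_constants=True moves constant tensor data out of the .pte and into
# the '_default_external_constant' blob, which is written to a .ptd file below.
exec_program = to_edge(export(model, inputs, strict=True)).to_executorch(
    config=ExecutorchBackendConfig(external_constants=True)
)

with tempfile.TemporaryDirectory() as tmpdir:
    pte_file = os.path.join(tmpdir, "linear.pte")
    ptd_file = os.path.join(tmpdir, "linear.ptd")
    with open(pte_file, "wb") as f:
        f.write(exec_program.buffer)
    with open(ptd_file, "wb") as f:
        f.write(bytes(exec_program._tensor_data.pop("_default_external_constant")))

    # New in this diff: data_path points at the .ptd holding the external constants.
    module = _load_for_executorch(pte_file, data_path=ptd_file)
    print(module.forward(inputs)[0])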