Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,24 @@ void setup_output_storage(
inline std::unique_ptr<Module> load_module_from_buffer(
const void* ptr,
size_t ptr_len,
std::optional<const void*> data_map_ptr,
std::optional<size_t> data_map_len,
std::unique_ptr<runtime::EventTracer> event_tracer,
Program::Verification program_verification) {
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);

if (data_map_ptr.has_value() && data_map_len.has_value()) {
auto data_map_loader = std::make_unique<BufferDataLoader>(
data_map_ptr.value(), data_map_len.value());
return std::make_unique<Module>(
std::move(loader),
nullptr, // memory_allocator
nullptr, // temp_allocator
std::move(event_tracer), // event_tracer
std::move(data_map_loader)); // data_map_loader
}

return std::make_unique<Module>(
std::move(loader),
nullptr, // memory_allocator
Expand Down Expand Up @@ -504,6 +518,7 @@ struct PyMethodMeta final {
struct PyModule final {
explicit PyModule(
const py::bytes& buffer,
std::optional<const py::bytes> data_map_buffer,
bool enable_etdump,
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Expand All @@ -512,12 +527,21 @@ struct PyModule final {
module_(load_module_from_buffer(
buffer.cast<std::string_view>().data(),
py::len(buffer),
data_map_buffer.has_value()
? std::optional<const void*>(
data_map_buffer.value().cast<std::string_view>().data())
: std::nullopt,
data_map_buffer.has_value()
? std::optional<size_t>(py::len(data_map_buffer.value()))
: std::nullopt,
setup_event_tracer(enable_etdump, debug_buffer_size),
program_verification)) {}

explicit PyModule(
const void* ptr,
size_t ptr_len,
std::optional<const void*> data_map_ptr,
std::optional<size_t> data_map_ptr_len,
bool enable_etdump,
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Expand All @@ -526,6 +550,8 @@ struct PyModule final {
module_(load_module_from_buffer(
ptr,
ptr_len,
data_map_ptr,
data_map_ptr_len,
setup_event_tracer(enable_etdump, debug_buffer_size),
program_verification)) {}

Expand All @@ -551,12 +577,17 @@ struct PyModule final {
// Module is only valid as long as the python buffer is alive.
static std::unique_ptr<PyModule> load_from_buffer(
const py::bytes& buffer,
std::optional<const py::bytes> data_map_buffer,
bool enable_etdump,
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency) {
return std::make_unique<PyModule>(
buffer, enable_etdump, debug_buffer_size, program_verification);
buffer,
data_map_buffer,
enable_etdump,
debug_buffer_size,
program_verification);
}

static std::unique_ptr<PyModule> load_from_file(
Expand All @@ -576,13 +607,25 @@ struct PyModule final {

static std::unique_ptr<PyModule> load_from_bundled_program(
PyBundledModule& m,
std::optional<const py::bytes> data_map_buffer,
bool enable_etdump,
size_t debug_buffer_size = 0) {
std::optional<const void*> data_map_ptr = std::nullopt;
std::optional<size_t> data_map_len = std::nullopt;

if (data_map_buffer.has_value()) {
data_map_ptr = data_map_buffer.value().cast<std::string_view>().data();
data_map_len = py::len(data_map_buffer.value());
}

return std::make_unique<PyModule>(
m.get_program_ptr(),
m.get_program_len(),
data_map_ptr,
data_map_len,
enable_etdump,
debug_buffer_size);
debug_buffer_size,
Program::Verification::InternalConsistency);
}

py::list run_method(
Expand Down Expand Up @@ -1423,6 +1466,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
"_load_for_executorch_from_buffer",
&PyModule::load_from_buffer,
py::arg("buffer"),
py::arg("data_map_buffer") = std::nullopt,
py::arg("enable_etdump") = false,
py::arg("debug_buffer_size") = 0,
py::arg("program_verification") =
Expand All @@ -1432,6 +1476,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
"_load_for_executorch_from_bundled_program",
&PyModule::load_from_bundled_program,
py::arg("ptr"),
py::arg("data_map_buffer") = std::nullopt,
py::arg("enable_etdump") = false,
py::arg("debug_buffer_size") = 0,
call_guard);
Expand Down
6 changes: 5 additions & 1 deletion extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def _load_for_executorch(
@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_from_buffer(
buffer: bytes,
data_map_buffer: Optional[bytes] = None,
enable_etdump: bool = False,
debug_buffer_size: int = 0,
program_verification: Verification = Verification.InternalConsistency,
Expand All @@ -199,7 +200,10 @@ def _load_for_executorch_from_buffer(

@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_from_bundled_program(
module: BundledModule, enable_etdump: bool = False, debug_buffer_size: int = 0
module: BundledModule,
data_map_buffer: Optional[bytes] = None,
enable_etdump: bool = False,
debug_buffer_size: int = 0,
) -> ExecuTorchModule:
"""Same as _load_for_executorch, but takes a bundled program instead of a file path.

Expand Down
3 changes: 3 additions & 0 deletions extension/pybindings/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ runtime.python_library(
deps = [
"//caffe2:torch",
"//caffe2:torch_fx",
"//executorch/devtools/bundled_program:config",
"//executorch/devtools/bundled_program:core",
"//executorch/devtools/bundled_program/serialize:lib",
"//executorch/exir:lib",
"//executorch/exir:pass_manager",
"//executorch/exir:scalar_type",
Expand Down
80 changes: 70 additions & 10 deletions extension/pybindings/test/test_pybindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,24 +635,84 @@ def test_program_data_separation(self) -> None:
external_constants=True,
)
)
program_buffer = exec_program.buffer
assert len(exec_program._tensor_data) == 1
data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant"))

import os
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
pte_file = os.path.join(tmpdir, "linear.pte")
with open(pte_file, "wb") as f:
f.write(exec_program.buffer)

f.write(program_buffer)
ptd_file = os.path.join(tmpdir, "linear.ptd")
with open(ptd_file, "wb") as ptd:
tensor_data = bytes(
exec_program._tensor_data.pop("_default_external_constant")
)
ptd.write(tensor_data)
ptd.write(data_buffer)
expected = eager_module(inputs[0])
# Test 1: File-based loading with external data file
executorch_module_file = self.runtime._load_for_executorch(
pte_file, ptd_file
)
executorch_output_file = executorch_module_file.forward(inputs)[0]
self.assertTrue(torch.allclose(expected, executorch_output_file))

executorch_program = self.runtime._load_for_executorch(pte_file, ptd_file)
# Test 2: Buffer-based loading with external data buffer
executorch_module_buffer = self.load_fn(program_buffer, data_buffer)
executorch_output_buffer = executorch_module_buffer.forward(inputs)[0]
self.assertTrue(torch.allclose(expected, executorch_output_buffer))

expected = eager_module(inputs[0])
executorch_output = executorch_program.forward(inputs)[0]
self.assertTrue(torch.allclose(expected, executorch_output))
# Test 3: Buffer-based loading without external data file (should fail or work differently)
# This should fail because the program expects external data
executorch_module_no_data = self.load_fn(program_buffer)
with self.assertRaises(RuntimeError):
executorch_module_no_data.forward(inputs)

# Test 4: Test with invalid data buffer (should fail)
invalid_bytes = b"invalid bytes"
executorch_module_invalid_data = self.load_fn(program_buffer, invalid_bytes)
with self.assertRaises(RuntimeError):
executorch_module_invalid_data.forward(inputs)

# Test 5: Test bundled program loading with external data
# First create a bundled program with external constants
from executorch.devtools.bundled_program.config import (
MethodTestCase,
MethodTestSuite,
)
from executorch.devtools.bundled_program.core import BundledProgram
from executorch.devtools.bundled_program.serialize import (
serialize_from_bundled_program_to_flatbuffer,
)

method_test_suites = [
MethodTestSuite(
method_name="forward",
test_cases=[
MethodTestCase(
inputs=input,
expected_outputs=expected,
)
for input in inputs
],
),
]
bundled_program = BundledProgram(exec_program, method_test_suites)
bundled_buffer = serialize_from_bundled_program_to_flatbuffer(bundled_program)
bundled_module = self.runtime._load_bundled_program_from_buffer(bundled_buffer)

# Load module from bundled program with external data
executorch_module_bundled = (
self.runtime._load_for_executorch_from_bundled_program(
bundled_module, data_buffer
)
)
executorch_output_bundled = executorch_module_bundled.forward(inputs)[0]
self.assertTrue(torch.allclose(expected, executorch_output_bundled))

# Test 6: Bundled program without external data should fail
executorch_module_bundled_no_data = (
self.runtime._load_for_executorch_from_bundled_program(bundled_module)
)
with self.assertRaises(RuntimeError):
executorch_module_bundled_no_data.forward(inputs)
Loading