diff --git a/backends/qnnpack/test/TARGETS b/backends/qnnpack/test/TARGETS
index d0a25fa9ce4..99097c58687 100644
--- a/backends/qnnpack/test/TARGETS
+++ b/backends/qnnpack/test/TARGETS
@@ -16,7 +16,6 @@ python_unittest(
         "//executorch/backends/qnnpack/partition:qnnpack_partitioner",
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_api",
-        "//executorch/exir/serialize:lib",
         "//executorch/extension/pybindings:portable",  # @manual
         "//executorch/extension/pytree:pylib",
     ],
diff --git a/backends/qnnpack/test/test_qnnpack.py b/backends/qnnpack/test/test_qnnpack.py
index b6c65bdbae7..71aa47e03c2 100644
--- a/backends/qnnpack/test/test_qnnpack.py
+++ b/backends/qnnpack/test/test_qnnpack.py
@@ -19,7 +19,6 @@

 from executorch.exir.backend.backend_api import to_backend, validation_disabled

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 5ed35237bee..7b63edde711 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -19,7 +19,7 @@

 ctypes.CDLL("libvulkan.so.1")

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
+
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
@@ -85,7 +85,6 @@ def forward(self, *args):
             )

             # Test the model with executor
-            # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
             executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
             # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
             inputs_flattened, _ = tree_flatten(sample_inputs)
diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS
index 3305f259931..9b59a6eb9cd 100644
--- a/backends/xnnpack/test/TARGETS
+++ b/backends/xnnpack/test/TARGETS
@@ -28,7 +28,6 @@ python_unittest(
         "//executorch/exir:tracer",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/passes:spec_prop_pass",
-        "//executorch/exir/serialize:lib",
         "//executorch/extension/pybindings:portable",  # @manual
         "//executorch/extension/pytree:pylib",
     ],
@@ -59,7 +58,6 @@ python_unittest(
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:spec_prop_pass",
-        "//executorch/exir/serialize:lib",
         "//executorch/extension/pybindings:portable",  # @manual
         "//executorch/extension/pytree:pylib",
     ],
@@ -89,7 +87,6 @@ python_unittest(
         "//executorch/exir:tracer",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/passes:spec_prop_pass",
-        "//executorch/exir/serialize:lib",
         "//executorch/extension/pybindings:portable",  # @manual
         "//executorch/extension/pytree:pylib",
         "//pytorch/vision:torchvision",
@@ -133,6 +130,7 @@ python_unittest(
        "//caffe2:torch",
        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
        "//executorch/backends/xnnpack/test/tester:tester",
+       "//executorch/backends/xnnpack/utils:xnnpack_utils",
        "//pytorch/vision:torchvision",
     ],
 )
diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py
index 7af509c462a..bc748ae513f 100644
--- a/backends/xnnpack/test/test_xnnpack_utils.py
+++ b/backends/xnnpack/test/test_xnnpack_utils.py
@@ -37,7 +37,6 @@
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from executorch.exir.tracer import _default_decomposition_table

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
@@ -230,7 +229,6 @@ def forward(self, *args):
             )

             # Test the model with executor
-            # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
             executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
             # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
             inputs_flattened, _ = tree_flatten(sample_inputs)
@@ -439,7 +437,6 @@ def forward(self, x):
             output_path=filename,
         )

-        # pyre-ignore
         executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
         inputs_flattened, _ = tree_flatten(example_inputs)
diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index a736284bd9d..a7825883f4a 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -30,10 +30,7 @@
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.passes.spec_prop_pass import SpecPropPass

-# pyre-ignore[21]: Could not find module `executorch.pybindings.portable`.
-from executorch.extension.pybindings.portable import (  # @manual
-    _load_for_executorch_from_buffer,
-)
+from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer
 from torch.ao.quantization.backend_config import BackendConfig
 from torch.ao.quantization.backend_config.executorch import (
     get_executorch_backend_config,
diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py
index e4cb98bcffa..0f1135c20c8 100644
--- a/examples/export/test/test_export.py
+++ b/examples/export/test/test_export.py
@@ -13,7 +13,6 @@
 from executorch.examples.export.utils import export_to_edge
 from executorch.examples.models import MODEL_NAME_TO_MODEL

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
@@ -37,7 +36,7 @@ def _assert_eager_lowered_same_result(
         edge_model = export_to_edge(eager_model, example_inputs)

         executorch_prog = edge_model.to_executorch()
-        # pyre-ignore
+
         pte_model = _load_for_executorch_from_buffer(executorch_prog.buffer)

         with torch.no_grad():
diff --git a/exir/backend/test/demos/rpc/test_rpc.py b/exir/backend/test/demos/rpc/test_rpc.py
index f53d72b1b5e..fc754742bd3 100644
--- a/exir/backend/test/demos/rpc/test_rpc.py
+++ b/exir/backend/test/demos/rpc/test_rpc.py
@@ -17,7 +17,6 @@
 )
 from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
diff --git a/exir/backend/test/demos/test_delegate_aten_mode.py b/exir/backend/test/demos/test_delegate_aten_mode.py
index 9198de3989b..7c80cf20cc9 100644
--- a/exir/backend/test/demos/test_delegate_aten_mode.py
+++ b/exir/backend/test/demos/test_delegate_aten_mode.py
@@ -15,7 +15,6 @@
     BackendWithCompilerDemo,
 )

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.aten_mode_lib import (  # @manual
     _load_for_executorch_from_buffer,
 )
diff --git a/exir/backend/test/demos/test_xnnpack_qnnpack.py b/exir/backend/test/demos/test_xnnpack_qnnpack.py
index f7df6c22605..4d8bbd4af0c 100644
--- a/exir/backend/test/demos/test_xnnpack_qnnpack.py
+++ b/exir/backend/test/demos/test_xnnpack_qnnpack.py
@@ -22,7 +22,6 @@
 from executorch.exir.backend.backend_api import to_backend, validation_disabled
 from executorch.exir.passes.spec_prop_pass import SpecPropPass

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py
index 34f57a76c01..c79ccc728bb 100644
--- a/exir/backend/test/test_backends.py
+++ b/exir/backend/test/test_backends.py
@@ -46,7 +46,6 @@
     Program,
 )

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
@@ -224,7 +223,6 @@ def forward(self, x):
         )

         buff = exec_prog.buffer
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(buff)
         model_inputs = torch.ones(1)
         model_outputs = executorch_module.forward([model_inputs])
@@ -281,7 +279,6 @@ def forward(self, a, x, b):
         )

         buff = exec_prog.buffer
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(buff)

         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
@@ -338,7 +335,6 @@ def forward(self, x):

         # This line should raise an exception like
         # RuntimeError: failed with error 0x12
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         _load_for_executorch_from_buffer(buff)

     @vary_segments
@@ -434,7 +430,6 @@ def forward(self, x):
             )
         )

-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(buff)
         model_inputs = torch.ones(1)

@@ -561,7 +556,6 @@ def forward(self, x):
         )

         flatbuffer = exec_prog.buffer
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(flatbuffer)
         model_outputs = executorch_module.forward([*model_inputs])

@@ -858,7 +852,6 @@ def forward(self, a, x, b):
         # There should be 2 delegated modules
         self.assertEqual(counter, 2)

-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
         inputs_flattened, _ = tree_flatten(inputs)
diff --git a/exir/backend/test/test_backends_lifted.py b/exir/backend/test/test_backends_lifted.py
index 0d52eb7af47..985a223ce7c 100644
--- a/exir/backend/test/test_backends_lifted.py
+++ b/exir/backend/test/test_backends_lifted.py
@@ -49,7 +49,6 @@
     Program,
 )

-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (  # @manual
     _load_for_executorch_from_buffer,
 )
@@ -231,7 +230,6 @@ def forward(self, x):
         )

         buff = exec_prog.buffer
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(buff)
         model_inputs = torch.ones(1)
         model_outputs = executorch_module.forward([model_inputs])
@@ -290,7 +288,6 @@ def forward(self, a, x, b):
         )

         buff = exec_prog.buffer
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(buff)

         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
@@ -347,7 +344,6 @@ def forward(self, x):

         # This line should raise an exception like
         # RuntimeError: failed with error 0x12
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         _load_for_executorch_from_buffer(buff)

     @vary_segments
@@ -443,7 +439,6 @@ def forward(self, x):
             )
         )

-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(buff)
         model_inputs = torch.ones(1)

@@ -570,7 +565,6 @@ def forward(self, x):
         )

         flatbuffer = exec_prog.buffer
-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(flatbuffer)
         model_outputs = executorch_module.forward([*model_inputs])

@@ -869,7 +863,6 @@ def forward(self, a, x, b):
         # There should be 2 delegated modules
         self.assertEqual(counter, 2)

-        # pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
         inputs_flattened, _ = tree_flatten(inputs)
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 31f52c9bc3a..3a73c178548 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -38,10 +38,7 @@
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.models import MLP, Mul

-# pyre-ignore
-from executorch.extension.pybindings.portable import (  # @manual
-    _load_for_executorch_from_buffer,
-)
+from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer

 from functorch.experimental import control_flow
diff --git a/extension/aten_util/aten_bridge.cpp b/extension/aten_util/aten_bridge.cpp
index 8f417c85b22..b4e4416060f 100644
--- a/extension/aten_util/aten_bridge.cpp
+++ b/extension/aten_util/aten_bridge.cpp
@@ -147,39 +147,5 @@ at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
   check_tensor_meta(t, etensor);
   return t;
 }
-
-std::unique_ptr<torch::executor::TensorImpl> eTensorFromAtTensor(
-    const at::Tensor& tensor,
-    KeepAliveSizes& keep_alive) {
-  auto sizes = tensor.sizes();
-  auto options = tensor.options();
-  keep_alive.sizes32.emplace_back(sizes.size());
-  auto& sizes32 = keep_alive.sizes32.back();
-  for (size_t i = 0; i < sizes.size(); ++i) {
-    // NOLINTNEXTLINE
-    sizes32[i] = sizes[i];
-  }
-
-  const torch::executor::ScalarType edtype =
-      torchToExecuTorchScalarType(options.dtype());
-
-  return std::make_unique<torch::executor::TensorImpl>(
-      edtype, sizes32.size(), sizes32.data(), tensor.mutable_data_ptr());
-}
-
-at::Tensor atTensorFromETensor(
-    torch::executor::TensorImpl* etensor,
-    KeepAliveSizes& keep_alive) {
-  c10::ScalarType dtype = execuTorchtoTorchScalarType(etensor->scalar_type());
-  keep_alive.sizes64.emplace_back(etensor->sizes().size());
-  auto& sizes64 = keep_alive.sizes64.back();
-  for (size_t i = 0; i < etensor->sizes().size(); ++i) {
-    // NOLINTNEXTLINE
-    sizes64[i] = etensor->sizes()[i];
-  }
-  return at::from_blob(
-      etensor->mutable_data(), sizes64, at::TensorOptions(dtype));
-}
-
 } // namespace util
 } // namespace torch
diff --git a/extension/aten_util/aten_bridge.h b/extension/aten_util/aten_bridge.h
index 36182e13d62..596c64d1e37 100644
--- a/extension/aten_util/aten_bridge.h
+++ b/extension/aten_util/aten_bridge.h
@@ -21,22 +21,6 @@
 namespace torch {
 namespace util {

-using sizes32_t = std::vector<int32_t>;
-using sizes64_t = std::vector<int64_t>;
-
-struct KeepAliveSizes {
-  std::vector<sizes32_t> sizes32;
-  std::vector<sizes64_t> sizes64;
-};
-
-// TODO: we should really remove this as
-__ET_DEPRECATED std::unique_ptr<torch::executor::TensorImpl>
-eTensorFromAtTensor(const at::Tensor& tensor, KeepAliveSizes& keep_alive);
-
-__ET_DEPRECATED at::Tensor atTensorFromETensor(
-    torch::executor::TensorImpl* etensor,
-    KeepAliveSizes& keep_alive);
-
 torch::executor::ScalarType torchToExecuTorchScalarType(caffe2::TypeMeta type);

 c10::ScalarType execuTorchtoTorchScalarType(torch::executor::ScalarType type);
diff --git a/extension/pybindings/TARGETS b/extension/pybindings/TARGETS
index 4e21936cdde..6d96f4dc90d 100644
--- a/extension/pybindings/TARGETS
+++ b/extension/pybindings/TARGETS
@@ -3,10 +3,25 @@
 # targets.bzl. This file can contain fbcode-only targets.

 load("@fbcode//executorch/extension/pybindings:targets.bzl", "ATEN_MODULE_DEPS", "MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB", "MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB", "PORTABLE_MODULE_DEPS", "define_common_targets", "executorch_pybindings")
+load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule")
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

 define_common_targets()

+# In order for pyre to recognize the pybindings module, the name of the .pyi file must exactly
+# match the name of the lib. To avoid copy-pasting the .pyi file in-tree a whole bunch of times,
+# we use genrules to do it for us.
+buck_genrule(
+    name = "pybindings_types_gen",
+    srcs = [":pybinding_types"],
+    outs = {
+        "aten_mode_lib.pyi": ["aten_mode_lib.pyi"],
+        "portable.pyi": ["portable.pyi"],
+    },
+    cmd = "cp $(location :pybinding_types)/* $OUT/portable.pyi && cp $(location :pybinding_types)/* $OUT/aten_mode_lib.pyi",
+    visibility = ["//executorch/extension/pybindings/..."],
+)
+
 executorch_pybindings(
     srcs = [
         "module.cpp",
@@ -22,6 +37,7 @@ executorch_pybindings(
     ],
     cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB,
     python_module_name = "portable",
+    types = ["//executorch/extension/pybindings:pybindings_types_gen[portable.pyi]"],
     visibility = ["PUBLIC"],
 )

@@ -31,6 +47,7 @@ executorch_pybindings(
     ],
     cppdeps = ATEN_MODULE_DEPS + MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB,
     python_module_name = "aten_mode_lib",
+    types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_mode_lib.pyi]"],
     visibility = ["PUBLIC"],
 )
diff --git a/extension/pybindings/module.cpp b/extension/pybindings/module.cpp
index 147750d36b1..a5634df9d57 100644
--- a/extension/pybindings/module.cpp
+++ b/extension/pybindings/module.cpp
@@ -55,7 +55,6 @@
   })

 namespace py = pybind11;
-using ATTensor = at::Tensor;

 namespace torch {
 namespace executor {
@@ -134,15 +133,7 @@ class Module final {
   /// Executes the specified method on the provided inputs and returns its
   /// outputs.
-  template <typename... Types>
   std::vector<EValue> run_method(
-      const std::string& method_name,
-      Types&&... args) {
-    return run_method_internal(method_name, std::vector<EValue>{args...});
-  }
-
- private:
-  std::vector<EValue> run_method_internal(
       const std::string& method_name,
       const std::vector<EValue>& args) {
     auto& method = methods_[method_name];
@@ -187,6 +178,7 @@ class Module final {
     return result;
   }

+ private:
   /// A wrapper/util class for executorch memory allocations/manager.
   class Memory {
    public:
@@ -266,66 +258,6 @@ inline std::unique_ptr<Module> load_from_file(const std::string& path) {
   return std::make_unique<Module>(std::move(loader));
 }

-// Struct used to manage the memory of tensors allocated in lean (not aten) mode
-#ifdef USE_ATEN_LIB
-struct KeepAlive {};
-#else
-struct KeepAlive {
-  std::vector<std::unique_ptr<torch::executor::TensorImpl>> tensors;
-  torch::util::KeepAliveSizes sizes;
-};
-#endif
-
-EValue pyToEValue(py::handle h, KeepAlive& keep_alive) {
-  const std::string& type_str = py::str(h.get_type());
-  EXECUTORCH_SCOPE_PROF("pyToEValue");
-  if (type_str == "<class 'torch.Tensor'>") {
-    auto atTensor = h.cast<at::Tensor>();
-#ifdef USE_ATEN_LIB
-    EValue evalue(atTensor);
-#else
-    auto etensorImpl =
-        torch::util::eTensorFromAtTensor(atTensor, keep_alive.sizes);
-    EValue evalue(torch::executor::Tensor(etensorImpl.get()));
-    keep_alive.tensors.push_back(std::move(etensorImpl));
-#endif
-    return evalue;
-  } else if (py::isinstance<py::none>(h)) {
-    return EValue();
-  } else if (py::isinstance<py::int_>(h)) {
-    return EValue(py::cast<int64_t>(h));
-  } else if (py::isinstance<py::float_>(h)) {
-    return EValue(py::cast<double>(h));
-  } else {
-    // Unsupported pytype
-    ET_ASSERT_UNREACHABLE_MSG(type_str.c_str());
-  }
-}
-
-py::object pyFromEValue(const EValue& v, KeepAlive& keep_alive) {
-  EXECUTORCH_SCOPE_PROF("pyFromEValue");
-  if (Tag::None == v.tag) {
-    return py::none();
-  } else if (Tag::Int == v.tag) {
-    return py::cast(v.toInt());
-  } else if (Tag::Double == v.tag) {
-    return py::cast(v.toDouble());
-  } else if (Tag::Bool == v.tag) {
-    return py::cast(v.toBool());
-  } else if (Tag::Tensor == v.tag) {
-#ifdef USE_ATEN_LIB
-    return py::cast(v.toTensor().clone());
-#else
-    // Clone so the outputs in python do not share a lifetime with the module
-    // object
-    return py::cast(torch::util::atTensorFromETensor(
-                        v.toTensor().unsafeGetTensorImpl(), keep_alive.sizes)
-                        .clone());
-#endif
-  }
-  ET_ASSERT_UNREACHABLE();
-}
-
 static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;

 struct PyBundledModule final {
@@ -403,30 +335,125 @@ struct PyModule final {
     return std::make_unique<PyModule>(m.get_program_ptr(), m.get_program_len());
   }

-  py::list run_method(const std::string& name, const py::sequence& pyinputs) {
-    std::vector<EValue> inputs;
-    const auto inputs_size = py::len(pyinputs);
-    inputs.reserve(inputs_size);
+  py::list run_method(
+      const std::string& method_name,
+      const py::sequence& inputs) {
+    const auto inputs_size = py::len(inputs);
+    std::vector<EValue> cpp_inputs;
+    cpp_inputs.reserve(inputs_size);
+
+#ifndef USE_ATEN_LIB // Portable mode
+    // Keep the ETensors and their metadata in scope for Module->run_method.
+    std::vector<torch::executor::TensorImpl> input_tensors;
+    std::vector<std::vector<torch::executor::Tensor::SizesType>> input_sizes;
+    std::vector<std::vector<torch::executor::Tensor::StridesType>>
+        input_strides;
+    std::vector<std::vector<torch::executor::Tensor::DimOrderType>>
+        input_dim_order;
+    // We store pointers to these vector elements, so it is important to
+    // reserve capacity up front; otherwise a resize would invalidate those
+    // pointers. We don't need to do this for the others since they are
+    // vectors of vectors, and we don't store a pointer to the root-level
+    // vector data.
+    input_tensors.reserve(inputs_size);
+#endif
+
+    // Convert python objects into EValues.
     for (size_t i = 0; i < inputs_size; ++i) {
-      inputs.emplace_back(pyToEValue(pyinputs[i], keep_alive_));
+      auto python_input = inputs[i];
+      const std::string& type_str = py::str(python_input.get_type());
+      if (type_str == "<class 'torch.Tensor'>") {
+        auto at_tensor = python_input.cast<at::Tensor>();
+        // alias_etensor_to_attensor will assert on this later, so to better
+        // propagate up to python we check early and throw an exception.
+        if (!at_tensor.is_contiguous()) {
+          auto error_msg = "Input " + std::to_string(i) + " for method " +
+              method_name + " is not contiguous.";
+          throw std::runtime_error(error_msg);
+        }
+
+#ifdef USE_ATEN_LIB
+        EValue evalue(at_tensor);
+#else
+        // Convert at::Tensor to torch::executor::Tensor.
+        auto type = torch::util::torchToExecuTorchScalarType(
+            at_tensor.options().dtype());
+        size_t dim = at_tensor.dim();
+        // Can't directly alias at::Tensor sizes and strides due to the int64
+        // vs int32 typing conflict.
+        input_sizes.emplace_back(
+            at_tensor.sizes().begin(), at_tensor.sizes().end());
+        input_strides.emplace_back(
+            at_tensor.strides().begin(), at_tensor.strides().end());
+
+        // Only works for MemoryFormat::Contiguous inputs.
+        std::vector<torch::executor::Tensor::DimOrderType> dim_order;
+        for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
+          dim_order.push_back(cur_dim);
+        }
+        input_dim_order.push_back(std::move(dim_order));
+        input_tensors.emplace_back(
+            type,
+            dim,
+            input_sizes[i].data(),
+            nullptr,
+            input_dim_order[i].data(),
+            input_strides[i].data());
+
+        torch::executor::Tensor temp =
+            torch::executor::Tensor(&input_tensors[i]);
+        torch::util::alias_etensor_to_attensor(at_tensor, temp);
+        EValue evalue(temp);
+#endif
+
+        cpp_inputs.push_back(evalue);
+      } else if (py::isinstance<py::none>(python_input)) {
+        cpp_inputs.push_back(EValue());
+      } else if (py::isinstance<py::int_>(python_input)) {
+        cpp_inputs.push_back(EValue(py::cast<int64_t>(python_input)));
+      } else if (py::isinstance<py::float_>(python_input)) {
+        cpp_inputs.push_back(EValue(py::cast<double>(python_input)));
+      } else {
+        // Unsupported pytype
+        ET_ASSERT_UNREACHABLE_MSG(type_str.c_str());
+      }
     }

-    auto outputs = module_->run_method(name, inputs);
+    auto outputs = module_->run_method(method_name, cpp_inputs);

+    // Retrieve outputs
     const auto outputs_size = outputs.size();
     py::list list(outputs_size);
     for (size_t i = 0; i < outputs_size; ++i) {
-      list[i] = pyFromEValue(outputs[i], keep_alive_);
+      auto& v = outputs[i];
+      if (Tag::None == v.tag) {
+        list[i] = py::none();
+      } else if (Tag::Int == v.tag) {
+        list[i] = py::cast(v.toInt());
+      } else if (Tag::Double == v.tag) {
+        list[i] = py::cast(v.toDouble());
+      } else if (Tag::Bool == v.tag) {
+        list[i] = py::cast(v.toBool());
+      } else if (Tag::Tensor == v.tag) {
+#ifdef USE_ATEN_LIB
+        // Clone so the outputs in python do not share a lifetime with the
+        // module object.
+        list[i] = py::cast(v.toTensor().clone());
+#else
+        list[i] = py::cast(
+            torch::util::alias_attensor_to_etensor(v.toTensor()).clone());
+#endif
+      } else {
+        ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
+      }
     }
     return list;
   }

-  py::list forward(const py::sequence& pyinputs) {
-    return run_method("forward", pyinputs);
+  py::list forward(const py::sequence& inputs) {
+    return run_method("forward", inputs);
   }

  private:
-  KeepAlive keep_alive_;
   std::unique_ptr<Module> module_;
 };

@@ -461,7 +488,7 @@ void init_module_functions(py::module_& m) {
   m.def("_create_profile_block", &create_profile_block);
   m.def("_reset_profile_results", []() { EXECUTORCH_RESET_PROFILE_RESULTS(); });

-  py::class_<PyModule>(m, "Module")
+  py::class_<PyModule>(m, "ExecutorchModule")
       .def("run_method", &PyModule::run_method)
       .def("forward", &PyModule::forward);
diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi
new file mode 100644
index 00000000000..e04a07bc984
--- /dev/null
+++ b/extension/pybindings/pybindings.pyi
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+from typing import Any, Dict, List, Sequence, Tuple
+
+class ExecutorchModule:
+    def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
+    def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
+
+def _load_for_executorch(path: str) -> ExecutorchModule: ...
+def _load_for_executorch_from_buffer(buffer: bytes) -> ExecutorchModule: ...
+def _create_profile_block(name: str) -> None: ...
+def _dump_profile_results() -> bytes: ...
+def _reset_profile_results() -> None: ...
diff --git a/extension/pybindings/targets.bzl b/extension/pybindings/targets.bzl
index 7f5eb2e533f..aec79fc9baf 100644
--- a/extension/pybindings/targets.bzl
+++ b/extension/pybindings/targets.bzl
@@ -46,12 +46,13 @@ MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB = [
     "//executorch/kernels/aten:generated_lib_aten",
 ]

-def executorch_pybindings(python_module_name, srcs = [], cppdeps = [], visibility = ["//executorch/..."]):
+def executorch_pybindings(python_module_name, srcs = [], cppdeps = [], visibility = ["//executorch/..."], types = []):
     runtime.cxx_python_extension(
         name = python_module_name,
         srcs = [
             "//executorch/extension/pybindings:pybindings.cpp",
         ] + srcs,
+        types = types,
         base_module = "executorch.extension.pybindings",
         preprocessor_flags = [
             "-DEXECUTORCH_PYTHON_MODULE_NAME={}".format(python_module_name),
@@ -88,6 +89,15 @@ def define_common_targets():
         visibility = ["//executorch/extension/pybindings/..."],
     )

+    # The cxx_python_extension kwarg 'types' can't take export_file rules directly, and we need to
+    # rename the .pyi file to match the lib anyway, so we just expose the file like this and then
+    # have genrules consume and rename it before passing it to executorch pybindings.
+    runtime.filegroup(
+        name = "pybinding_types",
+        srcs = ["pybindings.pyi"],
+        visibility = ["//executorch/extension/pybindings/..."],
+    )
+
     executorch_pybindings(
         srcs = [
             "module_stub.cpp",
diff --git a/extension/pybindings/test/test.py b/extension/pybindings/test/test.py
index 53985fee72a..b8f92fcd7cf 100644
--- a/extension/pybindings/test/test.py
+++ b/extension/pybindings/test/test.py
@@ -15,14 +15,13 @@
 from executorch.exir.scalar_type import ScalarType
 from executorch.exir.schema import Program

-# executorch.extension.pybindings.portable is a cpp_python_extension target.
-# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
+# pyre-ignore[21]
 from executorch.extension.pybindings.portable import (
-    _get_io_metadata_for_program_operators,  # @manual
-    _get_program_from_buffer,  # @manual
-    _get_program_operators,  # @manual
-    _load_for_executorch_from_buffer,  # @manual
-    IOMetaData,  # @manual
+    _get_io_metadata_for_program_operators,
+    _get_program_from_buffer,
+    _get_program_operators,
+    _load_for_executorch_from_buffer,
+    IOMetaData,
 )
diff --git a/profiler/test/test_profiler_e2e.py b/profiler/test/test_profiler_e2e.py
index d9bec8d0168..92a4bd02e2a 100644
--- a/profiler/test/test_profiler_e2e.py
+++ b/profiler/test/test_profiler_e2e.py
@@ -15,7 +15,6 @@

 from executorch import exir

-# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.extension.pybindings.portable`.
 from executorch.extension.pybindings.portable import (
     _create_profile_block,
     _dump_profile_results,
@@ -56,13 +55,11 @@ def setUpClass(cls) -> None:
             exir.capture(model, inputs).to_edge().to_executorch().buffer
         )

-        # pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         cls.module = _load_for_executorch_from_buffer(cls.__buffer)

         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
         cls.inputs_flattened, _ = tree_flatten(inputs)
         cls.module.run_method("forward", tuple(cls.inputs_flattened))
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         prof_dump = _dump_profile_results()
         assert (
             len(prof_dump) > 0
@@ -72,15 +69,11 @@ def setUpClass(cls) -> None:

     def test_profiler_new_block(self) -> None:
         block_names = ["block_1", "block_2"]
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         _reset_profile_results()
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         _create_profile_block(block_names[0])
         self.module.run_method("forward", tuple(self.inputs_flattened))
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         _create_profile_block(block_names[1])
         self.module.run_method("forward", tuple(self.inputs_flattened))
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `portable`.
         prof_dump = _dump_profile_results()
         self.assertGreater(
             len(prof_dump),
diff --git a/shim/xplat/executorch/build/runtime_wrapper.bzl b/shim/xplat/executorch/build/runtime_wrapper.bzl
index 78bd43a8cea..841463b551a 100644
--- a/shim/xplat/executorch/build/runtime_wrapper.bzl
+++ b/shim/xplat/executorch/build/runtime_wrapper.bzl
@@ -270,6 +270,8 @@ def _cxx_test(*args, **kwargs):
 def _cxx_python_extension(*args, **kwargs):
     _patch_kwargs_common(kwargs)
     kwargs["srcs"] = _patch_executorch_references(kwargs["srcs"])
+    if "types" in kwargs:
+        kwargs["types"] = _patch_executorch_references(kwargs["types"])
     env.cxx_python_extension(*args, **kwargs)

 def _export_file(*args, **kwargs):
diff --git a/test/end2end/test_end2end.py b/test/end2end/test_end2end.py
index 5f10f9cb919..6c3ee9eb058 100644
--- a/test/end2end/test_end2end.py
+++ b/test/end2end/test_end2end.py
@@ -59,7 +59,6 @@
 kernel_mode = None  # either aten mode or lean mode

 try:
-    # pyre-fixme[21]: Could not find module `executorch.extension.pybindings.portable`.
     from executorch.extension.pybindings.portable import (
         _load_bundled_program_from_buffer,
         _load_for_executorch_from_buffer,
@@ -72,7 +71,6 @@
     pass

 try:
-    # pyre-fixme[21]: Could not find module `executorch.extension.pybindings.portable`.
     from executorch.extension.pybindings.aten_mode_lib import (
         _load_bundled_program_from_buffer,
         _load_for_executorch_from_buffer,
@@ -554,7 +552,6 @@ def wrapper(self: unittest.TestCase) -> None:
         if run_executor:
             print("Running on the runtime")
-            # pyre-fixme[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
             executorch_module = _load_for_executorch_from_buffer(buff)
             # compare the result between eager module and executor
             for idx, inputs in enumerate(inputs_list):
@@ -607,7 +604,6 @@ def wrapper(self: unittest.TestCase) -> None:
                 executorch_bundled_program
             )

-            # pyre-fixme[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
             default_execution_plan_id = Module.FORWARD_METHOD_INDEX

             # TODO(T144329357): check bundled attachment correctness
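
With the pybindings.pyi stub above in place, pyre can resolve the extension module, which is what allows every pyre-ignore/pyre-fixme suppression in this patch to be deleted. A minimal end-to-end usage sketch follows; the AddOne toy module and variable names are hypothetical, but the exir and pybindings calls mirror the tests touched in this diff:

import torch
from executorch import exir
from executorch.extension.pybindings.portable import (
    _load_for_executorch_from_buffer,
)

class AddOne(torch.nn.Module):  # hypothetical toy model, for illustration only
    def forward(self, x):
        return x + 1

inputs = (torch.ones(1),)
# Export to an ExecuTorch flatbuffer, the same way the tests above do.
buffer = exir.capture(AddOne(), inputs).to_edge().to_executorch().buffer

# Returns the ExecutorchModule declared in pybindings.pyi; no pyre-ignore needed.
module = _load_for_executorch_from_buffer(buffer)
outputs = module.forward([*inputs])
print(outputs[0])  # expected: tensor([2.])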
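
The rewritten run_method in module.cpp aliases input tensors instead of copying them through the old KeepAlive machinery, which only works for contiguous memory; the new is_contiguous() guard surfaces that constraint as a Python exception instead of a runtime assert. A sketch of the resulting behavior, assuming `module` was loaded as above with a 2-D input:

x = torch.ones(2, 3).t()  # a transposed view is not contiguous
assert not x.is_contiguous()

try:
    module.forward([x])
except RuntimeError as e:
    # e.g. "Input 0 for method forward is not contiguous."
    print(e)

module.forward([x.contiguous()])  # copying into a contiguous layout works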
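
The profiling hooks declared in the .pyi can be driven the same way test_profiler_e2e.py does; a sketch reusing `module` and `inputs` from the first example:

from executorch.extension.pybindings.portable import (
    _create_profile_block,
    _dump_profile_results,
    _reset_profile_results,
)

_reset_profile_results()
_create_profile_block("block_1")
module.run_method("forward", tuple(inputs))
_create_profile_block("block_2")
module.run_method("forward", tuple(inputs))
prof_dump = _dump_profile_results()  # serialized profiling results as bytes
assert len(prof_dump) > 0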