Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace extension {
namespace ET_MODULE_NAMESPACE {

using ET_MERGED_DATA_MAP_NAMESPACE::MergedDataMap;
using ET_RUNTIME_NAMESPACE::Kernel;
using ET_RUNTIME_NAMESPACE::MethodMeta;
using ET_RUNTIME_NAMESPACE::Program;

Expand Down Expand Up @@ -365,7 +366,8 @@ runtime::Error Module::load_method(
const std::string& method_name,
runtime::HierarchicalAllocator* planned_memory,
torch::executor::EventTracer* event_tracer,
const LoadBackendOptionsMap* backend_options) {
const LoadBackendOptionsMap* backend_options,
std::vector<Kernel> kernel_registry) {
if (!is_method_loaded(method_name)) {
ET_CHECK_OK_OR_RETURN_ERROR(load());

Expand Down Expand Up @@ -402,12 +404,16 @@ runtime::Error Module::load_method(

method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
memory_allocator_.get(), planned_memory, temp_allocator_.get());
method_holder.kernel_registry = std::move(kernel_registry);
auto res_method = program_->load_method(
method_name.c_str(),
method_holder.memory_manager.get(),
event_tracer ? event_tracer : this->event_tracer(),
merged_data_map_.get(),
effective_backend_options);
effective_backend_options,
runtime::Span<const Kernel>(
method_holder.kernel_registry.data(),
method_holder.kernel_registry.size()));
if (!res_method.ok()) {
return res_method.error();
}
Expand Down
14 changes: 11 additions & 3 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
namespace executorch {
namespace extension {

using ET_RUNTIME_NAMESPACE::Kernel;
using ET_RUNTIME_NAMESPACE::Method;
using ET_RUNTIME_NAMESPACE::MethodMeta;
using ET_RUNTIME_NAMESPACE::NamedDataMap;
Expand Down Expand Up @@ -255,7 +256,8 @@ class Module {
const std::string& method_name,
runtime::HierarchicalAllocator* planned_memory = nullptr,
torch::executor::EventTracer* event_tracer = nullptr,
const LoadBackendOptionsMap* backend_options = nullptr);
const LoadBackendOptionsMap* backend_options = nullptr,
std::vector<Kernel> kernel_registry = {});

ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method(
const std::string& method_name,
Expand Down Expand Up @@ -303,9 +305,14 @@ class Module {
ET_NODISCARD inline runtime::Error load_forward(
runtime::HierarchicalAllocator* planned_memory = nullptr,
torch::executor::EventTracer* event_tracer = nullptr,
const LoadBackendOptionsMap* backend_options = nullptr) {
const LoadBackendOptionsMap* backend_options = nullptr,
std::vector<Kernel> kernel_registry = {}) {
return load_method(
"forward", planned_memory, event_tracer, backend_options);
"forward",
planned_memory,
event_tracer,
backend_options,
std::move(kernel_registry));
}

ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
Expand Down Expand Up @@ -698,6 +705,7 @@ class Module {
std::unique_ptr<PlannedMemory> planned_memory;
std::unique_ptr<runtime::MemoryManager> memory_manager;
std::unique_ptr<Method> method;
std::vector<Kernel> kernel_registry;
};

std::string file_path_;
Expand Down
Loading