Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions extension/module/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ Module::Module(
load_mode_(load_mode),
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)),
data_map_loader_(nullptr),
data_map_(nullptr) {
event_tracer_(std::move(event_tracer)) {
runtime::runtime_init();
}

Expand All @@ -87,13 +85,27 @@ Module::Module(
const LoadMode load_mode,
std::unique_ptr<runtime::EventTracer> event_tracer)
: file_path_(file_path),
data_map_path_(data_map_path),
load_mode_(load_mode),
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)),
data_map_loader_(nullptr),
data_map_(nullptr) {
event_tracer_(std::move(event_tracer)) {
if (!data_map_path.empty()) {
data_files_.push_back(data_map_path);
}
runtime::runtime_init();
}

Module::Module(
const std::string& file_path,
std::vector<std::string> data_files,
const LoadMode load_mode,
std::unique_ptr<runtime::EventTracer> event_tracer)
: file_path_(file_path),
data_files_(std::move(data_files)),
load_mode_(load_mode),
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)) {
runtime::runtime_init();
}

Expand All @@ -110,9 +122,10 @@ Module::Module(
temp_allocator_(
temp_allocator ? std::move(temp_allocator)
: std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)),
data_map_loader_(std::move(data_map_loader)),
data_map_(nullptr) {
event_tracer_(std::move(event_tracer)) {
if (data_map_loader) {
data_map_loaders_.push_back(std::move(data_map_loader));
}
runtime::runtime_init();
}

Expand All @@ -129,9 +142,10 @@ Module::Module(
temp_allocator_(
temp_allocator ? std::move(temp_allocator)
: std::make_unique<MallocMemoryAllocator>()),
event_tracer_(std::move(event_tracer)),
data_map_loader_(std::move(data_map_loader)),
data_map_(nullptr) {
event_tracer_(std::move(event_tracer)) {
if (data_map_loader) {
data_map_loaders_.push_back(std::move(data_map_loader));
}
runtime::runtime_init();
}

Expand All @@ -140,14 +154,27 @@ runtime::Error Module::load(const Program::Verification verification) {
if (!data_loader_) {
data_loader_ = ET_UNWRAP(make_data_loader(file_path_, load_mode_));
}
if (!data_map_path_.empty()) {
data_map_loader_ =
ET_UNWRAP(make_data_loader(data_map_path_, load_mode_));
if (data_files_.size() > 0) {
ET_CHECK_OR_RETURN_ERROR(
data_files_.size() == 1,
NotImplemented,
"Multiple named data map paths are not supported yet.");
for (const auto& data_file : data_files_) {
data_map_loaders_.push_back(
ET_UNWRAP(make_data_loader(data_file, load_mode_)));
}
}
if (data_map_loader_) {
data_map_ =
ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loader_.get()));

if (data_map_loaders_.size() > 0) {
ET_CHECK_OR_RETURN_ERROR(
data_map_loaders_.size() == 1 && merged_data_map_ == nullptr,
NotImplemented,
"Multiple named data map loaders are not supported yet.");
// TODO(lfq): support multiple named data map loaders.
merged_data_map_ =
ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loaders_[0].get()));
}

auto program =
ET_UNWRAP_UNIQUE(Program::load(data_loader_.get(), verification));
program_ = std::shared_ptr<Program>(
Expand Down Expand Up @@ -209,7 +236,7 @@ runtime::Error Module::load_method(
method_name.c_str(),
method_holder.memory_manager.get(),
event_tracer ? event_tracer : this->event_tracer(),
data_map_.get()));
merged_data_map_.get()));
methods_.emplace(method_name, std::move(method_holder));
}
return runtime::Error::Ok;
Expand Down
24 changes: 20 additions & 4 deletions extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Module {
* memory locking behavior.
*
* @param[in] file_path The path to the ExecuTorch program file to load.
* @param[in] data_map_path The path to a .ptd file
* @param[in] data_map_path The path to a .ptd file.
* @param[in] load_mode The loading mode to use.
* @param[in] event_tracer A EventTracer used for tracking and logging events.
*/
Expand All @@ -80,6 +80,21 @@ class Module {
const LoadMode load_mode = LoadMode::File,
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);

/**
* Constructs an instance by loading a program from a file with specified
* memory locking behavior.
*
* @param[in] file_path The path to the ExecuTorch program file to load.
* @param[in] data_files The path to one or more .ptd file/s.
* @param[in] load_mode The loading mode to use.
* @param[in] event_tracer A EventTracer used for tracking and logging events.
*/
explicit Module(
const std::string& file_path,
std::vector<std::string> data_files,
const LoadMode load_mode = LoadMode::File,
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);

/**
* Constructs an instance with the provided data loader and memory allocator.
*
Expand Down Expand Up @@ -614,15 +629,16 @@ class Module {
};

std::string file_path_;
std::string data_map_path_;
std::vector<std::string> data_files_;
LoadMode load_mode_{LoadMode::File};
std::shared_ptr<Program> program_;
std::unique_ptr<runtime::DataLoader> data_loader_;
std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
std::unique_ptr<runtime::EventTracer> event_tracer_;
std::unique_ptr<runtime::DataLoader> data_map_loader_;
std::unique_ptr<NamedDataMap> data_map_;
std::vector<std::unique_ptr<runtime::DataLoader>> data_map_loaders_;
std::vector<std::unique_ptr<NamedDataMap>> named_data_maps_;
std::unique_ptr<NamedDataMap> merged_data_map_;
ET_DEPRECATED std::vector<uint8_t> debug_buffer_;

protected:
Expand Down
15 changes: 15 additions & 0 deletions extension/module/test/module_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,18 @@ TEST_F(ModuleTest, TestPTD) {
auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f});
ASSERT_EQ(module.forward(tensor).error(), Error::Ok);
}

TEST_F(ModuleTest, TestPTD_Multiple) {
std::vector<std::string> data_files = {add_mul_data_path_};
Module module(add_mul_path_, data_files);

ASSERT_EQ(module.load_method("forward"), Error::Ok);

auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f});
ASSERT_EQ(module.forward(tensor).error(), Error::Ok);

// Confirm that the data_file is not std::move'd away.
ASSERT_EQ(std::strcmp(data_files[0].c_str(), add_mul_data_path_.c_str()), 0);

// TODO(lfq): add test when merge capability is supported.
}
Loading