Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions extension/flat_tensor/test/flat_tensor_data_map_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class FlatTensorDataMapTest : public ::testing::Test {
// first.
executorch::runtime::runtime_init();

// Load data map. The eager linear model is defined at:
// //executorch/test/models/linear_model.py
// Load data map. The eager addmul model is defined at:
// //executorch/test/models/export_program.py
const char* path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
Result<FileDataLoader> loader = FileDataLoader::from(path);
ASSERT_EQ(loader.error(), Error::Ok);
Expand Down
149 changes: 149 additions & 0 deletions runtime/executor/merged_data_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/core/named_data_map.h>

namespace executorch {
namespace ET_RUNTIME_NAMESPACE {
namespace internal {

/**
* A NamedDataMap implementation that wraps other NamedDataMaps.
*/
/**
 * A NamedDataMap implementation that wraps two other NamedDataMaps.
 *
 * Keys from the first map are enumerated before keys from the second; the
 * two key sets must be disjoint (enforced in load()).
 */
class MergedDataMap final : public NamedDataMap {
 public:
  /**
   * Creates a new NamedDataMap that wraps two other data maps.
   *
   * @param[in] first The first NamedDataMap to merge.
   * @param[in] second The second NamedDataMap to merge.
   * Note: the data maps must outlive the MergedDataMap instance.
   */
  static Result<MergedDataMap> load(
      const NamedDataMap* first,
      const NamedDataMap* second) {
    ET_CHECK_OR_RETURN_ERROR(
        first != nullptr && second != nullptr,
        InvalidArgument,
        "Input data map is null.");

    // Check for duplicate keys: every key of `first` must be absent from
    // `second`.
    const uint32_t first_num_keys = first->get_num_keys().get();
    for (uint32_t k = 0; k < first_num_keys; k++) {
      const auto key = first->get_key(k).get();
      const Error err = second->get_tensor_layout(key).error();
      ET_CHECK_OR_RETURN_ERROR(
          err != Error::Ok, InvalidArgument, "Duplicate key %s.", key);
      // Any error other than NotFound (e.g. an I/O failure) is a real
      // failure inside `second`; propagate it instead of mislabeling it
      // as a duplicate-key error.
      if (err != Error::NotFound) {
        return err;
      }
    }
    return MergedDataMap(first, second);
  }

  /**
   * Retrieve the tensor_layout for the specified key.
   *
   * @param[in] key The name of the tensor to get metadata on.
   *
   * @return Error::NotFound if the key is not present in either map.
   */
  ET_NODISCARD
  Result<const TensorLayout> get_tensor_layout(
      executorch::aten::string_view key) const override {
    auto layout = first_->get_tensor_layout(key);
    if (layout.ok()) {
      return layout.get();
    }
    // Only fall through to the second map when the first map genuinely
    // does not contain the key; other errors are surfaced directly.
    if (layout.error() != Error::NotFound) {
      return layout.error();
    }
    return second_->get_tensor_layout(key);
  }

  /**
   * Retrieve read-only data for the specified key.
   *
   * @param[in] key The name of the tensor to get data on.
   *
   * @return error if the key is not present or data cannot be loaded.
   */
  ET_NODISCARD
  Result<FreeableBuffer> get_data(
      executorch::aten::string_view key) const override {
    auto data = first_->get_data(key);
    // NotFound means "try the second map"; Ok or any other error is final.
    if (data.error() != Error::NotFound) {
      return data;
    }
    return second_->get_data(key);
  }

  /**
   * Loads the data of the specified tensor into the provided buffer.
   * Not used in the MergedDataMap.
   *
   * @param[in] key The name of the tensor to get the data of.
   * @param[in] buffer The buffer to load data into. Must point to at least
   * `size` bytes of memory.
   * @param[in] size The number of bytes to load.
   *
   * @returns an Error indicating if the load was successful.
   */
  ET_NODISCARD Error load_data_into(
      ET_UNUSED executorch::aten::string_view key,
      ET_UNUSED void* buffer,
      ET_UNUSED size_t size) const override {
    return Error::NotImplemented;
  }

  /**
   * @returns The number of keys in the map (sum of both wrapped maps).
   */
  ET_NODISCARD Result<uint32_t> get_num_keys() const override {
    return first_->get_num_keys().get() + second_->get_num_keys().get();
  }

  /**
   * @returns The key at the specified index, error if index out of bounds.
   * Indices [0, first_num_keys) address the first map; the remainder
   * address the second map.
   */
  ET_NODISCARD Result<const char*> get_key(uint32_t index) const override {
    const uint32_t first_num_keys = first_->get_num_keys().get();
    const uint32_t total_num_keys =
        first_num_keys + second_->get_num_keys().get();
    // `index` is unsigned, so only the upper bound needs checking; a
    // `index >= 0` comparison would be tautological.
    ET_CHECK_OR_RETURN_ERROR(
        index < total_num_keys,
        InvalidArgument,
        "Index %u out of range of size %u",
        index,
        total_num_keys);

    if (index < first_num_keys) {
      return first_->get_key(index);
    }
    return second_->get_key(index - first_num_keys);
  }

  MergedDataMap(MergedDataMap&&) noexcept = default;

  ~MergedDataMap() override = default;

 private:
  MergedDataMap(const NamedDataMap* first, const NamedDataMap* second)
      : first_{first}, second_{second} {}

  // Not copyable or assignable.
  MergedDataMap(const MergedDataMap& rhs) = delete;
  MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
  MergedDataMap& operator=(const MergedDataMap& rhs) = delete;

  // Non-owning; both must outlive this object (see load()).
  const NamedDataMap* first_;
  const NamedDataMap* second_;
};

} // namespace internal
} // namespace ET_RUNTIME_NAMESPACE
} // namespace executorch
10 changes: 10 additions & 0 deletions runtime/executor/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def define_common_targets():
exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"],
)

runtime.cxx_library(
name = "merged_data_map" + aten_suffix,
exported_headers = [
"merged_data_map.h",
],
exported_deps = [
"//executorch/runtime/core:named_data_map" + aten_suffix,
],
)

runtime.cxx_library(
name = "program" + aten_suffix,
exported_deps = [
Expand Down
148 changes: 148 additions & 0 deletions runtime/executor/test/merged_data_map_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <array>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <unordered_map>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/merged_data_map.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;
using executorch::runtime::DataLoader;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::TensorLayout;
using executorch::runtime::internal::MergedDataMap;

class MergedDataMapTest : public ::testing::Test {
 protected:
  // Loads the .ptd file at `path` into a FlatTensorDataMap registered under
  // `module_name`, keeping the backing FileDataLoader alive alongside it.
  void load_flat_tensor_data_map(const char* path, const char* module_name) {
    Result<FileDataLoader> loader_result = FileDataLoader::from(path);
    ASSERT_EQ(loader_result.error(), Error::Ok);
    loaders_.emplace(
        module_name,
        std::make_unique<FileDataLoader>(std::move(loader_result.get())));

    Result<FlatTensorDataMap> map_result =
        FlatTensorDataMap::load(loaders_[module_name].get());
    EXPECT_EQ(map_result.error(), Error::Ok);

    data_maps_.emplace(
        module_name,
        std::make_unique<FlatTensorDataMap>(std::move(map_result.get())));
  }

  void SetUp() override {
    // ET_LOG is used below, so the PAL must be initialized before anything
    // else.
    executorch::runtime::runtime_init();

    // Load FlatTensor data maps.
    // The eager addmul and linear models are defined at:
    // //executorch/test/models/export_program.py
    load_flat_tensor_data_map(
        std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"), "addmul");
    load_flat_tensor_data_map(
        std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
  }

 private:
  // Declared before data_maps_ so it is destroyed after them (members are
  // destructed in reverse declaration order); tests shouldn't touch it.
  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;

 protected:
  std::unordered_map<std::string, std::unique_ptr<NamedDataMap>> data_maps_;
};

// Check that two tensor layouts are equivalent.
void check_tensor_layout(TensorLayout& layout1, TensorLayout& layout2) {
EXPECT_EQ(layout1.scalar_type(), layout2.scalar_type());
EXPECT_EQ(layout1.nbytes(), layout2.nbytes());
EXPECT_EQ(layout1.sizes().size(), layout2.sizes().size());
for (size_t i = 0; i < layout1.sizes().size(); i++) {
EXPECT_EQ(layout1.sizes()[i], layout2.sizes()[i]);
}
EXPECT_EQ(layout1.dim_order().size(), layout2.dim_order().size());
for (size_t i = 0; i < layout1.dim_order().size(); i++) {
EXPECT_EQ(layout1.dim_order()[i], layout2.dim_order()[i]);
}
}

// Given that ndm is part of merged, check that all the API calls on ndm produce
// the same results as merged.
void compare_ndm_api_calls(
const NamedDataMap* ndm,
const NamedDataMap* merged) {
uint32_t num_keys = ndm->get_num_keys().get();
for (uint32_t i = 0; i < num_keys; i++) {
auto key = ndm->get_key(i).get();

// Compare get_tensor_layout.
auto ndm_meta = ndm->get_tensor_layout(key).get();
auto merged_meta = merged->get_tensor_layout(key).get();
check_tensor_layout(ndm_meta, merged_meta);

// Coompare get_data.
auto ndm_data = ndm->get_data(key);
auto merged_data = merged->get_data(key);
EXPECT_EQ(ndm_data.get().size(), merged_data.get().size());
for (size_t j = 0; j < ndm_meta.nbytes(); j++) {
EXPECT_EQ(
((uint8_t*)ndm_data.get().data())[j],
((uint8_t*)merged_data.get().data())[j]);
}
ndm_data->Free();
merged_data->Free();
}
}

TEST_F(MergedDataMapTest, LoadNullDataMap) {
  // Null inputs must be rejected up front.
  auto merged = MergedDataMap::load(/*first=*/nullptr, /*second=*/nullptr);
  EXPECT_EQ(merged.error(), Error::InvalidArgument);
}

TEST_F(MergedDataMapTest, LoadMultipleDataMaps) {
  // Two data maps with disjoint key sets merge successfully.
  auto merged = MergedDataMap::load(
      data_maps_["addmul"].get(), data_maps_["linear"].get());
  EXPECT_EQ(merged.error(), Error::Ok);
}

TEST_F(MergedDataMapTest, LoadDuplicateDataMapsFail) {
  // Merging a data map with itself guarantees key collisions, which must
  // be rejected.
  auto merged = MergedDataMap::load(
      data_maps_["addmul"].get(), data_maps_["addmul"].get());
  EXPECT_EQ(merged.error(), Error::InvalidArgument);
}

TEST_F(MergedDataMapTest, CheckDataMapContents) {
  Result<MergedDataMap> merged_map = MergedDataMap::load(
      data_maps_["addmul"].get(), data_maps_["linear"].get());
  EXPECT_EQ(merged_map.error(), Error::Ok);

  // Num keys: the merged map exposes the union of both key sets.
  size_t addmul_num_keys = data_maps_["addmul"]->get_num_keys().get();
  size_t linear_num_keys = data_maps_["linear"]->get_num_keys().get();
  EXPECT_EQ(
      merged_map->get_num_keys().get(), addmul_num_keys + linear_num_keys);

  // Load data into is not implemented for the merged data map. Use a stack
  // buffer instead of malloc/free so nothing leaks if an assertion fails.
  std::array<uint8_t, 10> scratch{};
  ASSERT_EQ(
      Error::NotImplemented,
      merged_map->load_data_into("a", scratch.data(), scratch.size()));

  // API calls on each constituent map produce results equivalent to the
  // merged map's.
  compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get());
  compare_ndm_api_calls(data_maps_["linear"].get(), &merged_map.get());
}
14 changes: 14 additions & 0 deletions runtime/executor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def define_common_targets(is_fbcode = False):
"ET_MODULE_STATEFUL_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleStateful.pte])",
"ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
"ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
"ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
}

runtime.cxx_test(
Expand All @@ -142,6 +143,19 @@ def define_common_targets(is_fbcode = False):
env = modules_env,
)

runtime.cxx_test(
name = "merged_data_map_test",
srcs = [
"merged_data_map_test.cpp",
],
deps = [
"//executorch/extension/data_loader:file_data_loader",
"//executorch/extension/flat_tensor:flat_tensor_data_map",
"//executorch/runtime/executor:merged_data_map",
],
env = modules_env,
)

runtime.cxx_test(
name = "method_test",
srcs = [
Expand Down
Loading