diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
index 5a94b47b954..37e1cd2edac 100644
--- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
+++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
@@ -33,8 +33,8 @@ class FlatTensorDataMapTest : public ::testing::Test {
     // first.
     executorch::runtime::runtime_init();
 
-    // Load data map. The eager linear model is defined at:
-    // //executorch/test/models/linear_model.py
+    // Load data map. The eager addmul model is defined at:
+    // //executorch/test/models/export_program.py
     const char* path = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
     Result<FileDataLoader> loader = FileDataLoader::from(path);
     ASSERT_EQ(loader.error(), Error::Ok);
diff --git a/runtime/executor/merged_data_map.h b/runtime/executor/merged_data_map.h
new file mode 100644
index 00000000000..0f0175098ae
--- /dev/null
+++ b/runtime/executor/merged_data_map.h
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/named_data_map.h>
+
+namespace executorch {
+namespace ET_RUNTIME_NAMESPACE {
+namespace internal {
+
+/**
+ * A NamedDataMap implementation that wraps other NamedDataMaps.
+ */
+class MergedDataMap final : public NamedDataMap {
+ public:
+  /**
+   * Creates a new NamedDataMap that wraps two other data maps.
+   *
+   * @param[in] first The first NamedDataMap to merge.
+   * @param[in] second The second NamedDataMap to merge.
+   * Note: the data maps must outlive the MergedDataMap instance.
+   */
+  static Result<MergedDataMap> load(
+      const NamedDataMap* first,
+      const NamedDataMap* second) {
+    ET_CHECK_OR_RETURN_ERROR(
+        first != nullptr && second != nullptr,
+        InvalidArgument,
+        "Input data map is null.");
+
+    // Check for duplicate keys.
+    for (uint32_t k = 0; k < first->get_num_keys().get(); k++) {
+      const auto key = first->get_key(k).get();
+      ET_CHECK_OR_RETURN_ERROR(
+          second->get_tensor_layout(key).error() == Error::NotFound,
+          InvalidArgument,
+          "Duplicate key %s.",
+          key);
+    }
+    return MergedDataMap(first, second);
+  }
+
+  /**
+   * Retrieve the tensor_layout for the specified key.
+   *
+   * @param[in] key The name of the tensor to get metadata on.
+   *
+   * @return Error::NotFound if the key is not present.
+   */
+  ET_NODISCARD
+  Result<const TensorLayout> get_tensor_layout(
+      executorch::aten::string_view key) const override {
+    auto layout = first_->get_tensor_layout(key);
+    if (layout.ok()) {
+      return layout.get();
+    }
+    if (layout.error() != Error::NotFound) {
+      return layout.error();
+    }
+    return second_->get_tensor_layout(key);
+  }
+
+  /**
+   * Retrieve read-only data for the specified key.
+   *
+   * @param[in] key The name of the tensor to get data on.
+   *
+   * @return error if the key is not present or data cannot be loaded.
+   */
+  ET_NODISCARD
+  Result<FreeableBuffer> get_data(
+      executorch::aten::string_view key) const override {
+    auto data = first_->get_data(key);
+    if (data.error() != Error::NotFound) {
+      return data;
+    }
+    return second_->get_data(key);
+  }
+
+  /**
+   * Loads the data of the specified tensor into the provided buffer.
+   * Not used in the MergedDataMap.
+   *
+   * @param[in] key The name of the tensor to get the data of.
+   * @param[in] buffer The buffer to load data into. Must point to at least
+   * `size` bytes of memory.
+   * @param[in] size The number of bytes to load.
+   *
+   * @returns an Error indicating if the load was successful.
+   */
+  ET_NODISCARD Error load_data_into(
+      ET_UNUSED executorch::aten::string_view key,
+      ET_UNUSED void* buffer,
+      ET_UNUSED size_t size) const override {
+    return Error::NotImplemented;
+  }
+
+  /**
+   * @returns The number of keys in the map.
+   */
+  ET_NODISCARD Result<uint32_t> get_num_keys() const override {
+    return first_->get_num_keys().get() + second_->get_num_keys().get();
+  }
+
+  /**
+   * @returns The key at the specified index, error if index out of bounds.
+   */
+  ET_NODISCARD Result<const char*> get_key(uint32_t index) const override {
+    uint32_t total_num_keys = get_num_keys().get();
+    ET_CHECK_OR_RETURN_ERROR(
+        index < total_num_keys,
+        InvalidArgument,
+        "Index %u out of range of size %u",
+        index,
+        total_num_keys);
+
+    if (index < first_->get_num_keys().get()) {
+      return first_->get_key(index);
+    } else {
+      return second_->get_key(index - first_->get_num_keys().get());
+    }
+  }
+
+  MergedDataMap(MergedDataMap&&) noexcept = default;
+
+  ~MergedDataMap() override = default;
+
+ private:
+  MergedDataMap(const NamedDataMap* first, const NamedDataMap* second)
+      : first_{first}, second_{second} {}
+
+  // Not copyable or assignable.
+  MergedDataMap(const MergedDataMap& rhs) = delete;
+  MergedDataMap& operator=(MergedDataMap&& rhs) noexcept = delete;
+  MergedDataMap& operator=(const MergedDataMap& rhs) = delete;
+
+  const NamedDataMap* first_;
+  const NamedDataMap* second_;
+};
+
+} // namespace internal
+} // namespace ET_RUNTIME_NAMESPACE
+} // namespace executorch
diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl
index 649b2c13cc1..98165373b73 100644
--- a/runtime/executor/targets.bzl
+++ b/runtime/executor/targets.bzl
@@ -69,6 +69,16 @@ def define_common_targets():
             exported_preprocessor_flags = [] if runtime.is_oss else ["-DEXECUTORCH_INTERNAL_FLATBUFFERS=1"],
         )
 
+        runtime.cxx_library(
+            name = "merged_data_map" + aten_suffix,
+            exported_headers = [
+                "merged_data_map.h",
+            ],
+            exported_deps = [
+                "//executorch/runtime/core:named_data_map" + aten_suffix,
+            ],
+        )
+
         runtime.cxx_library(
             name = "program" + aten_suffix,
             exported_deps = [
diff --git a/runtime/executor/test/merged_data_map_test.cpp b/runtime/executor/test/merged_data_map_test.cpp
new file mode 100644
index 00000000000..c9d1d510b97
--- /dev/null
+++ b/runtime/executor/test/merged_data_map_test.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/data_loader/file_data_loader.h>
+#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/result.h>
+#include <executorch/runtime/executor/merged_data_map.h>
+#include <executorch/runtime/platform/runtime.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using executorch::extension::FileDataLoader;
+using executorch::extension::FlatTensorDataMap;
+using executorch::runtime::DataLoader;
+using executorch::runtime::Error;
+using executorch::runtime::FreeableBuffer;
+using executorch::runtime::NamedDataMap;
+using executorch::runtime::Result;
+using executorch::runtime::TensorLayout;
+using executorch::runtime::internal::MergedDataMap;
+
+class MergedDataMapTest : public ::testing::Test {
+ protected:
+  void load_flat_tensor_data_map(const char* path, const char* module_name) {
+    Result<FileDataLoader> loader = FileDataLoader::from(path);
+    ASSERT_EQ(loader.error(), Error::Ok);
+    loaders_.insert(
+        {module_name,
+         std::make_unique<FileDataLoader>(std::move(loader.get()))});
+
+    Result<FlatTensorDataMap> data_map =
+        FlatTensorDataMap::load(loaders_[module_name].get());
+    EXPECT_EQ(data_map.error(), Error::Ok);
+
+    data_maps_.insert(
+        {module_name,
+         std::make_unique<FlatTensorDataMap>(std::move(data_map.get()))});
+  }
+
+  void SetUp() override {
+    // Since these tests cause ET_LOG to be called, the PAL must be initialized
+    // first.
+    executorch::runtime::runtime_init();
+
+    // Load FlatTensor data maps.
+    // The eager addmul and linear models are defined at:
+    // //executorch/test/models/export_program.py
+    load_flat_tensor_data_map(
+        std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"), "addmul");
+    load_flat_tensor_data_map(
+        std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
+  }
+
+ private:
+  // Must outlive data_maps_, but tests shouldn't need to touch it.
+  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;
+
+ protected:
+  std::unordered_map<std::string, std::unique_ptr<FlatTensorDataMap>>
+      data_maps_;
+};
+
+// Check that two tensor layouts are equivalent.
+void check_tensor_layout(TensorLayout& layout1, TensorLayout& layout2) {
+  EXPECT_EQ(layout1.scalar_type(), layout2.scalar_type());
+  EXPECT_EQ(layout1.nbytes(), layout2.nbytes());
+  EXPECT_EQ(layout1.sizes().size(), layout2.sizes().size());
+  for (size_t i = 0; i < layout1.sizes().size(); i++) {
+    EXPECT_EQ(layout1.sizes()[i], layout2.sizes()[i]);
+  }
+  EXPECT_EQ(layout1.dim_order().size(), layout2.dim_order().size());
+  for (size_t i = 0; i < layout1.dim_order().size(); i++) {
+    EXPECT_EQ(layout1.dim_order()[i], layout2.dim_order()[i]);
+  }
+}
+
+// Given that ndm is one of the maps wrapped by merged, check that all API
+// calls on ndm produce the same results as on merged.
+void compare_ndm_api_calls(
+    const NamedDataMap* ndm,
+    const NamedDataMap* merged) {
+  uint32_t num_keys = ndm->get_num_keys().get();
+  for (uint32_t i = 0; i < num_keys; i++) {
+    auto key = ndm->get_key(i).get();
+
+    // Compare get_tensor_layout.
+    auto ndm_meta = ndm->get_tensor_layout(key).get();
+    auto merged_meta = merged->get_tensor_layout(key).get();
+    check_tensor_layout(ndm_meta, merged_meta);
+
+    // Compare get_data.
+    auto ndm_data = ndm->get_data(key);
+    auto merged_data = merged->get_data(key);
+    EXPECT_EQ(ndm_data.get().size(), merged_data.get().size());
+    for (size_t j = 0; j < ndm_meta.nbytes(); j++) {
+      EXPECT_EQ(
+          ((uint8_t*)ndm_data.get().data())[j],
+          ((uint8_t*)merged_data.get().data())[j]);
+    }
+    ndm_data->Free();
+    merged_data->Free();
+  }
+}
+
+TEST_F(MergedDataMapTest, LoadNullDataMap) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(nullptr, nullptr);
+  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
+}
+
+TEST_F(MergedDataMapTest, LoadMultipleDataMaps) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(
+      data_maps_["addmul"].get(), data_maps_["linear"].get());
+  EXPECT_EQ(merged_map.error(), Error::Ok);
+}
+
+TEST_F(MergedDataMapTest, LoadDuplicateDataMapsFail) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(
+      data_maps_["addmul"].get(), data_maps_["addmul"].get());
+  EXPECT_EQ(merged_map.error(), Error::InvalidArgument);
+}
+
+TEST_F(MergedDataMapTest, CheckDataMapContents) {
+  Result<MergedDataMap> merged_map = MergedDataMap::load(
+      data_maps_["addmul"].get(), data_maps_["linear"].get());
+  EXPECT_EQ(merged_map.error(), Error::Ok);
+
+  // Num keys.
+  size_t addmul_num_keys = data_maps_["addmul"]->get_num_keys().get();
+  size_t linear_num_keys = data_maps_["linear"]->get_num_keys().get();
+  EXPECT_EQ(
+      merged_map->get_num_keys().get(), addmul_num_keys + linear_num_keys);
+
+  // load_data_into is not implemented for the merged data map.
+  void* memory_block = malloc(10);
+  ASSERT_EQ(
+      Error::NotImplemented, merged_map->load_data_into("a", memory_block, 10));
+  free(memory_block);
+
+  // API calls produce equivalent results.
+  compare_ndm_api_calls(data_maps_["addmul"].get(), &merged_map.get());
+  compare_ndm_api_calls(data_maps_["linear"].get(), &merged_map.get());
+}
diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl
index 39ff0668d5d..7b4672e4414 100644
--- a/runtime/executor/test/targets.bzl
+++ b/runtime/executor/test/targets.bzl
@@ -125,6 +125,7 @@ def define_common_targets(is_fbcode = False):
             "ET_MODULE_STATEFUL_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleStateful.pte])",
             "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
             "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
+            "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
         }
 
         runtime.cxx_test(
@@ -142,6 +143,19 @@ def define_common_targets(is_fbcode = False):
             env = modules_env,
         )
 
+        runtime.cxx_test(
+            name = "merged_data_map_test",
+            srcs = [
+                "merged_data_map_test.cpp",
+            ],
+            deps = [
+                "//executorch/extension/data_loader:file_data_loader",
+                "//executorch/extension/flat_tensor:flat_tensor_data_map",
+                "//executorch/runtime/executor:merged_data_map",
+            ],
+            env = modules_env,
+        )
+
         runtime.cxx_test(
             name = "method_test",
            srcs = [
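
Usage note (not part of the diff): the sketch below shows how two FlatTensor-backed data maps might be merged and queried through the new MergedDataMap, mirroring the calls exercised in merged_data_map_test.cpp above. The .ptd file names are hypothetical and the include paths are assumed to match the targets declared in this change.

// Illustrative sketch only; assumes hypothetical model_a.ptd / model_b.ptd
// files with disjoint tensor keys.
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/runtime/executor/merged_data_map.h>
#include <executorch/runtime/platform/runtime.h>

#include <cstdio>

using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;
using executorch::runtime::Result;
using executorch::runtime::internal::MergedDataMap;

int main() {
  executorch::runtime::runtime_init();

  // Load each .ptd file; the loaders must outlive the data maps.
  Result<FileDataLoader> loader_a = FileDataLoader::from("model_a.ptd");
  Result<FileDataLoader> loader_b = FileDataLoader::from("model_b.ptd");
  if (!loader_a.ok() || !loader_b.ok()) {
    return 1;
  }

  // Build a FlatTensorDataMap over each loader.
  Result<FlatTensorDataMap> map_a = FlatTensorDataMap::load(&loader_a.get());
  Result<FlatTensorDataMap> map_b = FlatTensorDataMap::load(&loader_b.get());
  if (!map_a.ok() || !map_b.ok()) {
    return 1;
  }

  // Merge the two maps; both must outlive the merged view, and load() fails
  // with InvalidArgument if they share a key.
  Result<MergedDataMap> merged =
      MergedDataMap::load(&map_a.get(), &map_b.get());
  if (!merged.ok()) {
    return 1;
  }

  // Enumerate keys from both underlying maps through the single interface.
  for (uint32_t i = 0; i < merged->get_num_keys().get(); i++) {
    const char* key = merged->get_key(i).get();
    auto layout = merged->get_tensor_layout(key);
    if (layout.ok()) {
      std::printf("%s: %zu bytes\n", key, layout->nbytes());
    }
  }
  return 0;
}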