diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index ff526e359d4..bf54ae014b5 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -141,7 +141,7 @@ ET_NODISCARD Result FlatTensorDataMap::get_data( DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); } -ET_NODISCARD Result FlatTensorDataMap::load_data_into( +ET_NODISCARD Error FlatTensorDataMap::load_data_into( ET_UNUSED const char* key, ET_UNUSED void* buffer, ET_UNUSED size_t size) const { @@ -156,7 +156,7 @@ ET_NODISCARD Result FlatTensorDataMap::load_data_into( return tensor_layout.error(); } ET_CHECK_OR_RETURN_ERROR( - size < tensor_layout.get().nbytes(), + size <= tensor_layout.get().nbytes(), InvalidArgument, "Buffer size %zu is smaller than tensor size %zu", size, @@ -187,6 +187,7 @@ ET_NODISCARD Result FlatTensorDataMap::get_key( if (index < 0 || index >= flat_tensor_->tensors()->size()) { return Error::InvalidArgument; } + return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str(); } diff --git a/extension/flat_tensor/flat_tensor_data_map.h b/extension/flat_tensor/flat_tensor_data_map.h index 00f4bf07d19..972a5fa9c55 100644 --- a/extension/flat_tensor/flat_tensor_data_map.h +++ b/extension/flat_tensor/flat_tensor_data_map.h @@ -75,7 +75,7 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap { * * @returns an Error indicating if the load was successful. */ - ET_NODISCARD executorch::runtime::Result + ET_NODISCARD executorch::runtime::Error load_data_into(const char* key, void* buffer, size_t size) const override; /** diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp index 681bc39a129..ac4583eda88 100644 --- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp +++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp @@ -137,3 +137,26 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) { Result key2_res = data_map->get_key(2); EXPECT_EQ(key2_res.error(), Error::InvalidArgument); } + +TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) { + Result data_map = + FlatTensorDataMap::load(data_map_loader_.get()); + EXPECT_EQ(data_map.error(), Error::Ok); + + // get the metadata + auto meta_data_res = data_map->get_metadata("a"); + ASSERT_EQ(meta_data_res.error(), Error::Ok); + + // get data blob + void* data = malloc(meta_data_res->nbytes()); + auto load_into_error = + data_map->load_data_into("a", data, meta_data_res->nbytes()); + ASSERT_EQ(load_into_error, Error::Ok); + + // Check tensor data is correct. + float* data_a = static_cast(data); + for (int i = 0; i < 4; i++) { + EXPECT_EQ(data_a[i], 3.0); + } + free(data); +} diff --git a/runtime/core/named_data_map.h b/runtime/core/named_data_map.h index 68639ed872a..ef5e413db67 100644 --- a/runtime/core/named_data_map.h +++ b/runtime/core/named_data_map.h @@ -53,10 +53,9 @@ class ET_EXPERIMENTAL NamedDataMap { * size of the data for a given key. * @param buffer The buffer to load the data into. Must point to at least * `size` bytes of memory. - * @return Result containing the number of bytes written on success. This will - * fail if the buffer is too small. + * @returns an Error indicating if the load was successful. */ - ET_NODISCARD virtual Result + ET_NODISCARD virtual Error load_data_into(const char* key, void* buffer, size_t size) const = 0; /** diff --git a/runtime/executor/tensor_parser_exec_aten.cpp b/runtime/executor/tensor_parser_exec_aten.cpp index a1ac245acca..66202acabc3 100644 --- a/runtime/executor/tensor_parser_exec_aten.cpp +++ b/runtime/executor/tensor_parser_exec_aten.cpp @@ -224,17 +224,12 @@ ET_NODISCARD Result getTensorDataPtr( if (!planned_ptr.ok()) { return planned_ptr.error(); } - auto size = + auto load_error = named_data_map->load_data_into(fqn, planned_ptr.get(), nbytes); - if (size.error() != Error::Ok) { - return size.error(); + if (load_error != Error::Ok) { + return load_error; } - ET_CHECK_OR_RETURN_ERROR( - size.get() == nbytes, - InvalidExternalData, - "Expected to load %zu bytes, actually loaded %u bytes", - nbytes, - static_cast(size.get())); + return planned_ptr; } }