Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions extension/flat_tensor/flat_tensor_data_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}

ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
ET_NODISCARD Error FlatTensorDataMap::load_data_into(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for updating this 🙏

ET_UNUSED const char* key,
ET_UNUSED void* buffer,
ET_UNUSED size_t size) const {
Expand All @@ -156,7 +156,7 @@ ET_NODISCARD Result<size_t> FlatTensorDataMap::load_data_into(
return tensor_layout.error();
}
ET_CHECK_OR_RETURN_ERROR(
size < tensor_layout.get().nbytes(),
size <= tensor_layout.get().nbytes(),
InvalidArgument,
"Buffer size %zu is smaller than tensor size %zu",
size,
Expand Down Expand Up @@ -187,6 +187,7 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
if (index < 0 || index >= flat_tensor_->tensors()->size()) {
return Error::InvalidArgument;
}

return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
}

Expand Down
2 changes: 1 addition & 1 deletion extension/flat_tensor/flat_tensor_data_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class FlatTensorDataMap final : public executorch::runtime::NamedDataMap {
*
* @returns an Error indicating if the load was successful.
*/
ET_NODISCARD executorch::runtime::Result<size_t>
ET_NODISCARD executorch::runtime::Error
load_data_into(const char* key, void* buffer, size_t size) const override;

/**
Expand Down
23 changes: 23 additions & 0 deletions extension/flat_tensor/test/flat_tensor_data_map_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,26 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
Result<const char*> key2_res = data_map->get_key(2);
EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
}

TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);

// get the metadata
auto meta_data_res = data_map->get_metadata("a");
ASSERT_EQ(meta_data_res.error(), Error::Ok);

// get data blob
void* data = malloc(meta_data_res->nbytes());
auto load_into_error =
data_map->load_data_into("a", data, meta_data_res->nbytes());
ASSERT_EQ(load_into_error, Error::Ok);

// Check tensor data is correct.
float* data_a = static_cast<float*>(data);
for (int i = 0; i < 4; i++) {
EXPECT_EQ(data_a[i], 3.0);
}
free(data);
}
5 changes: 2 additions & 3 deletions runtime/core/named_data_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ class ET_EXPERIMENTAL NamedDataMap {
* size of the data for a given key.
* @param buffer The buffer to load the data into. Must point to at least
* `size` bytes of memory.
* @return Result containing the number of bytes written on success. This will
* fail if the buffer is too small.
* @returns an Error indicating if the load was successful.
*/
ET_NODISCARD virtual Result<size_t>
ET_NODISCARD virtual Error
load_data_into(const char* key, void* buffer, size_t size) const = 0;

/**
Expand Down
13 changes: 4 additions & 9 deletions runtime/executor/tensor_parser_exec_aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,12 @@ ET_NODISCARD Result<void*> getTensorDataPtr(
if (!planned_ptr.ok()) {
return planned_ptr.error();
}
auto size =
auto load_error =
named_data_map->load_data_into(fqn, planned_ptr.get(), nbytes);
if (size.error() != Error::Ok) {
return size.error();
if (load_error != Error::Ok) {
return load_error;
}
ET_CHECK_OR_RETURN_ERROR(
size.get() == nbytes,
InvalidExternalData,
"Expected to load %zu bytes, actually loaded %u bytes",
nbytes,
static_cast<unsigned int>(size.get()));

return planned_ptr;
}
}
Expand Down
Loading