Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions extension/flat_tensor/flat_tensor_data_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
segments->size());
// Validate the segment.
ET_CHECK_OR_RETURN_ERROR(
segments->Get(segment_index)->offset() < segment_end_offset,
(segments->Get(segment_index)->offset() +
segments->Get(segment_index)->size()) <= segment_end_offset,
InvalidExternalData,
"Invalid segment offset %" PRIu64
" is larger than the segment_base_offset + segment_data_size %" PRIu64
Expand Down Expand Up @@ -206,15 +207,21 @@ ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
}
Result<FlatTensorHeader> fh =
FlatTensorHeader::Parse(header->data(), header->size());
if (fh.error() == Error::NotFound) {
// No header, throw error.
ET_LOG(Error, "No FlatTensorHeader found.");
return fh.error();
} else if (fh.error() != Error::Ok) {
// corruption, throw error.
ET_LOG(Error, "Flat tensor header may be corrupt.");
return fh.error();
}

ET_CHECK_OR_RETURN_ERROR(
fh.ok(),
InvalidExternalData,
"Failed to parse FlatTensor header with error code %u. File may be corrupt.",
static_cast<uint32_t>(fh.error()));

size_t expected_size = fh->segment_base_offset + fh->segment_data_size;
size_t actual_size = loader->size().get();
ET_CHECK_OR_RETURN_ERROR(
expected_size == actual_size,
InvalidExternalData,
"File size is too small; file may be corrupted or truncated. Expected %zu from flat_tensor header, received %zu from data loader",
expected_size,
actual_size);

// Load flatbuffer data as a segment.
Result<FreeableBuffer> flat_tensor_data = loader->load(
Expand Down
7 changes: 6 additions & 1 deletion extension/flat_tensor/serialize/serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ runtime::Error save_ptd(
// Write the tensors.
size_t total_segment_size = 0;
uint32_t i = 0;
size_t tensor_count = tensor_map.size();
for (const auto& [name, tensor] : tensor_map) {
auto key = builder.CreateString(name);
// Write the tensor layouts.
Expand All @@ -99,7 +100,11 @@ runtime::Error save_ptd(
/*_fbb=*/builder,
/*offset=*/total_segment_size,
/*size=*/tensor.nbytes()));
total_segment_size += aligned_size(tensor.nbytes(), tensor_alignment);

// Do not pad the last tensor.
total_segment_size += (i == tensor_count - 1)
? tensor.nbytes()
: aligned_size(tensor.nbytes(), tensor_alignment);
i++;
}

Expand Down
32 changes: 27 additions & 5 deletions extension/flat_tensor/test/flat_tensor_data_map_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
#include <executorch/extension/flat_tensor/serialize/flat_tensor_generated.h>
Expand All @@ -17,14 +18,15 @@
#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::BufferDataLoader;
using executorch::extension::FileDataLoader;
using executorch::extension::FlatTensorDataMap;
using executorch::extension::FlatTensorHeader;
using executorch::runtime::DataLoader;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Result;
using executorch::runtime::TensorLayout;
using torch::executor::util::FileDataLoader;

class FlatTensorDataMapTest : public ::testing::Test {
protected:
Expand All @@ -51,7 +53,7 @@ TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) {
EXPECT_EQ(data_map.error(), Error::Ok);
}

TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
TEST_F(FlatTensorDataMapTest, GetMetadata) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);
Expand Down Expand Up @@ -93,7 +95,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
EXPECT_EQ(const_c_res.error(), Error::NotFound);
}

TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
TEST_F(FlatTensorDataMapTest, GetData) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);
Expand All @@ -114,7 +116,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
EXPECT_EQ(data_c_res.error(), Error::NotFound);
}

TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
TEST_F(FlatTensorDataMapTest, GetKeys) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);
Expand All @@ -138,7 +140,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
}

TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
TEST_F(FlatTensorDataMapTest, LoadInto) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);
Expand All @@ -160,3 +162,23 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
}
free(data);
}

TEST_F(FlatTensorDataMapTest, LoadAndCheckSize) {
Result<FlatTensorDataMap> data_map =
FlatTensorDataMap::load(data_map_loader_.get());
EXPECT_EQ(data_map.error(), Error::Ok);

// Truncate the file.
size_t trunc_size = data_map_loader_->size().get() - 8;
Result<FreeableBuffer> truncated_file = data_map_loader_->load(
0,
trunc_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
ASSERT_EQ(truncated_file.error(), Error::Ok);

BufferDataLoader truncated_loader =
BufferDataLoader(truncated_file->data(), trunc_size);
Result<FlatTensorDataMap> truncated_program =
FlatTensorDataMap::load(&truncated_loader);
ASSERT_EQ(truncated_program.error(), Error::InvalidExternalData);
}
1 change: 1 addition & 0 deletions extension/flat_tensor/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def define_common_targets(is_fbcode=False):
"flat_tensor_data_map_test.cpp",
],
deps = [
"//executorch/extension/data_loader:buffer_data_loader",
"//executorch/extension/data_loader:file_data_loader",
"//executorch/extension/flat_tensor:flat_tensor_data_map",
"//executorch/runtime/core:named_data_map",
Expand Down
6 changes: 4 additions & 2 deletions extension/flat_tensor/test/test_serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) {
const uint64_t segment_offset = 48 + 280 + 8; // 8 is padding.
EXPECT_EQ(*(uint64_t*)(header_buffer + 24), segment_offset);

// Segment total size, 8 bytes of data (2 floats), 24 bytes of padding.
const uint64_t segment_size = 32;
// Segment total size = 20
// linear.bias: 4 bytes + 12 bytes of padding.
// linear.weight: 4 bytes + 0 padding (last segment).
const uint64_t segment_size = 20;
EXPECT_EQ(*(uint64_t*)(header_buffer + 32), segment_size);

// Check Flatbuffer
Expand Down
Loading