diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp
index 3a69dc8b92c..9a17e226a58 100644
--- a/extension/flat_tensor/flat_tensor_data_map.cpp
+++ b/extension/flat_tensor/flat_tensor_data_map.cpp
@@ -73,7 +73,11 @@ Result get_named_data(
       segments->size());
   // Validate the segment.
   ET_CHECK_OR_RETURN_ERROR(
-      segments->Get(segment_index)->offset() < segment_end_offset,
+      // Bound the size first, then compare the offset against the room left,
+      // so an offset + size sum that overflows cannot pass this check.
+      segments->Get(segment_index)->size() <= segment_end_offset &&
+          segments->Get(segment_index)->offset() <=
+              segment_end_offset - segments->Get(segment_index)->size(),
       InvalidExternalData,
       "Invalid segment offset %" PRIu64
       " is larger than the segment_base_offset + segment_data_size %" PRIu64
@@ -206,15 +210,23 @@ ET_NODISCARD Result FlatTensorDataMap::get_key(
   }
   Result fh =
       FlatTensorHeader::Parse(header->data(), header->size());
-  if (fh.error() == Error::NotFound) {
-    // No header, throw error.
-    ET_LOG(Error, "No FlatTensorHeader found.");
-    return fh.error();
-  } else if (fh.error() != Error::Ok) {
-    // corruption, throw error.
-    ET_LOG(Error, "Flat tensor header may be corrupt.");
-    return fh.error();
-  }
+
+  ET_CHECK_OR_RETURN_ERROR(
+      fh.ok(),
+      InvalidExternalData,
+      "Failed to parse FlatTensor header with error code %u. File may be corrupt.",
+      static_cast(fh.error()));
+
+  // The segment region described by the header must end exactly where the
+  // loader's data ends; any other size is reported as invalid external data.
+  size_t expected_size = fh->segment_base_offset + fh->segment_data_size;
+  size_t actual_size = loader->size().get();
+  ET_CHECK_OR_RETURN_ERROR(
+      expected_size == actual_size,
+      InvalidExternalData,
+      "File size mismatch; file may be corrupted or truncated. Expected %zu from flat_tensor header, received %zu from data loader",
+      expected_size,
+      actual_size);
 
   // Load flatbuffer data as a segment.
   Result flat_tensor_data = loader->load(
diff --git a/extension/flat_tensor/serialize/serialize.cpp b/extension/flat_tensor/serialize/serialize.cpp
index 9930de6bba6..ff9d568fdf4 100644
--- a/extension/flat_tensor/serialize/serialize.cpp
+++ b/extension/flat_tensor/serialize/serialize.cpp
@@ -76,6 +76,7 @@ runtime::Error save_ptd(
   // Write the tensors.
   size_t total_segment_size = 0;
   uint32_t i = 0;
+  size_t tensor_count = tensor_map.size();
   for (const auto& [name, tensor] : tensor_map) {
     auto key = builder.CreateString(name);
     // Write the tensor layouts.
@@ -99,7 +100,11 @@ runtime::Error save_ptd(
         /*_fbb=*/builder,
         /*offset=*/total_segment_size,
         /*size=*/tensor.nbytes()));
-    total_segment_size += aligned_size(tensor.nbytes(), tensor_alignment);
+
+    // Do not pad the last tensor.
+    total_segment_size += (i == tensor_count - 1)
+        ? tensor.nbytes()
+        : aligned_size(tensor.nbytes(), tensor_alignment);
     i++;
   }
 
diff --git a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
index 37e1cd2edac..0872a988333 100644
--- a/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
+++ b/extension/flat_tensor/test/flat_tensor_data_map_test.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include
 #include
 #include
 #include
@@ -17,6 +18,8 @@
 #include
 
 using namespace ::testing;
+using executorch::extension::BufferDataLoader;
+using executorch::extension::FileDataLoader;
 using executorch::extension::FlatTensorDataMap;
 using executorch::extension::FlatTensorHeader;
 using executorch::runtime::DataLoader;
@@ -24,7 +27,6 @@ using executorch::runtime::Error;
 using executorch::runtime::FreeableBuffer;
 using executorch::runtime::Result;
 using executorch::runtime::TensorLayout;
-using torch::executor::util::FileDataLoader;
 
 class FlatTensorDataMapTest : public ::testing::Test {
  protected:
@@ -51,7 +53,7 @@ TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) {
   EXPECT_EQ(data_map.error(), Error::Ok);
 }
 
-TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
+TEST_F(FlatTensorDataMapTest, GetMetadata) {
   Result data_map =
       FlatTensorDataMap::load(data_map_loader_.get());
   EXPECT_EQ(data_map.error(), Error::Ok);
@@ -93,7 +95,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
   EXPECT_EQ(const_c_res.error(), Error::NotFound);
 }
 
-TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
+TEST_F(FlatTensorDataMapTest, GetData) {
   Result data_map =
       FlatTensorDataMap::load(data_map_loader_.get());
   EXPECT_EQ(data_map.error(), Error::Ok);
@@ -114,7 +116,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
   EXPECT_EQ(data_c_res.error(), Error::NotFound);
 }
 
-TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
+TEST_F(FlatTensorDataMapTest, GetKeys) {
   Result data_map =
       FlatTensorDataMap::load(data_map_loader_.get());
   EXPECT_EQ(data_map.error(), Error::Ok);
@@ -138,7 +140,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
   EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
 }
 
-TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
+TEST_F(FlatTensorDataMapTest, LoadInto) {
   Result data_map =
       FlatTensorDataMap::load(data_map_loader_.get());
   EXPECT_EQ(data_map.error(), Error::Ok);
@@ -160,3 +162,23 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
   }
   free(data);
 }
+
+TEST_F(FlatTensorDataMapTest, LoadAndCheckSize) {
+  Result data_map =
+      FlatTensorDataMap::load(data_map_loader_.get());
+  EXPECT_EQ(data_map.error(), Error::Ok);
+
+  // Truncate the file.
+  size_t trunc_size = data_map_loader_->size().get() - 8;
+  Result truncated_file = data_map_loader_->load(
+      0,
+      trunc_size,
+      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
+  ASSERT_EQ(truncated_file.error(), Error::Ok);
+
+  BufferDataLoader truncated_loader =
+      BufferDataLoader(truncated_file->data(), trunc_size);
+  Result truncated_program =
+      FlatTensorDataMap::load(&truncated_loader);
+  ASSERT_EQ(truncated_program.error(), Error::InvalidExternalData);
+}
diff --git a/extension/flat_tensor/test/targets.bzl b/extension/flat_tensor/test/targets.bzl
index 4d798cc1a7c..d1272ec5d97 100644
--- a/extension/flat_tensor/test/targets.bzl
+++ b/extension/flat_tensor/test/targets.bzl
@@ -45,6 +45,7 @@ def define_common_targets(is_fbcode=False):
             "flat_tensor_data_map_test.cpp",
         ],
         deps = [
+            "//executorch/extension/data_loader:buffer_data_loader",
             "//executorch/extension/data_loader:file_data_loader",
             "//executorch/extension/flat_tensor:flat_tensor_data_map",
             "//executorch/runtime/core:named_data_map",
diff --git a/extension/flat_tensor/test/test_serialize.cpp b/extension/flat_tensor/test/test_serialize.cpp
index 57a0253485b..35a1e9ee8dc 100644
--- a/extension/flat_tensor/test/test_serialize.cpp
+++ b/extension/flat_tensor/test/test_serialize.cpp
@@ -86,8 +86,10 @@ TEST_F(FlatTensorSerializeTest, ValidFlatTensorSerialized) {
   const uint64_t segment_offset = 48 + 280 + 8; // 8 is padding.
   EXPECT_EQ(*(uint64_t*)(header_buffer + 24), segment_offset);
 
-  // Segment total size, 8 bytes of data (2 floats), 24 bytes of padding.
-  const uint64_t segment_size = 32;
+  // Segment total size = 20
+  // linear.bias: 4 bytes + 12 bytes of padding.
+  // linear.weight: 4 bytes + 0 padding (last segment).
+  const uint64_t segment_size = 20;
   EXPECT_EQ(*(uint64_t*)(header_buffer + 32), segment_size);
 
   // Check Flatbuffer