diff --git a/extension/data_loader/mmap_data_loader.cpp b/extension/data_loader/mmap_data_loader.cpp index 32999930b0a..10bd2f35f5e 100644 --- a/extension/data_loader/mmap_data_loader.cpp +++ b/extension/data_loader/mmap_data_loader.cpp @@ -150,10 +150,10 @@ void MunmapSegment(void* context, void* data, size_t size) { } } // namespace -Result MmapDataLoader::load( - size_t offset, - size_t size, - ET_UNUSED const DataLoader::SegmentInfo& segment_info) const { +/** + * Validates that file read range is within bounds. + */ +Error MmapDataLoader::validate_input(size_t offset, size_t size) const { ET_CHECK_OR_RETURN_ERROR( // Probably had its value moved to another instance. fd_ >= 0, @@ -173,6 +173,18 @@ Result MmapDataLoader::load( InvalidArgument, "Offset %zu too large for off_t", offset); + return Error::Ok; +} + +Result MmapDataLoader::load( + size_t offset, + size_t size, + ET_UNUSED const DataLoader::SegmentInfo& segment_info) const { + // Ensure read range is valid. + auto validation_err = validate_input(offset, size); + if (validation_err != Error::Ok) { + return validation_err; + } // mmap() will fail if the size is zero. if (size == 0) { @@ -267,5 +279,69 @@ Result MmapDataLoader::size() const { return file_size_; } +Error MmapDataLoader::load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const { + ET_CHECK_OR_RETURN_ERROR( + buffer != nullptr, InvalidArgument, "Buffer is null"); + + // Ensure read range is valid. + auto err = validate_input(offset, size); + if (err != Error::Ok) { + return err; + } + + // Nothing to copy. + if (size == 0) { + return Error::Ok; + } + + // Find the range of pages that covers the requested region. + Range range = + get_overlapping_pages(static_cast(offset), size, page_size_); + + size_t map_size = range.size; + if (range.start + map_size > file_size_) { + // Clamp to the end of the file. + // + // The Windows implementation of mmap uses CreateFileMapping which returns + // error STATUS_SECTION_TOO_BIG (0xc0000040) if we try to map past the end + // of the last page of a file mapped in as read-only. + map_size = file_size_ - range.start; + } + + // Map the pages read-only. MAP_PRIVATE vs. MAP_SHARED doesn't matter since + // the data is read-only, but use PRIVATE just to further avoid accidentally + // modifying the file. + void* pages = ::mmap( + nullptr, + map_size, + PROT_READ, + MAP_PRIVATE, + fd_, + static_cast(range.start)); + ET_CHECK_OR_RETURN_ERROR( + pages != MAP_FAILED, + AccessFailed, + "Failed to map %s: mmap(..., size=%zd, ..., fd=%d, offset=0x%zx)", + file_name_, + range.size, + fd_, + range.start); + + // Offset into mapped region. + const size_t map_delta = offset - range.start; + + // Copy data into caller's buffer. + std::memcpy(buffer, static_cast(pages) + map_delta, size); + + // Unmap mapped region. + ::munmap(pages, map_size); + + return Error::Ok; +} + } // namespace extension } // namespace executorch diff --git a/extension/data_loader/mmap_data_loader.h b/extension/data_loader/mmap_data_loader.h index c55f81a490b..c0496a39d4b 100644 --- a/extension/data_loader/mmap_data_loader.h +++ b/extension/data_loader/mmap_data_loader.h @@ -95,6 +95,13 @@ class MmapDataLoader final : public executorch::runtime::DataLoader { ET_NODISCARD executorch::runtime::Result size() const override; + ET_NODISCARD + executorch::runtime::Error load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const override; + private: MmapDataLoader( int fd, @@ -113,6 +120,10 @@ class MmapDataLoader final : public executorch::runtime::DataLoader { MmapDataLoader& operator=(const MmapDataLoader&) = delete; MmapDataLoader& operator=(MmapDataLoader&&) = delete; + ET_NODISCARD executorch::runtime::Error validate_input( + size_t offset, + size_t size) const; + const char* const file_name_; // String data is owned by the instance. const size_t file_size_; const size_t page_size_; diff --git a/extension/data_loader/test/mmap_data_loader_test.cpp b/extension/data_loader/test/mmap_data_loader_test.cpp index c01b3454493..76b972c46d0 100644 --- a/extension/data_loader/test/mmap_data_loader_test.cpp +++ b/extension/data_loader/test/mmap_data_loader_test.cpp @@ -376,3 +376,56 @@ TEST_F(MmapDataLoaderTest, DEPRECATEDFrom) { ASSERT_EQ(total_size.error(), Error::Ok); EXPECT_EQ(*total_size, contents_size); } + +// Tests that load_into copies bytes correctly. +TEST_F(MmapDataLoaderTest, LoadIntoCopiesCorrectly) { + // Create a test string. + const char* test_text = "FILE_CONTENTS"; + const size_t text_size = std::strlen(test_text); + TempFile tf(test_text); + + // Wrap it in a loader. + Result mdl = MmapDataLoader::from(tf.path().c_str()); + ASSERT_EQ(mdl.error(), Error::Ok); + + // Destination buffer. + std::vector dst(text_size); + + // Call load_into() + Error err = mdl->load_into( + /*offset=*/0, + /*size=*/text_size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program), + dst.data()); + ASSERT_EQ(err, Error::Ok); + + // Verify memory copied correctly. + EXPECT_EQ(0, std::memcmp(dst.data(), test_text, text_size)); +} + +// Tests that load_into copies offset slice correctly. +TEST_F(MmapDataLoaderTest, LoadIntoCopiesOffsetCorrectly) { + // Create a test string. + const char* contents = "ABCDEFGH"; + TempFile tf(contents); + + // Wrap it in a loader. + Result mdl = MmapDataLoader::from(tf.path().c_str()); + ASSERT_EQ(mdl.error(), Error::Ok); + + // Copying 3 bytes starting at offset 2 = "CDE" + const size_t offset = 2; + const size_t size = 3; + uint8_t dst[size]; + + // Call load_into() + Error err = mdl->load_into( + offset, + size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program), + dst); + ASSERT_EQ(err, Error::Ok); + + // Verify memory copied correctly. + EXPECT_EQ(0, std::memcmp(dst, contents + offset, size)); +} \ No newline at end of file