diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp
index 93c150c0b90..f1a2d3b8f2f 100644
--- a/examples/portable/executor_runner/executor_runner.cpp
+++ b/examples/portable/executor_runner/executor_runner.cpp
@@ -22,6 +22,9 @@
 
 #include
 
+#include
+#include
+
 #include
 #include
 #include
@@ -36,6 +39,10 @@ DEFINE_string(
     model_path,
     "model.pte",
     "Model serialized in flatbuffer format.");
+DEFINE_bool(
+    is_fd_uri,
+    false,
+    "True if the model_path passed is a file descriptor with the prefix \"fd:///\".");
 
 using executorch::extension::FileDataLoader;
 using executorch::runtime::Error;
@@ -66,7 +73,12 @@ int main(int argc, char** argv) {
   // DataLoaders that use mmap() or point to data that's already in memory, and
   // users can create their own DataLoaders to load from arbitrary sources.
   const char* model_path = FLAGS_model_path.c_str();
-  Result loader = FileDataLoader::from(model_path);
+  const bool is_fd_uri = FLAGS_is_fd_uri;
+  // Choose the FileDataLoader factory based on the --is_fd_uri flag.
+  Result loader = is_fd_uri
+      ? FileDataLoader::fromFileDescriptorUri(model_path)
+      : FileDataLoader::from(model_path);
+  // NOTE(review): the failure message below names from() on both paths.
   ET_CHECK_MSG(
       loader.ok(),
       "FileDataLoader::from() failed: 0x%" PRIx32,
diff --git a/extension/data_loader/file_data_loader.cpp b/extension/data_loader/file_data_loader.cpp
index 1d097cfd989..0324751bfa4 100644
--- a/extension/data_loader/file_data_loader.cpp
+++ b/extension/data_loader/file_data_loader.cpp
@@ -43,6 +43,8 @@ namespace extension {
 
 namespace {
 
+static constexpr char kFdFilesystemPrefix[] = "fd:///"; // Prefix of fd URIs.
+
 /**
  * Returns true if the value is an integer power of 2.
*/
@@ -74,25 +76,36 @@ FileDataLoader::~FileDataLoader() {
   ::close(fd_);
 }
 
-Result FileDataLoader::from(
-    const char* file_name,
-    size_t alignment) {
+Result getFDFromUri(const char* file_descriptor_uri) {
+  // Parses "fd:///<N>" into N; the URI must start with the "fd:///" prefix.
   ET_CHECK_OR_RETURN_ERROR(
-      is_power_of_2(alignment),
+      strncmp(
+          file_descriptor_uri,
+          kFdFilesystemPrefix,
+          strlen(kFdFilesystemPrefix)) == 0,
       InvalidArgument,
-      "Alignment %zu is not a power of 2",
-      alignment);
+      "File descriptor uri (%s) does not start with %s",
+      file_descriptor_uri,
+      kFdFilesystemPrefix);
 
-  // Use open() instead of fopen() to avoid the layer of buffering that
-  // fopen() does. We will be reading large portions of the file in one shot,
-  // so buffering does not help.
-  int fd = ::open(file_name, O_RDONLY);
-  if (fd < 0) {
-    ET_LOG(
-        Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno);
-    return Error::AccessFailed;
-  }
+  // Parse the suffix strictly: atoi() would turn junk like "fd:///x" into fd 0.
+  const char* fd_str = file_descriptor_uri + strlen(kFdFilesystemPrefix);
+  char* fd_end = nullptr;
+  errno = 0;
+  const long parsed_fd = ::strtol(fd_str, &fd_end, 10);
+  ET_CHECK_OR_RETURN_ERROR(
+      *fd_str >= '0' && *fd_str <= '9' && *fd_end == '\0' && errno == 0 &&
+          static_cast<int>(parsed_fd) == parsed_fd,
+      InvalidArgument,
+      "File descriptor uri (%s) does not end in a valid file descriptor",
+      file_descriptor_uri);
+  return static_cast<int>(parsed_fd);
+}
 
+Result FileDataLoader::fromFileDescriptor(
+    const char* file_name,
+    const int fd,
+    size_t alignment) {
   // Cache the file size.
struct stat st;
   int err = ::fstat(fd, &st);
@@ -119,6 +132,47 @@ Result FileDataLoader::from(
   return FileDataLoader(fd, file_size, alignment, file_name_copy);
 }
 
+Result FileDataLoader::fromFileDescriptorUri(
+    const char* file_descriptor_uri,
+    size_t alignment) {
+  ET_CHECK_OR_RETURN_ERROR(
+      is_power_of_2(alignment),
+      InvalidArgument,
+      "Alignment %zu is not a power of 2",
+      alignment);
+
+  auto parsed_fd = getFDFromUri(file_descriptor_uri);
+  if (!parsed_fd.ok()) {
+    return parsed_fd.error();
+  }
+
+  int fd = parsed_fd.get();
+  // The URI itself is passed through as the loader's file name argument.
+  return fromFileDescriptor(file_descriptor_uri, fd, alignment);
+}
+
+Result FileDataLoader::from(
+    const char* file_name,
+    size_t alignment) {
+  ET_CHECK_OR_RETURN_ERROR(
+      is_power_of_2(alignment),
+      InvalidArgument,
+      "Alignment %zu is not a power of 2",
+      alignment);
+
+  // Use open() instead of fopen() to avoid the layer of buffering that
+  // fopen() does. We will be reading large portions of the file in one shot,
+  // so buffering does not help.
+  int fd = ::open(file_name, O_RDONLY);
+  if (fd < 0) {
+    ET_LOG(
+        Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno);
+    return Error::AccessFailed;
+  }
+
+  return fromFileDescriptor(file_name, fd, alignment);
+}
+
 namespace {
 /**
  * FreeableBuffer::FreeFn-compatible callback.
diff --git a/extension/data_loader/file_data_loader.h b/extension/data_loader/file_data_loader.h
index 7cf2a92c4ad..959684137b8 100644
--- a/extension/data_loader/file_data_loader.h
+++ b/extension/data_loader/file_data_loader.h
@@ -26,6 +26,27 @@ namespace extension {
  */
 class FileDataLoader final : public executorch::runtime::DataLoader {
  public:
+  /**
+   * Creates a new FileDataLoader around an already-open file descriptor named
+   * by an "fd:///<N>" URI. Ownership of the descriptor passes to the loader.
+   * This helper is for when ET runs in a process that cannot access the
+   * filesystem, and the caller opens the file and hands over the descriptor.
+   *
+   * @param[in] file_descriptor_uri File descriptor with the prefix "fd:///",
+   *     followed by the file descriptor number.
+   * @param[in] alignment Alignment in bytes of pointers returned by this
+   *     instance. Must be a power of two.
+   *
+   * @returns A new FileDataLoader on success.
+   * @retval Error::InvalidArgument `alignment` is not a power of two, or
+   *     `file_descriptor_uri` does not start with "fd:///".
+   * @retval Error::AccessFailed The size of the file could not be found.
+   * @retval Error::MemoryAllocationFailed Internal memory allocation failure.
+   */
+  static executorch::runtime::Result fromFileDescriptorUri(
+      const char* file_descriptor_uri,
+      size_t alignment = alignof(std::max_align_t));
+
   /**
    * Creates a new FileDataLoader that wraps the named file.
    *
@@ -79,6 +100,11 @@ class FileDataLoader final : public executorch::runtime::DataLoader {
       void* buffer) const override;
 
  private:
+  static executorch::runtime::Result fromFileDescriptor(
+      const char* file_name,
+      const int fd,
+      size_t alignment = alignof(std::max_align_t));
+
   FileDataLoader(
       int fd,
       size_t file_size,
diff --git a/extension/data_loader/test/file_data_loader_test.cpp b/extension/data_loader/test/file_data_loader_test.cpp
index 1d4f4c16196..b8921aebb54 100644
--- a/extension/data_loader/test/file_data_loader_test.cpp
+++ b/extension/data_loader/test/file_data_loader_test.cpp
@@ -40,6 +40,103 @@ class FileDataLoaderTest : public ::testing::TestWithParam {
   }
 };
 
+TEST_P(FileDataLoaderTest, InBoundsFileDescriptorLoadsSucceed) {
+  // Write some heterogeneous data to a file.
+  uint8_t data[256];
+  for (int i = 0; i < sizeof(data); ++i) {
+    data[i] = i;
+  }
+  TempFile tf(data, sizeof(data));
+
+  int fd = ::open(tf.path().c_str(), O_RDONLY);
+  // NOTE(review): open() is not checked for failure before fd is used.
+  // Hand the descriptor to a loader through its "fd:///<N>" URI.
+  Result fdl = FileDataLoader::fromFileDescriptorUri(
+      ("fd:///" + std::to_string(fd)).c_str(), alignment());
+  ASSERT_EQ(fdl.error(), Error::Ok);
+
+  // size() should succeed and reflect the total size.
+  Result size = fdl->size();
+  ASSERT_EQ(size.error(), Error::Ok);
+  EXPECT_EQ(*size, sizeof(data));
+
+  // Load the first bytes of the data.
+  {
+    Result fb = fdl->load(
+        /*offset=*/0,
+        /*size=*/8,
+        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
+    ASSERT_EQ(fb.error(), Error::Ok);
+    EXPECT_ALIGNED(fb->data(), alignment());
+    EXPECT_EQ(fb->size(), 8);
+    EXPECT_EQ(
+        0,
+        std::memcmp(
+            fb->data(),
+            "\x00\x01\x02\x03"
+            "\x04\x05\x06\x07",
+            fb->size()));
+
+    // Freeing should release the buffer and clear out the segment.
+    fb->Free();
+    EXPECT_EQ(fb->size(), 0);
+    EXPECT_EQ(fb->data(), nullptr);
+
+    // Safe to call multiple times.
+    fb->Free();
+  }
+
+  // Load the last few bytes of the data, a different size than the first time.
+  {
+    Result fb = fdl->load(
+        /*offset=*/sizeof(data) - 3,
+        /*size=*/3,
+        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
+    ASSERT_EQ(fb.error(), Error::Ok);
+    EXPECT_ALIGNED(fb->data(), alignment());
+    EXPECT_EQ(fb->size(), 3);
+    EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size()));
+  }
+
+  // Loading all of the data succeeds.
+  {
+    Result fb = fdl->load(
+        /*offset=*/0,
+        /*size=*/sizeof(data),
+        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
+    ASSERT_EQ(fb.error(), Error::Ok);
+    EXPECT_ALIGNED(fb->data(), alignment());
+    EXPECT_EQ(fb->size(), sizeof(data));
+    EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size()));
+  }
+
+  // Loading zero-sized data succeeds, even at the end of the data.
+  {
+    Result fb = fdl->load(
+        /*offset=*/sizeof(data),
+        /*size=*/0,
+        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
+    ASSERT_EQ(fb.error(), Error::Ok);
+    EXPECT_EQ(fb->size(), 0);
+  }
+}
+
+TEST_P(FileDataLoaderTest, FileDescriptorLoadPrefixFail) {
+  // Write some heterogeneous data to a file.
+  uint8_t data[256];
+  for (int i = 0; i < sizeof(data); ++i) {
+    data[i] = i;
+  }
+  TempFile tf(data, sizeof(data));
+
+  int fd = ::open(tf.path().c_str(), O_RDONLY);
+  // NOTE(review): fd is never closed here, since the loader rejects the URI.
+  // A bare descriptor number without the "fd:///" prefix must be rejected.
+  Result fdl = FileDataLoader::fromFileDescriptorUri(
+      std::to_string(fd).c_str(), alignment());
+  ASSERT_EQ(fdl.error(), Error::InvalidArgument);
+}
+
 TEST_P(FileDataLoaderTest, InBoundsLoadsSucceed) {
   // Write some heterogeneous data to a file.
   uint8_t data[256];