diff --git a/extension/llm/runner/audio.h b/extension/llm/runner/audio.h index ce71513ed17..0fcbceb719d 100644 --- a/extension/llm/runner/audio.h +++ b/extension/llm/runner/audio.h @@ -11,7 +11,6 @@ #pragma once #include #include -#include #include #include @@ -41,27 +40,16 @@ struct ET_EXPERIMENTAL RawAudio { */ class ET_EXPERIMENTAL Audio final { public: - // Default constructor - Audio() : batch_size_(0), n_bins_(0), n_frames_(0) {} - // Constructor for uint8_t data Audio( std::vector&& data, int32_t batch_size, int32_t n_bins, int32_t n_frames) - : data_(std::move(data)), - batch_size_(batch_size), - n_bins_(n_bins), - n_frames_(n_frames) { - ET_CHECK_MSG( - data_.index() == 0 && - std::get>(data_).size() == - static_cast(batch_size * n_bins * n_frames), - "data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)", - std::get>(data_).size(), - batch_size * n_bins * n_frames); - } + : Audio(make_tensor_ptr( + {batch_size, n_bins, n_frames}, + std::move(data), + executorch::aten::ScalarType::Byte)) {} // Constructor for float data Audio( @@ -69,53 +57,55 @@ class ET_EXPERIMENTAL Audio final { int32_t batch_size, int32_t n_bins, int32_t n_frames) - : data_(std::move(data)), - batch_size_(batch_size), - n_bins_(n_bins), - n_frames_(n_frames) { - ET_CHECK_MSG( - data_.index() == 1 && - std::get>(data_).size() == - static_cast(batch_size * n_bins * n_frames), - "data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)", - std::get>(data_).size(), - batch_size * n_bins * n_frames); + : Audio(make_tensor_ptr({batch_size, n_bins, n_frames}, std::move(data))) {} + + Audio(executorch::extension::TensorPtr tensor) : tensor_(std::move(tensor)) { + ET_CHECK_MSG(tensor_, "Null tensor"); + ET_CHECK_MSG(tensor_->dim() == 3, "Invalid tensor rank"); } // Type checkers bool is_uint8() const { - return std::holds_alternative>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Byte; } bool is_float() const { - return std::holds_alternative>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Float; } // Data access - const std::vector& get_uint8_data() const& { - return std::get>(data_); + std::vector copy_uint8_data() const { + ET_DCHECK_MSG(is_uint8(), "Audio dtype is not uint8"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } - std::vector& get_uint8_data() & { - return std::get>(data_); + std::vector copy_uint8_data() { + ET_DCHECK_MSG(is_uint8(), "Audio dtype is not uint8"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } - const std::vector& get_float_data() const& { - return std::get>(data_); + std::vector copy_float_data() const { + ET_DCHECK_MSG(is_float(), "Audio dtype is not float"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } - std::vector& get_float_data() & { - return std::get>(data_); + std::vector copy_float_data() { + ET_DCHECK_MSG(is_float(), "Audio dtype is not float"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } int32_t get_batch_size() const { - return batch_size_; + return tensor_->size(0); } int32_t get_n_bins() const { - return n_bins_; + return tensor_->size(1); } int32_t get_n_frames() const { - return n_frames_; + return tensor_->size(2); } /** * Convert the audio data to a TensorPtr, with optional batch dimension. @@ -124,34 +114,20 @@ class ET_EXPERIMENTAL Audio final { */ executorch::runtime::Result toTensor( bool with_batch = false) const { - std::vector sizes = { - get_batch_size(), get_n_bins(), get_n_frames()}; if (with_batch) { - sizes.insert(sizes.begin(), 1); - } - if (is_float()) { - return executorch::extension::from_blob( - const_cast(get_float_data().data()), - sizes, - ::executorch::aten::ScalarType::Float); - } else if (is_uint8()) { - return executorch::extension::from_blob( - const_cast(get_uint8_data().data()), - sizes, - ::executorch::aten::ScalarType::Byte); + return make_tensor_ptr( + *tensor_, + {1, + static_cast(tensor_->size(0)), + static_cast(tensor_->size(1)), + static_cast(tensor_->size(2))}); } - ET_LOG( - Error, - "Shouldn't reach here, audio data is not initialized with uint8_t or float vector."); - return ::executorch::runtime::Error::NotSupported; + return tensor_; } private: // Members - std::variant, std::vector> data_; - int32_t batch_size_; - int32_t n_bins_; - int32_t n_frames_; + executorch::extension::TensorPtr tensor_; }; } // namespace llm diff --git a/extension/llm/runner/image.h b/extension/llm/runner/image.h index dbdba273536..d6fba70d357 100644 --- a/extension/llm/runner/image.h +++ b/extension/llm/runner/image.h @@ -10,9 +10,7 @@ #pragma once #include -#include #include -#include #include #include @@ -22,21 +20,19 @@ namespace executorch { namespace extension { namespace llm { +// Assuming NCHW format class ET_EXPERIMENTAL Image { public: - // Default constructor - Image() : width_(0), height_(0), channels_(0) {} - // Constructor for uint8_t data Image( std::vector&& data, int32_t width, int32_t height, int32_t channels) - : data_(std::move(data)), - width_(width), - height_(height), - channels_(channels) {} + : Image(make_tensor_ptr( + {channels, height, width}, + std::move(data), + executorch::aten::ScalarType::Byte)) {} // Constructor for float data Image( @@ -44,78 +40,74 @@ class ET_EXPERIMENTAL Image { int32_t width, int32_t height, int32_t channels) - : data_(std::move(data)), - width_(width), - height_(height), - channels_(channels) {} + : Image(make_tensor_ptr({channels, height, width}, std::move(data))) {} + + Image(executorch::extension::TensorPtr tensor) : tensor_(std::move(tensor)) { + ET_CHECK_MSG(tensor_, "Null tensor"); + ET_CHECK_MSG(tensor_->dim() == 3, "Invalid tensor rank"); + } // Getters - int32_t width() const { - return width_; + int32_t channels() const { + return tensor_->size(0); } + int32_t height() const { - return height_; + return tensor_->size(1); } - int32_t channels() const { - return channels_; + + int32_t width() const { + return tensor_->size(2); } // Data access bool is_uint8() const { - return std::holds_alternative>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Byte; } bool is_float() const { - return std::holds_alternative>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Float; } - const std::vector& get_uint8_data() const& { - return std::get>(data_); + std::vector copy_uint8_data() const { + ET_DCHECK_MSG(is_uint8(), "Image dtype is not uint8"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } - std::vector& get_uint8_data() & { - return std::get>(data_); + std::vector copy_uint8_data() { + ET_DCHECK_MSG(is_uint8(), "Image dtype is not uint8"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } - const std::vector& get_float_data() const& { - return std::get>(data_); + std::vector copy_float_data() const { + ET_DCHECK_MSG(is_float(), "Image dtype is not float"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } - std::vector& get_float_data() & { - return std::get>(data_); + std::vector copy_float_data() { + ET_DCHECK_MSG(is_float(), "Image dtype is not float"); + auto data = tensor_->const_data_ptr(); + return std::vector(data, data + tensor_->numel()); } executorch::runtime::Result toTensor( bool with_batch = false) const { - // Note: This creates a 3D tensor (CHW). The model might expect a 4D - // tensor (NCHW). The caller should handle reshaping if needed. - std::vector sizes = { - channels(), height(), width()}; if (with_batch) { - sizes.insert(sizes.begin(), 1); - } - if (is_float()) { - return executorch::extension::from_blob( - const_cast(get_float_data().data()), - sizes, - ::executorch::aten::ScalarType::Float); - } else if (is_uint8()) { - return executorch::extension::from_blob( - const_cast(get_uint8_data().data()), - sizes, - ::executorch::aten::ScalarType::Byte); + return make_tensor_ptr( + *tensor_, + {1, + executorch::aten::SizesType(tensor_->size(0)), + executorch::aten::SizesType(tensor_->size(1)), + executorch::aten::SizesType(tensor_->size(2))}); } - ET_LOG( - Error, "Image data is not initialized with uint8_t or float vector."); - return ::executorch::runtime::Error::NotSupported; + return tensor_; } private: - // Assuming NCHW format - std::variant, std::vector> data_; - int32_t width_; - int32_t height_; - int32_t channels_; + executorch::extension::TensorPtr tensor_; }; } // namespace llm diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index bcc6aba0f8e..79a8883eead 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -277,11 +277,11 @@ PYBIND11_MODULE(_llm_runner, m) { .def_property_readonly( "uint8_data", static_cast& (Image::*)() const&>( - &Image::get_uint8_data)) + &Image::copy_uint8_data)) .def_property_readonly( "float_data", static_cast& (Image::*)() const&>( - &Image::get_float_data)) + &Image::copy_float_data)) .def("__repr__", [](const Image& img) { std::string dtype = "unknown"; if (img.is_uint8()) { @@ -317,15 +317,14 @@ PYBIND11_MODULE(_llm_runner, m) { .def_property_readonly( "uint8_data", static_cast& (Audio::*)() const&>( - &Audio::get_uint8_data)) + &Audio::copy_uint8_data)) .def_property_readonly( "float_data", static_cast& (Audio::*)() const&>( - &Audio::get_float_data)) + &Audio::copy_float_data)) .def_property_readonly("batch_size", &Audio::get_batch_size) .def_property_readonly("n_bins", &Audio::get_n_bins) .def_property_readonly("n_frames", &Audio::get_n_frames) - .def("toTensor", &Audio::toTensor) .def("__repr__", [](const Audio& audio) { std::string dtype = "unknown"; if (audio.is_uint8()) { @@ -473,6 +472,10 @@ PYBIND11_MODULE(_llm_runner, m) { m.def( "make_image_input", [](torch::Tensor image_tensor) -> MultimodalInput { + if (image_tensor.scalar_type() != torch::kUInt8 && image_tensor.scalar_type() != torch::kFloat) { + throw std::runtime_error( + "Unsupported image tensor dtype. Only uint8 and float32 are supported."); + } if (image_tensor.dim() == 4) { if (image_tensor.size(0) != 1) { throw std::runtime_error( @@ -480,56 +483,24 @@ PYBIND11_MODULE(_llm_runner, m) { } image_tensor = image_tensor.squeeze(0); } - if (image_tensor.dim() != 3) { throw std::runtime_error( "Image tensor must be 3-dimensional (H, W, C) or 4-dimensional (1, H, W, C)"); } - - int64_t height, width, channels; // Check for memory format and permute to CHW if necessary if (image_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) { // Input is HWC, permute to CHW - height = image_tensor.size(0); - width = image_tensor.size(1); - channels = image_tensor.size(2); image_tensor = image_tensor.permute({2, 0, 1}); - } else if (image_tensor.is_contiguous(at::MemoryFormat::Contiguous)) { - // Input is CHW - channels = image_tensor.size(0); - height = image_tensor.size(1); - width = image_tensor.size(2); - } else { + } else if (!image_tensor.is_contiguous(at::MemoryFormat::Contiguous)) { throw std::runtime_error( "Image tensor must be contiguous in either channels last (H, W, C) or contiguous (C, H, W) format."); } - + int64_t channels = image_tensor.size(0); if (channels != 3 && channels != 4) { throw std::runtime_error( "Image must have 3 (RGB) or 4 (RGBA) channels"); } - - image_tensor = image_tensor.contiguous(); - if (image_tensor.scalar_type() == torch::kUInt8) { - uint8_t* data = image_tensor.data_ptr(); - std::vector image_data(data, data + image_tensor.numel()); - return MultimodalInput(Image( - std::move(image_data), - static_cast(width), - static_cast(height), - static_cast(channels))); - } else if (image_tensor.scalar_type() == torch::kFloat) { - float* data = image_tensor.data_ptr(); - std::vector image_data(data, data + image_tensor.numel()); - return MultimodalInput(Image( - std::move(image_data), - static_cast(width), - static_cast(height), - static_cast(channels))); - } else { - throw std::runtime_error( - "Unsupported image tensor dtype. Only uint8 and float32 are supported."); - } + return MultimodalInput(Image(tensor_to_tensor_ptr(image_tensor))); }, "Create an image input from a torch tensor (H, W, C), (1, H, W, C), (C, H, W), or (1, C, H, W)", py::arg("image_tensor")); @@ -537,36 +508,15 @@ PYBIND11_MODULE(_llm_runner, m) { m.def( "make_audio_input", [](torch::Tensor audio_tensor) -> MultimodalInput { - if (audio_tensor.dim() != 3) { + if (audio_tensor.scalar_type() != torch::kUInt8 && audio_tensor.scalar_type() != torch::kFloat) { throw std::runtime_error( - "Audio tensor must be 3-dimensional (batch_size, n_bins, n_frames)"); + "Unsupported audio tensor dtype. Only uint8 and float32 are supported."); } - - int64_t batch_size = audio_tensor.size(0); - int64_t n_bins = audio_tensor.size(1); - int64_t n_frames = audio_tensor.size(2); - - audio_tensor = audio_tensor.contiguous(); - if (audio_tensor.scalar_type() == torch::kUInt8) { - uint8_t* data = audio_tensor.data_ptr(); - std::vector audio_data(data, data + audio_tensor.numel()); - return MultimodalInput(Audio( - std::move(audio_data), - static_cast(batch_size), - static_cast(n_bins), - static_cast(n_frames))); - } else if (audio_tensor.scalar_type() == torch::kFloat) { - float* data = audio_tensor.data_ptr(); - std::vector audio_data(data, data + audio_tensor.numel()); - return MultimodalInput(Audio( - std::move(audio_data), - static_cast(batch_size), - static_cast(n_bins), - static_cast(n_frames))); - } else { + if (audio_tensor.dim() != 3) { throw std::runtime_error( - "Unsupported audio tensor dtype. Only uint8 and float32 are supported for preprocessed audio."); + "Audio tensor must be 3-dimensional (batch_size, n_bins, n_frames)"); } + return MultimodalInput(Audio(tensor_to_tensor_ptr(audio_tensor))); }, "Create a preprocessed audio input from a torch tensor (batch_size, n_bins, n_frames)", py::arg("audio_tensor")); @@ -644,4 +594,4 @@ PYBIND11_MODULE(_llm_runner, m) { .def("__repr__", [](const PyMultimodalRunner& runner) { return ""; }); -} \ No newline at end of file +} diff --git a/extension/llm/runner/test/test_multimodal_input.cpp b/extension/llm/runner/test/test_multimodal_input.cpp index 85d45d69173..b472517be33 100644 --- a/extension/llm/runner/test/test_multimodal_input.cpp +++ b/extension/llm/runner/test/test_multimodal_input.cpp @@ -71,7 +71,7 @@ TEST_F(MultimodalInputTest, ImageConstructorFromImage) { EXPECT_EQ(input.get_image().width(), 224); EXPECT_EQ(input.get_image().height(), 224); EXPECT_EQ(input.get_image().channels(), 3); - EXPECT_EQ(input.get_image().get_uint8_data().size(), 224 * 224 * 3); + EXPECT_EQ(input.get_image().copy_uint8_data().size(), 224 * 224 * 3); } TEST_F(MultimodalInputTest, ImageConstructorFromRvalueImage) { @@ -79,7 +79,7 @@ TEST_F(MultimodalInputTest, ImageConstructorFromRvalueImage) { int width = img.width(); int height = img.height(); int channels = img.channels(); - size_t data_size = img.get_uint8_data().size(); + size_t data_size = img.copy_uint8_data().size(); MultimodalInput input(std::move(img)); @@ -89,7 +89,7 @@ TEST_F(MultimodalInputTest, ImageConstructorFromRvalueImage) { EXPECT_EQ(input.get_image().width(), width); EXPECT_EQ(input.get_image().height(), height); EXPECT_EQ(input.get_image().channels(), channels); - EXPECT_EQ(input.get_image().get_uint8_data().size(), data_size); + EXPECT_EQ(input.get_image().copy_uint8_data().size(), data_size); } // Test copy constructor and assignment @@ -356,7 +356,7 @@ TEST_F(MultimodalInputTest, DifferentImageSizes) { EXPECT_EQ(input.get_image().width(), 32); EXPECT_EQ(input.get_image().height(), 32); EXPECT_EQ(input.get_image().channels(), 1); - EXPECT_EQ(input.get_image().get_uint8_data().size(), 32 * 32); + EXPECT_EQ(input.get_image().copy_uint8_data().size(), 32 * 32); } // Test with empty text