diff --git a/extension/llm/runner/audio.h b/extension/llm/runner/audio.h index ce71513ed17..cc7e6b1714a 100644 --- a/extension/llm/runner/audio.h +++ b/extension/llm/runner/audio.h @@ -11,7 +11,6 @@ #pragma once #include #include -#include #include #include @@ -41,27 +40,16 @@ struct ET_EXPERIMENTAL RawAudio { */ class ET_EXPERIMENTAL Audio final { public: - // Default constructor - Audio() : batch_size_(0), n_bins_(0), n_frames_(0) {} - // Constructor for uint8_t data Audio( std::vector&& data, int32_t batch_size, int32_t n_bins, int32_t n_frames) - : data_(std::move(data)), - batch_size_(batch_size), - n_bins_(n_bins), - n_frames_(n_frames) { - ET_CHECK_MSG( - data_.index() == 0 && - std::get>(data_).size() == - static_cast(batch_size * n_bins * n_frames), - "data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)", - std::get>(data_).size(), - batch_size * n_bins * n_frames); - } + : Audio(make_tensor_ptr( + {batch_size, n_bins, n_frames}, + std::move(data), + executorch::aten::ScalarType::Byte)) {} // Constructor for float data Audio( @@ -69,89 +57,64 @@ class ET_EXPERIMENTAL Audio final { int32_t batch_size, int32_t n_bins, int32_t n_frames) - : data_(std::move(data)), - batch_size_(batch_size), - n_bins_(n_bins), - n_frames_(n_frames) { - ET_CHECK_MSG( - data_.index() == 1 && - std::get>(data_).size() == - static_cast(batch_size * n_bins * n_frames), - "data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)", - std::get>(data_).size(), - batch_size * n_bins * n_frames); + : Audio(make_tensor_ptr({batch_size, n_bins, n_frames}, std::move(data))) {} + + explicit Audio( + executorch::extension::TensorPtr tensor) : tensor_(std::move(tensor)) { + ET_CHECK_MSG(tensor_, "Null tensor"); + ET_CHECK_MSG(tensor_->dim() == 3, "Invalid tensor rank"); } // Type checkers bool is_uint8() const { - return std::holds_alternative>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Byte; } bool is_float() const { - return std::holds_alternative>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Float; } // Data access - const std::vector& get_uint8_data() const& { - return std::get>(data_); - } - - std::vector& get_uint8_data() & { - return std::get>(data_); - } - - const std::vector& get_float_data() const& { - return std::get>(data_); + const uint8_t* uint8_data() const { + ET_DCHECK_MSG(is_uint8(), "Dtype is not uint8"); + return tensor_->const_data_ptr(); } - std::vector& get_float_data() & { - return std::get>(data_); + const float* float_data() const { + ET_DCHECK_MSG(is_float(), "Dtype is not float"); + return tensor_->const_data_ptr(); } int32_t get_batch_size() const { - return batch_size_; + return tensor_->size(0); } int32_t get_n_bins() const { - return n_bins_; + return tensor_->size(1); } int32_t get_n_frames() const { - return n_frames_; + return tensor_->size(2); } /** * Convert the audio data to a TensorPtr, with optional batch dimension. * The tensor will have shape (batch_size, n_bins, n_frames) or (1, * batch_size, n_bins, n_frames) if with_batch is true. */ - executorch::runtime::Result toTensor( + executorch::extension::TensorPtr tensor( bool with_batch = false) const { - std::vector sizes = { - get_batch_size(), get_n_bins(), get_n_frames()}; if (with_batch) { - sizes.insert(sizes.begin(), 1); - } - if (is_float()) { - return executorch::extension::from_blob( - const_cast(get_float_data().data()), - sizes, - ::executorch::aten::ScalarType::Float); - } else if (is_uint8()) { - return executorch::extension::from_blob( - const_cast(get_uint8_data().data()), - sizes, - ::executorch::aten::ScalarType::Byte); + return make_tensor_ptr( + *tensor_, + {1, + static_cast(tensor_->size(0)), + static_cast(tensor_->size(1)), + static_cast(tensor_->size(2))}); } - ET_LOG( - Error, - "Shouldn't reach here, audio data is not initialized with uint8_t or float vector."); - return ::executorch::runtime::Error::NotSupported; + return tensor_; } private: // Members - std::variant, std::vector> data_; - int32_t batch_size_; - int32_t n_bins_; - int32_t n_frames_; + executorch::extension::TensorPtr tensor_; }; } // namespace llm diff --git a/extension/llm/runner/image.h b/extension/llm/runner/image.h index dbdba273536..9c7746fff2a 100644 --- a/extension/llm/runner/image.h +++ b/extension/llm/runner/image.h @@ -10,9 +10,7 @@ #pragma once #include -#include #include -#include #include #include @@ -22,21 +20,19 @@ namespace executorch { namespace extension { namespace llm { +// Assuming NCHW format class ET_EXPERIMENTAL Image { public: - // Default constructor - Image() : width_(0), height_(0), channels_(0) {} - // Constructor for uint8_t data Image( std::vector&& data, int32_t width, int32_t height, int32_t channels) - : data_(std::move(data)), - width_(width), - height_(height), - channels_(channels) {} + : Image(make_tensor_ptr( + {channels, height, width}, + std::move(data), + executorch::aten::ScalarType::Byte)) {} // Constructor for float data Image( @@ -44,78 +40,60 @@ class ET_EXPERIMENTAL Image { int32_t width, int32_t height, int32_t channels) - : data_(std::move(data)), - width_(width), - height_(height), - channels_(channels) {} + : Image(make_tensor_ptr({channels, height, width}, std::move(data))) {} + + explicit Image(executorch::extension::TensorPtr tensor) : tensor_(std::move(tensor)) { + ET_CHECK_MSG(tensor_, "Null tensor"); + ET_CHECK_MSG(tensor_->dim() == 3, "Invalid tensor rank"); + } // Getters - int32_t width() const { - return width_; + int32_t channels() const { + return tensor_->size(0); } + int32_t height() const { - return height_; + return tensor_->size(1); } - int32_t channels() const { - return channels_; + + int32_t width() const { + return tensor_->size(2); } // Data access bool is_uint8() const { - return std::holds_alternative>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Byte; } bool is_float() const { - return std::holds_alternative>(data_); - } - - const std::vector& get_uint8_data() const& { - return std::get>(data_); + return tensor_->scalar_type() == ::executorch::aten::ScalarType::Float; } - std::vector& get_uint8_data() & { - return std::get>(data_); + const uint8_t* uint8_data() const { + ET_DCHECK_MSG(is_uint8(), "Dtype is not uint8"); + return tensor_->const_data_ptr(); } - const std::vector& get_float_data() const& { - return std::get>(data_); + const float* float_data() const { + ET_DCHECK_MSG(is_float(), "Dtype is not float"); + return tensor_->const_data_ptr(); } - std::vector& get_float_data() & { - return std::get>(data_); - } - - executorch::runtime::Result toTensor( + executorch::extension::TensorPtr tensor( bool with_batch = false) const { - // Note: This creates a 3D tensor (CHW). The model might expect a 4D - // tensor (NCHW). The caller should handle reshaping if needed. - std::vector sizes = { - channels(), height(), width()}; if (with_batch) { - sizes.insert(sizes.begin(), 1); - } - if (is_float()) { - return executorch::extension::from_blob( - const_cast(get_float_data().data()), - sizes, - ::executorch::aten::ScalarType::Float); - } else if (is_uint8()) { - return executorch::extension::from_blob( - const_cast(get_uint8_data().data()), - sizes, - ::executorch::aten::ScalarType::Byte); + return make_tensor_ptr( + *tensor_, + {1, + executorch::aten::SizesType(tensor_->size(0)), + executorch::aten::SizesType(tensor_->size(1)), + executorch::aten::SizesType(tensor_->size(2))}); } - ET_LOG( - Error, "Image data is not initialized with uint8_t or float vector."); - return ::executorch::runtime::Error::NotSupported; + return tensor_; } private: - // Assuming NCHW format - std::variant, std::vector> data_; - int32_t width_; - int32_t height_; - int32_t channels_; + executorch::extension::TensorPtr tensor_; }; } // namespace llm diff --git a/extension/llm/runner/multimodal_prefiller.cpp b/extension/llm/runner/multimodal_prefiller.cpp index 7f5a8356979..97d52268fd8 100644 --- a/extension/llm/runner/multimodal_prefiller.cpp +++ b/extension/llm/runner/multimodal_prefiller.cpp @@ -77,9 +77,7 @@ Result MultimodalPrefiller::prefill( // The model might expect a 4D tensor (NCHW), but toTensor() returns a 3D // tensor (CHW). Add a batch dimension of 1 if needed. auto expected_dims = input_meta.sizes(); - auto image_tensor = ET_UNWRAP( - image.toTensor(/*with_batch*/ expected_dims.size() == 4), - "Failed to convert image to tensor"); + auto image_tensor = image.tensor(/*with_batch*/ expected_dims.size() == 4); ET_LOG( Info, "Image tensor dim: %zu, dtype: %s", @@ -108,8 +106,7 @@ Result MultimodalPrefiller::prefill( auto expected_dtype = input_meta.scalar_type(); // Create tensor with original dtype - auto audio_tensor = - ET_UNWRAP(audio.toTensor(), "Failed to convert audio to tensor"); + auto audio_tensor = audio.tensor(); // Convert to expected dtype if needed if (audio_tensor->scalar_type() != expected_dtype) { diff --git a/extension/llm/runner/pybindings.cpp b/extension/llm/runner/pybindings.cpp index bcc6aba0f8e..993358d46fb 100644 --- a/extension/llm/runner/pybindings.cpp +++ b/extension/llm/runner/pybindings.cpp @@ -42,6 +42,23 @@ using namespace executorch::runtime; } \ }) +static TensorPtr tensor_to_tensor_ptr(const torch::Tensor& tensor) { + auto contiguous_tensor = tensor.contiguous(); + void* data_ptr = contiguous_tensor.data_ptr(); + const auto dtype = contiguous_tensor.options().dtype(); + std::vector sizes; + sizes.reserve(contiguous_tensor.sizes().size()); + + for (const auto size : contiguous_tensor.sizes()) { + sizes.push_back(size); + } + return executorch::extension::from_blob( + data_ptr, + sizes, + torch_to_executorch_scalar_type(dtype), + [tensor = std::move(contiguous_tensor)](void*) {}); +} + // Python wrapper class for MultimodalRunner class PyMultimodalRunner { public: @@ -132,7 +149,7 @@ class PyMultimodalRunner { } } - void prefill(std::vector inputs) { + void prefill(const std::vector& inputs) { if (!runner_) { throw std::runtime_error("Runner not initialized"); } @@ -274,14 +291,29 @@ PYBIND11_MODULE(_llm_runner, m) { .def_property_readonly("width", &Image::width) .def_property_readonly("height", &Image::height) .def_property_readonly("channels", &Image::channels) - .def_property_readonly( - "uint8_data", - static_cast& (Image::*)() const&>( - &Image::get_uint8_data)) - .def_property_readonly( - "float_data", - static_cast& (Image::*)() const&>( - &Image::get_float_data)) + .def( + "tensor", + [](const Image& image, bool with_batch) { + return tensor_to_torch_tensor(*image.tensor(with_batch)); + }, + py::arg("with_batch") = false) + .def_buffer([](Image& image) -> py::buffer_info { + auto tensor = image.tensor(); + const auto scalar_type = tensor->scalar_type(); + const auto element_size = elementSize(scalar_type); + const auto* format = scalar_type == aten::ScalarType::Byte + ? py::format_descriptor::format() + : py::format_descriptor::format(); + py::buffer_info buffer_info( + tensor->mutable_data_ptr(), + element_size, + format, + tensor->dim(), + std::vector{tensor->sizes().begin(), tensor->sizes().end()} + ); + buffer_info.readonly = true; + return buffer_info; + }) .def("__repr__", [](const Image& img) { std::string dtype = "unknown"; if (img.is_uint8()) { @@ -297,7 +329,6 @@ PYBIND11_MODULE(_llm_runner, m) { // Bind Audio class py::class_