Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 38 additions & 62 deletions extension/llm/runner/audio.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#pragma once
#include <executorch/runtime/platform/compiler.h>
#include <cstdint>
#include <variant>
#include <vector>

#include <executorch/extension/tensor/tensor.h>
Expand Down Expand Up @@ -41,81 +40,72 @@ struct ET_EXPERIMENTAL RawAudio {
*/
class ET_EXPERIMENTAL Audio final {
public:
// Default constructor
Audio() : batch_size_(0), n_bins_(0), n_frames_(0) {}

// Constructor for uint8_t data
Audio(
std::vector<uint8_t>&& data,
int32_t batch_size,
int32_t n_bins,
int32_t n_frames)
: data_(std::move(data)),
batch_size_(batch_size),
n_bins_(n_bins),
n_frames_(n_frames) {
ET_CHECK_MSG(
data_.index() == 0 &&
std::get<std::vector<uint8_t>>(data_).size() ==
static_cast<size_t>(batch_size * n_bins * n_frames),
"data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)",
std::get<std::vector<uint8_t>>(data_).size(),
batch_size * n_bins * n_frames);
}
: Audio(make_tensor_ptr(
{batch_size, n_bins, n_frames},
std::move(data),
executorch::aten::ScalarType::Byte)) {}

// Constructor for float data
Audio(
std::vector<float>&& data,
int32_t batch_size,
int32_t n_bins,
int32_t n_frames)
: data_(std::move(data)),
batch_size_(batch_size),
n_bins_(n_bins),
n_frames_(n_frames) {
ET_CHECK_MSG(
data_.index() == 1 &&
std::get<std::vector<float>>(data_).size() ==
static_cast<size_t>(batch_size * n_bins * n_frames),
"data.size() (%zu) does not match batch_size * n_bins * n_frames (%d)",
std::get<std::vector<float>>(data_).size(),
batch_size * n_bins * n_frames);
: Audio(make_tensor_ptr({batch_size, n_bins, n_frames}, std::move(data))) {}

Audio(executorch::extension::TensorPtr tensor) : tensor_(std::move(tensor)) {
ET_CHECK_MSG(tensor_, "Null tensor");
ET_CHECK_MSG(tensor_->dim() == 3, "Invalid tensor rank");
}

// Type checkers
bool is_uint8() const {
return std::holds_alternative<std::vector<uint8_t>>(data_);
return tensor_->scalar_type() == ::executorch::aten::ScalarType::Byte;
}

bool is_float() const {
return std::holds_alternative<std::vector<float>>(data_);
return tensor_->scalar_type() == ::executorch::aten::ScalarType::Float;
}

// Data access
const std::vector<uint8_t>& get_uint8_data() const& {
return std::get<std::vector<uint8_t>>(data_);
std::vector<uint8_t> copy_uint8_data() const {
ET_DCHECK_MSG(is_uint8(), "Audio dtype is not uint8");
auto data = tensor_->const_data_ptr<uint8_t>();
return std::vector<uint8_t>(data, data + tensor_->numel());
}

std::vector<uint8_t>& get_uint8_data() & {
return std::get<std::vector<uint8_t>>(data_);
std::vector<uint8_t> copy_uint8_data() {
ET_DCHECK_MSG(is_uint8(), "Audio dtype is not uint8");
auto data = tensor_->const_data_ptr<uint8_t>();
return std::vector<uint8_t>(data, data + tensor_->numel());
}

const std::vector<float>& get_float_data() const& {
return std::get<std::vector<float>>(data_);
std::vector<float> copy_float_data() const {
ET_DCHECK_MSG(is_float(), "Audio dtype is not float");
auto data = tensor_->const_data_ptr<float>();
return std::vector<float>(data, data + tensor_->numel());
}

std::vector<float>& get_float_data() & {
return std::get<std::vector<float>>(data_);
std::vector<float> copy_float_data() {
ET_DCHECK_MSG(is_float(), "Audio dtype is not float");
auto data = tensor_->const_data_ptr<float>();
return std::vector<float>(data, data + tensor_->numel());
}

int32_t get_batch_size() const {
return batch_size_;
return tensor_->size(0);
}
int32_t get_n_bins() const {
return n_bins_;
return tensor_->size(1);
}
int32_t get_n_frames() const {
return n_frames_;
return tensor_->size(2);
}
/**
* Convert the audio data to a TensorPtr, with optional batch dimension.
Expand All @@ -124,34 +114,20 @@ class ET_EXPERIMENTAL Audio final {
*/
executorch::runtime::Result<executorch::extension::TensorPtr> toTensor(
bool with_batch = false) const {
std::vector<executorch::aten::SizesType> sizes = {
get_batch_size(), get_n_bins(), get_n_frames()};
if (with_batch) {
sizes.insert(sizes.begin(), 1);
}
if (is_float()) {
return executorch::extension::from_blob(
const_cast<float*>(get_float_data().data()),
sizes,
::executorch::aten::ScalarType::Float);
} else if (is_uint8()) {
return executorch::extension::from_blob(
const_cast<uint8_t*>(get_uint8_data().data()),
sizes,
::executorch::aten::ScalarType::Byte);
return make_tensor_ptr(
*tensor_,
{1,
static_cast<executorch::aten::SizesType>(tensor_->size(0)),
static_cast<executorch::aten::SizesType>(tensor_->size(1)),
static_cast<executorch::aten::SizesType>(tensor_->size(2))});
}
ET_LOG(
Error,
"Shouldn't reach here, audio data is not initialized with uint8_t or float vector.");
return ::executorch::runtime::Error::NotSupported;
return tensor_;
}

private:
// Members
std::variant<std::vector<uint8_t>, std::vector<float>> data_;
int32_t batch_size_;
int32_t n_bins_;
int32_t n_frames_;
executorch::extension::TensorPtr tensor_;
};

} // namespace llm
Expand Down
96 changes: 44 additions & 52 deletions extension/llm/runner/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

#pragma once
#include <executorch/runtime/platform/compiler.h>
#include <cstddef>
#include <cstdint>
#include <variant>
#include <vector>

#include <executorch/extension/tensor/tensor.h>
Expand All @@ -22,100 +20,94 @@ namespace executorch {
namespace extension {
namespace llm {

// Assuming NCHW format
class ET_EXPERIMENTAL Image {
public:
// Default constructor
Image() : width_(0), height_(0), channels_(0) {}

// Constructor for uint8_t data
Image(
std::vector<uint8_t>&& data,
int32_t width,
int32_t height,
int32_t channels)
: data_(std::move(data)),
width_(width),
height_(height),
channels_(channels) {}
: Image(make_tensor_ptr(
{channels, height, width},
std::move(data),
executorch::aten::ScalarType::Byte)) {}

// Constructor for float data
Image(
std::vector<float>&& data,
int32_t width,
int32_t height,
int32_t channels)
: data_(std::move(data)),
width_(width),
height_(height),
channels_(channels) {}
: Image(make_tensor_ptr({channels, height, width}, std::move(data))) {}

Image(executorch::extension::TensorPtr tensor) : tensor_(std::move(tensor)) {
ET_CHECK_MSG(tensor_, "Null tensor");
ET_CHECK_MSG(tensor_->dim() == 3, "Invalid tensor rank");
}

// Getters
int32_t width() const {
return width_;
int32_t channels() const {
return tensor_->size(0);
}

int32_t height() const {
return height_;
return tensor_->size(1);
}
int32_t channels() const {
return channels_;

int32_t width() const {
return tensor_->size(2);
}

// Data access
bool is_uint8() const {
return std::holds_alternative<std::vector<uint8_t>>(data_);
return tensor_->scalar_type() == ::executorch::aten::ScalarType::Byte;
}

bool is_float() const {
return std::holds_alternative<std::vector<float>>(data_);
return tensor_->scalar_type() == ::executorch::aten::ScalarType::Float;
}

const std::vector<uint8_t>& get_uint8_data() const& {
return std::get<std::vector<uint8_t>>(data_);
std::vector<uint8_t> copy_uint8_data() const {
ET_DCHECK_MSG(is_uint8(), "Image dtype is not uint8");
auto data = tensor_->const_data_ptr<uint8_t>();
return std::vector<uint8_t>(data, data + tensor_->numel());
}

std::vector<uint8_t>& get_uint8_data() & {
return std::get<std::vector<uint8_t>>(data_);
std::vector<uint8_t> copy_uint8_data() {
ET_DCHECK_MSG(is_uint8(), "Image dtype is not uint8");
auto data = tensor_->const_data_ptr<uint8_t>();
return std::vector<uint8_t>(data, data + tensor_->numel());
}

const std::vector<float>& get_float_data() const& {
return std::get<std::vector<float>>(data_);
std::vector<float> copy_float_data() const {
ET_DCHECK_MSG(is_float(), "Image dtype is not float");
auto data = tensor_->const_data_ptr<float>();
return std::vector<float>(data, data + tensor_->numel());
}

std::vector<float>& get_float_data() & {
return std::get<std::vector<float>>(data_);
std::vector<float> copy_float_data() {
ET_DCHECK_MSG(is_float(), "Image dtype is not float");
auto data = tensor_->const_data_ptr<float>();
return std::vector<float>(data, data + tensor_->numel());
}

executorch::runtime::Result<executorch::extension::TensorPtr> toTensor(
bool with_batch = false) const {
// Note: This creates a 3D tensor (CHW). The model might expect a 4D
// tensor (NCHW). The caller should handle reshaping if needed.
std::vector<executorch::aten::SizesType> sizes = {
channels(), height(), width()};
if (with_batch) {
sizes.insert(sizes.begin(), 1);
}
if (is_float()) {
return executorch::extension::from_blob(
const_cast<float*>(get_float_data().data()),
sizes,
::executorch::aten::ScalarType::Float);
} else if (is_uint8()) {
return executorch::extension::from_blob(
const_cast<uint8_t*>(get_uint8_data().data()),
sizes,
::executorch::aten::ScalarType::Byte);
return make_tensor_ptr(
*tensor_,
{1,
executorch::aten::SizesType(tensor_->size(0)),
executorch::aten::SizesType(tensor_->size(1)),
executorch::aten::SizesType(tensor_->size(2))});
}
ET_LOG(
Error, "Image data is not initialized with uint8_t or float vector.");
return ::executorch::runtime::Error::NotSupported;
return tensor_;
}

private:
// Assuming NCHW format
std::variant<std::vector<uint8_t>, std::vector<float>> data_;
int32_t width_;
int32_t height_;
int32_t channels_;
executorch::extension::TensorPtr tensor_;
};

} // namespace llm
Expand Down
Loading
Loading