Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 88 additions & 1 deletion extension/llm/runner/multimodal_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#include <executorch/extension/llm/runner/audio.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/runtime/platform/compiler.h>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

namespace executorch::extension::llm {

Expand All @@ -29,15 +31,46 @@ class ET_EXPERIMENTAL MultimodalInput {
/// Type of multimodal input data
enum class Type {
TEXT, ///< Text string input
TOKENS, ///< Pre-tokenized input (vector of token IDs)
IMAGE, ///< Processed image input
AUDIO, ///< Processed audio input
RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
UNSUPPORTED ///< Unsupported input type
};

/**
* Return a human-readable name for a MultimodalInput::Type.
* Preferred for logging and debugging; returns string literals.
*/
static constexpr const char* TypeName(Type t) noexcept {
switch (t) {
case Type::TEXT:
return "text";
case Type::TOKENS:
return "tokens";
case Type::IMAGE:
return "image";
case Type::AUDIO:
return "audio";
case Type::RAW_AUDIO:
return "raw_audio";
default:
return "unknown";
}
}

/** Convenience wrapper that returns a std::string. */
static inline std::string TypeToString(Type t) {
return TypeName(t);
}

// Constructors
explicit MultimodalInput(const std::string& text) : data_(text) {}
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
explicit MultimodalInput(const std::vector<uint64_t>& tokens)
: data_(tokens) {}
explicit MultimodalInput(std::vector<uint64_t>&& tokens)
: data_(std::move(tokens)) {}
explicit MultimodalInput(const Image& image) : data_(image) {}
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
Expand Down Expand Up @@ -65,6 +98,13 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::holds_alternative<std::string>(data_);
}

/**
* Check if this input contains pre-tokenized data.
*/
bool is_tokens() const noexcept {
return std::holds_alternative<std::vector<uint64_t>>(data_);
}

/**
* Check if this input contains image data.
* @return true if this input contains an image, false otherwise.
Expand Down Expand Up @@ -97,6 +137,8 @@ class ET_EXPERIMENTAL MultimodalInput {
Type get_type() const noexcept {
if (is_text())
return Type::TEXT;
if (is_tokens())
return Type::TOKENS;
if (is_image())
return Type::IMAGE;
if (is_audio())
Expand All @@ -106,6 +148,15 @@ class ET_EXPERIMENTAL MultimodalInput {
return Type::UNSUPPORTED;
}

/**
* Get a human-readable name for the contained input type.
* Returns one of: "text", "tokens", "image", "audio", "raw_audio", or
* "unknown".
*/
const char* type_name() const noexcept {
return TypeName(get_type());
}

/**
* Get the text data from this input.
* @return Reference to the stored text string.
Expand Down Expand Up @@ -133,6 +184,21 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get<std::string>(std::move(data_));
}

/**
* Get the token vector from this input.
*/
const std::vector<uint64_t>& get_tokens() const& {
return std::get<std::vector<uint64_t>>(data_);
}

std::vector<uint64_t>& get_tokens() & {
return std::get<std::vector<uint64_t>>(data_);
}

std::vector<uint64_t>&& get_tokens() && {
return std::get<std::vector<uint64_t>>(std::move(data_));
}

/**
* Get the image data from this input.
* @return Reference to the stored Image object.
Expand Down Expand Up @@ -250,6 +316,16 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get_if<Image>(&data_);
}

/** Try to get the tokens from this input safely. */
const std::vector<uint64_t>* try_get_tokens() const noexcept {
return std::get_if<std::vector<uint64_t>>(&data_);
}

/** Try to get the tokens from this input safely (mutable). */
std::vector<uint64_t>* try_get_tokens() noexcept {
return std::get_if<std::vector<uint64_t>>(&data_);
}

/**
* Try to get the audio data from this input safely.
* @return Pointer to the Audio object if this input contains audio,
Expand Down Expand Up @@ -287,7 +363,8 @@ class ET_EXPERIMENTAL MultimodalInput {
}

private:
std::variant<std::string, Image, Audio, RawAudio> data_;
std::variant<std::string, std::vector<uint64_t>, Image, Audio, RawAudio>
data_;
};

// Convenience factory functions
Expand All @@ -307,6 +384,16 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
return MultimodalInput(std::move(image));
}

inline MultimodalInput make_token_input(
const std::vector<uint64_t>& tokens) noexcept {
return MultimodalInput(tokens);
}

inline MultimodalInput make_token_input(
std::vector<uint64_t>&& tokens) noexcept {
return MultimodalInput(std::move(tokens));
}

inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
return MultimodalInput(audio);
}
Expand Down
12 changes: 8 additions & 4 deletions extension/llm/runner/multimodal_prefiller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,14 @@ Result<uint64_t> MultimodalPrefiller::prefill(
auto audio_encoder_outputs = audio_encoder_result.get();

encoder_output = audio_encoder_outputs[0];
} else if (input.is_text()) {
auto& text = input.get_text();
std::vector<uint64_t> tokens =
ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
} else if (input.is_text() || input.is_tokens()) {
std::vector<uint64_t> tokens;
if (input.is_text()) {
auto& text = input.get_text();
tokens = ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
} else {
tokens = input.get_tokens();
}

auto text_tensor = executorch::extension::from_blob(
tokens.data(),
Expand Down
6 changes: 6 additions & 0 deletions extension/llm/runner/multimodal_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ Error MultimodalRunner::generate(
// Process multimodal inputs in order
for (size_t i = 0; i < inputs.size(); ++i) {
const MultimodalInput& input = inputs[i];
ET_LOG(
Info,
"Prefilling input %zu/%zu, type: %s",
i,
inputs.size(),
input.type_name());
if (config.echo && i == inputs.size() - 1 && input.is_text()) {
wrapped_callback(input.get_text());
}
Expand Down
Loading
Loading