Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/models/llava/export_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,11 @@ def export_all(llava_model: LlavaModel):
{
"image_encoder": image_encoder_ep,
"token_embedding": token_embedding_ep,
"text_model": text_model_ep,
"text_decoder": text_model_ep,
},
partitioner={
"image_encoder": [XnnpackPartitioner()],
"text_model": [
"text_decoder": [
# First partition the DQLinear nodes, then partition the rest of the nodes,
# to avoid multiple DQLinear nodes in the same partition,
# to avoid holding multiple unpacked and packed weight buffers in memory,
Expand All @@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel):
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass={
"image_encoder": ConstraintBasedSymShapeEvalPass(),
"text_model": ConstraintBasedSymShapeEvalPass(),
"text_decoder": ConstraintBasedSymShapeEvalPass(),
"token_embedding": HintBasedSymShapeEvalPass(),
},
)
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llava/runner/llava_text_decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
}

inline static const std::string kTokenEmbeddingMethod = "token_embedding";
inline static const std::string kTextModelMethod = "text_model";
inline static const std::string kTextModelMethod = "text_decoder";
};

} // namespace example
8 changes: 4 additions & 4 deletions examples/models/llava/test/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_llava_export(self):
"token_embedding", (prompt_before_image,)
)[0]
llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
)

Expand All @@ -107,7 +107,7 @@ def test_llava_export(self):
# pte prefill image
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
llava_module.run_method(
"text_model",
"text_decoder",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
Expand All @@ -122,7 +122,7 @@ def test_llava_export(self):
"token_embedding", (prompt_after_image,)
)[0]
pte_prefill_after_img = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]

Expand All @@ -139,7 +139,7 @@ def test_llava_export(self):
"token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
)[0]
logits = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
new_tokens.append(torch.argmax(logits).item())
Expand Down
8 changes: 4 additions & 4 deletions examples/models/llava/test/test_pte.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main():
"token_embedding", (prompt_before_image,)
)[0]
pte_prefill_before_img = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
)[0]
print(pte_prefill_before_img)
Expand All @@ -60,7 +60,7 @@ def main():
logging.warning("Image encoder finished")
logging.warning("Image token prefill started")
pte_prefill_img = llava_module.run_method(
"text_model",
"text_decoder",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
Expand All @@ -77,7 +77,7 @@ def main():
"token_embedding", (prompt_after_image,)
)[0]
pte_prefill_after_img = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]
logging.warning("Text token prefill finished")
Expand All @@ -91,7 +91,7 @@ def main():
"token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
)[0]
logits = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
new_tokens.append(torch.argmax(logits[..., -1, :]).item())
Expand Down
52 changes: 52 additions & 0 deletions extension/llm/runner/audio.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// A simple audio struct.

#pragma once
#include <executorch/runtime/platform/compiler.h>
#include <cstdint>
#include <vector>

namespace executorch {
namespace extension {
namespace llm {

/**
* Audio inputs as a raw audio tensor, for use when the audio processing
* into a mel spectrogram is baked into the audio encoder with torch.export.
*/
struct ET_EXPERIMENTAL RawAudio {
std::vector<uint8_t> data;
int32_t batch_size;
int32_t n_channels; // For mono, use n_channels = 1.
int32_t n_samples;
};

/**
* Pre-processed audio inputs, ready to feed directly into an audio
* encoder.
*/
struct ET_EXPERIMENTAL Audio {
std::vector<uint8_t> data;
int32_t batch_size;
int32_t n_bins;
int32_t n_frames;
};

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::Audio;
} // namespace executor
} // namespace torch
3 changes: 2 additions & 1 deletion extension/llm/runner/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";

// Multimodal method name conventions
inline constexpr auto kImageEncoderMethod = "image_encoder";
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
inline constexpr auto kTextModelMethod = "text_model";
inline constexpr auto kTextModelMethod = "text_decoder";

} // namespace executorch::extension::llm
161 changes: 153 additions & 8 deletions extension/llm/runner/multimodal_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#pragma once

#include <executorch/extension/llm/runner/audio.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/runtime/platform/compiler.h>
#include <string>
Expand All @@ -19,19 +20,31 @@
namespace executorch::extension::llm {

/**
* A generic class to hold either image or text data for multimodal inputs.
* This allows the generate() API to take a std::vector of these objects
* instead of separate image and text parameters.
* A generic class to hold either image, text, or audio data for multimodal
* inputs. This allows the generate() API to take a std::vector of these objects
* instead of separate image, text, and audio parameters.
*/
class ET_EXPERIMENTAL MultimodalInput {
public:
enum class Type { TEXT, IMAGE };
/// Type of multimodal input data
enum class Type {
TEXT, ///< Text string input
IMAGE, ///< Processed image input
AUDIO, ///< Processed audio input
RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
UNSUPPORTED ///< Unsupported input type
};

// Constructors
explicit MultimodalInput(const std::string& text) : data_(text) {}
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
explicit MultimodalInput(const Image& image) : data_(image) {}
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {}
explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {}
explicit MultimodalInput(RawAudio&& raw_audio)
: data_(std::move(raw_audio)) {}

// Copy constructor and assignment
MultimodalInput(const MultimodalInput& other) = default;
Expand Down Expand Up @@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::holds_alternative<Image>(data_);
}

/**
* Check if this input contains audio data.
* @return true if this input contains audio, false otherwise.
*/
bool is_audio() const noexcept {
return std::holds_alternative<Audio>(data_);
}

/**
* Check if this input contains raw audio data.
* @return true if this input contains raw audio, false otherwise.
*/
bool is_raw_audio() const noexcept {
return std::holds_alternative<RawAudio>(data_);
}

/**
* Get the type of data stored in this input.
* @return Type::TEXT if text data, Type::IMAGE if image data.
* @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
* audio data, Type::RAW_AUDIO if raw audio data.
*/
Type get_type() const noexcept {
return is_text() ? Type::TEXT : Type::IMAGE;
if (is_text())
return Type::TEXT;
if (is_image())
return Type::IMAGE;
if (is_audio())
return Type::AUDIO;
if (is_raw_audio())
return Type::RAW_AUDIO;
return Type::UNSUPPORTED;
}

/**
Expand Down Expand Up @@ -122,6 +160,60 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get<Image>(std::move(data_));
}

/**
* Get the audio data from this input.
* @return Reference to the stored Audio object.
* @throws std::bad_variant_access if this input doesn't contain audio.
*/
const Audio& get_audio() const& {
return std::get<Audio>(data_);
}

/**
* Get the audio data from this input (mutable version).
* @return Mutable reference to the stored Audio object.
* @throws std::bad_variant_access if this input doesn't contain audio.
*/
Audio& get_audio() & {
return std::get<Audio>(data_);
}

/**
* Get the audio data from this input (rvalue version).
* @return Rvalue reference to the stored Audio object for efficient moves.
* @throws std::bad_variant_access if this input doesn't contain audio.
*/
Audio&& get_audio() && {
return std::get<Audio>(std::move(data_));
}

/**
* Get the raw audio data from this input.
* @return Reference to the stored RawAudio object.
* @throws std::bad_variant_access if this input doesn't contain raw audio.
*/
const RawAudio& get_raw_audio() const& {
return std::get<RawAudio>(data_);
}

/**
* Get the raw audio data from this input (mutable version).
* @return Mutable reference to the stored RawAudio object.
* @throws std::bad_variant_access if this input doesn't contain raw audio.
*/
RawAudio& get_raw_audio() & {
return std::get<RawAudio>(data_);
}

/**
* Get the raw audio data from this input (rvalue version).
* @return Rvalue reference to the stored RawAudio object for efficient moves.
* @throws std::bad_variant_access if this input doesn't contain raw audio.
*/
RawAudio&& get_raw_audio() && {
return std::get<RawAudio>(std::move(data_));
}

/**
* Try to get the text data from this input safely.
* @return Pointer to the text string if this input contains text, nullptr
Expand Down Expand Up @@ -158,8 +250,44 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get_if<Image>(&data_);
}

/**
* Try to get the audio data from this input safely.
* @return Pointer to the Audio object if this input contains audio,
* nullptr otherwise.
*/
const Audio* try_get_audio() const noexcept {
return std::get_if<Audio>(&data_);
}

/**
* Try to get the audio data from this input safely (mutable version).
* @return Pointer to the Audio object if this input contains audio,
* nullptr otherwise.
*/
Audio* try_get_audio() noexcept {
return std::get_if<Audio>(&data_);
}

/**
* Try to get the raw audio data from this input safely.
* @return Pointer to the RawAudio object if this input contains raw audio,
* nullptr otherwise.
*/
const RawAudio* try_get_raw_audio() const noexcept {
return std::get_if<RawAudio>(&data_);
}

/**
* Try to get the raw audio data from this input safely (mutable version).
* @return Pointer to the RawAudio object if this input contains raw audio,
* nullptr otherwise.
*/
RawAudio* try_get_raw_audio() noexcept {
return std::get_if<RawAudio>(&data_);
}

private:
std::variant<std::string, Image> data_;
std::variant<std::string, Image, Audio, RawAudio> data_;
};

// Convenience factory functions
Expand All @@ -179,4 +307,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
return MultimodalInput(std::move(image));
}

} // namespace executorch::extension::llm
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
return MultimodalInput(audio);
}

inline MultimodalInput make_audio_input(Audio&& audio) noexcept {
return MultimodalInput(std::move(audio));
}

inline MultimodalInput make_raw_audio_input(
const RawAudio& raw_audio) noexcept {
return MultimodalInput(raw_audio);
}

inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept {
return MultimodalInput(std::move(raw_audio));
}

} // namespace executorch::extension::llm
Loading
Loading