diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index d95bd7fb054..9ba6a510736 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -226,11 +226,11 @@ def export_all(llava_model: LlavaModel): { "image_encoder": image_encoder_ep, "token_embedding": token_embedding_ep, - "text_model": text_model_ep, + "text_decoder": text_model_ep, }, partitioner={ "image_encoder": [XnnpackPartitioner()], - "text_model": [ + "text_decoder": [ # First partition the DQLinear nodes, then partition the rest of the nodes, # to avoid multiple DQLinear nodes in the same partition, # to avoid holding multiple unpacked and packed weight buffers in memory, @@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel): memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass={ "image_encoder": ConstraintBasedSymShapeEvalPass(), - "text_model": ConstraintBasedSymShapeEvalPass(), + "text_decoder": ConstraintBasedSymShapeEvalPass(), "token_embedding": HintBasedSymShapeEvalPass(), }, ) diff --git a/examples/models/llava/runner/llava_text_decoder_runner.h b/examples/models/llava/runner/llava_text_decoder_runner.h index 09b8e82d49d..cfa92e0c253 100644 --- a/examples/models/llava/runner/llava_text_decoder_runner.h +++ b/examples/models/llava/runner/llava_text_decoder_runner.h @@ -89,7 +89,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner } inline static const std::string kTokenEmbeddingMethod = "token_embedding"; - inline static const std::string kTextModelMethod = "text_model"; + inline static const std::string kTextModelMethod = "text_decoder"; }; } // namespace example diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py index 05cfd5b1497..def9eaa02bd 100644 --- a/examples/models/llava/test/test_llava.py +++ b/examples/models/llava/test/test_llava.py @@ -96,7 +96,7 @@ def test_llava_export(self): "token_embedding", (prompt_before_image,) )[0] 
llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img), ) @@ -107,7 +107,7 @@ def test_llava_export(self): # pte prefill image pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0] llava_module.run_method( - "text_model", + "text_decoder", ( torch.tensor([start_pos], dtype=torch.int64), pte_embeds_img, @@ -122,7 +122,7 @@ def test_llava_export(self): "token_embedding", (prompt_after_image,) )[0] pte_prefill_after_img = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img), )[0] @@ -139,7 +139,7 @@ def test_llava_export(self): "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),) )[0] logits = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds), )[0] new_tokens.append(torch.argmax(logits).item()) diff --git a/examples/models/llava/test/test_pte.py b/examples/models/llava/test/test_pte.py index f12d72f854b..1f4aaa9938c 100644 --- a/examples/models/llava/test/test_pte.py +++ b/examples/models/llava/test/test_pte.py @@ -47,7 +47,7 @@ def main(): "token_embedding", (prompt_before_image,) )[0] pte_prefill_before_img = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img), )[0] print(pte_prefill_before_img) @@ -60,7 +60,7 @@ def main(): logging.warning("Image encoder finished") logging.warning("Image token prefill started") pte_prefill_img = llava_module.run_method( - "text_model", + "text_decoder", ( torch.tensor([start_pos], dtype=torch.int64), pte_embeds_img, @@ -77,7 +77,7 @@ def main(): "token_embedding", (prompt_after_image,) )[0] pte_prefill_after_img = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img), )[0] logging.warning("Text token prefill finished") @@ 
-91,7 +91,7 @@ def main(): "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),) )[0] logits = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds), )[0] new_tokens.append(torch.argmax(logits[..., -1, :]).item()) diff --git a/extension/llm/runner/audio.h b/extension/llm/runner/audio.h new file mode 100644 index 00000000000..868765950af --- /dev/null +++ b/extension/llm/runner/audio.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple audio struct. + +#pragma once +#include <executorch/runtime/platform/compiler.h> +#include <cstdint> +#include <vector> + +namespace executorch { +namespace extension { +namespace llm { + +/** + * Audio inputs as a raw audio tensor, for use when the audio processing + * into a mel spectrogram is baked into the audio encoder with torch.export. + */ +struct ET_EXPERIMENTAL RawAudio { + std::vector<uint8_t> data; + int32_t batch_size; + int32_t n_channels; // For mono, use n_channels = 1. + int32_t n_samples; +}; + +/** + * Pre-processed audio inputs, ready to feed directly into an audio + * encoder. + */ +struct ET_EXPERIMENTAL Audio { + std::vector<uint8_t> data; + int32_t batch_size; + int32_t n_bins; + int32_t n_frames; +}; + +} // namespace llm +} // namespace extension +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces.
+using ::executorch::extension::llm::Audio; +} // namespace executor +} // namespace torch diff --git a/extension/llm/runner/constants.h b/extension/llm/runner/constants.h index fc6ddcb451c..4ba88203c50 100644 --- a/extension/llm/runner/constants.h +++ b/extension/llm/runner/constants.h @@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache"; // Multimodal method name conventions inline constexpr auto kImageEncoderMethod = "image_encoder"; +inline constexpr auto kAudioEncoderMethod = "audio_encoder"; inline constexpr auto kTokenEmbeddingMethod = "token_embedding"; -inline constexpr auto kTextModelMethod = "text_model"; +inline constexpr auto kTextModelMethod = "text_decoder"; } // namespace executorch::extension::llm diff --git a/extension/llm/runner/multimodal_input.h b/extension/llm/runner/multimodal_input.h index ae243992fec..728d8aef08f 100644 --- a/extension/llm/runner/multimodal_input.h +++ b/extension/llm/runner/multimodal_input.h @@ -11,6 +11,7 @@ #pragma once +#include <executorch/extension/llm/runner/audio.h> #include <executorch/extension/llm/runner/image.h> #include <string> #include <variant> @@ -19,19 +20,31 @@ namespace executorch::extension::llm { /** - * A generic class to hold either image or text data for multimodal inputs. - * This allows the generate() API to take a std::vector of these objects - * instead of separate image and text parameters. + * A generic class to hold either image, text, or audio data for multimodal + * inputs. This allows the generate() API to take a std::vector of these objects + * instead of separate image, text, and audio parameters.
*/ class ET_EXPERIMENTAL MultimodalInput { public: - enum class Type { TEXT, IMAGE }; + /// Type of multimodal input data + enum class Type { + TEXT, ///< Text string input + IMAGE, ///< Processed image input + AUDIO, ///< Processed audio input + RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file) + UNSUPPORTED ///< Unsupported input type + }; // Constructors explicit MultimodalInput(const std::string& text) : data_(text) {} explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {} explicit MultimodalInput(const Image& image) : data_(image) {} explicit MultimodalInput(Image&& image) : data_(std::move(image)) {} + explicit MultimodalInput(const Audio& audio) : data_(audio) {} + explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {} + explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {} + explicit MultimodalInput(RawAudio&& raw_audio) + : data_(std::move(raw_audio)) {} // Copy constructor and assignment MultimodalInput(const MultimodalInput& other) = default; @@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput { return std::holds_alternative<Image>(data_); } + /** + * Check if this input contains audio data. + * @return true if this input contains audio, false otherwise. + */ + bool is_audio() const noexcept { + return std::holds_alternative