From 71dea2a8b09e2544c5f0dd58c9b63dc27ae46e58 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 19 Sep 2025 16:36:10 -0700 Subject: [PATCH] [multimodal] Add token support to MultimodalInput --- extension/llm/runner/multimodal_input.h | 89 +++++- extension/llm/runner/multimodal_prefiller.cpp | 12 +- extension/llm/runner/multimodal_runner.cpp | 6 + .../llm/runner/test/test_multimodal_input.cpp | 255 ++++++++++++++++++ 4 files changed, 357 insertions(+), 5 deletions(-) diff --git a/extension/llm/runner/multimodal_input.h b/extension/llm/runner/multimodal_input.h index 728d8aef08f..737821f51e9 100644 --- a/extension/llm/runner/multimodal_input.h +++ b/extension/llm/runner/multimodal_input.h @@ -14,8 +14,10 @@ #include #include #include +#include #include #include +#include namespace executorch::extension::llm { @@ -29,15 +31,46 @@ class ET_EXPERIMENTAL MultimodalInput { /// Type of multimodal input data enum class Type { TEXT, ///< Text string input + TOKENS, ///< Pre-tokenized input (vector of token IDs) IMAGE, ///< Processed image input AUDIO, ///< Processed audio input RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file) UNSUPPORTED ///< Unsupported input type }; + /** + * Return a human-readable name for a MultimodalInput::Type. + * Preferred for logging and debugging; returns string literals. + */ + static constexpr const char* TypeName(Type t) noexcept { + switch (t) { + case Type::TEXT: + return "text"; + case Type::TOKENS: + return "tokens"; + case Type::IMAGE: + return "image"; + case Type::AUDIO: + return "audio"; + case Type::RAW_AUDIO: + return "raw_audio"; + default: + return "unknown"; + } + } + + /** Convenience wrapper that returns a std::string. */ + static inline std::string TypeToString(Type t) { + return TypeName(t); + } + // Constructors explicit MultimodalInput(const std::string& text) : data_(text) {} explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {} + explicit MultimodalInput(const std::vector& tokens) + : data_(tokens) {} + explicit MultimodalInput(std::vector&& tokens) + : data_(std::move(tokens)) {} explicit MultimodalInput(const Image& image) : data_(image) {} explicit MultimodalInput(Image&& image) : data_(std::move(image)) {} explicit MultimodalInput(const Audio& audio) : data_(audio) {} @@ -65,6 +98,13 @@ class ET_EXPERIMENTAL MultimodalInput { return std::holds_alternative(data_); } + /** + * Check if this input contains pre-tokenized data. + */ + bool is_tokens() const noexcept { + return std::holds_alternative>(data_); + } + /** * Check if this input contains image data. * @return true if this input contains an image, false otherwise. @@ -97,6 +137,8 @@ class ET_EXPERIMENTAL MultimodalInput { Type get_type() const noexcept { if (is_text()) return Type::TEXT; + if (is_tokens()) + return Type::TOKENS; if (is_image()) return Type::IMAGE; if (is_audio()) @@ -106,6 +148,15 @@ class ET_EXPERIMENTAL MultimodalInput { return Type::UNSUPPORTED; } + /** + * Get a human-readable name for the contained input type. + * Returns one of: "text", "tokens", "image", "audio", "raw_audio", or + * "unknown". + */ + const char* type_name() const noexcept { + return TypeName(get_type()); + } + /** * Get the text data from this input. * @return Reference to the stored text string. @@ -133,6 +184,21 @@ class ET_EXPERIMENTAL MultimodalInput { return std::get(std::move(data_)); } + /** + * Get the token vector from this input. + */ + const std::vector& get_tokens() const& { + return std::get>(data_); + } + + std::vector& get_tokens() & { + return std::get>(data_); + } + + std::vector&& get_tokens() && { + return std::get>(std::move(data_)); + } + /** * Get the image data from this input. * @return Reference to the stored Image object. @@ -250,6 +316,16 @@ class ET_EXPERIMENTAL MultimodalInput { return std::get_if(&data_); } + /** Try to get the tokens from this input safely. */ + const std::vector* try_get_tokens() const noexcept { + return std::get_if>(&data_); + } + + /** Try to get the tokens from this input safely (mutable). */ + std::vector* try_get_tokens() noexcept { + return std::get_if>(&data_); + } + /** * Try to get the audio data from this input safely. * @return Pointer to the Audio object if this input contains audio, @@ -287,7 +363,8 @@ class ET_EXPERIMENTAL MultimodalInput { } private: - std::variant data_; + std::variant, Image, Audio, RawAudio> + data_; }; // Convenience factory functions @@ -307,6 +384,16 @@ inline MultimodalInput make_image_input(Image&& image) noexcept { return MultimodalInput(std::move(image)); } +inline MultimodalInput make_token_input( + const std::vector& tokens) noexcept { + return MultimodalInput(tokens); +} + +inline MultimodalInput make_token_input( + std::vector&& tokens) noexcept { + return MultimodalInput(std::move(tokens)); +} + inline MultimodalInput make_audio_input(const Audio& audio) noexcept { return MultimodalInput(audio); } diff --git a/extension/llm/runner/multimodal_prefiller.cpp b/extension/llm/runner/multimodal_prefiller.cpp index 824fdf943a9..2c83df24f55 100644 --- a/extension/llm/runner/multimodal_prefiller.cpp +++ b/extension/llm/runner/multimodal_prefiller.cpp @@ -110,10 +110,14 @@ Result MultimodalPrefiller::prefill( auto audio_encoder_outputs = audio_encoder_result.get(); encoder_output = audio_encoder_outputs[0]; - } else if (input.is_text()) { - auto& text = input.get_text(); - std::vector tokens = - ET_UNWRAP_TOKENIZER(tokenizer_->encode(text)); + } else if (input.is_text() || input.is_tokens()) { + std::vector tokens; + if (input.is_text()) { + auto& text = input.get_text(); + tokens = ET_UNWRAP_TOKENIZER(tokenizer_->encode(text)); + } else { + tokens = input.get_tokens(); + } auto text_tensor = executorch::extension::from_blob( tokens.data(), diff --git a/extension/llm/runner/multimodal_runner.cpp b/extension/llm/runner/multimodal_runner.cpp index 6928a9b2827..a5de59cbe98 100644 --- a/extension/llm/runner/multimodal_runner.cpp +++ b/extension/llm/runner/multimodal_runner.cpp @@ -116,6 +116,12 @@ Error MultimodalRunner::generate( // Process multimodal inputs in order for (size_t i = 0; i < inputs.size(); ++i) { const MultimodalInput& input = inputs[i]; + ET_LOG( + Info, + "Prefilling input %zu/%zu, type: %s", + i, + inputs.size(), + input.type_name()); if (config.echo && i == inputs.size() - 1 && input.is_text()) { wrapped_callback(input.get_text()); } diff --git a/extension/llm/runner/test/test_multimodal_input.cpp b/extension/llm/runner/test/test_multimodal_input.cpp index 486515175e8..85d45d69173 100644 --- a/extension/llm/runner/test/test_multimodal_input.cpp +++ b/extension/llm/runner/test/test_multimodal_input.cpp @@ -14,6 +14,7 @@ using namespace ::testing; using executorch::extension::llm::Image; using executorch::extension::llm::make_image_input; using executorch::extension::llm::make_text_input; +using executorch::extension::llm::make_token_input; using executorch::extension::llm::MultimodalInput; class MultimodalInputTest : public Test { @@ -415,3 +416,257 @@ TEST_F(MultimodalInputTest, AssignmentBetweenTypes) { EXPECT_TRUE(input.is_text()); EXPECT_EQ(input.get_text(), text); } + +// Token-related tests +class MultimodalInputTokenTest : public Test { + protected: + std::vector createTestTokens() { + return {1, 2, 3, 4, 5}; + } +}; + +// Test token constructors +TEST_F(MultimodalInputTokenTest, TokenConstructorFromVector) { + std::vector tokens = createTestTokens(); + MultimodalInput input(tokens); + + EXPECT_TRUE(input.is_tokens()); + EXPECT_FALSE(input.is_text()); + EXPECT_FALSE(input.is_image()); + EXPECT_EQ(input.get_type(), MultimodalInput::Type::TOKENS); + EXPECT_EQ(input.get_tokens(), tokens); + EXPECT_EQ(input.get_tokens().size(), 5); +} + +TEST_F(MultimodalInputTokenTest, TokenConstructorFromRvalueVector) { + std::vector tokens = createTestTokens(); + std::vector original_tokens = tokens; + MultimodalInput input(std::move(tokens)); + + EXPECT_TRUE(input.is_tokens()); + EXPECT_FALSE(input.is_text()); + EXPECT_FALSE(input.is_image()); + EXPECT_EQ(input.get_type(), MultimodalInput::Type::TOKENS); + EXPECT_EQ(input.get_tokens(), original_tokens); + EXPECT_EQ(input.get_tokens().size(), 5); +} + +// Test token type checking +TEST_F(MultimodalInputTokenTest, TokenTypeChecking) { + std::vector tokens = createTestTokens(); + MultimodalInput input(tokens); + + EXPECT_TRUE(input.is_tokens()); + EXPECT_FALSE(input.is_text()); + EXPECT_FALSE(input.is_image()); + EXPECT_FALSE(input.is_audio()); + EXPECT_FALSE(input.is_raw_audio()); + EXPECT_EQ(input.get_type(), MultimodalInput::Type::TOKENS); + EXPECT_STREQ(input.type_name(), "tokens"); +} + +// Test token getters +TEST_F(MultimodalInputTokenTest, GetTokensWithTokenInput) { + std::vector tokens = createTestTokens(); + MultimodalInput input(tokens); + + // Test const lvalue reference version + const MultimodalInput& const_input = input; + EXPECT_EQ(const_input.get_tokens(), tokens); + EXPECT_EQ(const_input.get_tokens().size(), 5); + + // Test mutable lvalue reference version + std::vector& mutable_tokens = input.get_tokens(); + mutable_tokens.push_back(6); + EXPECT_EQ(input.get_tokens().size(), 6); + EXPECT_EQ(input.get_tokens().back(), 6); + + // Test rvalue reference version + std::vector moved_tokens = std::move(input).get_tokens(); + EXPECT_EQ(moved_tokens.size(), 6); + EXPECT_EQ(moved_tokens.back(), 6); +} + +// Test token getters with wrong types (should throw) +TEST_F(MultimodalInputTokenTest, GetTokensWithTextInputThrows) { + std::string text = "Hello"; + MultimodalInput input(text); + + EXPECT_THROW(input.get_tokens(), std::bad_variant_access); + EXPECT_THROW(std::move(input).get_tokens(), std::bad_variant_access); +} + +TEST_F(MultimodalInputTokenTest, GetTextWithTokenInputThrows) { + std::vector tokens = createTestTokens(); + MultimodalInput input(tokens); + + EXPECT_THROW(input.get_text(), std::bad_variant_access); + EXPECT_THROW(std::move(input).get_text(), std::bad_variant_access); +} + +// Test safe token getters (try_get_*) +TEST_F(MultimodalInputTokenTest, TryGetTokensWithTokenInput) { + std::vector tokens = createTestTokens(); + MultimodalInput input(tokens); + + // Test const version + const MultimodalInput& const_input = input; + const std::vector* tokens_ptr = const_input.try_get_tokens(); + ASSERT_NE(tokens_ptr, nullptr); + EXPECT_EQ(*tokens_ptr, tokens); + + // Test mutable version + std::vector* mutable_tokens_ptr = input.try_get_tokens(); + ASSERT_NE(mutable_tokens_ptr, nullptr); + EXPECT_EQ(*mutable_tokens_ptr, tokens); + + // Modify through pointer + mutable_tokens_ptr->push_back(100); + EXPECT_EQ(input.get_tokens().size(), 6); + EXPECT_EQ(input.get_tokens().back(), 100); +} + +TEST_F(MultimodalInputTokenTest, TryGetTokensWithTextInput) { + std::string text = "Hello"; + MultimodalInput input(text); + + // Should return nullptr for wrong type + EXPECT_EQ(input.try_get_tokens(), nullptr); + + const MultimodalInput& const_input = input; + EXPECT_EQ(const_input.try_get_tokens(), nullptr); +} + +// Test token convenience factory functions +TEST_F(MultimodalInputTokenTest, MakeTokenInputFromVector) { + std::vector tokens = createTestTokens(); + MultimodalInput input = make_token_input(tokens); + + EXPECT_TRUE(input.is_tokens()); + EXPECT_EQ(input.get_tokens(), tokens); + EXPECT_EQ(input.get_tokens().size(), 5); +} + +TEST_F(MultimodalInputTokenTest, MakeTokenInputFromRvalueVector) { + std::vector tokens = createTestTokens(); + std::vector original_tokens = tokens; + MultimodalInput input = make_token_input(std::move(tokens)); + + EXPECT_TRUE(input.is_tokens()); + EXPECT_EQ(input.get_tokens(), original_tokens); + EXPECT_EQ(input.get_tokens().size(), 5); +} + +// Test token copy semantics +TEST_F(MultimodalInputTokenTest, TokenCopyConstructor) { + std::vector tokens = createTestTokens(); + MultimodalInput original(tokens); + MultimodalInput copy(original); + + EXPECT_TRUE(copy.is_tokens()); + EXPECT_EQ(copy.get_tokens(), tokens); + EXPECT_EQ(original.get_tokens(), tokens); // Original should be unchanged + + // Modify copy, original should be unaffected + copy.get_tokens().push_back(999); + EXPECT_EQ(copy.get_tokens().size(), 6); + EXPECT_EQ(original.get_tokens().size(), 5); +} + +TEST_F(MultimodalInputTokenTest, TokenCopyAssignment) { + std::vector tokens = createTestTokens(); + MultimodalInput original(tokens); + MultimodalInput copy("initial text"); // Start with different type + + copy = original; + + EXPECT_TRUE(copy.is_tokens()); + EXPECT_EQ(copy.get_tokens(), tokens); + EXPECT_EQ(original.get_tokens(), tokens); // Original should be unchanged +} + +// Test token move semantics +TEST_F(MultimodalInputTokenTest, TokenMoveConstructor) { + std::vector tokens = createTestTokens(); + std::vector original_tokens = tokens; + MultimodalInput original(std::move(tokens)); + MultimodalInput moved(std::move(original)); + + EXPECT_TRUE(moved.is_tokens()); + EXPECT_EQ(moved.get_tokens(), original_tokens); +} + +TEST_F(MultimodalInputTokenTest, TokenMoveAssignment) { + std::vector tokens = createTestTokens(); + std::vector original_tokens = tokens; + MultimodalInput original(std::move(tokens)); + MultimodalInput moved("initial text"); // Start with different type + + moved = std::move(original); + + EXPECT_TRUE(moved.is_tokens()); + EXPECT_EQ(moved.get_tokens(), original_tokens); +} + +// Test TypeName and TypeToString static methods for TOKENS +TEST_F(MultimodalInputTokenTest, TypeNameAndToString) { + EXPECT_STREQ( + MultimodalInput::TypeName(MultimodalInput::Type::TOKENS), "tokens"); + EXPECT_EQ( + MultimodalInput::TypeToString(MultimodalInput::Type::TOKENS), "tokens"); + + std::vector tokens = createTestTokens(); + MultimodalInput input(tokens); + EXPECT_STREQ(input.type_name(), "tokens"); +} + +// Test assignment between token and other types +TEST_F(MultimodalInputTokenTest, AssignmentBetweenTokensAndOtherTypes) { + std::vector tokens = createTestTokens(); + std::string text = "Hello"; + + MultimodalInput input(tokens); + EXPECT_TRUE(input.is_tokens()); + + // Assign text to token input + input = MultimodalInput(text); + EXPECT_TRUE(input.is_text()); + EXPECT_EQ(input.get_text(), text); + + // Assign tokens back to text input + input = MultimodalInput(tokens); + EXPECT_TRUE(input.is_tokens()); + EXPECT_EQ(input.get_tokens(), tokens); +} + +// Test token values with specific patterns +TEST_F(MultimodalInputTokenTest, SpecificTokenValues) { + std::vector tokens = { + 0, 1, 2, 65535, 4294967295ULL, 18446744073709551615ULL}; + MultimodalInput input(tokens); + + EXPECT_TRUE(input.is_tokens()); + EXPECT_EQ(input.get_tokens().size(), 6); + EXPECT_EQ(input.get_tokens()[0], 0); + EXPECT_EQ(input.get_tokens()[1], 1); + EXPECT_EQ(input.get_tokens()[2], 2); + EXPECT_EQ(input.get_tokens()[3], 65535); + EXPECT_EQ(input.get_tokens()[4], 4294967295ULL); + EXPECT_EQ(input.get_tokens()[5], 18446744073709551615ULL); // Max uint64_t +} + +// Test token modification through reference +TEST_F(MultimodalInputTokenTest, TokenModificationThroughReference) { + std::vector tokens = createTestTokens(); + MultimodalInput input(tokens); + + // Get mutable reference and modify + std::vector& token_ref = input.get_tokens(); + token_ref[0] = 999; + token_ref.push_back(1000); + + // Verify changes + EXPECT_EQ(input.get_tokens()[0], 999); + EXPECT_EQ(input.get_tokens().size(), 6); + EXPECT_EQ(input.get_tokens().back(), 1000); +}