RFC: Multimodal Support on ExecuTorch #12913

@larryliu0820

🚀 The feature, motivation and pitch

Context

Recently we've been seeing a lot of edge-sized multimodal models coming out:

  • Gemma3 4b
  • Voxtral

ExecuTorch should make these models work out of the box, with export and runtime just a click away.

Note: the scope of this RFC only covers the EarlyFusion model architecture type; see the definition in torchtune:

    EarlyFusion is a type of fused model architecture where pretrained encoder(s) are combined
    with a pretrained decoder (LLM) at the model input and not in internal layers. This is a popular architecture
    for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures
    <https://arxiv.org/abs/2405.17927>`_. This module works both for decoders in which the encoder tokens are
    inside the vocab and outside the vocab.

Example models are Gemma3 4b and Voxtral. In contrast, there is another popular model architecture type called DeepFusion:

    DeepFusion is a type of fused model architecture where a pretrained encoder is combined
    with a pretrained decoder (LLM) in the internal decoder layers. This is a popular architecture for multimodal models, with
    a full overview available in `The Evolution of Multimodal Model Architectures <https://arxiv.org/abs/2405.17927>`_.
    A common deep fusion architecture is to fuse the encoder input into the decoder with interspersed cross-attention
    layers. 

We leave DeepFusion out of scope because it needs a significant model-definition rewrite to work with torch.export. One example of DeepFusion is the Llama 3.2 vision model.

Status Quo

ET has limited support for multimodal models. As of today (07/28/2025) we have one example, Llava 1.5, but it is tightly coupled with the model definition in llama_transformer.py, so we have to translate HF model weights to match llama_transformer.

We have a runner implemented for Llava (llava_runner.cpp), but it is Llava-specific and not generic enough to run other multimodal models.

Goal

Provide APIs to export EarlyFusion multimodal models from HF transformers, supporting the image-text-to-text and audio-text-to-text generation tasks. Provide a C++ runner to run these models out of the box on XNNPACK.

Non-goal

We are not touching the export flow or the runner for Qualcomm or MTK. We should eventually unify the runner implementations, but let's start with XNNPACK.

Export

Due to the complexity of model definitions and the sheer volume of users on HuggingFace, I think we should fully embrace HuggingFace transformers models and avoid weight translation. We have an ongoing work stream to plug the ExecuTorch export flow into transformers: optimum-executorch. With the prototyping PR huggingface/optimum-executorch#111 we have shown that it is possible to extend the current capability of optimum-executorch to support multimodal tasks.

Details

API entry point:

  • optimum.executorch.modeling.ExecuTorchModelForMultimodalToTextCausalLM (temporary name). Similar to ExecuTorchModelForCausalLM, this new module handles AutoModelForCausalLM in HuggingFace but targets the following tasks:
      • image-text-to-text: takes both image and text input and generates text output.
      • audio-text-to-text: takes both audio and text input and generates text output.

Example usage:

from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch.modeling import ExecuTorchModelForMultimodalToTextCausalLM

model_id = "<hf-model-id>"  # an EarlyFusion multimodal model on the HF Hub

tokenizer = AutoTokenizer.from_pretrained(model_id)
image_url = "https://llava-vl.github.io/static/images/view.jpg"
conversation = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": image_url},
            {
                "type": "text",
                "text": "What are the things I should be cautious about when I visit here?",
            },
        ],
    },
]

# Preprocess the conversation into input_ids and pixel_values.
processor = AutoProcessor.from_pretrained(model_id)
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

# Export to a .pte lowered to XNNPACK, with quantized linear and embedding layers.
model = ExecuTorchModelForMultimodalToTextCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    task="image-text-to-text",
    export=True,
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
    qlinear=True,
    qembedding_config=True,
)

output = model.generate(
    tokenizer,
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=50,
)

Eager model construction

This lives in optimum/exporters/executorch/tasks/multimodal_to_text.py. It retrieves the eager model from AutoModelForCausalLM.from_pretrained(), applies quantization and module swapping, and returns a MultimodalToTextExportableModule.
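
For illustration, a minimal sketch of what this task module could look like; the function name and keyword arguments are assumptions, only AutoModelForCausalLM and MultimodalToTextExportableModule come from this RFC:

# Hypothetical sketch of optimum/exporters/executorch/tasks/multimodal_to_text.py.
from transformers import AutoModelForCausalLM


def load_multimodal_to_text_model(model_id: str, **kwargs):
    # Retrieve the eager HF model directly; no weight translation.
    eager_model = AutoModelForCausalLM.from_pretrained(model_id).eval()

    # Apply module swaps (custom SDPA, custom KV cache) and quantization
    # here, according to the export recipe (see "Export recipe" below).

    # Wrap it into the exportable module proposed in the "Exportable module"
    # section below, which knows how to export token_embedding, the
    # image/audio encoder and text_decoder.
    return MultimodalToTextExportableModule(eager_model, **kwargs)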

Export recipe

We can start with something really simple (qembedding means int8 quantization for the embedding layer, qlinear means 8da4w with group size 32, etc.). Later we should be able to leverage the existing llm_config system or the executorch.export recipes.
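
For illustration, a minimal sketch of one plausible torchao mapping for these two flags, applied to the eager_model from the sketch above; the specific config classes and the nn.Embedding filter are assumptions, not a fixed recipe format:

# Sketch only: quantize the eager model before export.
import torch
from torchao.quantization import (
    Int8DynamicActivationInt4WeightConfig,  # "8da4w"
    IntxWeightOnlyConfig,
    quantize_,
)

# qlinear=True -> 8-bit dynamic activations, 4-bit weights, group size 32, on linear layers.
quantize_(eager_model, Int8DynamicActivationInt4WeightConfig(group_size=32))

# qembedding=True -> int8 weight-only quantization on the embedding table.
quantize_(
    eager_model,
    IntxWeightOnlyConfig(weight_dtype=torch.int8),
    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding),
)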

Exportable module

MultimodalToTextExportableModule provides sample inputs and dynamic shapes for exporting the model and implements an export() method to facilitate that. It wraps the text model with TorchExportableModuleWithHybridCache, which hosts a KV cache member field; that cache is fed into the text model during export.
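
A rough skeleton of how the wrapper could be organized; everything beyond the names stated above (attribute lookups, constructor arguments, defaults) is an assumption:

# Sketch of the proposed MultimodalToTextExportableModule; internals are illustrative.
import torch
from transformers.integrations.executorch import TorchExportableModuleWithHybridCache


class MultimodalToTextExportableModule(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        # Modality-specific pieces pulled out of the HF model (attribute names vary per model).
        self.token_embedding = model.get_input_embeddings()
        self.encoder = model.vision_tower if hasattr(model, "vision_tower") else model.audio_tower
        # Wrap the text decoder so a static/hybrid KV cache lives as a member field
        # and is fed into the decoder during export. Cache sizing arguments omitted.
        self.text_decoder = TorchExportableModuleWithHybridCache(model.language_model)

    def export(self):
        # Returns torch.export.ExportedProgram objects keyed by method name:
        # {"token_embedding": ..., "image_encoder" or "audio_encoder": ..., "text_decoder": ...},
        # each exported with its own sample inputs and dynamic shapes.
        ...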

The way we should export the model is like the following:

  • We export token_embedding, which converts tokens to embeddings.
  • We then export image_encoder or audio_encoder to convert image or audio features to embeddings.
  • Lastly, we export text_decoder, which takes the embeddings, prefills the KV cache, and generates tokens in an autoregressive loop.
  • We export a single .pte file with all 3 methods (a sketch of this lowering step follows the diagram below).

Please see the diagram below:

Supported Model Architecture:
┌─────────────────────────────────────────────────────────────────────────┐
│                        Multimodal LLM Architecture                      │
└─────────────────────────────────────────────────────────────────────────┘
   Input: std::vector<MultimodalInput>
          ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐
          │     Image       │  │     Audio       │  │      Text       │
          │    [224x        │  │    [16kHz       │  │     "What"      │
          │     224x3]      │  │     audio]      │  │                 │
          └─────────────────┘  └─────────────────┘  └─────────────────┘
                   │                    │                    │
                   ▼                    ▼                    ▼
          ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐ ◄─┐
          │     Encoder     │  │     Encoder     │  │ Text Tokenizer  │   │
          │   (Vision)      │  │   (Audio)       │  │   & Embedding   │   │
          │                 │  │                 │  │                 │   │
          │ pixels → embed  │  │ waveform→embed  │  │ tokens → embed  │   │
          └─────────────────┘  └─────────────────┘  └─────────────────┘   │
                   │                    │                    │            │
                   ▼                    ▼                    ▼            │
          ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐   │
          │     [D_emb]     │  │     [D_emb]     │  │     [D_emb]     │   │
          │    Embedding    │  │    Embedding    │  │    Embedding    │   │
          └─────────────────┘  └─────────────────┘  └─────────────────┘   │
                   │                    │                    │            │
                   └────────────────────┼────────────────────┘            │
                                        │                                 │
                                        ▼                                 │
                   ┌─────────────────────────────┐                        │
                   │      Text Decoder Block     │                        │
                   │    (Transformer Layers)     │                        │
                   │                             │                        │
                   │  ┌─────────────────────┐    │                        │
                   │  │   Self-Attention    │    │                        │
                   │  │   + Feed Forward    │    │                        │
                   │  │   (with KV Cache)   │    │                        │
                   │  └─────────────────────┘    │                        │
                   │           │                 │                        │
                   │           ▼                 │                        │
                   │    Token Generation         │                        │
                   │    (pos_ tracking)          │                        │
                   └─────────────────────────────┘                        │
                                  │───────────────────────────────────────┘
                                  │          (Autoregressive)
                                  ▼
                         ┌─────────────────┐
                         │  Generated Text │
                         │ "This image     │
                         │  shows a cat    │
                         │  sitting..."    │
                         └─────────────────┘
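
A minimal sketch of the lowering step that packs the three exported methods into a single XNNPACK-delegated .pte; the exported_programs dict is assumed to come from MultimodalToTextExportableModule.export():

# Sketch: combine the three ExportedPrograms into one .pte, delegated to XNNPACK.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

exported_programs = exportable_module.export()
# e.g. {"token_embedding": ..., "image_encoder": ..., "text_decoder": ...}

et_program = to_edge_transform_and_lower(
    exported_programs,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("model.pte", "wb") as f:
    f.write(et_program.buffer)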

Runtime

The design of the multimodal runner should follow these principles:

The contract between the runtime and export:

  • The names of the methods need to match: image_encoder, audio_encoder and text_decoder.
  • The ordering of the inputs matters, since the runner prefills all inputs in order.
  • Export metadata, such as the EOS token id, must be carried in the .pte.

Example code:

#include <iostream>
#include <vector>

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/multimodal_input.h>
#include <executorch/extension/llm/runner/multimodal_runner.h>

using namespace ::executorch::extension::llm;

// Create inputs; `image` is an Image produced by the preprocessing step below.
std::vector<MultimodalInput> inputs;
inputs.emplace_back(make_text_input("Describe this image:"));
inputs.emplace_back(make_image_input(std::move(image)));

// Type checking
for (const auto& input : inputs) {
  if (input.is_text()) {
    std::cout << input.get_text() << std::endl;
  }
  if (input.is_image()) {
    const auto& img = input.get_image();
    std::cout << "Image: " << img.width << "x" << img.height << std::endl;
  }
}

// Create runner using helper functions
auto tokenizer = load_tokenizer("tokenizer.bin");
auto runner = create_multimodal_runner("model.pte", std::move(tokenizer));

// Generate text; `config`, `token_callback` and `stats_callback` are defined by the caller.
auto error = runner->generate(inputs, config, token_callback, stats_callback);
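
Internally, generate() has to stitch the three methods together: encode each input to embeddings, prefill them in order, then decode autoregressively. Below is a minimal Python sketch of that loop, with token_embedding, image_encoder and text_decoder standing in for the exported methods; everything else (greedy sampling, position tracking) is purely illustrative:

import torch


def generate(token_embedding, image_encoder, text_decoder, input_ids, pixel_values, eos_id, max_new_tokens=50):
    # 1. Convert text tokens and image pixels to embeddings.
    text_embeds = token_embedding(input_ids)    # [1, T, D_emb]
    image_embeds = image_encoder(pixel_values)  # [1, I, D_emb]
    # 2. Prefill the KV cache with all inputs, in the order they appear in the prompt.
    embeds = torch.cat([image_embeds, text_embeds], dim=1)
    logits = text_decoder(embeds, torch.tensor([0]))
    pos = embeds.size(1)
    # 3. Autoregressive decode until EOS or the token budget is exhausted.
    tokens = []
    next_token = int(torch.argmax(logits[:, -1]))
    while next_token != eos_id and len(tokens) < max_new_tokens:
        tokens.append(next_token)
        logits = text_decoder(token_embedding(torch.tensor([[next_token]])), torch.tensor([pos]))
        pos += 1
        next_token = int(torch.argmax(logits[:, -1]))
    return tokens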

Preprocessing and post processing

I plan to add simple image/audio preprocessing to the pytorch-labs/tokenizers repo, focusing on converting image and audio files to ExecuTorch tensors.
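
Until that lands, a minimal Python sketch of the intended kind of conversion (the resize and normalization choices are illustrative, not a final preprocessing spec):

# Sketch: turn an image file into the CHW float tensor an image encoder typically expects.
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # target resolution comes from the model's processor config
    transforms.ToTensor(),          # HWC uint8 -> CHW float32 in [0, 1]
])

pixel_values = preprocess(Image.open("view.jpg").convert("RGB")).unsqueeze(0)  # [1, 3, 224, 224]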

Bindings and mobile APIs

I plan to add Python bindings so people can quickly experiment with their .pte files. We should also add the multimodal runner to the Swift package and the prebuilt Android .aar.
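
Until the dedicated multimodal bindings exist, the generic ExecuTorch pybindings can already load the .pte and run individual methods; a sketch, assuming the three-method layout above:

# Sketch: poke at the exported methods with the existing generic pybindings.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("model.pte")
image_embeds = module.run_method("image_encoder", (pixel_values,))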

Alternatives

No response

Additional context

No response

RFC (Optional)

No response
