[MM]: Optimize encoder cache memory consumption by storing encoder outputs only #25903

@ywang96

Description

🚀 The feature, motivation and pitch

Currently, the encoder cache stores the post-scattering embeddings, i.e. the tensors that the encoder outputs are scattered into.

# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
    self.encoder_cache[mm_hash] = scatter_mm_placeholders(
        output,
        is_embed=pos_info.is_embed,
    )

This is because the representation of a multimodal item in the token sequence very often includes special tokens in addition to the embedding placeholder tokens (e.g. break tokens, image-start tokens, image-end tokens). For example, Pixtral uses the following layout:

[Image: Pixtral's multimodal token layout, with image embedding placeholders interleaved with special tokens]
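
For reference, the scattering step looks roughly like the sketch below. This is a simplified rendition of what scatter_mm_placeholders does, not the exact vLLM implementation: given the item's is_embed mask, the encoder output rows are written into a tensor that spans the item's entire placeholder range, so the non-embedding (special token) slots still occupy cache rows.

import torch

def scatter_mm_placeholders(output: torch.Tensor,
                            is_embed: torch.Tensor | None) -> torch.Tensor:
    # If there are no interleaved special tokens, the encoder output
    # already lines up 1:1 with the placeholder positions.
    if is_embed is None:
        return output
    # Allocate one row per position in the item's range (mm_position.length),
    # then fill only the positions that are actual embedding placeholders.
    placeholders = output.new_full((is_embed.shape[0], output.shape[-1]),
                                   fill_value=torch.nan)
    placeholders[is_embed] = output
    return placeholders

# Toy example: 4 encoder output rows scattered into a 7-token range.
output = torch.randn(4, 8)
is_embed = torch.tensor([False, True, True, False, True, True, False])
cached = scatter_mm_placeholders(output, is_embed)
assert cached.shape == (7, 8)  # rows are allocated for the special tokens too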

Storing the embeddings after scattering eases the scheduling logic, since the scheduler doesn't need to be aware of whether a given token is an embedding or not; it simply grabs the slice it needs, based on the mm_position information, to be merged into the text embeddings. Because of this design, we also had to reserve space in the encoder cache during the profiling run for the embeddings after scattering, which was addressed in #25810.

# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
# (encode_budget, hidden_size) and scatter encoder
# output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
if encoder_output_shape[0] < encoder_budget:
    expanded_outputs = []
    for output in dummy_encoder_outputs:
        expanded = output.new_zeros(
            (encoder_budget, encoder_output_shape[-1]))
        num_tokens = output.shape[0]
        expanded[:num_tokens].copy_(output)
        expanded_outputs.append(expanded)

However, the Qwen3-VL release challenges this design. Previously we assumed there were very few such non-embedding special tokens in the entire sequence, but this has flipped for Qwen3-VL video inference because of the new timestamp insertion: the special tokens for each timestamp can take up to 12 tokens, which means in the worst case we are reserving memory for 12x as many tokens as actually needed, memory that could otherwise be allocated to the decoder KV cache.
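
To make the overhead concrete, here is a back-of-the-envelope illustration; the token counts and hidden size below are made up for the sake of the example, not measured from Qwen3-VL:

# Hypothetical per-item numbers, for illustration only.
num_embed_tokens = 1024      # rows actually produced by the vision encoder
num_special_tokens = 11264   # timestamp/break/start/end tokens in the same range
hidden_size = 4096
bytes_per_elem = 2           # bf16 / fp16

# Current design: one cached row per position in mm_position.length.
scattered_bytes = (num_embed_tokens + num_special_tokens) * hidden_size * bytes_per_elem
# Proposed design: one cached row per encoder output token only.
raw_bytes = num_embed_tokens * hidden_size * bytes_per_elem

print(f"scattered cache entry: {scattered_bytes / 2**20:.0f} MiB")   # 96 MiB
print(f"raw encoder outputs:   {raw_bytes / 2**20:.0f} MiB")         # 8 MiB
print(f"overhead factor:       {scattered_bytes / raw_bytes:.1f}x")  # 12.0x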

To optimize this, we should store only the raw encoder outputs in the encoder cache. This requires some non-trivial work on the scheduler side, since it will now need to schedule based on the mm_position.is_embed information in addition to mm_position.offset and mm_position.length:

for i, mm_feature in enumerate(mm_features):
    start_pos = mm_feature.mm_position.offset
    num_encoder_tokens = mm_feature.mm_position.length
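
One possible shape of that bookkeeping is sketched below; the helper and its parameter names are hypothetical (not an existing vLLM API) and only illustrate how mm_position.is_embed would be consulted when deciding how many cached encoder-output rows a scheduled token window actually consumes:

import torch

def num_encoder_rows_needed(is_embed: torch.Tensor,
                            start_pos: int,
                            num_computed_tokens: int,
                            num_scheduled_tokens: int) -> int:
    # Hypothetical helper: count how many encoder-output rows fall inside
    # the token window [num_computed_tokens,
    # num_computed_tokens + num_scheduled_tokens), expressed relative to
    # the item's mm_position.offset (start_pos).
    window_start = max(num_computed_tokens - start_pos, 0)
    window_end = min(num_computed_tokens + num_scheduled_tokens - start_pos,
                     is_embed.shape[0])
    if window_end <= window_start:
        return 0
    # Only positions marked as embeddings consume rows from the cached raw
    # encoder output; special tokens (timestamps etc.) do not.
    return int(is_embed[window_start:window_end].sum())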

Alternatives

No response

Additional context

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
