
Conversation

@Sugar-zsg (Contributor) commented Sep 15, 2025

Improves performance by getting the max encoder length directly from the initialized vllm_config.scheduler_config. This avoids the expensive lookup and re-computation previously done by MULTIMODAL_REGISTRY.get_encdec_max_encoder_len.
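
In essence, the change swaps a per-call registry computation for a plain attribute read. A minimal before/after sketch (pieced together from the snippets quoted later in this thread, not the exact diff):

    # Before: re-derives the max encoder length on every call (slow).
    def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
        return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
            vllm_config.model_config)

    # After: reads the value cached on the scheduler config at init time.
    def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
        return vllm_config.scheduler_config.max_num_encoder_input_tokens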

Test Results:

  • Environment: H20 GPU
  • Data: 10s audio
  • Before: Average latency was 1300ms.
  • After: Average latency is now 305ms.


@gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request aims to improve performance by optimizing how the maximum encoder length is retrieved. However, the current implementation introduces a critical correctness issue by using an incorrect value, which creates an inconsistency between memory allocation for the KV cache and the metadata used in the attention mechanism. This could lead to out-of-bounds memory access. I've provided a detailed comment explaining the issue and recommending a safer approach.

Comment on lines 25 to 33

Severity: critical

This change introduces a critical correctness issue by using a value for max_encoder_len that is inconsistent with how the cross-attention KV cache is allocated. This can lead to out-of-bounds memory access and other runtime errors.

The new implementation uses scheduler_config.max_num_encoder_input_tokens, which is derived from max_num_batched_tokens. This is a general batching configuration (e.g., 16384) and not the model-specific maximum encoder sequence length (e.g., 1500 for Whisper).

Meanwhile, the memory for the cross-attention KV cache is allocated based on the correct, smaller value via MULTIMODAL_REGISTRY.get_encdec_max_encoder_len in CrossAttentionSpec.max_memory_usage_bytes. Using a much larger max_seq_len in the attention metadata here creates a dangerous discrepancy.

While the performance motivation is valid, a safer solution is to compute the correct value once during engine initialization and cache it in the configuration. For now, it's best to revert to the original implementation to avoid memory corruption issues.

def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
    return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
        vllm_config.model_config)
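
To make the concern concrete, here is an illustrative consistency check (a hypothetical helper, assuming an already-initialized vllm_config; not code from the PR):

    def check_encoder_len_consistency(vllm_config: "VllmConfig") -> None:
        # Model-specific value used to size the cross-attention KV cache
        # (e.g. 1500 for Whisper).
        alloc_len = MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
            vllm_config.model_config)
        # Value the new code would place in the attention metadata.
        meta_len = vllm_config.scheduler_config.max_num_encoder_input_tokens
        # If meta_len exceeds alloc_len, attention could index past the
        # allocated cache.
        assert meta_len <= alloc_len, (
            f"max_seq_len mismatch: metadata={meta_len}, "
            f"allocated={alloc_len}")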

@DarkLight1337 (Member) left a comment

LGTM, but best for @russellb @heheda12345 to confirm.

@Sugar-zsg (Contributor, Author) commented

@russellb @heheda12345 Could you please review this PR when you have time? Thanks!

@heheda12345 (Collaborator) left a comment

> Before: Average latency was 1300ms.
> After: Average latency is now 305ms.

The change makes sense to me, but the speedup is much larger than I would expect from this change. What is the expensive operation that leads to such a large speedup?

@russellb (Member) commented

Thanks for the PR! You will also need to update your commit message(s) to include the Signed-off-by header to satisfy the DCO check in CI.

Sugar-zsg and others added 6 commits September 16, 2025 10:07
Improves performance by getting the max encoder length directly from the initialized `vllm_config.scheduler_config`. This avoids the expensive lookup and re-computation previously done by `MULTIMODAL_REGISTRY.get_encdec_max_encoder_len`.

Signed-off-by: Sugar-zsg <952242923@qq.com>
Signed-off-by: Sugar-zsg <952242923@qq.com>
@Sugar-zsg (Contributor, Author) commented Sep 16, 2025

> Before: Average latency was 1300ms.
> After: Average latency is now 305ms.
>
> The change makes sense to me, but the speedup is much larger than I would expect from this change. What is the expensive operation that leads to such a large speedup?

During decoding, each call to _prepare_inputs triggers the cross-attention builder, which in turn executes MULTIMODAL_REGISTRY.get_encdec_max_encoder_len. In my tests, this method takes around 10ms per call, while the decoder’s forward computation itself only takes about 2ms.
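
A rough way to reproduce that kind of measurement (an illustrative sketch, not the benchmark actually used; time_ms_per_call is a hypothetical helper, and vllm_config is assumed to be an already-initialized VllmConfig):

    import time

    from vllm.multimodal import MULTIMODAL_REGISTRY

    def time_ms_per_call(fn, n: int = 100) -> float:
        """Average wall-clock time of fn() in milliseconds over n calls."""
        start = time.perf_counter()
        for _ in range(n):
            fn()
        return (time.perf_counter() - start) / n * 1000.0

    def compare(vllm_config) -> None:
        # Registry path: recomputes the max encoder length on every call.
        slow = time_ms_per_call(
            lambda: MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
                vllm_config.model_config))
        # Cached path: a plain attribute read, set once at engine init.
        fast = time_ms_per_call(
            lambda: vllm_config.scheduler_config.max_num_encoder_input_tokens)
        print(f"registry: {slow:.3f} ms/call, cached: {fast:.6f} ms/call")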

@Sugar-zsg (Contributor, Author) commented

@heheda12345 @DarkLight1337 These have been resolved. Please check when you have a chance.

@russellb (Member) left a comment

Thanks!

I confirmed that this value in the scheduler_config is initialized from the same code in the multimodal registry:

    self.scheduler_config.max_num_encoder_input_tokens = \
        MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)

@russellb (Member) commented

I haven't enabled auto-merge yet in case you wanted to look again @heheda12345

@russellb added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) on Sep 16, 2025
@heheda12345 enabled auto-merge (squash) on September 16, 2025 16:12
@heheda12345 merged commit cd1f885 into vllm-project:main on Sep 16, 2025
43 checks passed
russellb added a commit to russellb/vllm that referenced this pull request Sep 16, 2025
This is the same change that was made in vllm-project#24866. In that PR, it was
pointed out that this code:

    MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(...)

is much slower than getting the same value that's cached on the
scheduler config:

    scheduler_config.max_num_encoder_input_tokens

This PR makes the same change in more spots: the scheduler, kv cache
manager, and gpu model runner.

Related to issue vllm-project#24946.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
@russellb (Member) commented

Related PR making the same change in some other areas: #24989

FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)