
Conversation

@youkaichao (Member) commented on Nov 22, 2024

Previously, we registered attention ops separately for each backend, e.g. FlashInfer and FlashAttention.

This PR changes the registration to a unified attention interface, so that we don't need to register these attention backends one by one.

How it works (a rough sketch follows this list):

  1. When we create an attention class, we register it in the per-model static forward context, identified by its layer name.
  2. When we call the attention implementation, we pass the layer name in through a PyTorch custom op; inside the custom op, we look up the attention object by name and call its implementation.
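
A minimal, self-contained sketch of these two steps (the names _static_forward_context, UnifiedAttention, and unified_attention are illustrative stand-ins, not the actual vLLM symbols):

from typing import Callable, Dict

# Per-model static forward context: layer name -> attention object.
_static_forward_context: Dict[str, "UnifiedAttention"] = {}

class UnifiedAttention:
    def __init__(self, layer_name: str, impl: Callable):
        self.layer_name = layer_name
        self.impl = impl  # backend-specific attention implementation
        # Step 1: register this layer under its unique name at construction time.
        _static_forward_context[layer_name] = self

def unified_attention(query, key, value, kv_cache, layer_name: str):
    # Step 2: the custom op receives only the layer name, looks up the
    # attention object in the static forward context, and calls its
    # implementation, regardless of which backend it wraps.
    layer = _static_forward_context[layer_name]
    return layer.impl(query, key, value, kv_cache)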

TODO:

In the future, we should make all attention implementations accept an output argument, so that they are aligned with the V1 attention behavior.

Signed-off-by: youkaichao <youkaichao@gmail.com>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 (Member) left a comment


The model changes LGTM.

alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)

self.use_direct_call = envs.VLLM_USE_V1 or current_platform.is_tpu()
Collaborator


Just curious: why is TPU included in this exception?
For V1, it's because of the output argument, right?

Member Author


For V1, yes.

For TPU, the KV cache is not a single tensor; it is a tuple of tensors, so the signature of its attention op does not match the others.
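
To make this concrete, a rough sketch (assumed names, not the actual vLLM source) of how the use_direct_call flag could steer dispatch, reusing the hypothetical unified_attention helper from the PR description above:

class Attention:
    def __init__(self, layer_name: str, impl, use_direct_call: bool):
        self.layer_name = layer_name
        self.impl = impl
        self.use_direct_call = use_direct_call

    def forward(self, query, key, value, kv_cache, attn_metadata):
        if self.use_direct_call:
            # V1 (needs an output argument) and TPU (kv_cache is a tuple of
            # tensors) have signatures that do not fit the unified interface,
            # so they bypass the custom op and call the backend directly.
            return self.impl.forward(query, key, value, kv_cache, attn_metadata)
        # All other backends dispatch through the unified custom op, passing
        # only the layer name; the op resolves the layer from the context.
        return unified_attention(query, key, value, kv_cache, self.layer_name)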


 @contextmanager
-def set_forward_context(context: Any):
+def set_forward_context(context: Any, vllm_config: VllmConfig):
Collaborator


A dumb question: can we use get_current_vllm_config() here?

Member Author


No, get_current_vllm_config() only works during model initialization.

This is model execution, so the config has to be passed in explicitly.
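
For context, a simplified sketch of what the execution-time context manager might look like (hypothetical code, not the vLLM implementation; the static_forward_context attribute name is an assumption for illustration):

from contextlib import contextmanager
from typing import Any, Optional

_forward_context: Optional[dict] = None

@contextmanager
def set_forward_context(context: Any, vllm_config: Any):
    # Make the attention metadata and the per-model static forward context
    # available globally for the duration of a single forward pass.
    global _forward_context
    prev = _forward_context
    _forward_context = {
        "attn_metadata": context,
        # The layer-name -> attention-object mapping built at model init time;
        # the attribute name here is assumed for illustration.
        "static_forward_context": getattr(vllm_config, "static_forward_context", {}),
    }
    try:
        yield
    finally:
        _forward_context = prev

def get_forward_context() -> dict:
    # Called from inside the attention custom op during model execution.
    assert _forward_context is not None, (
        "get_forward_context() must be called within set_forward_context()")
    return _forward_context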

@youkaichao added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Nov 22, 2024
@youkaichao (Member Author) commented

The error comes from a Hugging Face timeout.

@youkaichao merged commit eebad39 into vllm-project:main on Nov 22, 2024
54 of 56 checks passed
@youkaichao deleted the unified_vllm_attention branch on November 22, 2024 at 22:04
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
anko-intel pushed a commit to HabanaAI/vllm-fork that referenced this pull request Feb 12, 2025
joerunde pushed a commit to vllm-project/vllm-spyre that referenced this pull request Jun 27, 2025
This PR does some refactoring, primarily on spyre_model_runner. It tries to reduce code duplication between static batching and continuous batching. However, this work will not be complete until a follow-up PR removes the KV cache manager from the Spyre model runner.

Summary of changes:

- Reduce code duplication in the Spyre model runner: common methods live in the
  `SpyreModelRunner` class, while `StaticBatchingSpyreModelRunner` and
  `ContinuousBatchingSpyreModelRunner` override a few of them for their
  specific logic.
- Changed the `ContinuousBatchingFmsModel` class to get the attention metadata
  via the forward context, and changed the model runner to use
  `with set_forward_context` to pass the attention metadata. This is how vLLM
  supports multiple attention backends
  [[REF](vllm-project/vllm#10558)]
- Moved the left pads to the `CachedRequestState`.
- Bugfix: `execute_model` in the CB model runner was inconsistent with the
  input batch data when it output the result in `CBSpyreModelRunnerOutput`.
  Changed it, together with `prepare_prompt`, to use the input batch data.
- Misc: a few renamed variables, more comments, and TODOs.

---------

Signed-off-by: Wallas Santos <wallashss@ibm.com>