[RFC] Initial Support for Cloud TPUs #3620
Comments
Nit: image links are broken.
@youkaichao Thanks for letting me know! Just fixed it.
Can you elaborate on the custom Pallas kernel for PagedAttention? Are there any links?
Good question. It's not open-sourced yet, but I was told that it will be released in the JAX repository in a week.
Is this still true after we moved to 1D query? Or does it mean we need to support both 1D and 2D query inputs?
You can find a sample Pallas kernel implementation for FlashAttention in TorchXLA. A similar mechanism would apply to other kernels. Also cc @liangfu to review.
@rkooo567 I believe the change won't affect the TPU backend, since GPUs and TPUs share only the scheduler, not the worker and model runner. Also, we will introduce another attention backend for TPUs, so the changes in
Thanks for the proposal @WoosukKwon. I'm interested in learning a few more details:
Here is the WIP PR for the PagedAttention kernel on Pallas + TorchXLA: pytorch/xla#6912. We expect it to land pretty soon. cc @wonjoolee95
Hey! I'm really excited about TPU support for vLLM. I just wanted to ask about support for larger multi-host pods, since it looks like only single-worker TPUs are supported. Is this on the roadmap?
@miladm The PagedAttention kernel will soon be replaced by FlashAttention in both the prefill and decoding stages. In that case, memory block management will return to the memory manager. PagedAttention is a mistake.
Are there any initial benchmarks for models like Gemma2 9B and 27B on TPU v5e or v4? We are considering switching.
It is not an XLA requirement; it is a hardware requirement. If the hardware allocates memory at compile time, then the IR must specify the shape sizes so that allocation can be optimized. A typical optimization is to use static memory as a memory pool to allocate memory "dynamically". However, if your chip is a GPU, you can allocate memory as needed. XLA provides bounded shapes for the first case: https://github.com/pytorch/xla/blob/master/docs/dynamic_shape.md
Do you have any microbenchmarks on TPU (static compilation with memory optimization) comparing PagedAttention against decoding with fixed lengths of prefill tokens and decode tokens?
Progress
Project Scope
This project focuses on making vLLM compatible with Google Cloud TPUs. Our goal is seamless integration so users can easily run vLLM on TPUs for both online and offline inference. We will target common setups, such as popular models like Gemma running in the bfloat16 data type.
Target TPUs and Models
We will focus on the most recent generations of TPUs, namely TPU v4, v5e, and v5p, given their superior performance over previous generations. We will start by making sure vLLM works with dense models such as Gemma. After that, we will expand support to Mixture-of-Experts (MoE) models such as Mixtral.
Features Not Included (for now)
The following features are outside the scope of this initial project, but we'd like to tackle them in the future:
Design
Overview
To integrate the TPU backend into vLLM, we will add a new TPU executor and TPU worker, which are the counterparts of the GPU executor and GPU worker, respectively. Unlike NVIDIA and AMD GPUs, which share the same executor and worker, TPUs get a separate code path because of the significant differences between GPUs and TPUs. On the other hand, the two backends will share the other components of `LLMEngine`, namely the scheduler, KV cache manager, and tokenizer, as they are (almost) device agnostic.

PyTorch XLA and JAX
As many components of vLLM are device and runtime agnostic, it is possible to use JAX for TPU integration. However, for faster initial integration and maximum code reuse, we will start with PyTorch XLA. Adding a JAX backend to vLLM will be interesting future work.
TPU Workers
For tensor-parallel inference, the vLLM TPU executor will spin up multiple TPU workers, one per TPU chip. Specifically, we will use Ray to connect and manage the TPU workers, which may reside in different TPU VMs. Note that we do not plan to support multi-slice inference at the moment, but we will support multi-host inference within the same TPU pod slice.
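Megatron-style partitioning, which the next paragraph names, is what lets each chip's worker hold only a shard of the weights. Below is a minimal sketch in plain Python, with no PyTorch, Ray, or real communication. The first weight matrix is split by columns (the `ColumnParallelLinear` role), the second by rows (the `RowParallelLinear` role), each loop iteration stands in for one worker, and the all-reduce is simulated by summing partial outputs. The function names are illustrative rather than vLLM's API, and ReLU stands in for the model's activation.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Naive dense matmul; assumes len(a[0]) == len(b).
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def relu(m: Matrix) -> Matrix:
    # Element-wise activation, so it can be applied to any column shard.
    return [[max(0.0, v) for v in row] for row in m]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_columns(m: Matrix, n: int) -> List[Matrix]:
    # Column-parallel layout: shard i owns columns [i*w, (i+1)*w).
    assert len(m[0]) % n == 0
    w = len(m[0]) // n
    return [[row[i * w:(i + 1) * w] for row in m] for i in range(n)]


def split_rows(m: Matrix, n: int) -> List[Matrix]:
    # Row-parallel layout: shard i owns rows [i*h, (i+1)*h).
    assert len(m) % n == 0
    h = len(m) // n
    return [m[i * h:(i + 1) * h] for i in range(n)]


def mlp_reference(x: Matrix, a: Matrix, b: Matrix) -> Matrix:
    # Unsharded two-layer MLP: relu(x @ a) @ b.
    return matmul(relu(matmul(x, a)), b)


def mlp_tensor_parallel(x: Matrix, a: Matrix, b: Matrix, num_workers: int) -> Matrix:
    a_shards = split_columns(a, num_workers)  # first layer: column-parallel
    b_shards = split_rows(b, num_workers)     # second layer: row-parallel
    partials = []
    for a_i, b_i in zip(a_shards, b_shards):  # one iteration per worker/chip
        h_i = relu(matmul(x, a_i))            # local, no communication needed
        partials.append(matmul(h_i, b_i))     # partial sum of the final output
    # Simulated all-reduce: sum the partial outputs from every worker.
    out = partials[0]
    for p in partials[1:]:
        out = add(out, p)
    return out
```

Because the activation is element-wise, each worker can apply it to its own column shard, so the only communication in this two-layer block is the single sum (all-reduce) at the end.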
Like the GPU executor, the TPU executor will use Megatron-style model partitioning for tensor-parallel inference. The partitioning strategy will be hardcoded into the model by replacing `nn.Linear` with `RowParallelLinear` and `ColumnParallelLinear`. Auto-sharding the model is possible future work.

GPU Executor vs. TPU Executor
For GPUs, vLLM uses both eager mode and CUDA graphs for model execution: eager mode for prefills and CUDA graphs for decodes. vLLM currently does not use `torch.compile` for GPUs, but plans to in the future. For TPUs, on the other hand, vLLM will use `torch.compile` (with the `openxla_eval` backend) to trace the PyTorch model and lower it into an XLA graph.

While vLLM's GPU and TPU backends will take separate code paths, they will share the PyTorch model code. Most of the custom ops for GPUs will not be needed for TPUs, since the XLA compiler can generate equivalent code automatically. Therefore, for each target op, vLLM will have two implementations, `_forward` and `_forward_cuda`, and select one of them at run time depending on the hardware backend. For example, we can define the target ops/layers as follows:

Important exceptions to this are the FlashAttention and PagedAttention custom ops, which cannot be generated by the XLA compiler. We will use custom Pallas kernels for them.
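The `_forward` / `_forward_cuda` dispatch described above can be sketched as follows. This is a minimal illustration, not vLLM's actual code: the `CustomOp` base class, the `backend` string, and the `SiluAndMul` example op are hypothetical, and plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
import math
from typing import Callable, List


class CustomOp:
    """Hypothetical base class that routes forward() to a backend-specific method."""

    def __init__(self, backend: str) -> None:
        # Resolve the implementation once, at construction time.
        if backend == "cuda":
            self._impl: Callable = self._forward_cuda
        else:
            # TPUs (and any other backend) use the plain implementation,
            # which the XLA compiler can lower on its own.
            self._impl = self._forward

    def forward(self, *args, **kwargs):
        return self._impl(*args, **kwargs)

    def _forward(self, *args, **kwargs):
        raise NotImplementedError

    def _forward_cuda(self, *args, **kwargs):
        # Default: ops without a dedicated CUDA kernel reuse the plain path.
        return self._forward(*args, **kwargs)


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


class SiluAndMul(CustomOp):
    """Hypothetical example op: out = silu(x[:d]) * x[d:], with d = len(x) // 2."""

    def _forward(self, x: List[float]) -> List[float]:
        if len(x) % 2:
            raise ValueError("input length must be even")
        d = len(x) // 2
        return [silu(g) * u for g, u in zip(x[:d], x[d:])]

    def _forward_cuda(self, x: List[float]) -> List[float]:
        # On a real GPU build, a fused custom CUDA kernel would be launched here.
        # This sketch has no kernel, so it reuses the reference math to stay runnable.
        return self._forward(x)
```

Resolving the implementation once in the constructor keeps per-call overhead to a single indirect call, and ops that need no GPU-specific kernel simply inherit the default `_forward_cuda`, which falls back to `_forward`.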
Handling Dynamic Shapes
vLLM's continuous batching has two phases: prefill and decode. vLLM dynamically switches between the two phases based on its scheduling decisions. The input tensor shape for prefills is `[batch_size, prefill_len, hidden_size]`, while the input tensor shape for decodes is `[batch_size, 1, hidden_size]`, since LLMs decode tokens one by one (here we do not consider special cases such as speculative decoding). In LLM inference, `batch_size` and `prefill_len` can vary at every step.

To meet XLA's static shape requirement, we will bucketize the possible input shapes. For decodes, we will bucketize the `batch_size` dimension by creating buckets for `batch_size=[8, 16, 24, 32, 40, …, 256]`. For prefills, to reduce the number of compiled graphs, we will fix `batch_size` to 1 and bucketize the `prefill_len` dimension by creating buckets for `prefill_len=[8, 16, 32, 64, 128, …, max_model_len]`. Since each prefill input contains enough tokens to utilize TPUs efficiently, fixing `batch_size` to 1 should not hurt performance much. The specific bucket sizes will be tuned after benchmarking the compilation overhead and end-to-end performance.