🚀 The feature, motivation and pitch
I am interested in adding a cpu_offload_weight capability to vLLM, and I wonder whether the community would like to see it happen.
The following is my design, along with a simple PoC implementation. I would love to hear feedback and suggestions before investing more time in it.
Objective
The rapid growth of AI model sizes drives demand for GPUs with ever-larger memory capacities. This puts many models out of reach for users without access to powerful GPUs, preventing them from running vLLM at all.
To address this, we are developing a feature called "cpu-offload-weight" for vLLM. With CPU offload, users can experiment with large models even without access to high-end GPUs. This democratizes access to vLLM, empowering a broader community of learners and researchers to engage with cutting-edge AI models.
Proposed Features
Upon initial loading of the model, cpu-offload pins the entire set of weights in CPU memory. During inference, weights are streamed into the GPU as required: when a layer is about to be computed, its weights are copied from CPU to GPU. Once the layer's computation completes, its outputs are retained in GPU memory to serve as inputs to subsequent layers, while the layer's GPU weight copy is released, freeing memory for the next layer. This keeps peak GPU weight usage bounded by roughly one layer, so vLLM can run even with limited GPU resources.
No changes to KV caches or other data structures are needed.
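The per-layer streaming scheme above can be sketched in a few lines of PyTorch. This is an illustrative toy, not vLLM's actual code: `OffloadedLinear` is a hypothetical name, and the example falls back to plain CPU tensors when no GPU is available so it can run anywhere.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch of per-layer weight offloading (not vLLM's actual code).
# Weights live in (ideally pinned) CPU memory; each layer copies its weights
# to the GPU just before computing and drops the GPU copy right after.
device = "cuda" if torch.cuda.is_available() else "cpu"
pin = torch.cuda.is_available()  # pinning requires a CUDA build with a GPU

class OffloadedLinear:
    """A linear layer whose weight lives on the CPU between forward calls."""

    def __init__(self, in_features: int, out_features: int):
        w = torch.randn(out_features, in_features) * 0.02
        # Pinned host memory enables fast, asynchronous host-to-device copies.
        self.weight_cpu = w.pin_memory() if pin else w

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stream the weight in only for the duration of this layer's compute.
        w = self.weight_cpu.to(device, non_blocking=True)
        out = F.linear(x, w)
        del w  # the GPU copy becomes free; the CPU master copy remains
        return out

layers = [OffloadedLinear(16, 16) for _ in range(4)]
x = torch.randn(2, 16, device=device)
for layer in layers:
    x = layer.forward(x)  # activations stay on the device between layers
print(x.shape)  # torch.Size([2, 16])
```

Note how only the activations persist on the GPU across layers, matching the description above.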
PoC Implementation
I tried this concept in my forked vLLM repo, bd-iaas-us@352a767#diff-bec94efb3136aec6b8c231b14f18447f7ffdee404b534cdc0e57ed92e7d96b8c.
With this change, I am able to run facebook/opt-13b on a single T4 GPU. This GPU has only 15 GB of memory, while the weights of facebook/opt-13b alone take about 25 GB, let alone the KV cache, etc.
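A quick back-of-the-envelope check makes the gap concrete (parameter count rounded to ~13 billion, fp16 weights):

```python
# Why opt-13b cannot fit on a T4 without offloading (rough arithmetic).
params = 13e9          # ~13 billion parameters
bytes_per_param = 2    # fp16
weight_gb = params * bytes_per_param / 1e9
t4_gb = 15             # approximate T4 memory capacity
print(f"weights: {weight_gb:.0f} GB, T4: {t4_gb} GB")  # weights: 26 GB, T4: 15 GB
```

The weights alone exceed the card's capacity by a wide margin before counting activations or the KV cache.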
This implementation is preliminary, yet it shows:

- CPU offload of weights does help democratize vLLM, allowing people to try out big models on GPUs with limited memory.
- The changes to existing models can be safe and simple. If you take a look at the PR I shared above, the changes to opt.py are just 1) adding a new parameter, cpu_offload_weight, to the init of the modules, and 2) adding one line of decorator to some forward() functions to indicate which weights need to be loaded from the CPU.
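The one-line-decorator idea can be sketched roughly as follows. The decorator name, attribute names, and `TinyLayer` class are all hypothetical, chosen only to illustrate the shape of the change, not the PR's actual API:

```python
import functools
import torch

# Hypothetical sketch of the decorator idea: it copies the named weight
# attributes to the input's device before forward() runs, and restores the
# CPU master copies afterwards so the device copies can be freed.
def load_weights_from_cpu(*weight_names):
    def wrap(forward):
        @functools.wraps(forward)
        def inner(self, x, *args, **kwargs):
            cpu_copies = {name: getattr(self, name) for name in weight_names}
            for name, w in cpu_copies.items():
                setattr(self, name, w.to(x.device, non_blocking=True))
            try:
                return forward(self, x, *args, **kwargs)
            finally:
                # Put the CPU copies back; the device copies become garbage.
                for name, w in cpu_copies.items():
                    setattr(self, name, w)
        return inner
    return wrap

class TinyLayer:
    def __init__(self):
        self.weight = torch.randn(8, 8) * 0.1  # lives on the CPU

    @load_weights_from_cpu("weight")
    def forward(self, x):
        return x @ self.weight.t()

layer = TinyLayer()
y = layer.forward(torch.randn(2, 8))
print(y.shape)  # torch.Size([2, 8])
```

The model code itself stays unchanged apart from the decorator line, which is what keeps the per-model diff small.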
Proposed Future Improvements
CPU offload will certainly increase latency and hurt throughput. I am going to invest more time in the following improvements:
- Prefetching layers' weights: loading the weights of the next layer while the GPU is computing the current layer, if memory permits.
- Parallel fetching: for inference on multiple GPUs, parallelizing the fetch of each layer across the GPUs, with each GPU fetching only a portion of the layer. Employing the GPUs' aggregate PCIe links in this way scales the effective transfer bandwidth roughly linearly with the number of GPUs, reducing latency.
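The prefetching idea can be sketched as below. This is a simplified, hypothetical version: it issues the next layer's host-to-device copy (non-blocking, so on a GPU it can overlap with compute via the copy engine) before computing the current layer; a real implementation would also use a dedicated CUDA stream plus events to guarantee overlap and ordering. The example degrades gracefully to CPU, where the copies are no-ops:

```python
import torch

# Sketch of weight prefetching: start fetching layer i+1 while layer i is
# computing. Function and variable names are illustrative, not vLLM's API.
device = "cuda" if torch.cuda.is_available() else "cpu"

def forward_with_prefetch(cpu_weights, x):
    """cpu_weights: per-layer weight tensors kept on the CPU."""
    current = cpu_weights[0].to(device, non_blocking=True)
    for i, _ in enumerate(cpu_weights):
        nxt = None
        if i + 1 < len(cpu_weights):
            # Kick off the next layer's copy before computing this one.
            nxt = cpu_weights[i + 1].to(device, non_blocking=True)
        x = torch.relu(x @ current.t())
        current = nxt  # the finished layer's device copy is now free
    return x

weights = [torch.randn(16, 16) * 0.1 for _ in range(4)]
out = forward_with_prefetch(weights, torch.randn(2, 16, device=device))
print(out.shape)  # torch.Size([2, 16])
```

At most two layers' weights are resident on the device at once, which is the memory cost of hiding the transfer latency.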
Alternatives
No response
Additional context
No response