Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Doc] Default to multiprocessing for single-node distributed case #5230

Merged
merged 13 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/source/serving/distributed_serving.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
Distributed Inference and Serving
=================================

vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inferencing currently requires Ray.

Multiprocessing will be used by default when not running in a Ray placement group and if there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`, otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed-executor-backend` argument or :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray.

njhill marked this conversation as resolved.
Show resolved Hide resolved
.. code-block:: console

Expand All @@ -25,10 +27,12 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
$ --model facebook/opt-13b \
$ --tensor-parallel-size 4

To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
To scale vLLM beyond a single machine, install and start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:

.. code-block:: console

$ pip install ray

$ # On head node
$ ray start --head

Expand Down
17 changes: 16 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,24 @@ def __init__(
f"'{self.distributed_executor_backend}'.")

if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from torch.cuda import device_count
youkaichao marked this conversation as resolved.
Show resolved Hide resolved

from vllm.executor import ray_utils
backend = "mp"
ray_found = ray_utils.ray is not None
self.distributed_executor_backend = "ray" if ray_found else "mp"
if device_count() < self.world_size:
if not ray_found:
raise ValueError("Unable to load Ray which is "
"required for multi-node inference")
backend = "ray"
elif ray_found:
from ray.util import get_current_placement_group
if get_current_placement_group() is not None:
njhill marked this conversation as resolved.
Show resolved Hide resolved
backend = "ray"
logger.info("Defaulting to use %s for distributed inference",
backend)

self._verify_args()

Expand Down
Loading