diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst
index 4f36dca15d7..b0c45dbf702 100644
--- a/docs/source/serving/distributed_serving.rst
+++ b/docs/source/serving/distributed_serving.rst
@@ -3,11 +3,9 @@
 Distributed Inference and Serving
 =================================
 
-vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm `_. We manage the distributed runtime with `Ray `_. To run distributed inference, install Ray with:
+vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm `_. We manage the distributed runtime with either `Ray `_ or Python native multiprocessing. Multiprocessing can be used when deploying on a single node; multi-node inference currently requires Ray.
 
-.. code-block:: console
-
-    $ pip install ray
+Multiprocessing will be used by default when not running in a Ray placement group and when there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`; otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed_executor_backend` argument or the :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. Ray does not need to be installed for the multiprocessing case.
 
 To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
 
@@ -25,10 +23,12 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
     $     --model facebook/opt-13b \
     $     --tensor-parallel-size 4
 
-To scale vLLM beyond a single machine, start a `Ray runtime `_ via CLI before running vLLM:
+To scale vLLM beyond a single machine, install and start a `Ray runtime `_ via CLI before running vLLM:
 
 .. code-block:: console
 
+    $ pip install ray
+
     $ # On head node
     $ ray start --head
 
diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index 1d060e26584..f8a6de54653 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -77,7 +77,11 @@ def __init__(
             swap_space=swap_space,
             enforce_eager=enforce_eager,
             max_seq_len_to_capture=max_seq_len_to_capture,
+            # For now use ray for the distributed back-end, since
+            # we rely on the use of engine_use_ray=True to avoid
+            # reinitializing CUDA in the same process (driver worker)
             engine_use_ray=True,
+            distributed_executor_backend="ray",
             disable_custom_all_reduce=disable_custom_all_reduce,
             **kwargs,
         )
diff --git a/vllm/config.py b/vllm/config.py
index 7ffb93c19ed..50b0156b1e8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -603,9 +603,25 @@ def __init__(
                 f"'{self.distributed_executor_backend}'.")
 
         if self.distributed_executor_backend is None and self.world_size > 1:
+            # We use multiprocessing by default if world_size fits on the
+            # current node and we aren't in a ray placement group.
+            from torch.cuda import device_count
+            from vllm.executor import ray_utils
+            backend = "mp"
             ray_found = ray_utils.ray is not None
-            self.distributed_executor_backend = "ray" if ray_found else "mp"
+            if device_count() < self.world_size:
+                if not ray_found:
+                    raise ValueError("Unable to load Ray which is "
+                                     "required for multi-node inference")
+                backend = "ray"
+            elif ray_found:
+                from ray.util import get_current_placement_group
+                if self.placement_group or get_current_placement_group():
+                    backend = "ray"
+            self.distributed_executor_backend = backend
+            logger.info("Defaulting to use %s for distributed inference",
+                        backend)
 
         self._verify_args()
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 93488439197..99c9e52034c 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -19,10 +19,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     """Python multiprocessing-based multi-GPU executor"""
 
     def _init_executor(self) -> None:
-        assert (
-            not self.speculative_config
-        ), "Speculative decoding not yet supported for MultiProcGPU backend."
-
         # Create the parallel GPU workers.
         world_size = self.parallel_config.tensor_parallel_size
 
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index 62887533f5c..28c8e8699f0 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -65,10 +65,11 @@ def _set_future_result(future: Union[ResultFuture, asyncio.Future],
         future.set_result(result)
         return
     loop = future.get_loop()
-    if result.exception is not None:
-        loop.call_soon_threadsafe(future.set_exception, result.exception)
-    else:
-        loop.call_soon_threadsafe(future.set_result, result.value)
+    if not loop.is_closed():
+        if result.exception is not None:
+            loop.call_soon_threadsafe(future.set_exception, result.exception)
+        else:
+            loop.call_soon_threadsafe(future.set_result, result.value)
 
 
 class ResultHandler(threading.Thread):
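A minimal usage sketch of the override described in the docs change above, selecting the backend explicitly through the :code:`LLM` class; it assumes a single node with at least four GPUs, and the model name and prompt are just placeholders:

.. code-block:: python

    from vllm import LLM

    # Force the multiprocessing executor; "ray" would select Ray instead.
    # Leaving distributed_executor_backend unset applies the new default:
    # "mp" when all GPUs fit on one node and no Ray placement group is active.
    llm = LLM(
        model="facebook/opt-13b",  # example model, matching the docs snippet
        tensor_parallel_size=4,
        distributed_executor_backend="mp",
    )

    for output in llm.generate("Hello, my name is"):
        print(output.outputs[0].text)

The equivalent for the API server is the :code:`--distributed-executor-backend mp` flag mentioned in the docs change.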
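The :code:`loop.is_closed()` guard added in :code:`multiproc_worker_utils.py` matters because :code:`call_soon_threadsafe` raises :code:`RuntimeError` once the event loop is closed, which can happen when a worker result arrives during shutdown. A standalone sketch of that failure mode using plain :code:`asyncio`, with no vLLM types involved:

.. code-block:: python

    import asyncio

    loop = asyncio.new_event_loop()
    loop.close()

    try:
        # Mirrors what _set_future_result would hit without the guard.
        loop.call_soon_threadsafe(print, "late result")
    except RuntimeError as exc:
        print(f"late result dropped: {exc}")  # "Event loop is closed"

With the guard in place, such late results are silently discarded instead of crashing the result-handler thread.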