33 changes: 33 additions & 0 deletions tests/config/test_config_generation.py
@@ -36,3 +36,36 @@ def create_config():
    assert deep_compare(normal_config_dict, empty_config_dict), (
        "Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
        " should be equivalent")


def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
    # In testing, this method needs to be nested inside as ray does not
    # see the test module.
    def create_config():
        engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
                                 trust_remote_code=True)
        return engine_args.create_engine_config()

    config = create_config()
    parallel_config = config.parallel_config
    assert parallel_config.ray_runtime_env is None

    import ray
    ray.init()

    runtime_env = {
        "env_vars": {
            "TEST_ENV_VAR": "test_value",
        },
    }

    config_ref = ray.remote(create_config).options(
        runtime_env=runtime_env).remote()

    config = ray.get(config_ref)
    parallel_config = config.parallel_config
    assert parallel_config.ray_runtime_env is not None
    assert parallel_config.ray_runtime_env.env_vars().get(
        "TEST_ENV_VAR") == "test_value"

    ray.shutdown()
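For context, a usage sketch of the behavior this test exercises, not part of the PR itself: when the engine config is built inside a Ray task or actor that was given a runtime_env, that env is captured on parallel_config.ray_runtime_env and later forwarded to the Ray workers vLLM launches. The model name and env var below are illustrative placeholders, and the snippet assumes a vLLM build that includes this change.

```python
# Usage sketch only; model name and env var are placeholders.
import ray

runtime_env = {"env_vars": {"MY_ENV_VAR": "some_value"}}

@ray.remote(runtime_env=runtime_env)
def build_config():
    # Imported here so the function is self-contained when Ray ships it
    # to a worker process.
    from vllm.engine.arg_utils import EngineArgs

    # Inside this task ray.is_initialized() is True, so create_engine_config()
    # records the task's runtime env on parallel_config.ray_runtime_env.
    config = EngineArgs(model="facebook/opt-125m").create_engine_config()
    return config.parallel_config.ray_runtime_env.env_vars()

ray.init()
print(ray.get(build_config.remote()))  # expect {'MY_ENV_VAR': 'some_value'}
ray.shutdown()
```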
5 changes: 5 additions & 0 deletions vllm/config.py
@@ -56,6 +56,7 @@

if TYPE_CHECKING:
    from _typeshed import DataclassInstance
    from ray.runtime_env import RuntimeEnv
    from ray.util.placement_group import PlacementGroup
    from transformers.configuration_utils import PretrainedConfig

@@ -73,6 +74,7 @@
else:
    DataclassInstance = Any
    PlacementGroup = Any
    RuntimeEnv = Any
    PretrainedConfig = Any
    ExecutorBase = Any
    QuantizationConfig = Any
@@ -1950,6 +1952,9 @@ class ParallelConfig:
    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

    ray_runtime_env: Optional["RuntimeEnv"] = None
    """Ray runtime environment to pass to distributed workers."""

    placement_group: Optional["PlacementGroup"] = None
    """ray distributed model workers placement group."""

11 changes: 11 additions & 0 deletions vllm/engine/arg_utils.py
@@ -36,6 +36,7 @@
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
@@ -1060,6 +1061,15 @@ def create_engine_config(
            calculate_kv_scales=self.calculate_kv_scales,
        )

        ray_runtime_env = None
        if is_ray_initialized():
            # Ray Serve LLM calls `create_engine_config` in the context
            # of a Ray task, therefore we check is_ray_initialized()
            # as opposed to is_in_ray_actor().
            import ray
            ray_runtime_env = ray.get_runtime_context().runtime_env
            logger.info("Using ray runtime env: %s", ray_runtime_env)
Review comment:
Please don't print things like envs to logs, they often can contain secrets.

Reply from @ruisearch42 (Collaborator, Author), Oct 6, 2025:
Thanks, makes sense. I created a PR to address this: #26302


        # Get the current placement group if Ray is initialized and
        # we are in a Ray actor. If so, then the placement group will be
        # passed to spawned processes.
@@ -1172,6 +1182,7 @@ def create_engine_config(
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            ray_runtime_env=ray_runtime_env,
            placement_group=placement_group,
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
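Following up on the review thread above, a hypothetical sketch of how the log line could avoid echoing values, for example by logging only the runtime-env field names and env-var keys. This is not the actual change made in #26302; the helper name is invented for illustration, and it assumes the runtime env behaves like a mapping.

```python
# Hypothetical illustration only; not the fix applied in #26302.
def describe_runtime_env(runtime_env) -> str:
    """Summarize a runtime env without printing values that may be secrets."""
    env = dict(runtime_env or {})
    env_var_keys = sorted((env.pop("env_vars", None) or {}).keys())
    return f"fields={sorted(env)}, env_var_keys={env_var_keys}"

# logger.info("Using ray runtime env: %s", describe_runtime_env(ray_runtime_env))
```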
7 changes: 5 additions & 2 deletions vllm/executor/ray_utils.py
@@ -295,9 +295,12 @@ def initialize_ray_cluster(
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources.")
            ray.init(address=ray_address, num_gpus=parallel_config.world_size)
            ray.init(address=ray_address,
                     num_gpus=parallel_config.world_size,
                     runtime_env=parallel_config.ray_runtime_env)
    else:
        ray.init(address=ray_address)
        ray.init(address=ray_address,
                 runtime_env=parallel_config.ray_runtime_env)

    device_str = current_platform.ray_device_key
    if not device_str:
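A small sketch, assuming a local Ray instance, of what forwarding the runtime env to ray.init() buys: env_vars in the job-level runtime env are visible inside the worker processes Ray subsequently starts. The env var name and value are placeholders.

```python
import os

import ray

# Placeholder env var; any key/value works for illustration.
ray.init(runtime_env={"env_vars": {"EXAMPLE_VAR": "example_value"}})

@ray.remote
def read_env():
    # Workers launched under this job-level runtime env inherit its env_vars.
    return os.environ.get("EXAMPLE_VAR")

print(ray.get(read_env.remote()))  # "example_value"
ray.shutdown()
```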
22 changes: 22 additions & 0 deletions vllm/ray/lazy_utils.py
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


def is_ray_initialized():
    """Check if Ray is initialized."""
    try:
        import ray
        return ray.is_initialized()
    except ImportError:
        return False


def is_in_ray_actor():
    """Check if we are in a Ray actor."""

    try:
        import ray
        return (ray.is_initialized()
                and ray.get_runtime_context().get_actor_id() is not None)
    except ImportError:
        return False
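A small sketch, assuming a local Ray instance, of the distinction that motivates having both helpers and that the arg_utils comment above relies on: inside a Ray task, Ray is initialized but there is no actor id, so is_ray_initialized() is True while is_in_ray_actor() is False; inside an actor, both are True.

```python
import ray

from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized

ray.init()

@ray.remote
def in_task():
    # A task runs in a Ray worker process but has no actor id.
    return is_ray_initialized(), is_in_ray_actor()

@ray.remote
class InActor:
    def check(self):
        return is_ray_initialized(), is_in_ray_actor()

print(ray.get(in_task.remote()))                 # (True, False)
print(ray.get(InActor.remote().check.remote()))  # (True, True)
ray.shutdown()
```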
12 changes: 1 addition & 11 deletions vllm/utils/__init__.py
@@ -71,6 +71,7 @@

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.ray.lazy_utils import is_in_ray_actor

if TYPE_CHECKING:
    from argparse import Namespace
@@ -2864,17 +2865,6 @@ def zmq_socket_ctx(
        ctx.destroy(linger=linger)


def is_in_ray_actor():
    """Check if we are in a Ray actor."""

    try:
        import ray
        return (ray.is_initialized()
                and ray.get_runtime_context().get_actor_id() is not None)
    except ImportError:
        return False


def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.