4 changes: 2 additions & 2 deletions tests/worker/tpu_worker_test.py
@@ -63,7 +63,7 @@ def test_init_success(self, mock_vllm_config):
assert worker.profile_dir is None
assert worker.devices == ['tpu:0']

@patch('tpu_inference.worker.tpu_worker.envs')
@patch('tpu_inference.worker.tpu_worker.vllm_envs')
def test_init_with_profiler_on_rank_zero(self, mock_envs,
mock_vllm_config):
"""Tests that the profiler directory is set correctly on rank 0."""
@@ -74,7 +74,7 @@ def test_init_with_profiler_on_rank_zero(self, mock_envs,
distributed_init_method="test_method")
assert worker.profile_dir == "/tmp/profiles"

@patch('tpu_inference.worker.tpu_worker.envs')
@patch('tpu_inference.worker.tpu_worker.vllm_envs')
def test_init_with_profiler_on_other_ranks(self, mock_envs,
mock_vllm_config):
"""Tests that the profiler directory is NOT set on non-rank 0 workers."""
107 changes: 107 additions & 0 deletions tpu_inference/envs.py
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project

import functools
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
JAX_PLATFORMS: str = ""
TPU_ACCELERATOR_TYPE: str | None = None
TPU_NAME: str | None = None
TPU_WORKER_ID: str | None = None
TPU_MULTIHOST_BACKEND: str = ""
PREFILL_SLICES: str = ""
DECODE_SLICES: str = ""
SKIP_JAX_PRECOMPILE: bool = False
MODEL_IMPL_TYPE: str = "flax_nnx"
NEW_MODEL_DESIGN: bool = False
PHASED_PROFILING_DIR: str = ""
PYTHON_TRACER_LEVEL: int = 1
USE_MOE_EP_KERNEL: bool = False
RAY_USAGE_STATS_ENABLED: str = "0"
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"

environment_variables: dict[str, Callable[[], Any]] = {
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
"JAX_PLATFORMS":
lambda: os.getenv("JAX_PLATFORMS", ""),
# TPU accelerator type (e.g., "v5litepod-16", "v4-8")
"TPU_ACCELERATOR_TYPE":
lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
# Name of the TPU resource
"TPU_NAME":
lambda: os.getenv("TPU_NAME", None),
# Worker ID for multi-host TPU setups
"TPU_WORKER_ID":
lambda: os.getenv("TPU_WORKER_ID", None),
# Backend for multi-host communication on TPU
"TPU_MULTIHOST_BACKEND":
lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
# Slice configuration for disaggregated prefill workers
"PREFILL_SLICES":
lambda: os.getenv("PREFILL_SLICES", ""),
# Slice configuration for disaggregated decode workers
"DECODE_SLICES":
lambda: os.getenv("DECODE_SLICES", ""),
# Skip JAX precompilation step during initialization
"SKIP_JAX_PRECOMPILE":
lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
# Model implementation type (e.g., "flax_nnx")
"MODEL_IMPL_TYPE":
lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
# Enable new experimental model design
"NEW_MODEL_DESIGN":
lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
# Directory to store phased profiling output
"PHASED_PROFILING_DIR":
lambda: os.getenv("PHASED_PROFILING_DIR", ""),
# Python tracer level for profiling
"PYTHON_TRACER_LEVEL":
lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
# Use custom expert-parallel kernel for MoE (Mixture of Experts)
"USE_MOE_EP_KERNEL":
lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
# Enable/disable Ray usage statistics collection
"RAY_USAGE_STATS_ENABLED":
lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
# Ray compiled DAG channel type for TPU
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
}


def __getattr__(name: str) -> Any:
Collaborator: Can we add test coverage like what vllm main does?

https://github.com/vllm-project/vllm/blob/main/tests/test_envs.py

Collaborator: Good idea. I initially assumed that passing the existing unit tests that use these environment variables was sufficient coverage, but it's probably a good idea to add a separate test as well.

@xingliu14, can you add a unit test like the one referenced here?

Collaborator (Author): Sure, I'll add a test in a follow-up PR.

"""
Gets environment variables lazily.

NOTE: After enable_envs_cache() is invoked (which happens after service
initialization), all environment variables are cached.
"""
if name in environment_variables:
return environment_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def enable_envs_cache() -> None:
"""
Enables caching of environment variables by wrapping the module's __getattr__
function with functools.cache(). This improves performance by avoiding
repeated re-evaluation of environment variables.

NOTE: This should be called after service initialization. Once enabled,
environment variable values are cached and will not reflect changes to
os.environ until the process is restarted.
"""
# Wrap __getattr__ with functools.cache
global __getattr__
__getattr__ = functools.cache(__getattr__)

# Cache all environment variables
for key in environment_variables:
__getattr__(key)


def __dir__() -> list[str]:
return list(environment_variables.keys())
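
Following up on the review thread above about mirroring vllm's tests/test_envs.py, a minimal sketch of what such coverage could look like. It assumes pytest with the built-in monkeypatch fixture and that enable_envs_cache() has not already been called in the test process; the module name, test names, and the variables exercised here are illustrative only, not the actual follow-up test promised in the thread.

# Hypothetical tests/test_envs.py sketch (assumes pytest + monkeypatch).
import pytest

from tpu_inference import envs


def test_lazy_env_lookup(monkeypatch: pytest.MonkeyPatch):
    # Values are read from os.environ at attribute-access time.
    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
    assert envs.SKIP_JAX_PRECOMPILE is True

    # Getters can normalize values; MODEL_IMPL_TYPE is lower-cased.
    monkeypatch.setenv("MODEL_IMPL_TYPE", "VLLM")
    assert envs.MODEL_IMPL_TYPE == "vllm"


def test_unknown_name_raises():
    with pytest.raises(AttributeError):
        _ = envs.DOES_NOT_EXIST


def test_dir_lists_known_variables():
    assert "PREFILL_SLICES" in dir(envs)


def test_enable_envs_cache_freezes_values(monkeypatch: pytest.MonkeyPatch):
    # The cache is process-global, so run this last (or in its own process):
    # once enabled, later changes to os.environ are no longer observed.
    monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/before")
    envs.enable_envs_cache()
    monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/after")
    assert envs.PHASED_PROFILING_DIR == "/tmp/before"
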
4 changes: 2 additions & 2 deletions tpu_inference/layers/vllm/quantization/unquantized.py
@@ -1,4 +1,3 @@
import os
from typing import Any, Callable, Optional, Union

import jax
@@ -22,6 +21,7 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)

from tpu_inference import envs
from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
from tpu_inference.layers.vllm.linear_common import (
@@ -164,7 +164,7 @@ def __init__(self,
ep_axis_name: str = 'model'):
super().__init__(moe)
self.mesh = mesh
self.use_kernel = bool(int(os.getenv("USE_MOE_EP_KERNEL", "0")))
self.use_kernel = envs.USE_MOE_EP_KERNEL
self.ep_axis_name = ep_axis_name
# TODO: Use autotune table once we have it.
self.block_size = {
4 changes: 2 additions & 2 deletions tpu_inference/models/common/model_loader.py
@@ -1,5 +1,4 @@
import functools
import os
from typing import Any, Optional

import jax
@@ -11,6 +10,7 @@
from vllm.config import VllmConfig
from vllm.utils.func_utils import supports_kw

from tpu_inference import envs
from tpu_inference.layers.jax.sharding import ShardingAxisName
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
@@ -314,7 +314,7 @@ def get_model(
mesh: Mesh,
is_draft_model: bool = False,
) -> Any:
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
impl = envs.MODEL_IMPL_TYPE
logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")

if impl == "flax_nnx":
19 changes: 10 additions & 9 deletions tpu_inference/platforms/tpu_platform.py
@@ -4,13 +4,14 @@
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

import jax.numpy as jnp
import vllm.envs as envs
import vllm.envs as vllm_envs
from torchax.ops.mappings import j2t_dtype
from tpu_info import device
from vllm.inputs import ProcessorInputs, PromptType
from vllm.platforms.interface import Platform, PlatformEnum
from vllm.sampling_params import SamplingParams, SamplingType

from tpu_inference import envs
from tpu_inference.layers.jax.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger

@@ -71,7 +72,7 @@ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
try:
if envs.VLLM_TPU_USING_PATHWAYS:
if vllm_envs.VLLM_TPU_USING_PATHWAYS:
# Causes multiprocess access to IFRT when calling jax.devices()
return "TPU v6 lite"
else:
@@ -87,7 +88,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return not envs.VLLM_USE_V1
return not vllm_envs.VLLM_USE_V1

@classmethod
def get_punica_wrapper(cls) -> str:
@@ -118,11 +119,11 @@ def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if not envs.VLLM_USE_V1:
if not vllm_envs.VLLM_USE_V1:
raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")

if envs.VLLM_TPU_USING_PATHWAYS:
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
if vllm_envs.VLLM_TPU_USING_PATHWAYS:
assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
"VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
)
cls._initialize_sharding_config(vllm_config)
@@ -144,7 +145,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
compilation_config.backend = "openxla"

# If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
impl = envs.MODEL_IMPL_TYPE

# NOTE(xiang): convert dtype to jnp.dtype
# NOTE(wenlong): skip this logic for mm model preprocessing
@@ -164,7 +165,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
vllm_config.model_config.dtype = j2t_dtype(
vllm_config.model_config.dtype.dtype)

if envs.VLLM_USE_V1:
if vllm_envs.VLLM_USE_V1:
# TODO(cuiq): remove this dependency.
from vllm.v1.attention.backends.pallas import \
PallasAttentionBackend
@@ -250,7 +251,7 @@ def validate_request(
"""Raises if this request is unsupported on this platform"""

if isinstance(params, SamplingParams):
if params.structured_outputs is not None and not envs.VLLM_USE_V1:
if params.structured_outputs is not None and not vllm_envs.VLLM_USE_V1:
raise ValueError("Structured output is not supported on "
f"{cls.device_name} V0.")
if params.sampling_type == SamplingType.RANDOM_SEED:
10 changes: 5 additions & 5 deletions tpu_inference/worker/tpu_worker.py
@@ -8,7 +8,7 @@
import jax.numpy as jnp
import jaxlib
import jaxtyping
import vllm.envs as envs
import vllm.envs as vllm_envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
@@ -22,7 +22,7 @@
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

from tpu_inference import utils
from tpu_inference import envs, utils
from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
get_node_id)
from tpu_inference.layers.jax.sharding import ShardingConfigManager
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self,
devices=None):
# If we use vLLM's model implementation in PyTorch, we should set it
# with torch version of the dtype.
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
impl = envs.MODEL_IMPL_TYPE
if impl != "vllm": # vllm-pytorch implementation does not need this conversion

# NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
@@ -86,11 +86,11 @@
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profile_dir = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
if not self.devices or 0 in self.device_ranks:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)

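
Taken together, the convention after this change: vLLM-owned settings are read through vllm_envs (the renamed vllm.envs import), while TPU-specific settings go through the new tpu_inference.envs module. A small usage sketch follows; resolve_profile_dir is a purely illustrative helper, not part of the diff.

from typing import Optional

import vllm.envs as vllm_envs  # vLLM-owned settings, e.g. VLLM_TORCH_PROFILER_DIR

from tpu_inference import envs  # TPU-specific settings, e.g. MODEL_IMPL_TYPE


def resolve_profile_dir(rank: int) -> Optional[str]:
    # Mirrors the rank-0 profiler logic in tpu_worker.py above.
    if vllm_envs.VLLM_TORCH_PROFILER_DIR and rank == 0:
        return vllm_envs.VLLM_TORCH_PROFILER_DIR
    return None


impl = envs.MODEL_IMPL_TYPE  # "flax_nnx" by default; "vllm" selects the PyTorch path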