diff --git a/tpu_inference/__init__.py b/tpu_inference/__init__.py
index d10311cb0..09d2fcdd7 100644
--- a/tpu_inference/__init__.py
+++ b/tpu_inference/__init__.py
@@ -1,15 +1,14 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
diff --git a/tpu_inference/core/disagg_utils.py b/tpu_inference/core/disagg_utils.py
index 58528b8ad..ecb16e9ac 100644
--- a/tpu_inference/core/disagg_utils.py
+++ b/tpu_inference/core/disagg_utils.py
@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import Tuple
 
-PREFILL_SLICES = 'PREFILL_SLICES'
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs
 
 
 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES in os.environ
+    return bool(envs.PREFILL_SLICES)
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def get_prefill_slices() -> Tuple[int, ...]:
-    if PREFILL_SLICES not in os.environ:
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(os.environ[PREFILL_SLICES])
+    return _parse_slices(envs.PREFILL_SLICES)
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if DECODE_SLICES not in os.environ:
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(os.environ[DECODE_SLICES])
+    return _parse_slices(envs.DECODE_SLICES)
diff --git a/tpu_inference/distributed/tpu_connector.py b/tpu_inference/distributed/tpu_connector.py
index 66a50b26a..cf09dcea7 100644
--- a/tpu_inference/distributed/tpu_connector.py
+++ b/tpu_inference/distributed/tpu_connector.py
@@ -60,7 +60,6 @@
 
 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.request import Request
 
+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports, get_kv_transfer_port,
                                              get_node_id,
@@ -441,8 +441,7 @@ def __init__(self, vllm_config: VllmConfig):
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
-                                    "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
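Note on the pattern used above: call sites now read attributes such as envs.JAX_PLATFORMS, envs.TPU_MULTIHOST_BACKEND, and envs.PREFILL_SLICES, which suggests tpu_inference/envs.py resolves names lazily from its environment_variables table (the scheme vLLM's own envs module uses). The sketch below is only an illustration of that scheme under those assumptions; apart from the JAX_PLATFORMS entry, the table entries and the module-level __getattr__ shown here are inferred, not lines from this patch.

# Illustrative sketch only -- not the actual contents of tpu_inference/envs.py.
# Assumes a vLLM-style lazy lookup table plus a module-level __getattr__ (PEP 562).
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    # Shown in this patch: JAX_PLATFORMS is normalized to lowercase at read time.
    "JAX_PLATFORMS":
    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
    # Assumed entries: the call sites compare against lowercase "ray" and
    # truth-test PREFILL_SLICES, so these defaults are inferred, not shown.
    "TPU_MULTIHOST_BACKEND":
    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
    "PREFILL_SLICES":
    lambda: os.getenv("PREFILL_SLICES", ""),
}


def __getattr__(name: str) -> Any:
    # envs.NAME re-reads the process environment on every attribute access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Reading through a callable on each access keeps the accessors in sync with os.environ changes made after import, which plain module constants would not be.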
diff --git a/tpu_inference/distributed/utils.py b/tpu_inference/distributed/utils.py
index cf1a0b966..61dde5e60 100644
--- a/tpu_inference/distributed/utils.py
+++ b/tpu_inference/distributed/utils.py
@@ -2,6 +2,7 @@
 
 from vllm.utils.network_utils import get_ip
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py
index 1ef212f00..e97993204 100644
--- a/tpu_inference/envs.py
+++ b/tpu_inference/envs.py
@@ -26,7 +26,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
diff --git a/tpu_inference/layers/vllm/sharding.py b/tpu_inference/layers/vllm/sharding.py
index b06f8b35f..b9fd4fdd9 100644
--- a/tpu_inference/layers/vllm/sharding.py
+++ b/tpu_inference/layers/vllm/sharding.py
@@ -19,6 +19,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    import os
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
 
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
diff --git a/tpu_inference/mock/vllm_envs.py b/tpu_inference/mock/vllm_envs.py
index 1a938002a..476643579 100644
--- a/tpu_inference/mock/vllm_envs.py
+++ b/tpu_inference/mock/vllm_envs.py
@@ -189,6 +189,20 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
     return bool(int(value))
 
 
+def _get_jax_platforms() -> str:
+    """Get JAX_PLATFORMS from tpu_inference.envs module.
+
+    Returns:
+        The JAX_PLATFORMS value.
+    """
+    try:
+        from tpu_inference import envs
+        return envs.JAX_PLATFORMS
+    except ImportError:
+        # Fallback if tpu_inference.envs is not available
+        return os.getenv("JAX_PLATFORMS", "").lower()
+
+
 def get_vllm_port() -> Optional[int]:
     """Get the port from VLLM_PORT environment variable.
@@ -941,7 +955,7 @@
 
     # Whether using Pathways
     "VLLM_TPU_USING_PATHWAYS":
-    lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()),
+    lambda: bool("proxy" in _get_jax_platforms()),
 
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM":
diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py
index 64f026dae..64730748f 100644
--- a/tpu_inference/models/jax/utils/weight_utils.py
+++ b/tpu_inference/models/jax/utils/weight_utils.py
@@ -18,7 +18,7 @@
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -421,7 +421,7 @@ def load_hf_weights(vllm_config,
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
     # Because multi-threading would cause different JAX processes to load
     # different weights at the same time.
-    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         max_workers = 1
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index e23d4f7e8..b3a4a7de3 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
@@ -183,7 +182,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"
 
-        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \
diff --git a/tpu_inference/tpu_info.py b/tpu_inference/tpu_info.py
index 9f5d02269..41b1a7e21 100644
--- a/tpu_inference/tpu_info.py
+++ b/tpu_inference/tpu_info.py
@@ -3,6 +3,7 @@
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name = os.getenv("TPU_NAME", None)
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id = os.getenv("TPU_WORKER_ID", None)
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
diff --git a/tpu_inference/utils.py b/tpu_inference/utils.py
index ea9edd20a..ca3d693da 100644
--- a/tpu_inference/utils.py
+++ b/tpu_inference/utils.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -14,8 +13,10 @@
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs, utils
+from vllm import envs as vllm_envs
+from vllm import utils
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
@@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
 
 
 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if envs.VLLM_TPU_USING_PATHWAYS:
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.
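A quick way to sanity-check the migrated call sites is to exercise the accessors directly. The snippet below is a hedged example: it assumes envs.TPU_MULTIHOST_BACKEND is normalized to lowercase and that envs.PREFILL_SLICES defaults to an empty string, which the comparisons in tpu_connector.py and disagg_utils.py rely on but this patch does not show explicitly.

# Hypothetical check of the centralized accessors; assumptions noted above.
import os

from tpu_inference import envs
from tpu_inference.core import disagg_utils

os.environ["TPU_MULTIHOST_BACKEND"] = "RAY"
assert envs.TPU_MULTIHOST_BACKEND == "ray"   # read lazily, lowercased

os.environ.pop("PREFILL_SLICES", None)
assert not envs.PREFILL_SLICES               # unset -> empty string, not a KeyError
assert not disagg_utils.is_disagg_enabled()  # bool("") is False

os.environ["PREFILL_SLICES"] = "set"         # any non-empty value flips the check
assert disagg_utils.is_disagg_enabled()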