5 changes: 2 additions & 3 deletions tpu_inference/__init__.py
@@ -1,15 +1,14 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
14 changes: 6 additions & 8 deletions tpu_inference/core/disagg_utils.py
@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import Tuple
 
-PREFILL_SLICES = 'PREFILL_SLICES'
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs
 
 
 def is_disagg_enabled() -> bool:
     # We trigger our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES in os.environ
+    return bool(envs.PREFILL_SLICES)
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
 
 
 def get_prefill_slices() -> Tuple[int, ...]:
-    if PREFILL_SLICES not in os.environ:
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(os.environ[PREFILL_SLICES])
+    return _parse_slices(envs.PREFILL_SLICES)
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if DECODE_SLICES not in os.environ:
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(os.environ[DECODE_SLICES])
+    return _parse_slices(envs.DECODE_SLICES)
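Note (reviewer sketch, not part of the diff): the new `is_disagg_enabled` tests for a non-empty value rather than mere presence, so an empty `PREFILL_SLICES` no longer enables the disaggregated path. A minimal standalone illustration, assuming `envs.PREFILL_SLICES` reads the variable with an empty-string default:

```python
import os

os.environ["PREFILL_SLICES"] = ""

# Old check: presence alone was enough to enable the disagg code path.
print("PREFILL_SLICES" in os.environ)         # True
# New check: the value must be non-empty (envs.PREFILL_SLICES is assumed
# to default to "" when the variable is unset).
print(bool(os.getenv("PREFILL_SLICES", "")))  # False
```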
5 changes: 2 additions & 3 deletions tpu_inference/distributed/tpu_connector.py
@@ -60,7 +60,6 @@
 
 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.request import Request
 
+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ def __init__(self, vllm_config: VllmConfig):
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
-                                    "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
5 changes: 3 additions & 2 deletions tpu_inference/distributed/utils.py
@@ -2,6 +2,7 @@
 
 from vllm.utils.network_utils import get_ip
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
2 changes: 1 addition & 1 deletion tpu_inference/envs.py
@@ -26,7 +26,7 @@
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
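For context, a minimal sketch of how attribute access like `envs.JAX_PLATFORMS` can be backed by this lambda table via a PEP 562 module-level `__getattr__` (the pattern vLLM's own `envs.py` uses; the actual `tpu_inference/envs.py` may differ in details, and the `TPU_MULTIHOST_BACKEND` entry below is an assumption inferred from the call sites in this PR):

```python
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    # Lower-cased so call sites can compare case-insensitively, e.g.
    # `"proxy" in envs.JAX_PLATFORMS` or `envs.TPU_MULTIHOST_BACKEND == "ray"`.
    "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", "").lower(),
    "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
}


def __getattr__(name: str) -> Any:
    # Each attribute access re-reads the environment, so variables set
    # after import are still picked up.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```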
4 changes: 2 additions & 2 deletions tpu_inference/layers/vllm/sharding.py
@@ -19,6 +19,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    import os
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 
16 changes: 15 additions & 1 deletion tpu_inference/mock/vllm_envs.py
@@ -189,6 +189,20 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
     return bool(int(value))
 
 
+def _get_jax_platforms() -> str:
+    """Get JAX_PLATFORMS from tpu_inference.envs module.
+
+    Returns:
+        The JAX_PLATFORMS value.
+    """
+    try:
+        from tpu_inference import envs
+        return envs.JAX_PLATFORMS
+    except ImportError:
+        # Fallback if tpu_inference.envs is not available
+        return os.getenv("JAX_PLATFORMS", "").lower()
+
+
 def get_vllm_port() -> Optional[int]:
     """Get the port from VLLM_PORT environment variable.
@@ -941,7 +955,7 @@ def get_vllm_port() -> Optional[int]:
 
     # Whether using Pathways
     "VLLM_TPU_USING_PATHWAYS":
-    lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()),
+    lambda: bool("proxy" in _get_jax_platforms()),
 
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM":
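The try/except in `_get_jax_platforms` presumably exists because this mock can be imported before `tpu_inference.envs` is importable; the fallback mirrors the old inline expression. A standalone sketch of the fallback branch's behavior (hypothetical helper name, not from the PR):

```python
import os

def jax_platforms_fallback() -> str:
    # Same semantics as the except branch above: raw env read, lower-cased.
    return os.getenv("JAX_PLATFORMS", "").lower()

os.environ["JAX_PLATFORMS"] = "Proxy,CPU"
assert jax_platforms_fallback() == "proxy,cpu"
# VLLM_TPU_USING_PATHWAYS then evaluates bool("proxy" in ...) -> True.
assert "proxy" in jax_platforms_fallback()
```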
4 changes: 2 additions & 2 deletions tpu_inference/models/jax/utils/weight_utils.py
@@ -18,7 +18,7 @@
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -421,7 +421,7 @@ def load_hf_weights(vllm_config,
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
     # Because multi-threading would cause different JAX processes to load
     # different weights at the same time.
-    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         max_workers = 1
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [
3 changes: 1 addition & 2 deletions tpu_inference/platforms/tpu_platform.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
@@ -183,7 +182,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"
 
-        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \
7 changes: 4 additions & 3 deletions tpu_inference/tpu_info.py
@@ -3,6 +3,7 @@
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,22 +33,22 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name = os.getenv("TPU_NAME", None)
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
 
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id = os.getenv("TPU_WORKER_ID", None)
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
9 changes: 5 additions & 4 deletions tpu_inference/utils.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -14,8 +13,10 @@
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs, utils
+from vllm import envs as vllm_envs
+from vllm import utils
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
@@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
 
 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if envs.VLLM_TPU_USING_PATHWAYS:
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)
 
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.