Merged
4 changes: 1 addition & 3 deletions requirements/tpu.txt
@@ -12,6 +12,4 @@ ray[data]
 setuptools==78.1.0
 nixl==0.3.0
 tpu_info==0.4.0
-
-# Install torch_xla
-torch_xla[tpu, pallas]==2.8.0
+tpu-inference==0.11.1
8 changes: 0 additions & 8 deletions vllm/distributed/device_communicators/tpu_communicator.py
@@ -97,11 +97,3 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
Comment on lines 97 to 99
P1: Communicator now crashes when tpu_inference is present

With this change the alias that swapped in tpu_inference's communicator was removed, but the guard at lines 22-31 still skips importing torch_xla, xm, xr, etc. whenever USE_TPU_INFERENCE is True. Because tpu-inference is now listed in requirements/tpu.txt, that flag is effectively always True on TPU installs, so the vLLM TpuCommunicator defined in this module is still used. Its methods, however, call xm.*/torch_xla.* (lines 92-99) even though those symbols were never imported, so a NameError is raised as soon as the communicator is instantiated or all_reduce/all_gather is called. Either the communicator needs to keep delegating to tpu_inference, or the XLA imports must happen unconditionally; otherwise TPU runs fail immediately when tpu-inference is installed. A sketch of both options follows the hunk below.

-
-
-if USE_TPU_INFERENCE:
-    from tpu_inference.distributed.device_communicators import (
-        TpuCommunicator as TpuInferenceCommunicator,
-    )
-
-    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
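To make the failure mode and the two possible fixes concrete, here is a minimal, illustrative sketch. It is not the actual vLLM module: the class body is reduced to the two collectives named in the review comment, the xm.* calls use the public torch_xla API, and the fallback import path is the one this PR deletes.

# Illustrative sketch only, not vLLM's real tpu_communicator.py.
# Option 1: import the XLA symbols unconditionally so the vLLM communicator's
# methods always have xm available, even when tpu_inference is installed.
import torch
import torch_xla.core.xla_model as xm


class TpuCommunicator:
    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Sum-reduce across the TPU replica group; xm.REDUCE_SUM and
        # xm.all_reduce are part of the public torch_xla API.
        return xm.all_reduce(xm.REDUCE_SUM, input_)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        assert dim == -1, "TPUs only support dim=-1 for all-gather."
        return xm.all_gather(input_, dim=dim)


# Option 2: keep the delegation that this PR removes, but guard it with
# try/except instead of the USE_TPU_INFERENCE flag, so the alias only applies
# when the package is actually importable.
try:
    from tpu_inference.distributed.device_communicators import (
        TpuCommunicator as TpuInferenceCommunicator,
    )

    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
except ImportError:
    pass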
4 changes: 3 additions & 1 deletion vllm/platforms/tpu.py
@@ -267,7 +267,9 @@ def check_max_model_len(cls, max_model_len: int) -> int:


 try:
-    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
+    from tpu_inference.platforms.tpu_platforms import (
+        TpuPlatform as TpuInferencePlatform,
+    )

     TpuPlatform = TpuInferencePlatform  # type: ignore
     USE_TPU_INFERENCE = True
2 changes: 1 addition & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -346,6 +346,6 @@ def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:


 if USE_TPU_INFERENCE:
-    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
+    from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker

     TPUWorker = TpuInferenceWorker  # type: ignore