[TPU] add tpu_inference #27277
Conversation
Code Review
This pull request transitions the TPU backend to use the tpu-inference library, updating dependencies and import paths accordingly. The changes correctly adapt vLLM to the new library version. However, I've identified a critical issue with the dependency management in requirements/tpu.txt. By removing torch_xla, the fallback mechanism for when tpu-inference is unavailable becomes non-functional, which could lead to confusing runtime errors. My review includes a suggestion to address this to ensure the system remains robust.
requirements/tpu.txt (Outdated)

tpu-inference==0.11.1
numba
Removing torch_xla from this requirements file makes the fallback path in the code fragile. The codebase contains logic to fall back to a torch_xla-based implementation if tpu-inference is not found or fails to import (e.g., in vllm/platforms/tpu.py).
With torch_xla removed from the dependencies, if tpu-inference is installed but fails to import for any reason (like a transitive dependency issue), the fallback will immediately fail with an ImportError for torch_xla. This creates a brittle setup and can cause confusing errors for users.
To ensure the fallback mechanism is robust, please re-add torch_xla to the dependencies. This ensures that even if the primary tpu-inference path fails, the system can gracefully fall back to the torch_xla-based implementation.
tpu-inference==0.11.1
numba
# Install torch_xla for fallback
torch_xla[tpu, pallas]==2.8.0
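For context, the fallback being discussed is an import-time guard of roughly the following shape. This is a minimal sketch only; the actual code in vllm/platforms/tpu.py may be structured differently, and everything other than the tpu_inference, torch_xla, and USE_TPU_INFERENCE names taken from this PR is illustrative.

# Sketch of the fallback guard the comment refers to (illustrative, not the
# exact vLLM code).
try:
    import tpu_inference  # noqa: F401  # preferred TPU backend
    USE_TPU_INFERENCE = True
except ImportError:
    USE_TPU_INFERENCE = False

if not USE_TPU_INFERENCE:
    # Without torch_xla listed in requirements/tpu.txt, this fallback import
    # itself raises ImportError, so the "graceful" fallback path never runs.
    import torch_xla.core.xla_model as xm  # noqa: F401

Re-adding torch_xla to the requirements keeps the except/fallback branch executable rather than turning it into a second ImportError.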
💡 Codex Review
vllm/vllm/distributed/device_communicators/tpu_communicator.py
Lines 22 to 62 in 89ad99a
if not USE_TPU_INFERENCE:
    logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
    if current_platform.is_tpu():
        import torch_xla
        import torch_xla.core.xla_model as xm
        import torch_xla.runtime as xr
        from torch_xla._internal import pjrt
        from torch_xla.distributed.xla_multiprocessing import (
            create_optimized_replica_groups,
        )
    if USE_RAY:
        from vllm.executor import ray_utils


class TpuCommunicator(DeviceCommunicatorBase):
    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: torch.device | None = None,
        device_group: ProcessGroup | None = None,
        unique_name: str = "",
    ):
        super().__init__(cpu_group, device, device_group, unique_name)

        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
        # must be used together. Therefore, the local rank and world size can
        # be simply calculated as follows.
        global_rank = self.global_rank
        global_world_size = self.global_world_size

        if USE_RAY:
            logger.info("TpuCommunicator initialized with RAY")
            # Calculate how many TPU nodes are in the current deployment. This
            # is the Ray placement group if it is deployed with Ray. Default
            # to the number of TPU nodes in the Ray cluster. The number of TPU
            # nodes is computed by the total number of TPUs divided by the
            # number of TPU accelerators per node, to account for clusters
            # with both CPUs and TPUs.
            num_nodes = ray_utils.get_num_tpu_nodes()
            num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
With tpu_inference installed USE_TPU_INFERENCE becomes True, but this module now always exposes the local TpuCommunicator instead of aliasing to the implementation in tpu_inference. The conditional import of ray_utils is still wrapped inside if not USE_TPU_INFERENCE, yet __init__ uses ray_utils whenever the distributed executor backend is Ray. In a TPU inference deployment that uses the Ray backend, ray_utils will be undefined and TpuCommunicator.__init__ will raise NameError before any communication is set up. The Ray utilities need to be imported regardless of USE_TPU_INFERENCE (or the Ray-specific logic needs a similar guard).
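One way to address this, sketched below using only the names visible in the snippet above, is to hoist the Ray import out of the USE_TPU_INFERENCE guard so it depends on USE_RAY alone; alternatively, the Ray-specific branch in __init__ could get an equivalent guard.

# Sketch: make the ray_utils import depend only on USE_RAY, so
# TpuCommunicator.__init__ can reach it whichever TPU backend is active.
if USE_RAY:
    from vllm.executor import ray_utils

if not USE_TPU_INFERENCE:
    logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
    if current_platform.is_tpu():
        import torch_xla
        import torch_xla.core.xla_model as xm
        import torch_xla.runtime as xr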
Force-pushed from cb27073 to db10aa5
💡 Codex Review
Here are some automated review suggestions for this pull request.
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    assert dim == -1, "TPUs only support dim=-1 for all-gather."
    return xm.all_gather(input_, dim=dim)
Communicator now crashes when tpu_inference is present
With this change the alias that swapped in tpu_inference's communicator was removed, but the guard at lines 22-31 still skips importing torch_xla, xm, xr, etc. whenever USE_TPU_INFERENCE is True. Because tpu-inference is now listed in requirements/tpu.txt, that flag is effectively always True on TPU installs. The vLLM TpuCommunicator defined in this module is therefore still used, yet its methods call xm.* / torch_xla.* (lines 92-99) even though those symbols were never imported, causing a NameError as soon as the communicator is instantiated or all_reduce/all_gather is called. Either the communicator needs to keep delegating to tpu_inference, or the XLA imports must happen unconditionally; otherwise TPU runs fail immediately when tpu-inference is installed.
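For illustration, the second option mentioned above (performing the XLA imports unconditionally on TPU) would look roughly like the sketch below; it uses only names already present in the snippet, and note that it would also require keeping torch_xla in requirements/tpu.txt, as discussed earlier.

# Sketch: import the XLA symbols whenever we are on a TPU, independent of
# USE_TPU_INFERENCE, so all_reduce/all_gather can always resolve xm.
if current_platform.is_tpu():
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr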
Force-pushed from 50e1f4a to dabc072
Signed-off-by: Johnny Yang <johnnyyang@google.com>
Force-pushed from dabc072 to d65756c
yaochengji left a comment
LGTM, thanks!
Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Benjamin Feuer <penfever@gmail.com>
Signed-off-by: Johnny Yang <johnnyyang@google.com>
Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Purpose
Add tpu-inference for vLLM TPU.
Test Plan
E2E tests on v6e.
Test Result
Successful
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.