Add new tpu backend for torch #63442

Draft
siyuanfoundation wants to merge 1 commit into ray-project:master from siyuanfoundation:torchtpu

Conversation

@siyuanfoundation siyuanfoundation commented May 18, 2026

Description

Add a new torch trainer backend for torch_tpu.

Related to this announcement: https://developers.googleblog.com/torchtpu-running-pytorch-natively-on-tpus-at-google-scale/

Related issues

Additional information

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces TPU support for Ray Train and Ray AIR, including a new TPUTorchDeviceManager and logic to inject environment variables for distributed TPU training. Feedback highlights a missing use_tpu field in ScalingConfig that will cause runtime errors and an empty registration method for the tpu_dist backend. It is also recommended to use math.prod instead of functools.reduce for more idiomatic calculations of TPU chip products.
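
The last suggestion can be illustrated with a small sketch. The topology string format ("2x2x4") and both helper names are assumptions made for illustration; they are not taken from the PR's diff.

```python
import math
from functools import reduce
from operator import mul
from typing import List


def _parse_dims(topology: str) -> List[int]:
    # Hypothetical format: dimensions separated by "x", e.g. "2x2x4".
    return [int(dim) for dim in topology.split("x")]


def num_chips_with_reduce(topology: str) -> int:
    # The style the review flags: correct, but needs two extra imports.
    return reduce(mul, _parse_dims(topology), 1)


def num_chips_with_prod(topology: str) -> int:
    # The suggested style: math.prod (Python 3.8+) states the intent directly.
    return math.prod(_parse_dims(topology))
```

Both helpers return the same value; `math.prod` simply removes the need for `functools.reduce` and `operator.mul`.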

Comment thread python/ray/air/_internal/device_manager/tpu.py Outdated
Comment thread python/ray/air/config.py Outdated
Comment thread python/ray/_private/accelerators/tpu.py Outdated
Comment thread python/ray/_private/accelerators/tpu.py Outdated
Comment thread python/ray/_private/accelerators/tpu.py Outdated
Signed-off-by: siyuanfoundation <sizhang@google.com>
return

@staticmethod
def set_current_process_visible_accelerator_ids(

@ryanaoleary ryanaoleary May 19, 2026

This is only called from ray.util.set_visible_accelerator_ids, so it won't run automatically, and I think automatic setup is the intended goal. I would prefer adding something like set_accelerator_env_vars to AcceleratorManager, so that the path for TPU is:

  1. On Ray node init we call ResourceAndLabelSpec to pass required resources/labels to the Raylet.
  2. Call set_accelerator_env_vars if the AcceleratorManager is not None
  3. For the TPU implementation, we check for the expected Torch and/or JAX env vars and set them in the raylet process if missing
  4. We may want to consider gating the above on whether the user passes some flag or env var that indicates they're going to use Torch TPU.

This would keep it extensible to other accelerator types.
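
A minimal sketch of the flow proposed above, under stated assumptions: the classes below are stand-ins rather than Ray's real AcceleratorManager hierarchy, set_accelerator_env_vars is the hook proposed in this comment and does not exist in Ray today, and every environment variable name and default value is a hypothetical placeholder.

```python
import os
from typing import Dict, MutableMapping, Optional, Type

# Hypothetical opt-in flag (step 4); not an existing Ray setting.
ENABLE_FLAG = "EXAMPLE_ENABLE_TORCH_TPU"


class AcceleratorManager:
    """Stand-in base class reduced to the proposed hook."""

    @staticmethod
    def set_accelerator_env_vars(
        env: Optional[MutableMapping[str, str]] = None,
    ) -> Dict[str, str]:
        # Default for accelerators that need no extra env vars.
        return {}


class TPUAcceleratorManager(AcceleratorManager):
    # Placeholder names and defaults, NOT real Torch or JAX variables.
    EXPECTED_ENV_VARS = {
        "EXAMPLE_TPU_WORKER_ID": "0",
        "EXAMPLE_TPU_WORKER_HOSTNAMES": "localhost",
    }

    @staticmethod
    def set_accelerator_env_vars(
        env: Optional[MutableMapping[str, str]] = None,
    ) -> Dict[str, str]:
        env = os.environ if env is None else env
        # Step 4: do nothing unless the user opted in.
        if env.get(ENABLE_FLAG) != "1":
            return {}
        added = {}
        # Step 3: fill in only the variables that are missing.
        for name, default in TPUAcceleratorManager.EXPECTED_ENV_VARS.items():
            if name not in env:
                env[name] = default
                added[name] = default
        return added


def on_node_init(
    manager: Optional[Type[AcceleratorManager]],
    env: Optional[MutableMapping[str, str]] = None,
) -> Dict[str, str]:
    # Step 2: call the hook only if an accelerator manager was detected.
    if manager is None:
        return {}
    return manager.set_accelerator_env_vars(env)
```

Because the hook lives on the base class with a no-op default, other accelerator types can override it without changing the node-init call site.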

)

def set_device(self, device: Union[torch.device, int, str, None]):
# TPU device setting is typically handled by torch_tpu.api.tpu_device()

Avoid a silent pass here. If we expect set_device to be called but want to force it to be a no-op, add some debug logging.
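
One way this could look, as a rough sketch: the class below is a stand-in reduced to set_device (the PR's version takes a torch.device, int, str, or None), and the log message wording is an assumption.

```python
import logging
from typing import Any

logger = logging.getLogger(__name__)


class TPUTorchDeviceManager:
    """Stand-in for the PR's device manager, reduced to set_device."""

    def set_device(self, device: Any) -> None:
        # Intentionally a no-op: per the PR's code comment, TPU device
        # setting is typically handled by torch_tpu.api.tpu_device().
        # Log at debug level so the skipped call is visible when tracing.
        logger.debug(
            "set_device(%r) is a no-op on TPU; device selection is "
            "handled elsewhere.",
            device,
        )
```

Debug level keeps normal runs quiet while still leaving a trace for anyone investigating why the call had no effect.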
