[Train] JaxTrainer Implementation Tracking Issue #55162

@ryanaoleary

Description

This issue tracks the implementation of a JaxTrainer to support JAX and SPMD workloads. Initial support will target SPMD on multi-host TPUs on Kubernetes.

Milestone 1: MVP of JaxTrainer with single-slice multi-host TPUs

Milestone 2: Full support for TPU multi-slice

We can update this issue with more steps as milestones become clear.

Use case

Full support for SPMD workloads orchestrated with Ray and Ray Train.

Metadata

Labels

community-backlog
enhancement: Request for new feature and/or capability
performance
train: Ray Train Related Issue
triage: Needs triage (eg: priority, bug/not-bug, and owning component)
usability
