Description
This issue will serve as an implementation tracker for a JaxTrainer to support JAX and SPMD workloads. Initial support will target SPMD workloads on multi-host TPUs on Kubernetes.
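For context, a minimal multi-host JAX SPMD training step (plain JAX, independent of the proposed Ray Train API) looks roughly like the sketch below; the toy linear model and learning rate are illustrative only:

```python
import functools

import jax
import jax.numpy as jnp

# Each TPU host runs this same program (SPMD). jax.pmap replicates the
# step over the host's local devices, and collectives such as pmean
# synchronize gradients across every device on every host in the slice.
@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, batch):
    def loss_fn(p):
        preds = batch["x"] @ p  # toy linear model, illustrative only
        return jnp.mean((preds - batch["y"]) ** 2)

    grads = jax.grad(loss_fn)(params)
    grads = jax.lax.pmean(grads, axis_name="batch")  # all-reduce gradients
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```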
Milestone 1: MVP of JaxTrainer with single-slice multi-host TPUs
- Add default TPU info to Ray node labels
- Add API change to `ScalingConfig` to specify `topology` and `accelerator` arguments (see the API sketch after this list)
- Add JaxTrainer wrapper of DataParallelTrainer to Ray Train with SPMD scheduling support
- Extensively test JAX training workloads with multi-host TPUs using the Anyscale and KubeRay operators
- Add documentation and guides to the Ray Train docs
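To make the proposed API concrete, here is a hypothetical usage sketch. Only the `topology` and `accelerator` arguments come from this issue; the import path, the `use_tpu` flag, the argument values, and the train loop are assumptions for illustration:

```python
from ray.train import ScalingConfig
from ray.train.jax import JaxTrainer  # hypothetical import path


def train_loop_per_worker():
    import jax

    # Hypothetical: the trainer initializes jax.distributed on each TPU
    # host so that jax.device_count() reports the full slice.
    print(jax.device_count(), jax.local_device_count())


trainer = JaxTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(
        num_workers=4,          # one Ray worker per TPU host (assumption)
        use_tpu=True,           # hypothetical flag
        topology="4x4",         # proposed `topology` argument
        accelerator="TPU-V4",   # proposed `accelerator` argument
    ),
)
result = trainer.fit()
```

Together with the TPU info surfaced through Ray node labels (first item above), this would let the trainer gang-schedule one worker per host on the same TPU slice.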
Milestone 2: Full support for TPU multi-slice
We can update this issue with more steps as milestones become clear.
Use case
Full support for SPMD workloads orchestrated with Ray and Ray Train.