Skip to content

[Doc] Add iter_jax_batches documentation to JaxTrainer guide#63294

Merged
ryanaoleary merged 8 commits into
ray-project:masterfrom
ryanaoleary:doc-iter-jax-batches
May 12, 2026
Merged

[Doc] Add iter_jax_batches documentation to JaxTrainer guide#63294
ryanaoleary merged 8 commits into
ray-project:masterfrom
ryanaoleary:doc-iter-jax-batches

Conversation

@ryanaoleary
Copy link
Copy Markdown
Contributor

Description

#61630 added support for JAX to Ray data, specifically implementing a iter_jax_batches util to yield natively sharded jax.Arrays. This provides first-class support for processing data within JaxTrainer workloads. This PR updates an existing GPT-2 guide using the JaxTrainer to showcase how this new util could simplify the Train code.

Related issues

#55162

Additional information

Optional: Add implementation details, API changes, usage examples, screenshots, etc.

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
@ryanaoleary ryanaoleary requested a review from a team as a code owner May 12, 2026 04:34
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the JAX trainer tutorial documentation in both the notebook and Markdown formats to include an alternative section on using the iter_jax_batches API for native JAX data ingestion. The reviewer suggested renaming the local_batch variable to batch or global_batch to better reflect that the API yields globally sharded arrays. Additionally, it was recommended to explicitly label iter_jax_batches as an "Alpha API" in the Markdown file for consistency and accuracy.

Comment thread doc/source/train/examples/jax/intro_to_jax_trainer/README.ipynb Outdated
Comment thread doc/source/train/examples/jax/intro_to_jax_trainer/README.md Outdated
Comment thread doc/source/train/examples/jax/intro_to_jax_trainer/README.md Outdated
ryanaoleary and others added 4 commits May 12, 2026 04:56
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
@ryanaoleary
Copy link
Copy Markdown
Contributor Author

ryanaoleary commented May 12, 2026

cc: @siyuanfoundation @liulehui added a section on the iter_jax_batches API so we have an example with JaxTrainer in the Ray docs


In the example above, we used `iter_batches(batch_format="numpy")` and manually transferred data to devices using `jax.make_array_from_process_local_data`.

Ray Data also provides an Alpha API, `iter_jax_batches`, which streamlines this process by automatically yielding globally sharded JAX Arrays. This can be more efficient and requires less boilerplate code.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can add a iter_jax_batches api doc in https://docs.ray.io/en/latest/data/api/dataset.html
and then add the link here,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done in d499698, i also modified the API in dataset.py to be part of CD_API_GROUP, so this requires Ray Data review now

Comment thread doc/source/train/examples/jax/intro_to_jax_trainer/README.md Outdated
Comment thread doc/source/train/examples/jax/intro_to_jax_trainer/README.md
Copy link
Copy Markdown
Contributor

@liulehui liulehui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tyty!❤️

@ryanaoleary ryanaoleary requested a review from a team as a code owner May 12, 2026 06:27
@ryanaoleary ryanaoleary requested a review from liulehui May 12, 2026 06:28
…ro_to_jax_trainer guides

doc: natively use iter_jax_batches inside intro_to_jax_trainer guides
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
@ryanaoleary ryanaoleary force-pushed the doc-iter-jax-batches branch from d499698 to 2a2a976 Compare May 12, 2026 06:29
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Fix All in Cursor

Reviewed by Cursor Bugbot for commit 2a2a976. Configure here.

Comment thread doc/source/train/examples/jax/intro_to_jax_trainer/README.ipynb Outdated
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
@ryanaoleary ryanaoleary force-pushed the doc-iter-jax-batches branch from 1163f7a to de6cb29 Compare May 12, 2026 06:43
@ray-gardener ray-gardener Bot added docs An issue or change related to documentation train Ray Train Related Issue community-contribution Contributed by the community labels May 12, 2026
@ryanaoleary ryanaoleary enabled auto-merge (squash) May 12, 2026 18:36
@github-actions github-actions Bot added the go add ONLY when ready to merge, run all tests label May 12, 2026
@ryanaoleary
Copy link
Copy Markdown
Contributor Author

@bveeramani wondering if you can review from the Ray Data side for the small dataset.py change

Copy link
Copy Markdown
Member

@bveeramani bveeramani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stamp for Data changes

@ryanaoleary ryanaoleary merged commit 5aba71b into ray-project:master May 12, 2026
11 checks passed
dancingactor pushed a commit to dancingactor/ray that referenced this pull request May 13, 2026
…ject#63294)

## Description
ray-project#61630 added support for JAX to
Ray data, specifically implementing a `iter_jax_batches` util to yield
natively sharded `jax.Arrays`. This provides first-class support for
processing data within JaxTrainer workloads. This PR updates an existing
GPT-2 guide using the `JaxTrainer` to showcase how this new util could
simplify the Train code.

## Related issues
ray-project#55162

## Additional information
> Optional: Add implementation details, API changes, usage examples,
screenshots, etc.

---------

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
am-kinetica pushed a commit to kineticadb/ray that referenced this pull request May 14, 2026
…ject#63294)

## Description
ray-project#61630 added support for JAX to
Ray data, specifically implementing a `iter_jax_batches` util to yield
natively sharded `jax.Arrays`. This provides first-class support for
processing data within JaxTrainer workloads. This PR updates an existing
GPT-2 guide using the `JaxTrainer` to showcase how this new util could
simplify the Train code.

## Related issues
ray-project#55162

## Additional information
> Optional: Add implementation details, API changes, usage examples,
screenshots, etc.

---------

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: anindyam1969 <amukherjee@kinetica.com>
Lucas61000 pushed a commit to Lucas61000/ray that referenced this pull request May 15, 2026
…ject#63294)

## Description
ray-project#61630 added support for JAX to
Ray data, specifically implementing a `iter_jax_batches` util to yield
natively sharded `jax.Arrays`. This provides first-class support for
processing data within JaxTrainer workloads. This PR updates an existing
GPT-2 guide using the `JaxTrainer` to showcase how this new util could
simplify the Train code.

## Related issues
ray-project#55162

## Additional information
> Optional: Add implementation details, API changes, usage examples,
screenshots, etc.

---------

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution Contributed by the community docs An issue or change related to documentation go add ONLY when ready to merge, run all tests train Ray Train Related Issue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants