[Doc] Add iter_jax_batches documentation to JaxTrainer guide#63294
Conversation
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
There was a problem hiding this comment.
Code Review
This pull request updates the JAX trainer tutorial documentation in both the notebook and Markdown formats to include an alternative section on using the iter_jax_batches API for native JAX data ingestion. The reviewer suggested renaming the local_batch variable to batch or global_batch to better reflect that the API yields globally sharded arrays. Additionally, it was recommended to explicitly label iter_jax_batches as an "Alpha API" in the Markdown file for consistency and accuracy.
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
|
cc: @siyuanfoundation @liulehui added a section on the |
|
|
||
| In the example above, we used `iter_batches(batch_format="numpy")` and manually transferred data to devices using `jax.make_array_from_process_local_data`. | ||
|
|
||
| Ray Data also provides an Alpha API, `iter_jax_batches`, which streamlines this process by automatically yielding globally sharded JAX Arrays. This can be more efficient and requires less boilerplate code. |
There was a problem hiding this comment.
I think we can add a iter_jax_batches api doc in https://docs.ray.io/en/latest/data/api/dataset.html
and then add the link here,
There was a problem hiding this comment.
done in d499698, i also modified the API in dataset.py to be part of CD_API_GROUP, so this requires Ray Data review now
…ro_to_jax_trainer guides doc: natively use iter_jax_batches inside intro_to_jax_trainer guides Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
d499698 to
2a2a976
Compare
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 2a2a976. Configure here.
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1163f7a to
de6cb29
Compare
|
@bveeramani wondering if you can review from the Ray Data side for the small |
…ject#63294) ## Description ray-project#61630 added support for JAX to Ray data, specifically implementing a `iter_jax_batches` util to yield natively sharded `jax.Arrays`. This provides first-class support for processing data within JaxTrainer workloads. This PR updates an existing GPT-2 guide using the `JaxTrainer` to showcase how this new util could simplify the Train code. ## Related issues ray-project#55162 ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…ject#63294) ## Description ray-project#61630 added support for JAX to Ray data, specifically implementing a `iter_jax_batches` util to yield natively sharded `jax.Arrays`. This provides first-class support for processing data within JaxTrainer workloads. This PR updates an existing GPT-2 guide using the `JaxTrainer` to showcase how this new util could simplify the Train code. ## Related issues ray-project#55162 ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: anindyam1969 <amukherjee@kinetica.com>
…ject#63294) ## Description ray-project#61630 added support for JAX to Ray data, specifically implementing a `iter_jax_batches` util to yield natively sharded `jax.Arrays`. This provides first-class support for processing data within JaxTrainer workloads. This PR updates an existing GPT-2 guide using the `JaxTrainer` to showcase how this new util could simplify the Train code. ## Related issues ray-project#55162 ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Signed-off-by: Ryan O'Leary <113500783+ryanaoleary@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

Description
#61630 added support for JAX to Ray data, specifically implementing a
iter_jax_batchesutil to yield natively shardedjax.Arrays. This provides first-class support for processing data within JaxTrainer workloads. This PR updates an existing GPT-2 guide using theJaxTrainerto showcase how this new util could simplify the Train code.Related issues
#55162
Additional information