Skip to content

Conversation

junjieqian
Copy link
Collaborator

@junjieqian junjieqian commented Sep 3, 2025

This PR supports checkpointing with torchax:

  1. load a checkpoint file in torch tensors and convert to Jax arrays; Or load a checkpoint file in Jax arrays
  2. save a checkpoint file in Jax arrays.

This support single worker now.

@junjieqian junjieqian requested a review from vlad-karp September 3, 2025 04:51
@junjieqian junjieqian marked this pull request as draft September 3, 2025 04:52
@junjieqian junjieqian force-pushed the junjieqian/checkpoint branch from 0ad1525 to d94e2cc Compare September 3, 2025 04:57
@junjieqian junjieqian force-pushed the junjieqian/checkpoint branch from d94e2cc to fae903e Compare September 3, 2025 05:02
@vlad-karp
Copy link
Collaborator

LGTM overall.
Two concerns to test later:

  1. multihost setup
  2. Reading/writing the states of our Flax backed SparsceCore modules

@junjieqian junjieqian force-pushed the junjieqian/checkpoint branch from a2098b8 to 08f610a Compare September 4, 2025 18:44
@junjieqian junjieqian marked this pull request as ready for review September 4, 2025 19:59
@junjieqian junjieqian requested a review from qihqi September 4, 2025 19:59
@qihqi qihqi merged commit 7aba922 into master Sep 10, 2025
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants