1 change: 1 addition & 0 deletions .github/workflows/_test.yml
@@ -121,6 +121,7 @@ jobs:
          # TODO: Add these in setup.py
          pip install fsspec
          pip install rich
          pip install flax
      - name: Checkout PyTorch Repo
        if: inputs.has_code_changes == 'true'
        uses: actions/checkout@v4
1 change: 1 addition & 0 deletions .github/workflows/_tpu_ci.yml
@@ -55,6 +55,7 @@ jobs:
          pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
          pip install --pre 'torch_xla[tpu]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
          pip install --upgrade protobuf
          pip install flax
      - name: Run Tests (${{ matrix.test_script }})
        if: inputs.has_code_changes == 'true'
        env:
36 changes: 36 additions & 0 deletions torchax/README.md
@@ -179,6 +179,42 @@
The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
function for the given input shapes. Subsequent calls with the same input shapes
will be fast as the compilation is cached.
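
This caching behavior is easy to observe with plain `jax.jit` (a minimal sketch, independent of torchax):

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
  return x * 2

double(jnp.ones((4,)))  # first call: traces and compiles for shape (4,)
double(jnp.ones((4,)))  # same shape: served from the compilation cache
double(jnp.ones((8,)))  # new shape: triggers a fresh trace and compilation
```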

## Saving and Loading Checkpoints

You can use `torchax.save_checkpoint` and `torchax.load_checkpoint` to save and load your training state. The state is a dictionary (pytree) that can contain the model's weights, optimizer state, and any other information you want to save; on save, `torch.Tensor` leaves are converted to `jax.Array`s and the checkpoint is written as a Flax-style checkpoint directory.

```python
import torch
import torchax
import optax

# Assume MyModel is defined elsewhere
model = MyModel()
optimizer = optax.adam(1e-3)

# Build the optimizer state over the model's parameters as a JAX pytree
params_jax, _ = torchax.extract_jax(model)
opt_state = optimizer.init(params_jax)
epoch = 10

state = {
    'model': model.state_dict(),  # torch.Tensors; converted to jax.Arrays on save
    'opt_state': opt_state,
    'epoch': epoch,
}

# Save checkpoint; the path is a directory, and `step` tags the checkpoint
torchax.save_checkpoint(state, '/path/to/checkpoint_dir', step=epoch)

# Load checkpoint; returns a pytree whose leaves are jax.Arrays
loaded_state = torchax.load_checkpoint('/path/to/checkpoint_dir')

# Convert the model weights back to torch.Tensors before restoring them.
# _to_torch is a module-internal helper (also used by torchax's tests).
from torchax.checkpoint import _to_torch
model.load_state_dict(_to_torch(loaded_state['model']))
opt_state = loaded_state['opt_state']
epoch = loaded_state['epoch']
```
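
`load_checkpoint` also accepts a regular PyTorch checkpoint file, so an existing `.pt` checkpoint can be brought into the JAX world without re-saving it. A minimal sketch (paths are illustrative):

```python
import torch
import torchax

# A plain PyTorch checkpoint, saved the usual way
torch.save({'model': model.state_dict(), 'epoch': 3}, '/tmp/ckpt.pt')

# load_checkpoint detects a single file (rather than a checkpoint
# directory) and converts every torch.Tensor leaf to a jax.Array.
state = torchax.load_checkpoint('/tmp/ckpt.pt')
```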

## Citation

```
102 changes: 102 additions & 0 deletions torchax/test/test_checkpoint.py
@@ -0,0 +1,102 @@
import unittest
import torch
import torchax
from torchax.checkpoint import _to_torch, _to_jax
import optax
import tempfile
import os
import jax
import jax.numpy as jnp


class CheckpointTest(unittest.TestCase):

  def test_save_and_load_jax_style_checkpoint(self):
    model = torch.nn.Linear(10, 20)
    optimizer = optax.adam(1e-3)

    torchax.enable_globally()
    params_jax, _ = torchax.extract_jax(model)
    opt_state = optimizer.init(params_jax)
    torchax.disable_globally()

    epoch = 1
    state = {
        'model': model.state_dict(),
        'opt_state': opt_state,
        'epoch': epoch,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint')
      torchax.save_checkpoint(state, path, step=epoch)
      loaded_state_jax = torchax.load_checkpoint(path)
      loaded_state = _to_torch(loaded_state_jax)

      self.assertEqual(state['epoch'], loaded_state['epoch'])

      # Compare model state_dict
      for key in state['model']:
        self.assertTrue(
            torch.allclose(state['model'][key], loaded_state['model'][key]))

      # Compare optimizer state
      original_leaves = jax.tree_util.tree_leaves(state['opt_state'])
      loaded_leaves = jax.tree_util.tree_leaves(loaded_state['opt_state'])
      for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
        if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
          # Coerce the loaded leaf to a jax array before comparing values
          self.assertTrue(jnp.allclose(original_leaf, jnp.asarray(loaded_leaf)))
        else:
          self.assertEqual(original_leaf, loaded_leaf)

  def test_load_pytorch_style_checkpoint(self):
    model = torch.nn.Linear(10, 20)
    optimizer = optax.adam(1e-3)

    torchax.enable_globally()
    params_jax, _ = torchax.extract_jax(model)
    opt_state = optimizer.init(params_jax)
    torchax.disable_globally()

    epoch = 1
    state = {
        'model': model.state_dict(),
        'opt_state': opt_state,
        'epoch': epoch,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint.pt')
      torch.save(state, path)
      loaded_state_jax = torchax.load_checkpoint(path)

      # Convert the original state to jax for comparison
      state_jax = _to_jax(state)

      self.assertEqual(state_jax['epoch'], loaded_state_jax['epoch'])

      # Compare model state_dict
      for key in state_jax['model']:
        self.assertTrue(
            jnp.allclose(state_jax['model'][key],
                         loaded_state_jax['model'][key]))

      # Compare optimizer state
      original_leaves = jax.tree_util.tree_leaves(state_jax['opt_state'])
      loaded_leaves = jax.tree_util.tree_leaves(loaded_state_jax['opt_state'])
      for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
        if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
          self.assertTrue(jnp.allclose(original_leaf, loaded_leaf))
        else:
          self.assertEqual(original_leaf, loaded_leaf)

  def test_load_non_existent_checkpoint(self):
    with self.assertRaises(FileNotFoundError):
      torchax.load_checkpoint('/path/to/non_existent_checkpoint')


if __name__ == '__main__':
  unittest.main()
4 changes: 3 additions & 1 deletion torchax/torchax/__init__.py
@@ -15,9 +15,11 @@
    'default_env',
    'extract_jax',
    'enable_globally',
    'save_checkpoint',
    'load_checkpoint',
]

from jax._src import xla_bridge
from .checkpoint import save_checkpoint, load_checkpoint

os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')

60 changes: 60 additions & 0 deletions torchax/torchax/checkpoint.py
@@ -0,0 +1,60 @@
import torch
import os
from typing import Any, Dict
from flax.training import checkpoints
import jax
import jax.numpy as jnp
import numpy as np


def _to_jax(pytree):
  # Replace every torch.Tensor leaf with an equivalent jax.Array.
  return jax.tree_util.tree_map(
      lambda x: jnp.asarray(x.cpu().numpy())
      if isinstance(x, torch.Tensor) else x, pytree)


def _to_torch(pytree):
  # Replace every jax array leaf with an equivalent torch.Tensor.
  return jax.tree_util.tree_map(
      lambda x: torch.from_numpy(np.asarray(x))
      if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)


def save_checkpoint(state: Dict[str, Any], path: str, step: int):
  """Saves a checkpoint in JAX (Flax) style.

  Args:
    state: A dictionary containing the state to save. torch.Tensors will be
      converted to jax.Arrays.
    path: The directory to save the checkpoint to.
    step: The training step.
  """
  state = _to_jax(state)
  checkpoints.save_checkpoint(path, state, step=step, overwrite=True)


def load_checkpoint(path: str) -> Dict[str, Any]:
  """Loads a checkpoint and returns it in JAX format.

  This function can load both PyTorch-style (single file) and JAX-style
  (directory) checkpoints.

  If the checkpoint is in PyTorch format, it will be converted to JAX format.

  Args:
    path: The path to the checkpoint.

  Returns:
    The loaded state in JAX format (pytree with jax.Array leaves).
  """
  if os.path.isdir(path):
    # JAX-style checkpoint
    state = checkpoints.restore_checkpoint(path, target=None)
    if state is None:
      raise FileNotFoundError(f"No checkpoint found at {path}")
    return state
  elif os.path.isfile(path):
    # PyTorch-style checkpoint
    state = torch.load(path, weights_only=False)
    return _to_jax(state)
  else:
    raise FileNotFoundError(f"No such file or directory: {path}")