
Commit d94e2cc

checkpoint in torchax
1 parent 89f929b commit d94e2cc

File tree

4 files changed: +202 -1 lines changed


torchax/README.md

Lines changed: 36 additions & 0 deletions
@@ -179,6 +179,42 @@ The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
compile for the given input shapes. Subsequent calls with the same input shapes
will be fast as the compilation is cached.

## Saving and Loading Checkpoints

You can use `torchax.save_checkpoint` and `torchax.load_checkpoint` to save and load your training state. The state is a dictionary that can hold the model's weights, the optimizer state, and any other information you want to persist. `save_checkpoint` writes a JAX-style (Flax) checkpoint, so its `path` argument is a directory and it takes the training `step`; `load_checkpoint` returns the state as a pytree with `jax.Array` leaves, and it also accepts PyTorch-style single-file checkpoints (see the sketch after the example).

```python
import torchax
import torch
import optax
from torchax.checkpoint import _to_torch

# Assume model, optimizer, and other states are defined
model = MyModel()
optimizer = optax.adam(1e-3)

# Initialize the optimizer state from the model's parameters as JAX arrays
torchax.enable_globally()
params_jax, _ = torchax.extract_jax(model)
opt_state = optimizer.init(params_jax)
torchax.disable_globally()

epoch = 10

# state_dict() covers both parameters and buffers
state = {
    'model': model.state_dict(),
    'opt_state': opt_state,
    'epoch': epoch,
}

# Save checkpoint; `path` is a directory and `step` tags the checkpoint file
torchax.save_checkpoint(state, '/path/to/checkpoint', step=epoch)

# Load checkpoint; the state comes back as a pytree with jax.Array leaves
loaded_state = torchax.load_checkpoint('/path/to/checkpoint')

# Restore state, converting JAX arrays back to torch tensors where needed
model.load_state_dict(_to_torch(loaded_state['model']))
opt_state = loaded_state['opt_state']
epoch = loaded_state['epoch']
```
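
`load_checkpoint` also understands PyTorch-style single-file checkpoints and converts them to JAX format on load. A minimal sketch (the `/tmp/ckpt.pt` path is just an illustration):

```python
import torch
import torchax

# A plain PyTorch checkpoint saved with torch.save...
torch.save({'epoch': 3}, '/tmp/ckpt.pt')

# ...loads back as a pytree with jax.Array leaves.
loaded = torchax.load_checkpoint('/tmp/ckpt.pt')
```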

## Citation

```

torchax/test/test_checkpoint.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
import unittest
import torch
import torch.nn as nn
import torchax
from torchax.checkpoint import _to_torch, _to_jax
import optax
import tempfile
import os
import jax
import jax.numpy as jnp
import shutil


class CheckpointTest(unittest.TestCase):

  def test_save_and_load_jax_style_checkpoint(self):
    model = torch.nn.Linear(10, 20)
    optimizer = optax.adam(1e-3)

    torchax.enable_globally()
    params_jax, _ = torchax.extract_jax(model)
    opt_state = optimizer.init(params_jax)
    torchax.disable_globally()

    epoch = 1
    state = {
        'model': model.state_dict(),
        'opt_state': opt_state,
        'epoch': epoch,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint')
      torchax.save_checkpoint(state, path, step=epoch)
      loaded_state_jax = torchax.load_checkpoint(path)
      loaded_state = _to_torch(loaded_state_jax)

    self.assertEqual(state['epoch'], loaded_state['epoch'])

    # Compare model state_dict
    for key in state['model']:
      self.assertTrue(
          torch.allclose(state['model'][key], loaded_state['model'][key]))

    # Compare optimizer state
    original_leaves = jax.tree_util.tree_leaves(state['opt_state'])
    loaded_leaves = jax.tree_util.tree_leaves(loaded_state['opt_state'])
    for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
      if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
        # The loaded state went through _to_torch, so convert the leaf back
        # to a jax array for comparison
        self.assertTrue(jnp.allclose(original_leaf, jnp.asarray(loaded_leaf)))
      else:
        self.assertEqual(original_leaf, loaded_leaf)

  def test_load_pytorch_style_checkpoint(self):
    model = torch.nn.Linear(10, 20)
    optimizer = optax.adam(1e-3)

    torchax.enable_globally()
    params_jax, _ = torchax.extract_jax(model)
    opt_state = optimizer.init(params_jax)
    torchax.disable_globally()

    epoch = 1
    state = {
        'model': model.state_dict(),
        'opt_state': opt_state,
        'epoch': epoch,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint.pt')
      torch.save(state, path)
      loaded_state_jax = torchax.load_checkpoint(path)

    # Convert the original state to JAX for comparison
    state_jax = _to_jax(state)

    self.assertEqual(state_jax['epoch'], loaded_state_jax['epoch'])

    # Compare model state_dict
    for key in state_jax['model']:
      self.assertTrue(
          jnp.allclose(state_jax['model'][key],
                       loaded_state_jax['model'][key]))

    # Compare optimizer state
    original_leaves = jax.tree_util.tree_leaves(state_jax['opt_state'])
    loaded_leaves = jax.tree_util.tree_leaves(loaded_state_jax['opt_state'])
    for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
      if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
        self.assertTrue(jnp.allclose(original_leaf, loaded_leaf))
      else:
        self.assertEqual(original_leaf, loaded_leaf)

  def test_load_non_existent_checkpoint(self):
    with self.assertRaises(FileNotFoundError):
      torchax.load_checkpoint('/path/to/non_existent_checkpoint')


if __name__ == '__main__':
  unittest.main()
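
The suite is self-contained; assuming `torchax`, `flax`, and `optax` are installed, it can be run directly with `python torchax/test/test_checkpoint.py` or through `pytest`.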

torchax/torchax/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -15,9 +15,12 @@
     'default_env',
     'extract_jax',
     'enable_globally',
+    'save_checkpoint',
+    'load_checkpoint',
 ]

-from jax._src import xla_bridge
+from .checkpoint import save_checkpoint, load_checkpoint

 os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')

torchax/torchax/checkpoint.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import torch
import os
from typing import Any, Dict
from flax.training import checkpoints
import jax
import jax.numpy as jnp
import numpy as np


def _to_jax(pytree):
  return jax.tree_util.tree_map(
      lambda x: jnp.asarray(x.cpu().numpy())
      if isinstance(x, torch.Tensor) else x, pytree)


def _to_torch(pytree):
  return jax.tree_util.tree_map(
      lambda x: torch.from_numpy(np.asarray(x))
      if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)


def save_checkpoint(state: Dict[str, Any], path: str, step: int):
  """Saves a checkpoint in JAX (Flax) style.

  Args:
    state: A dictionary containing the state to save. torch.Tensors will be
      converted to jax.Array.
    path: The path to save the checkpoint to. This is a directory.
    step: The training step.
  """
  state = _to_jax(state)
  checkpoints.save_checkpoint(path, state, step=step, overwrite=True)


def load_checkpoint(path: str) -> Dict[str, Any]:
  """Loads a checkpoint and returns it in JAX format.

  This function can load both PyTorch-style (single file) and JAX-style
  (directory) checkpoints.

  If the checkpoint is in PyTorch format, it will be converted to JAX format.

  Args:
    path: The path to the checkpoint.

  Returns:
    The loaded state in JAX format (pytree with jax.Array leaves).
  """
  if os.path.isdir(path):
    # JAX-style checkpoint
    state = checkpoints.restore_checkpoint(path, target=None)
    if state is None:
      raise FileNotFoundError(f"No checkpoint found at {path}")
    return state
  elif os.path.isfile(path):
    # PyTorch-style checkpoint
    state = torch.load(path, weights_only=False)
    return _to_jax(state)
  else:
    raise FileNotFoundError(f"No such file or directory: {path}")
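
For reference, a minimal round-trip sketch of the two conversion helpers above (the tensor values are illustrative only):

```python
import torch
import jax.numpy as jnp
from torchax.checkpoint import _to_jax, _to_torch

# torch -> jax: torch.Tensor leaves become jax.Array, other leaves pass through
state = {'w': torch.ones(2, 3), 'epoch': 7}
state_jax = _to_jax(state)
assert isinstance(state_jax['w'], jnp.ndarray)
assert state_jax['epoch'] == 7

# jax -> torch: the inverse mapping, back to torch.Tensor leaves
state_torch = _to_torch(state_jax)
assert torch.equal(state_torch['w'], torch.ones(2, 3))
```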
