diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml
index 4ef00dcedae..e21ecbfc47b 100644
--- a/.github/workflows/_test.yml
+++ b/.github/workflows/_test.yml
@@ -121,6 +121,7 @@ jobs:
           # TODO: Add these in setup.py
           pip install fsspec
           pip install rich
+          pip install flax
       - name: Checkout PyTorch Repo
         if: inputs.has_code_changes == 'true'
         uses: actions/checkout@v4
diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml
index dc766c53a89..b67f695f81e 100644
--- a/.github/workflows/_tpu_ci.yml
+++ b/.github/workflows/_tpu_ci.yml
@@ -55,6 +55,7 @@ jobs:
             pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html'
             pip install --pre 'torch_xla[tpu]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html'
             pip install --upgrade protobuf
+            pip install flax
       - name: Run Tests (${{ matrix.test_script }})
         if: inputs.has_code_changes == 'true'
         env:
diff --git a/torchax/README.md b/torchax/README.md
index 06d9e26d7dc..4a9d85ca644 100644
--- a/torchax/README.md
+++ b/torchax/README.md
@@ -179,6 +179,42 @@ The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
 compile for the given input shapes. Subsequent calls with the same input
 shapes will be fast as the compilation is cached.
 
+## Saving and Loading Checkpoints
+
+You can use `torchax.save_checkpoint` and `torchax.load_checkpoint` to save and
+load your training state. The state can be a dictionary containing the model's
+weights, optimizer state, and any other information you want to save.
+`save_checkpoint` writes a JAX-style (flax) checkpoint into a directory, and
+`load_checkpoint` returns the state as a pytree with `jax.Array` leaves.
+
+```python
+import torchax
+import torch
+import optax
+
+# Assume a model is defined
+model = MyModel()
+optimizer = optax.adam(1e-3)
+
+# Initialize the optimizer state from the model parameters in JAX form
+torchax.enable_globally()
+params_jax, _ = torchax.extract_jax(model)
+opt_state = optimizer.init(params_jax)
+torchax.disable_globally()
+
+epoch = 10
+state = {
+    'model': model.state_dict(),
+    'opt_state': opt_state,
+    'epoch': epoch,
+}
+
+# Save checkpoint; `path` is a directory and `step` names the checkpoint
+torchax.save_checkpoint(state, '/path/to/checkpoint_dir', step=epoch)
+
+# Load checkpoint; the returned state is in JAX format
+loaded_state = torchax.load_checkpoint('/path/to/checkpoint_dir')
+
+# Restore state (convert the JAX leaves back to torch.Tensors first)
+from torchax.checkpoint import _to_torch
+model.load_state_dict(_to_torch(loaded_state['model']))
+opt_state = loaded_state['opt_state']
+epoch = loaded_state['epoch']
+```
+
 ## Citation
 
 ```
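To make the round trip above concrete, here is a minimal end-to-end sketch of saving every epoch and resuming from the latest checkpoint. It is illustrative only: `CKPT_DIR` and the elided training step are hypothetical, and `_to_torch` is the private conversion helper from `torchax.checkpoint` that the tests below also import.

```python
import os
import torch
import torchax
from torchax.checkpoint import _to_torch  # private helper, also used by the tests

CKPT_DIR = '/tmp/my_run'  # hypothetical checkpoint directory

model = torch.nn.Linear(10, 20)
start_epoch = 0

if os.path.isdir(CKPT_DIR):
  # load_checkpoint returns a pytree with jax.Array leaves, so the model
  # weights must be converted back to torch.Tensors before load_state_dict.
  loaded = torchax.load_checkpoint(CKPT_DIR)
  model.load_state_dict(_to_torch(loaded['model']))
  start_epoch = int(loaded['epoch']) + 1

for epoch in range(start_epoch, 10):
  # ... one epoch of training elided ...
  state = {'model': model.state_dict(), 'epoch': epoch}
  torchax.save_checkpoint(state, CKPT_DIR, step=epoch)
```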
diff --git a/torchax/test/test_checkpoint.py b/torchax/test/test_checkpoint.py
new file mode 100644
index 00000000000..4867d44b1eb
--- /dev/null
+++ b/torchax/test/test_checkpoint.py
@@ -0,0 +1,102 @@
+import unittest
+import torch
+import torchax
+from torchax.checkpoint import _to_torch, _to_jax
+import optax
+import tempfile
+import os
+import jax
+import jax.numpy as jnp
+
+
+class CheckpointTest(unittest.TestCase):
+
+  def test_save_and_load_jax_style_checkpoint(self):
+    model = torch.nn.Linear(10, 20)
+    optimizer = optax.adam(1e-3)
+
+    torchax.enable_globally()
+    params_jax, _ = torchax.extract_jax(model)
+    opt_state = optimizer.init(params_jax)
+    torchax.disable_globally()
+
+    epoch = 1
+    state = {
+        'model': model.state_dict(),
+        'opt_state': opt_state,
+        'epoch': epoch,
+    }
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+      path = os.path.join(tmpdir, 'checkpoint')
+      torchax.save_checkpoint(state, path, step=epoch)
+      loaded_state_jax = torchax.load_checkpoint(path)
+      loaded_state = _to_torch(loaded_state_jax)
+
+      self.assertEqual(state['epoch'], loaded_state['epoch'])
+
+      # Compare model state_dict
+      for key in state['model']:
+        self.assertTrue(
+            torch.allclose(state['model'][key], loaded_state['model'][key]))
+
+      # Compare optimizer state
+      original_leaves = jax.tree_util.tree_leaves(state['opt_state'])
+      loaded_leaves = jax.tree_util.tree_leaves(loaded_state['opt_state'])
+      for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
+        if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
+          # _to_torch converted this leaf to a torch.Tensor; convert it
+          # back to a jax.Array for the comparison.
+          self.assertTrue(jnp.allclose(original_leaf, jnp.asarray(loaded_leaf)))
+        else:
+          self.assertEqual(original_leaf, loaded_leaf)
+
+  def test_load_pytorch_style_checkpoint(self):
+    model = torch.nn.Linear(10, 20)
+    optimizer = optax.adam(1e-3)
+
+    torchax.enable_globally()
+    params_jax, _ = torchax.extract_jax(model)
+    opt_state = optimizer.init(params_jax)
+    torchax.disable_globally()
+
+    epoch = 1
+    state = {
+        'model': model.state_dict(),
+        'opt_state': opt_state,
+        'epoch': epoch,
+    }
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+      path = os.path.join(tmpdir, 'checkpoint.pt')
+      torch.save(state, path)
+      loaded_state_jax = torchax.load_checkpoint(path)
+
+      # Convert the original state to JAX format for comparison
+      state_jax = _to_jax(state)
+
+      self.assertEqual(state_jax['epoch'], loaded_state_jax['epoch'])
+
+      # Compare model state_dict
+      for key in state_jax['model']:
+        self.assertTrue(
+            jnp.allclose(state_jax['model'][key],
+                         loaded_state_jax['model'][key]))
+
+      # Compare optimizer state
+      original_leaves = jax.tree_util.tree_leaves(state_jax['opt_state'])
+      loaded_leaves = jax.tree_util.tree_leaves(loaded_state_jax['opt_state'])
+      for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
+        if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
+          self.assertTrue(jnp.allclose(original_leaf, loaded_leaf))
+        else:
+          self.assertEqual(original_leaf, loaded_leaf)
+
+  def test_load_non_existent_checkpoint(self):
+    with self.assertRaises(FileNotFoundError):
+      torchax.load_checkpoint('/path/to/non_existent_checkpoint')
+
+
+if __name__ == '__main__':
+  unittest.main()
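Both tests above hinge on the `_to_jax`/`_to_torch` conversion contract from `torchax.checkpoint`: tensor leaves swap type, everything else passes through untouched. A standalone illustration (a sketch, not part of the test suite):

```python
import torch
import jax
from torchax.checkpoint import _to_jax, _to_torch

state = {'w': torch.ones(2, 3), 'step': 7}

jax_state = _to_jax(state)          # torch.Tensor leaves become jax.Array
assert isinstance(jax_state['w'], jax.Array)
assert jax_state['step'] == 7       # non-tensor leaves pass through unchanged

torch_state = _to_torch(jax_state)  # jax.Array leaves become torch.Tensor
assert torch.equal(torch_state['w'], state['w'])
```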
diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py
index fe4c1c8ff04..4921ae33104 100644
--- a/torchax/torchax/__init__.py
+++ b/torchax/torchax/__init__.py
@@ -15,9 +15,11 @@
     'default_env',
     'extract_jax',
     'enable_globally',
+    'save_checkpoint',
+    'load_checkpoint',
 ]
 
-from jax._src import xla_bridge
+from .checkpoint import save_checkpoint, load_checkpoint
 
 os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
diff --git a/torchax/torchax/checkpoint.py b/torchax/torchax/checkpoint.py
new file mode 100644
index 00000000000..daded1c3afa
--- /dev/null
+++ b/torchax/torchax/checkpoint.py
@@ -0,0 +1,60 @@
+import torch
+import os
+from typing import Any, Dict
+from flax.training import checkpoints
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+def _to_jax(pytree):
+  """Converts every torch.Tensor leaf of a pytree to a jax.Array."""
+  return jax.tree_util.tree_map(
+      lambda x: jnp.asarray(x.cpu().numpy())
+      if isinstance(x, torch.Tensor) else x, pytree)
+
+
+def _to_torch(pytree):
+  """Converts every jax.Array leaf of a pytree to a torch.Tensor."""
+  return jax.tree_util.tree_map(
+      lambda x: torch.from_numpy(np.asarray(x))
+      if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)
+
+
+def save_checkpoint(state: Dict[str, Any], path: str, step: int):
+  """Saves a checkpoint in JAX (flax) style.
+
+  Args:
+    state: A dictionary containing the state to save. torch.Tensors will be
+      converted to jax.Array.
+    path: The directory to save the checkpoint into.
+    step: The training step; it is appended to the checkpoint file name.
+  """
+  state = _to_jax(state)
+  checkpoints.save_checkpoint(path, state, step=step, overwrite=True)
+
+
+def load_checkpoint(path: str) -> Dict[str, Any]:
+  """Loads a checkpoint and returns it in JAX format.
+
+  This function can load both PyTorch-style (single file) and JAX-style
+  (directory) checkpoints. If the checkpoint is in PyTorch format, it is
+  converted to JAX format.
+
+  Args:
+    path: The path to the checkpoint.
+
+  Returns:
+    The loaded state in JAX format (a pytree with jax.Array leaves).
+  """
+  if os.path.isdir(path):
+    # JAX-style checkpoint
+    state = checkpoints.restore_checkpoint(path, target=None)
+    if state is None:
+      raise FileNotFoundError(f"No checkpoint found at {path}")
+    return state
+  elif os.path.isfile(path):
+    # PyTorch-style checkpoint
+    state = torch.load(path, weights_only=False)
+    return _to_jax(state)
+  else:
+    raise FileNotFoundError(f"No such file or directory: {path}")
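Because `load_checkpoint` dispatches on whether `path` is a file or a directory, it can also double as a one-shot converter from a PyTorch checkpoint to a flax-style one. A minimal sketch; the file names here are hypothetical:

```python
import torch
import torchax

# A PyTorch-style checkpoint: a single file written by torch.save.
state = {'model': torch.nn.Linear(4, 4).state_dict(), 'epoch': 3}
torch.save(state, '/tmp/legacy.pt')

# load_checkpoint sees a regular file, torch.load()s it, and converts
# every torch.Tensor leaf to a jax.Array.
jax_state = torchax.load_checkpoint('/tmp/legacy.pt')

# Re-save it as a JAX-style (flax) directory checkpoint.
torchax.save_checkpoint(jax_state, '/tmp/converted', step=int(jax_state['epoch']))
```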