In [None]:
import jax
import optax
from flax.training.train_state import TrainState
from orbax.checkpoint import PyTreeCheckpointer

from flightning import FLIGHTNING_PATH
from flightning.modules import MLP

# How to Save and Load JAX Policies

Saving policies or their parameters is easy with the `orbax` package. The idea is that many container objects such as lists, tuples, dictionaries (parameters are dictionaries containing arrays), and other custom classes are PyTrees. [More about PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html) We simply create a checkpointer object that can save and load PyTrees.

## Save Parameters

In [None]:
# Build a small MLP and initialize its parameters (a PyTree of arrays).
mlp = MLP([2, 3, 1])
params = mlp.initialize(jax.random.key(0))

# A TrainState bundling apply_fn, params and an Adam optimizer state;
# it is saved and restored in a later section.
train_state = TrainState.create(apply_fn=mlp.apply, params=params,
                                tx=optax.adam(1e-3))

# use absolute path (orbax checkpoint directories should be absolute;
# presumably FLIGHTNING_PATH is absolute — confirm if this fails)
path = FLIGHTNING_PATH + "/../examples/saved_params"

ckptr = PyTreeCheckpointer()
# force=True overwrites an existing checkpoint so the notebook can be re-run;
# without it, save() raises because the destination directory already exists.
ckptr.save(path, params, force=True)

You can check the directory: the call to `save` created a new folder `saved_params` that contains all data and metadata associated with the parameters.

## Load Parameters

In [None]:
# Read the parameter PyTree back from the checkpoint directory written above.
params_loaded = ckptr.restore(path)

# Show both trees, separated by a blank line, so they can be compared by eye.
print("Original params", params, "", "Loaded params", params_loaded, sep="\n")

That was easy. The parameters are all there.

## Save and Load Trainstates

Saving works the same way, but when loading we need to provide the target structure so that an object of the right type is created.

In [None]:
# Separate checkpoint directory for the full train state.
path = FLIGHTNING_PATH + "/../examples/saved_trainstate"

# force=True lets this cell be re-run over an existing checkpoint;
# without it, save() raises because the directory already exists.
ckptr.save(path, train_state, force=True)

# Restoring without a target: the result is not a TrainState object
# (see the printout below).
trainstate_loaded = ckptr.restore(path)

print(train_state)
print()
print(trainstate_loaded)

We observe that the original train state is an object of class `TrainState`, whereas the loaded one is just a dictionary. To fix this, we pass `restore` a target object of the correct type and structure; the returned object has that type but holds the values stored in the checkpoint.

In [None]:
# Build a template TrainState with the same structure as the saved one.
# Its parameters come from a different random key (42 instead of 0), so the
# printout shows that restore() fills in the values stored in the checkpoint.
mlp_obj = MLP([2, 3, 1])
params_obj = mlp_obj.initialize(jax.random.key(42))

train_state_obj = TrainState.create(apply_fn=mlp_obj.apply, params=params_obj,
                                    tx=optax.adam(1e-3))

# Passing the template as the restore target yields a TrainState object
# instead of a plain dictionary.
trainstate_loaded = ckptr.restore(path, train_state_obj)

print(trainstate_loaded)