Add some kind de/serialisation? #46
Comments
I would like to pick up on a comment you made in https://www.reddit.com/r/MachineLearning/comments/u34oh2/d_what_jax_nn_library_to_use/i4umg44/?context=3.
I agree that it would be attractive to have an option to serialize complete PyTrees, but I also see some challenges. Just to give an example, if you have an Equinox module,

    class MyModule(eqx.Module):
        activation: Any

    module = MyModule(activation=jnn.tanh)

pickling this will fail because jnn.tanh seems to point to a lambda and pickle doesn't support pickling lambdas (https://stackoverflow.com/questions/25348532/can-python-pickle-lambda-functions). But until we have a solution for this, why not go with something along the lines of

    def save_model_weights(model: eqx.Module, path: str):
        tree_weights, tree_other = eqx.partition(model, eqx.is_inexact_array)
        with open(path + "_weights.npy", "wb") as f:
            for x in jax.tree_leaves(tree_weights):
                np.save(f, x, allow_pickle=False)

    def load_model_weights(model, path):
        tree_weights, tree_other = eqx.partition(model, eqx.is_inexact_array)
        leaves_orig, treedef = jax.tree_flatten(tree_weights)
        with open(path + "_weights.npy", "rb") as f:
            flat_state = [jnp.asarray(np.load(f)) for _ in leaves_orig]
        tree_weights = jax.tree_unflatten(treedef, flat_state)
        return eqx.combine(tree_weights, tree_other)

The naming should make it clear that it's just about the model weights.

    def save_model(model: eqx.Module, path: str):
        ...

    def load_model(path):
        ...
|
Right, pickling is a whole can of worms. Let's leave that alone for now. Yep, you've convinced me! I think it seems reasonable to add something that saves just the leaves of a PyTree. Your implementation is also a fair bit tidier than the one I suggested over on reddit; nice. Can you open a PR for this? I'll pre-emptively offer a review - mostly nits.
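For reference, the pickling problem mentioned above can be reproduced with the standard library alone. This is a minimal sketch; `can_pickle` and `square` are hypothetical names used only for illustration, and the exact exception type raised for a lambda can differ between Python versions, so several are caught.

```python
import pickle


def can_pickle(obj) -> bool:
    """Return True if `obj` survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# Functions are pickled by reference (module + qualified name).
# A lambda's qualified name is "<lambda>", which cannot be looked up again.
square = lambda x: x * x

results = {
    "lambda": can_pickle(square),     # cannot be pickled
    "builtin": can_pickle(abs),       # importable by name, so it can
    "list": can_pickle([1.0, 2.0]),   # plain data pickles fine
}
print(results)
```

Saving only the array leaves sidesteps this entirely, because the non-array parts of the model (such as an activation function) never need to be written to disk.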
|
Thought. We may also wish to allow customising the use of For example: whilst it hasn't happened yet, there's discussion about eventually being able to de/serialise JIT-compiled functions in JAX. Meanwhile, the next release of Equinox will have |
Hi, thanks for the review and the thoughtful remarks! I agree with most of your comments, but I'm not entirely sold on exposing a function |
Hmm. I agree that the points you're making are reasonable. If we were to use a tree-based name:
As you can probably tell, I'm disinclined to write something as specific My feeling is that issues like new-user-discoverability can be handled through appropriate documentation etc. Perhaps we could change the proposed heading above from "Serialisation" to "Serialisation (save/load models to disk)" if we wanted to really make it clear where to look. WDYT? |
Hi, sorry for the late feedback.
I can understand your reasoning here. Would you be okay with that? Irrespective of how we do it, I'd be happy to contribute a first draft of a PR sometime this week. |
Is your reasoning that this would improve discoverability because it'll be in the documentation for At least wrt the former then this could be handled by improving the docstring and I'd be happy to have that in. The problem with adding new methods is that this breaks a different principle in Eqx, which is that no Module method is special cased. (Other than magic methods and (I appreciate it probably feels like you're threading a needle here.) |
Yes, exactly. It seems like an obvious place for a user to look.
Could you elaborate on what you mean that no Module method is special cased?
No worries at all, I think it is good if library authors are opinionated about their design choices! I appreciate the discussion. |
FWIW I do feel that discoverability could be better in Equinox. For example
Consider magic methods: these are special cased by Python itself. (Another example: PyTorch special-cases Equinox doesn't special-case any methods like this. You can't change how Equinox treats your class. (You don't need to.)
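To make "special-cased by Python itself" concrete, here is a small sketch with a hypothetical `Stack` class (not taken from Equinox or any other library). The interpreter looks up `__len__` on its own when `len()` or a truth test is used, whereas an ordinary method such as `size` only runs when called by name.

```python
class Stack:
    def __init__(self):
        self._items = []

    def push(self, item):
        self._items.append(item)

    def __len__(self):
        # Magic method: Python calls this for len(stack) and for truthiness.
        return len(self._items)

    def size(self):
        # Ordinary method: nothing calls this unless the user does.
        return len(self._items)


stack = Stack()
stack.push(1)
stack.push(2)

length_via_builtin = len(stack)       # dispatches to Stack.__len__
length_via_method = stack.size()      # explicit call
empty_is_falsy = not Stack()          # bool() falls back to __len__
```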
Not at all! Not how I read it. |
Also: this should respect The following is entirely untested, but I think that probably means doing something like:

    import pathlib
    from typing import Union

    import jax
    import jax.numpy as jnp
    import numpy as np

    # `experimental`, `is_array_like` and `PyTree` are assumed to come from Equinox itself.

    def _save_index(f, x: experimental.StateIndex):
        value = experimental.get_state(x)
        jnp.save(f, value)

    def _save_no(f, x):
        pass

    def _load_array(f, x):
        return jnp.load(f)

    def _load_numpy(f, x):
        return np.load(f)

    def _load_index(f, x: experimental.StateIndex):
        value = jnp.load(f)
        experimental.set_state(x, value)
        return x

    def _load_scalar(f, x):
        return np.load(f).item()

    def _load_no(f, x):
        return x

    # Each filter spec maps a leaf to the function that saves (or loads) it.
    def _default_serialise_filter_spec(x):
        if is_array_like(x):
            return jnp.save
        elif isinstance(x, experimental.StateIndex):
            return _save_index
        else:
            return _save_no

    def _default_deserialise_filter_spec(x):
        if isinstance(x, jnp.ndarray):
            return _load_array
        elif isinstance(x, np.ndarray):
            return _load_numpy
        elif isinstance(x, (bool, float, complex, int)):
            return _load_scalar
        elif isinstance(x, experimental.StateIndex):
            return _load_index
        else:
            return _load_no

    def _assert_same(new, old):
        if type(new) is not type(old):
            raise RuntimeError(...)
        if isinstance(new, (np.ndarray, jnp.ndarray)) and (new.shape != old.shape or new.dtype != old.dtype):
            raise RuntimeError(...)

    def _is_index(x):
        return isinstance(x, experimental.StateIndex)

    def tree_serialise_leaves(
        path: Union[str, pathlib.Path],
        pytree: PyTree,
        filter_spec=_default_serialise_filter_spec,
        is_leaf=_is_index,
    ):
        with open(pathlib.Path(path).with_suffix(".npy"), "wb") as f:
            def _serialise(spec, x):
                def __serialise(y):
                    # Look up the saver for this leaf, then write the leaf to the file.
                    spec(y)(f, y)
                return jax.tree_map(__serialise, x, is_leaf=is_leaf)
            jax.tree_map(_serialise, filter_spec, pytree, is_leaf=is_leaf)

    def tree_deserialise_leaves(
        path: Union[str, pathlib.Path],
        like: PyTree,
        filter_spec=_default_deserialise_filter_spec,
        is_leaf=_is_index,
    ):
        with open(pathlib.Path(path).with_suffix(".npy"), "rb") as f:
            def _deserialise(spec, x):
                def __deserialise(y):
                    # Look up the loader for this leaf, then read the next entry from the file.
                    return spec(y)(f, y)
                return jax.tree_map(__deserialise, x, is_leaf=is_leaf)
            out = jax.tree_map(_deserialise, filter_spec, like, is_leaf=is_leaf)
        jax.tree_map(_assert_same, out, like, is_leaf=is_leaf)
        return out
|
@jaschau heads-up when putting together your PR to branch off the |
@jaschau Any plans to pick this up as a PR? No worries if not -- I'll do it -- I just want to get serialisation+deserialisation into the next release of Equinox. |
Hi @patrick-kidger, sorry for the little progress here. I've been side-tracked with other topics so I haven't come around to working on this beyond our initial discussion. I'm afraid I cannot promise any serious progress on this in the upcoming weeks from my side, so please move ahead if you're eager to work on this. |
Equinox models are just PyTrees so they should be very easy to serialise/deserialise; just save the PyTree to disk in whatever way is desired. It might be worth adding some library functions for this just for convenience. Perhaps checking the device of JAX arrays etc?
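As a rough sketch of what "just save the PyTree to disk" could look like, the snippet below writes every leaf of a plain nested-dict PyTree with `np.save` and reads the leaves back in the same order. The helper names `save_leaves` and `load_leaves` are hypothetical (they are not Equinox functions), an in-memory buffer stands in for a file, and every leaf is assumed to be an array.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def save_leaves(f, tree):
    """Write each leaf of `tree` to the open binary file `f`, in flattening order."""
    for leaf in jax.tree_util.tree_leaves(tree):
        np.save(f, np.asarray(leaf), allow_pickle=False)


def load_leaves(f, like):
    """Read leaves back in the same order and rebuild a tree structured like `like`."""
    leaves, treedef = jax.tree_util.tree_flatten(like)
    loaded = [jnp.asarray(np.load(f, allow_pickle=False)) for _ in leaves]
    return jax.tree_util.tree_unflatten(treedef, loaded)


params = {"linear": {"weight": jnp.ones((2, 3)), "bias": jnp.zeros(3)}}

buffer = io.BytesIO()          # stands in for open(path, "wb") / open(path, "rb")
save_leaves(buffer, params)
buffer.seek(0)

# Only the structure of the template matters; its values are overwritten.
template = jax.tree_util.tree_map(jnp.zeros_like, params)
restored = load_leaves(buffer, template)
```

Only array data ends up in the file; the tree structure comes from the template passed at load time, which is why a model of the same shape has to be constructed before loading.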
This should respect the get_state/set_state stuff that's being put together. In addition, there should be a version of get_state which inlines its state in the call graph, for faster inference.