Add some kind of de/serialisation? #46

Closed
patrick-kidger opened this issue Mar 26, 2022 · 13 comments · Fixed by #82
Labels
feature New feature

Comments

@patrick-kidger
Owner

patrick-kidger commented Mar 26, 2022

Equinox models are just PyTrees so they should be very easy to serialise/deserialise; just save the PyTree to disk in whatever way is desired. It might be worth adding some library functions for this just for convenience. Perhaps checking the device of JAX arrays etc?

This should respect the get_state/set_state stuff that's being put together.

In addition, there should be a version of get_state which inlines its state in the call graph, for faster inference.
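The "just save the PyTree to disk" idea can be sketched in a few lines, assuming every leaf is an array and that a template PyTree with the same structure is available at load time. The helper names `save_pytree`/`load_pytree` are hypothetical, not Equinox API. Converting each leaf with `np.asarray` copies it to host memory, which sidesteps the device question when saving; restoring onto a particular device is left out.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def save_pytree(f, pytree):
    """Write every leaf of `pytree` to the open binary file `f`, one .npy record per leaf."""
    for leaf in jax.tree_util.tree_leaves(pytree):
        np.save(f, np.asarray(leaf), allow_pickle=False)  # np.asarray copies to host


def load_pytree(f, like):
    """Read the records back in flattening order and rebuild the structure of `like`."""
    leaves, treedef = jax.tree_util.tree_flatten(like)
    loaded = [jnp.asarray(np.load(f, allow_pickle=False)) for _ in leaves]
    return jax.tree_util.tree_unflatten(treedef, loaded)


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
buf = io.BytesIO()
save_pytree(buf, params)
buf.seek(0)
restored = load_pytree(buf, params)
```

Only the template's structure is used on load; its leaf values are ignored, which is why the structure has to be reproducible by the caller.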

@patrick-kidger patrick-kidger added the feature New feature label Apr 14, 2022
@jaschau

jaschau commented Apr 15, 2022

I would like to pick up on a comment you made in https://www.reddit.com/r/MachineLearning/comments/u34oh2/d_what_jax_nn_library_to_use/i4umg44/?context=3.

> The reason I'm dragging my heels on this is that I'm not yet completely sold on whether to use the code above, or to instead try and pickle the entire PyTree in one go. The former requires you to be able to produce the PyTree structure yourself (and by default doesn't save e.g. NumPy arrays); the latter has annoying compatibility concerns. (+I think Flax has a couple of ways of doing de/serialisation that might be worth using as inspiration.) So I want to be sure we get this right.

I agree that it would be attractive to have an option to serialize complete PyTrees, but I also see some challenges. Just to give an example, if you have an Equinox module,

from typing import Any
import equinox as eqx
import jax.nn as jnn

class MyModule(eqx.Module):
    activation: Any

module = MyModule(activation=jnn.tanh)

Pickling this will fail, because jnn.tanh seems to point to a lambda, and pickle doesn't support pickling lambdas (https://stackoverflow.com/questions/25348532/can-python-pickle-lambda-functions).
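The general limitation is easy to reproduce with plain Python functions (this checks a lambda and `math.tanh`, not `jnn.tanh` itself). Pickle stores a function by reference, as its module plus qualified name, so a lambda, whose name is `<lambda>`, cannot be looked up again. The helper `try_pickle` is hypothetical.

```python
import math
import pickle


def try_pickle(obj):
    """Return True if `obj` survives a pickle round trip, False if pickling fails."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A named, importable function pickles by reference (module "math", name "tanh").
named_ok = try_pickle(math.tanh)
# A lambda's qualified name is "<lambda>", which pickle cannot resolve on reload.
lambda_ok = try_pickle(lambda x: x)
```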

But until we have a solution for this, why not go with something along the lines of

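import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx


def save_model_weights(model: eqx.Module, path: str):
    # Keep only the floating-point arrays (the weights); everything else is dropped.
    tree_weights, _ = eqx.partition(model, eqx.is_inexact_array)
    with open(path + "_weights.npy", "wb") as f:
        # Write each leaf as its own .npy record, one after another in the same file.
        for x in jax.tree_util.tree_leaves(tree_weights):
            np.save(f, x, allow_pickle=False)


def load_model_weights(model: eqx.Module, path: str):
    # `model` acts as a template with the same structure as the saved model.
    tree_weights, tree_other = eqx.partition(model, eqx.is_inexact_array)

    leaves_orig, treedef = jax.tree_util.tree_flatten(tree_weights)
    with open(path + "_weights.npy", "rb") as f:
        # Read the records back in the same (flattening) order they were written.
        flat_state = [jnp.asarray(np.load(f)) for _ in leaves_orig]
    tree_weights = jax.tree_util.tree_unflatten(treedef, flat_state)
    return eqx.combine(tree_weights, tree_other)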

The naming should make it clear that it's just about the model weights.
Once we have a good solution for pickling entire trees, we can add functions

def save_model(model: eqx.Module, path: str):
    ...

def load_model(path):
    ...

@patrick-kidger
Owner Author

Right, pickling is a whole can of worms. Let's leave that alone for now.

Yep, you've convinced me! I think it seems reasonable to add something that saves just the leaves of a PyTree.

Your implementation is also a fair bit tidier than the one I suggested over on reddit; nice. Can you open a PR for this? I'll pre-emptively offer a review - mostly nits.

  1. Change the type annotation from Module to PyTree. Remember, Modules are never special-cased! Likewise, probably change the names to {save,load}_leaves or {de,}serialise_leaves or similar, to help emphasise this.
  2. Allow specifying the filter spec, instead of baking in is_inexact_array. And perhaps make the default is_array instead?
  3. Use jnp.{load,save} over np.{load,save}. I think this is needed to handle bfloat16 dtypes correctly.
  4. Allow path to be a Union[str, pathlib.Path]. I also probably wouldn't hardcode the + "_weights.npy" part of it; I think just pass path through unchanged. (Or only add a .npy suffix.)
  5. Validate that flat_state and leaves_orig match shape+dtype, and raise an error if not.
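One way review points 2, 4 and 5 might fit together is sketched below, using hypothetical names `save_leaves`/`load_leaves` and a stand-in `is_array` filter (neither is Equinox's actual API). It uses `np.{save,load}` for simplicity, so it does not address the bfloat16 concern from point 3.

```python
import pathlib
import tempfile
from typing import Any, Callable, Union

import jax
import jax.numpy as jnp
import numpy as np

PyTree = Any  # stand-in alias for "any PyTree"; not an Equinox type


def is_array(x: Any) -> bool:
    # Hypothetical default filter: NumPy and JAX arrays are the leaves that get saved.
    return isinstance(x, (np.ndarray, jax.Array))


def _npy_path(path: Union[str, pathlib.Path]) -> pathlib.Path:
    # Point 4: accept str or Path, and only add a .npy suffix.
    return pathlib.Path(path).with_suffix(".npy")


def save_leaves(path: Union[str, pathlib.Path], pytree: PyTree,
                filter_spec: Callable[[Any], bool] = is_array) -> None:
    with open(_npy_path(path), "wb") as f:
        for leaf in jax.tree_util.tree_leaves(pytree):
            if filter_spec(leaf):  # point 2: caller-chosen filter
                np.save(f, np.asarray(leaf), allow_pickle=False)


def load_leaves(path: Union[str, pathlib.Path], like: PyTree,
                filter_spec: Callable[[Any], bool] = is_array) -> PyTree:
    leaves, treedef = jax.tree_util.tree_flatten(like)
    new_leaves = []
    with open(_npy_path(path), "rb") as f:
        for leaf in leaves:
            if not filter_spec(leaf):
                new_leaves.append(leaf)  # unsaved leaves come from the template
                continue
            loaded = np.load(f, allow_pickle=False)
            # Point 5: refuse data that doesn't match the template's shape and dtype.
            if loaded.shape != leaf.shape or loaded.dtype != leaf.dtype:
                raise ValueError(
                    f"Expected shape {leaf.shape} and dtype {leaf.dtype}, "
                    f"got shape {loaded.shape} and dtype {loaded.dtype}"
                )
            new_leaves.append(jnp.asarray(loaded) if isinstance(leaf, jax.Array) else loaded)
    return jax.tree_util.tree_unflatten(treedef, new_leaves)
```

For example, `save_leaves("model", tree)` writes `model.npy`, and `load_leaves("model", tree)` reads it back into the structure of `tree`.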

@patrick-kidger
Owner Author

Thought. We may also wish to allow customising the use of jnp.{load,save} at some point.

For example: whilst it hasn't happened yet, there's discussion about eventually being able to de/serialise JIT-compiled functions in JAX. Meanwhile, the next release of Equinox will have filter_jit return a new _JitWrapper object: instead of capturing its inputs (function, filter spec, etc.) via closure, it will capture them as leaves in a PyTree instead. Mostly that doesn't change anything, but in this context it offers a neat opportunity. We could de/serialise a filter-jit-wrapped function in the same way as everything else; we'd just need to use whatever mechanism is introduced for de/serialising JIT-compiled functions in place of jnp.{load,save}.

@jaschau

jaschau commented Apr 17, 2022

Hi, thanks for the review and the thoughtful remarks! I agree with most of your comments, but I'm not entirely sold on exposing a function serialize_leaves. There are two reasons:
a) the naming seems to imply that it can serialize the leaves of an arbitrary tree. In practice, you would be limited to serializing numeric leaves due to the use of jnp.save, so as a user you have to be aware of what should arguably be an implementation detail.
b) the predominant use case will probably be saving the weights of a model. At least, that's what I would look for as a user of a neural network library, and the connection between a function serialize_leaves and serializing weights might not be immediately clear to a new user.
So I would rather make _serialize_leaves a library-internal function and expose serialize_model_weights to the user (which would internally call _serialize_leaves).
What do you think?

@patrick-kidger
Owner Author

patrick-kidger commented Apr 17, 2022

Hmm. I agree that the points you're making are reasonable.

If we were to use a tree-based name:

  • We can handle (a) by having the filter_spec specify the saving/loading function, so that it really is this more general thing. For example filter_spec=lambda x: jnp.save if is_array(x) else None. (Valid values being either callables or None.)
  • I think (b) is a more serious concern, and handling this might involve reorganising the documentation slightly. At the moment all the tree-based utilities just get lumped in one place, despite being a hodge-podge of things useful in different ways. Many of them are already important for manipulating models; in particular apply_updates and tree_at. Perhaps we could split out the above page into multiple pages:
Utilities
├ Manipulation
│ ├ apply_updates  # hmm, this one doesn't fit the consistent naming scheme.
│ ├ tree_at
│ └ tree_inference  # new in the next update, don't worry if you don't recognise this ;)
├ Serialisation
│ ├ tree_serialise_leaves
│ └ tree_deserialise_leaves
└ Miscellaneous
  ├ tree_pformat
  ├ tree_equal
  └ static_field  # This should really go with Module but it's a pretty advanced
                  # thing with niche uses, so we hide it here instead.
Experimental
└ Stateful operations
  ├ StateIndex
  ├ get_state
  └ set_state
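The first bullet's idea, where a filter spec maps each leaf to the function that saves it (or None to skip it), with a matching spec for loading, might look like the sketch below. All names are hypothetical; the scalar-aware specs show how a user could extend the defaults without touching library code.

```python
import io
from typing import Any, Callable, Optional

import jax
import numpy as np

Saver = Callable[[Any, Any], None]  # (file, leaf) -> None
Loader = Callable[[Any, Any], Any]  # (file, template_leaf) -> restored leaf


def _save_array(f, x):
    np.save(f, np.asarray(x), allow_pickle=False)


def _load_array(f, like):
    return np.load(f, allow_pickle=False)  # returns a NumPy array


def _load_float(f, like):
    # Floats were written as 0-d arrays; .item() turns them back into Python floats.
    return np.load(f, allow_pickle=False).item()


def default_save_spec(x) -> Optional[Saver]:
    return _save_array if isinstance(x, (np.ndarray, jax.Array)) else None


def default_load_spec(x) -> Optional[Loader]:
    return _load_array if isinstance(x, (np.ndarray, jax.Array)) else None


def scalar_save_spec(x) -> Optional[Saver]:
    # A user-supplied spec that additionally saves Python floats.
    return _save_array if isinstance(x, float) else default_save_spec(x)


def scalar_load_spec(x) -> Optional[Loader]:
    return _load_float if isinstance(x, float) else default_load_spec(x)


def tree_serialise(f, pytree, spec=default_save_spec):
    for leaf in jax.tree_util.tree_leaves(pytree):
        saver = spec(leaf)
        if saver is not None:  # None means "don't save this leaf"
            saver(f, leaf)


def tree_deserialise(f, like, spec=default_load_spec):
    leaves, treedef = jax.tree_util.tree_flatten(like)
    new_leaves = []
    for leaf in leaves:
        loader = spec(leaf)
        # Leaves without a loader keep the template's value.
        new_leaves.append(leaf if loader is None else loader(f, leaf))
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


tree = {"w": np.arange(3.0), "lr": 0.1, "act": "tanh"}
template = {"w": np.zeros(3), "lr": 0.0, "act": "tanh"}
buf = io.BytesIO()
tree_serialise(buf, tree, scalar_save_spec)
buf.seek(0)
restored = tree_deserialise(buf, template, scalar_load_spec)
```

Because saving and loading walk the leaves in the same flattening order, the two specs must agree on which leaves they handle.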

As you can probably tell, I'm disinclined to write something as specific as save_model_weights. It runs counter to the simple models-are-PyTrees idea that is what makes reasoning about Equinox so easy in the first place. I don't think de/serialisation is an important and special enough problem that it's worth breaking that abstraction.

My feeling is that issues like new-user-discoverability can be handled through appropriate documentation etc. Perhaps we could change the proposed heading above from "Serialisation" to "Serialisation (save/load models to disk)" if we wanted to really make it clear where to look.

WDYT?

@jaschau

jaschau commented Apr 20, 2022

Hi, sorry for the late feedback.

> As you can probably tell, I'm disinclined to write something as specific as save_model_weights. It runs counter to the simple models-are-PyTrees idea that is what makes reasoning about Equinox so easy in the first place. I don't think de/serialisation is an important and special enough problem that it's worth breaking that abstraction.

I can understand your reasoning here.
With respect to (b), I still think that the best documentation is the documentation you do not need, so one alternative I could think of is adding save_weights and load_weights methods to eqx.Module, each a one-line call to tree_serialise_leaves. This would of course increase the amount of code in eqx.Module, but I think that the one-liner would very naturally embody the models-are-PyTrees idea.

Would you be okay with that?

Irrespective of how we do it, I'd be happy to contribute a first draft of a PR sometime this week.

@patrick-kidger
Owner Author

patrick-kidger commented Apr 20, 2022

Is your reasoning that this would improve discoverability because it'll be in the documentation for Module? (Or because it'll appear in dir(Module())?)

At least wrt the former then this could be handled by improving the docstring and I'd be happy to have that in. The problem with adding new methods is that this breaks a different principle in Eqx, which is that no Module method is special cased. (Other than magic methods and tree_{un,}flatten.)

(I appreciate it probably feels like you're threading a needle here.)

@jaschau

jaschau commented Apr 20, 2022

> Is your reasoning that this would improve discoverability because it'll be in the documentation for Module? (Or because it'll appear in dir(Module())?)

Yes, exactly. It seems like an obvious place to look as a user.

> At least wrt the former then this could be handled by improving the docstring and I'd be happy to have that in. The problem with adding new methods is that this breaks a different principle in Eqx, which is that no Module method is special cased. (Other than magic methods and tree_{un,}flatten.)

Could you elaborate on what you mean that no Module method is special cased?


> (I appreciate it probably feels like you're threading a needle here.)

No worries at all, I think it is good if library authors are opinionated about their design choices! I appreciate the discussion.
Edit: not sure if opinionated carries a negative connotation - if so, it wasn't meant in that way!

@patrick-kidger
Owner Author

> Yes, exactly. It seems like an obvious place to look as a user.

FWIW I do feel that discoverability could be better in Equinox. For example tree_at is pretty important for modifying PyTrees, but it's not terribly well-advertised.

> Could you elaborate on what you mean that no Module method is special cased?

Consider magic methods: these are special-cased by Python itself. MyClass.__len__ has meaning beyond simply being a method on a class, and changing magic methods allows you to change how Python handles your class. Likewise there is tree_{un,}flatten, which lets you control how JAX handles your class.

(Another example: PyTorch special-cases forward, as the appropriate extension point for subclasses of torch.nn.Module.)

Equinox doesn't special-case any methods like this. You can't change how Equinox treats your class. (You don't need to.)
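To make the distinction concrete, here is a plain JAX sketch (not Equinox's own Module machinery) of a hypothetical `Linear` class. `__len__` is consulted by Python itself, while `tree_flatten`/`tree_unflatten`, registered via `jax.tree_util.register_pytree_node_class`, are consulted by JAX whenever it maps over or flattens the object.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """A plain class that opts into PyTree handling by defining tree_{un,}flatten."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __len__(self):
        # Magic method: Python calls this for len(obj).
        return self.weight.shape[0]

    def tree_flatten(self):
        # JAX calls this to split the object into children (leaves) and static aux data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # JAX calls this to rebuild the object from (possibly transformed) children.
        return cls(*children)


layer = Linear(jnp.ones((2, 3)), jnp.zeros(2))
doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer)
```

Here `doubled` is again a `Linear`, with every array leaf multiplied by two.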

> Edit: not sure if opinionated carries a negative connotation - if so, it wasn't meant in that way!

Not at all! Not how I read it.

@patrick-kidger
Owner Author

patrick-kidger commented Apr 21, 2022

Also: this should respect eqx.experimental.{StateIndex,get_state,set_state}. Ideally the default should also handle NumPy arrays.

The following is entirely untested, but I think that probably means doing something like:

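import pathlib
from typing import Any, Union

import jax
import jax.numpy as jnp
import numpy as np
import equinox.experimental as experimental  # Equinox v0.5-era API
from equinox import is_array_like

PyTree = Any  # stand-in for Equinox's PyTree annotation

# Savers are called as saver(f, leaf); loaders as loader(f, template_leaf) -> leaf.

def _save_index(f, x: experimental.StateIndex):
    value = experimental.get_state(x)
    jnp.save(f, value)  # the file argument was missing

def _save_no(f, x):
    pass

def _load_index(f, x: experimental.StateIndex):
    value = jnp.load(f)
    experimental.set_state(x, value)
    return x

def _load_jax_array(f, x):
    return jnp.load(f)  # jnp.load also restores bfloat16 dtypes

def _load_numpy_array(f, x):
    return np.load(f)

def _load_scalar(f, x):
    return np.load(f).item()

def _load_no(f, x):
    return x

def _default_serialise_filter_spec(x):
    if is_array_like(x):
        return jnp.save
    elif isinstance(x, experimental.StateIndex):
        return _save_index
    else:
        return _save_no

def _default_deserialise_filter_spec(x):
    # Wrap jnp.load/np.load: returned directly, they would receive the template leaf as mmap_mode.
    # Check NumPy first so NumPy arrays are not routed to the JAX loader.
    if isinstance(x, np.ndarray):
        return _load_numpy_array
    elif isinstance(x, jnp.ndarray):
        return _load_jax_array
    elif isinstance(x, (bool, float, complex, int)):
        return _load_scalar
    elif isinstance(x, experimental.StateIndex):
        return _load_index
    else:
        return _load_no

def _assert_same(new, old):
    if type(new) is not type(old):
        raise RuntimeError(f"Deserialised leaf has type {type(new)}, expected {type(old)}.")
    # jnp.ndarray, not jnp.array (a function, which isinstance rejects).
    if isinstance(new, (np.ndarray, jnp.ndarray)) and (new.shape != old.shape or new.dtype != old.dtype):
        raise RuntimeError(
            f"Deserialised leaf has shape {new.shape} and dtype {new.dtype}, "
            f"expected shape {old.shape} and dtype {old.dtype}."
        )

def _is_index(x):
    return isinstance(x, experimental.StateIndex)

def tree_serialise_leaves(path: Union[str, pathlib.Path], pytree: PyTree, filter_spec=_default_serialise_filter_spec, is_leaf=_is_index):
    # with_suffix needs the leading dot.
    with open(pathlib.Path(path).with_suffix(".npy"), "wb") as f:
        def _serialise(spec, x):
            def __serialise(y):
                spec(y)(f, y)  # spec(y) picks the saving function for this leaf
            # Pass is_leaf here too, so that StateIndex objects are not flattened.
            return jax.tree_util.tree_map(__serialise, x, is_leaf=is_leaf)
        jax.tree_util.tree_map(_serialise, filter_spec, pytree, is_leaf=is_leaf)

def tree_deserialise_leaves(path: Union[str, pathlib.Path], like: PyTree, filter_spec=_default_deserialise_filter_spec, is_leaf=_is_index):
    with open(pathlib.Path(path).with_suffix(".npy"), "rb") as f:
        def _deserialise(spec, x):
            def __deserialise(y):
                return spec(y)(f, y)  # spec(y) picks the loading function for this leaf
            return jax.tree_util.tree_map(__deserialise, x, is_leaf=is_leaf)
        out = jax.tree_util.tree_map(_deserialise, filter_spec, like, is_leaf=is_leaf)
    jax.tree_util.tree_map(_assert_same, out, like, is_leaf=is_leaf)
    return out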

@patrick-kidger
Owner Author

@jaschau heads-up: when putting together your PR, branch off the v050 branch.

@patrick-kidger
Owner Author

@jaschau Any plans to pick this up as a PR? No worries if not -- I'll do it -- I just want to get serialisation+deserialisation into the next release of Equinox.

@jaschau

jaschau commented May 4, 2022

Hi @patrick-kidger, sorry for the lack of progress here. I've been side-tracked with other topics, so I haven't got around to working on this beyond our initial discussion. I'm afraid I cannot promise any serious progress on this in the coming weeks, so please go ahead if you're eager to work on this.
