Standard MLP (equinox.nn.MLP) does not work with apply_updates function #108
Comments
This is actually intended behaviour. What's going on here is that the PyTree structure of the model may include non-JAX-arrays in its leaves, and indeed may contain arbitrary Python objects. Their interactions with other operations -- in this case you're trying to add two PyTrees together -- are resolved through filtering. With respect to applying updates, this is covered in the FAQ. (See also the docs and Section 4 of the paper.)

It's important that we have pretty much everything be part of the PyTree structure. For example, someone may wish to define and use

    class LearntActivation(eqx.Module):
        param: jnp.ndarray

        def __call__(self, x):
            return jax.nn.relu(self.param * x)

    eqx.nn.MLP(..., activation=LearntActivation(...))
As such I'm going to close #109 -- but I appreciate the willingness to get stuck in.
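For readers following along, the filtering idea in the comment above can be sketched in plain JAX. This is only an illustration under stated assumptions: `toy_model`, `make_updates` and `apply_filtered_updates` are hypothetical names, the dictionary stands in for a real `eqx.nn.MLP`, and the sketch does not use Equinox's own filtering utilities (which are what the FAQ refers to).

```python
import jax
import jax.numpy as jnp


def is_array(x):
    # Hypothetical filter: only JAX arrays count as trainable leaves.
    return isinstance(x, jax.Array)


def make_updates(tree, step):
    # Array leaves get an update; every other leaf is marked with None.
    return jax.tree_util.tree_map(
        lambda leaf: step * jnp.ones_like(leaf) if is_array(leaf) else None,
        tree,
    )


def apply_filtered_updates(tree, updates):
    # None in `updates` is treated as a leaf meaning "leave this entry unchanged".
    return jax.tree_util.tree_map(
        lambda u, p: p if u is None else p + u,
        updates,
        tree,
        is_leaf=lambda x: x is None,
    )


# Hypothetical stand-in for a model: two arrays and one callable leaf.
toy_model = {
    "weight": jnp.ones((3, 2)),
    "bias": jnp.zeros(3),
    "activation": jax.nn.relu,
}

updates = make_updates(toy_model, step=-0.1)
new_model = apply_filtered_updates(toy_model, updates)
```

The callable stays in the PyTree and passes through untouched, while the array leaves are updated. A parameterised activation such as `LearntActivation` would have its array parameter picked up by the same filter.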
Thanks for the quick response, @patrick-kidger! I was trying to use the filtering approach, but I was facing some problems. I had a deeper dive, and I realise now where I went wrong. I appreciate the help!
From what I can see, the standard MLP included in `equinox.nn.MLP` breaks when trying to apply updates.
Simple demo:
Error thrown:
The issue seems to arise in `tree_map`, when we flatten the tree:
Inspecting the resulting `leaves`:
We see that the flattened leaves also include the custom JVP object (i.e. the ReLU) and the identity function, i.e. the activation fields of the MLP class.
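The observation above can be reproduced without Equinox. The sketch below is not the demo from this issue: `toy_model` and `identity` are hypothetical stand-ins for the MLP and its default final activation. It assumes only that plain JAX treats any object it does not recognise as a container as a leaf.

```python
import jax
import jax.numpy as jnp


def identity(x):
    return x


# Hypothetical stand-in for an MLP: two array parameters plus two callables.
toy_model = {
    "weight": jnp.ones((3, 2)),
    "bias": jnp.zeros(3),
    "activation": jax.nn.relu,
    "final_activation": identity,
}

# Flattening returns every leaf, arrays and callables alike.
leaves = jax.tree_util.tree_leaves(toy_model)
array_leaves = [leaf for leaf in leaves if isinstance(leaf, jax.Array)]
callable_leaves = [leaf for leaf in leaves if callable(leaf)]

# Naively adding two such PyTrees leaf by leaf fails on the callables,
# because functions do not support `+`.
try:
    jax.tree_util.tree_map(lambda p, u: p + u, toy_model, toy_model)
    naive_add_failed = False
except TypeError:
    naive_add_failed = True
```

Here the two callables appear among the leaves next to the two arrays, and the unfiltered leaf-wise addition raises a `TypeError`.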
To fix this, I believe we should simply mark the `activation` and `final_activation` fields as static (i.e. they should not be treated as leaves of the PyTree):
This fixes the issue for me.
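To show what "static" means here, the following is a minimal sketch using a hand-registered PyTree node in plain JAX. It is not the patch proposed in this issue and not Equinox's implementation: `ToyLayer` is a hypothetical class, and storing the activation in the node's auxiliary data is an assumed analogue of a static field.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyLayer:
    """Hypothetical layer: `weight` is a leaf, `activation` is static metadata."""

    def __init__(self, weight, activation):
        self.weight = weight
        self.activation = activation

    def tree_flatten(self):
        children = (self.weight,)      # leaves, visible to tree_map
        aux_data = (self.activation,)  # part of the tree structure, not a leaf
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (weight,) = children
        (activation,) = aux_data
        return cls(weight, activation)

    def __call__(self, x):
        return self.activation(self.weight @ x)


layer = ToyLayer(jnp.ones((3, 2)), jax.nn.relu)

# Only the weight shows up as a leaf, so leaf-wise addition now succeeds.
leaves = jax.tree_util.tree_leaves(layer)
updated = jax.tree_util.tree_map(lambda p, u: p + u, layer, layer)
```

The trade-off is the one raised in the maintainer's reply: anything stored as static is invisible to `tree_map`, so an activation that carries its own array parameter (such as `LearntActivation`) would never receive updates.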