Standard MLP (equinox.nn.MLP) does not work with apply_updates function #108

Closed

callumtilbury opened this issue Jun 13, 2022 · 3 comments

@callumtilbury
From what I can see, the standard MLP included in equinox.nn.MLP breaks when trying to apply updates.

Simple demo:

import equinox as eqx
import jax.random as jrand

# Define a simple MLP
mlp = eqx.nn.MLP(in_size=1, width_size=1, out_size=1, depth=1, key=jrand.PRNGKey(0))

# Should be able to apply updates with itself
eqx.apply_updates(mlp, mlp)

Error thrown:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 eqx.apply_updates(mlp,mlp)

File ~/Desktop/equinox/equinox/update.py:37, in apply_updates(model, updates)
     18 """A `jax.tree_map`-broadcasted version of
     19 ```python
     20 model = model if update is None else model + update
   (...)
     34 The updated model.
     35 """
     36 # Assumes that updates is a prefix of model
---> 37 return jax.tree_map(_apply_update, updates, model, is_leaf=_is_none)

File ~/miniconda3/envs/eqx-fork/lib/python3.10/site-packages/jax/_src/tree_util.py:184, in tree_map(f, tree, is_leaf, *rest)
    182 leaves, treedef = tree_flatten(tree, is_leaf)
    183 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/miniconda3/envs/eqx-fork/lib/python3.10/site-packages/jax/_src/tree_util.py:184, in <genexpr>(.0)
    182 leaves, treedef = tree_flatten(tree, is_leaf)
    183 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/Desktop/equinox/equinox/update.py:10, in _apply_update(u, p)
      8     return p
      9 else:
---> 10     return p + u

TypeError: unsupported operand type(s) for +: 'custom_jvp' and 'custom_jvp'

The issue seems to arise in tree_map, when we flatten the tree:

leaves, treedef = tree_flatten(tree, is_leaf)

Inspecting the resulting leaves:

[DeviceArray([[0.58076453]], dtype=float32),
 DeviceArray([-0.44256163], dtype=float32),
 DeviceArray([[0.882236]], dtype=float32),
 DeviceArray([0.79829645], dtype=float32),
 <jax._src.custom_derivatives.custom_jvp object at 0x7fc6f0e88460>,
 <function MLP._identity at 0x7fc6f1db8670>]

We see that the tree flattening has also been applied to the custom_jvp object (i.e. the ReLU) and the identity function; these are the activation and final_activation fields of the MLP class.
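
For reference, the same flattened view can be reproduced directly with one call (a minimal sketch, assuming the mlp from the demo above):

import jax

# tree_leaves flattens the Module into its pytree leaves: the parameter
# arrays and the two activation callables.
print(jax.tree_util.tree_leaves(mlp))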

To fix this, I believe we should simply mark the activation and final_activation fields as static (i.e. they should not be treated as leaves of the PyTree):

from typing import Callable, List

import equinox as eqx
from equinox import static_field
from equinox.nn import Linear


class MLP(eqx.Module):
    """Standard Multi-Layer Perceptron; also known as a feed-forward network."""

    layers: List[Linear]
    activation: Callable = static_field()
    final_activation: Callable = static_field()

    ...

This fixes the issue for me.

callumtilbury added a commit to callumtilbury/equinox that referenced this issue Jun 13, 2022
@patrick-kidger (Owner)

This is actually intended behaviour.

What's going on here is that the PyTree structure of the model may include non-JAX-arrays in its leaves, and indeed may contain arbitrary Python objects. Their interactions with other operations -- in this case you're trying to add two PyTrees together -- are resolved through filtering.

With respect to applying updates, this is covered in the FAQ. (See also the docs and Section 4 of the paper.)
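
In practice that looks something like the following (a minimal sketch of the filtered pattern; the squared-error loss and fixed learning rate are purely illustrative):

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrand

mlp = eqx.nn.MLP(in_size=1, width_size=1, out_size=1, depth=1, key=jrand.PRNGKey(0))

@eqx.filter_grad
def loss_fn(model, x, y):
    # vmap the model over the batch dimension
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((8, 1))
y = jnp.zeros((8, 1))

# filter_grad differentiates only the inexact-array leaves; the activation
# callables come back as None in the gradient pytree.
grads = loss_fn(mlp, x, y)

# Scale gradients into updates, leaving the None leaves alone ...
updates = jax.tree_map(
    lambda g: None if g is None else -0.01 * g, grads, is_leaf=lambda x: x is None
)
# ... which apply_updates then skips.
mlp = eqx.apply_updates(mlp, updates)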

It's important that we have pretty much everything be part of the PyTree structure. For example, someone may wish to define and use:

import equinox as eqx
import jax
import jax.numpy as jnp

class LearntActivation(eqx.Module):
    param: jnp.ndarray

    def __call__(self, x):
        return jax.nn.relu(self.param * x)

eqx.nn.MLP(..., activation=LearntActivation(...))
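
This works precisely because the activation is part of the pytree. When the arrays are needed on their own, the model can be split and reassembled; a minimal sketch using eqx.partition and eqx.combine, with eqx.is_array as the filter:

params, static = eqx.partition(mlp, eqx.is_array)  # array leaves vs. everything else
mlp_again = eqx.combine(params, static)            # reassemble the original model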

@patrick-kidger (Owner)

As such I'm going to close #109 -- but I appreciate the willingness to get stuck in.

@callumtilbury (Author)

Thanks for the quick response, @patrick-kidger! I was trying to use the filtering approach, but I was facing some problems. I took a deeper dive, and I now realise where I went wrong. I appreciate the help!
