Skip to content

Commit

Permalink
Added support for sharing layers between different parts of a model.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 14, 2023
1 parent 56cc31f commit 57afee1
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/api/nn/shared.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Sharing layers

::: equinox.nn.Shared
selection:
members:
- __init__
- __call__
1 change: 1 addition & 0 deletions equinox/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Sequential as Sequential,
StatefulLayer as StatefulLayer,
)
from ._shared import Shared as Shared
from ._spectral_norm import SpectralNorm as SpectralNorm
from ._stateful import (
delete_init_state as delete_init_state,
Expand Down
158 changes: 158 additions & 0 deletions equinox/nn/_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from collections.abc import Callable

from jaxtyping import PyTree

from .._eval_shape import filter_eval_shape
from .._module import Module
from .._tree import tree_at, tree_equal


class SharedNode:
"""Placeholder value for nodes that have been removed by `eqx.nn.Shared`."""

def __repr__(self):
return "SharedNode"


class Shared(Module):
"""Used to tie together multiple nodes across a PyTree.
Note that Equinox modules are Py**Trees** -- so the same layer, appearing in two
difference parts of the tree, will be treated as two copies of this layer. For
example,
```python
class SubModel(eqx.Module):
linear: eqx.nn.Linear
class Model(eqx.Module):
linear: eqx.nn.Linear
submodel: SubModel
def __init__(self):
linear = eqx.nn.Linear(...)
self.linear = linear
self.submodel = SubModel(linear)
```
is used to declare `model.linear` and `model.submodel.linear` as two separate
layers. They will start with the same initial parameter values, and then update
independently during training.
For when we really do want to share layers or weights across different parts of a
model, then `eqx.nn.Shared` exists as a way to easily express this in the PyTree
paradigm.
!!! Example
It is common in many language models to have an initial embedding matrix at the
start, and then to reuse this as the weight of the final linear transformation.
```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Int
class LanguageModel(eqx.Module):
shared: eqx.nn.Shared
def __init__(self):
embedding = eqx.nn.Embedding(...)
linear = eqx.nn.Linear(...)
# These two weights will now be tied together.
where = lambda embed_and_lin: embed_and_lin[1].weight
get = lambda embed_and_lin: embed_and_lin[0].weight
self.shared = eqx.nn.Shared((embedding, linear), where, get)
def __call__(self, tokens: Int[Array, "sequence"]):
# Expand back out so we can evaluate these layers.
embedding, linear = self.shared()
assert embedding.weight is linear.weight # same parameter!
# Now go ahead and evaluate your language model.
values = jax.vmap(embedding)(tokens)
... # other layers, probably
return jax.vmap(linear)(values)
```
_(Side note: you will sometimes see some authors referring to transposing
the embedding matrix prior to the final linear layer. This is because some
other libraries store the weight matrices of linear layers the other way
around. If that had been necessary here then we could have done it with
`get = lambda embed_and_lin: jnp.transpose(embed_and_lin[0].weight)`.)_
"""

pytree: PyTree
where: Callable
get: Callable

def __init__(self, pytree: PyTree, where: Callable, get: Callable):
"""**Arguments:**
- `pytree`: The PyTree to share some nodes across.
- `where`: a function specifying either a single node, or a sequence of nodes,
as with `eqx.tree_at(where, pytree, ...)`.
- `get`: a function, which when evaluated on `pytree`, returns either a single
value (if `where` does), or a sequence of values (if `where` does, and in
this case this must be a sequence of the same length as `where`).
The node(s) of `get(pytree)` and the corresponding value(s) of `where(pytree)`
will be tied together.
!!! info
To explain how this works. The implementation is just:
```python
class Shared(eqx.Module):
pytree: PyTree
where: Callable
get: Callable
def __init__(self, pytree, where, get):
# `0` is just some dummy value
self.pytree = eqx.tree_at(where, pytree, replace_fn=lambda _: 0)
self.where = where
self.get = get
def __call__(self):
return eqx.tree_at(self.where, self.pytree, self.get(self.pytree))
```
so that at `__init__` time, the duplicate nodes specified in `where` are
removed from the PyTree. We no longer have a separate copy updating during
training.
And then at `__call__` time, references to the values returned by
`get(pytree)` are put in their place. We end up with a pytree of the same
structure as what we started with, which we can now use (evaluate as a
layer etc.) as normal.
!!! tip
If you need to apply any transform (e.g. transposing a matrix), then this
can be done as part of `get`. For example,
`get = lambda pair: jnp.transpose(pair[1].weight)`.
"""

source_struct = filter_eval_shape(get, pytree)
dest_struct = filter_eval_shape(where, pytree)
if tree_equal(source_struct, dest_struct) is not True:
raise ValueError(
"Every node being shared together must have the same pytree "
"structure, shape+dtype of arrays, etc., as each other. Got:\n"
f"{source_struct}\n"
"and\n"
f"{dest_struct}"
)
self.pytree = tree_at(where, pytree, replace_fn=lambda _: SharedNode())
self.where = where
self.get = get

def __call__(self):
"""**Arguments:**
None.
**Returns:**
A PyTree of the same structure as the original `pytree`, with `get(pytree)` in
the place of the nodes at `where(pytree)`.
"""
return tree_at(self.where, self.pytree, self.get(self.pytree))
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ nav:
- 'api/nn/mlp.md'
- 'api/nn/sequential.md'
- 'api/nn/inference.md'
- 'api/nn/shared.md'
- 'api/nn/stateful.md'
- Filtering:
- 'api/filtering/partition-combine.md'
Expand Down
125 changes: 125 additions & 0 deletions tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import pytest
from jaxtyping import Array, Float, Int

import equinox as eqx


def test_shared_array(getkey):
class MyModule(eqx.Module):
shared: eqx.nn.Shared

def __init__(self):
embedding = eqx.nn.Embedding(
num_embeddings=3, embedding_size=4, key=getkey()
)
head = eqx.nn.Linear(4, 3, key=getkey())
where = lambda pair: pair[1].weight
get = lambda pair: pair[0].weight
self.shared = eqx.nn.Shared((embedding, head), where, get)

def __call__(self, token: Int[Array, ""]):
nonlocal called
called = True
embedding, head = self.shared()
assert embedding.weight is head.weight
return head(embedding(token))

called = False
module = MyModule()
module(jnp.array(0))
assert called


# We share a non-leaf node
def test_shared_node(getkey):
class MyModule(eqx.Module):
shared: eqx.nn.Shared

def __init__(self):
attention = eqx.nn.MultiheadAttention(
num_heads=3, query_size=12, key=getkey()
)
my_proj = eqx.nn.Linear(12, 12, use_bias=False, key=getkey())
where = lambda pair: pair[1].key_proj
get = lambda pair: pair[0]
self.shared = eqx.nn.Shared((my_proj, attention), where, get)

def __call__(self, x: Float[Array, "seq 12"]):
nonlocal called
called = True
my_proj, attention = self.shared()
eq = eqx.tree_equal(my_proj, attention.key_proj)
x = attention(x, x, x)
out = jax.vmap(my_proj)(x)
return out, eq

called = False
module = MyModule()
x = jr.normal(getkey(), (5, 12))

@eqx.filter_jit
@eqx.filter_grad(has_aux=True)
def f(module, x):
out, eq = module(x)
return jnp.sum(out), eq

d_module, eq = f(module, x)
assert called
assert eq
module = eqx.apply_updates(module, d_module)
d_module, eq = f(module, x)
assert eq
module = eqx.apply_updates(module, d_module)


def test_mismatched_structure(getkey):
x = jr.normal(getkey(), (3, 4))
y = jr.normal(getkey(), (4, 3))
with pytest.raises(ValueError, match="Every node being shared together"):
eqx.nn.Shared((x, y), lambda pair: pair[0], lambda pair: pair[1])


def test_multi_shared(getkey):
class MyModule(eqx.Module):
shared: eqx.nn.Shared

def __init__(self):
my_proj = eqx.nn.Linear(12, 12, use_bias=False, key=getkey())
attention = eqx.nn.MultiheadAttention(
num_heads=3, query_size=12, key=getkey()
)
where = lambda pair: (pair[1].key_proj, pair[1].query_proj.weight)
get = lambda pair: (pair[0], pair[0].weight + 1)
self.shared = eqx.nn.Shared((my_proj, attention), where, get)

def __call__(self, x: Float[Array, "seq 12"]):
nonlocal called
called = True
my_proj, attention = self.shared()
eq1 = eqx.tree_equal(my_proj, attention.key_proj)
eq2 = (my_proj.weight + 1 == attention.query_proj.weight).all()
x = attention(x, x, x)
out = jax.vmap(my_proj)(x)
eq = eq1 & eq2
return out, eq

called = False
module = MyModule()
x = jr.normal(getkey(), (5, 12))

@eqx.filter_jit
@eqx.filter_grad(has_aux=True)
def f(module, x):
out, eq = module(x)
return jnp.sum(out), eq

d_module, eq = f(module, x)
assert called
assert eq
module = eqx.apply_updates(module, d_module)
d_module, eq = f(module, x)
assert eq
module = eqx.apply_updates(module, d_module)

0 comments on commit 57afee1

Please sign in to comment.