Add attention functions and tests #181
Conversation
Nits aside, this LGTM!
Have you checked (with the
I checked the outputs and they match with Also while doing this I noticed that the outputs do not match when using dropout -- I'm not sure yet if this is because the order that dropout is being applied is different, or whether there is a problem with vmap'ing when the
Yea the problem is vmap'ing with the
Just checking that you're not waiting on any input from me at the moment?
f1fb68f to 094785a
The most recent changes fix the problem with closing over the key in
However, vmap'ing dropout doesn't match the un-vmap'ed version:

```python
import jax
import jax.numpy as jnp
from equinox import nn

key = jax.random.PRNGKey(41)
x = jnp.arange(4*5).reshape(4, 5)
dropout = nn.Dropout(.5)

y1 = dropout(x, key=key)
y2 = jax.vmap(dropout, in_axes=0, out_axes=0)(x, key=jax.random.split(key, 4))
y3 = jax.vmap(dropout, in_axes=1, out_axes=1)(x, key=jax.random.split(key, 5))

jnp.allclose(y1, y2)  # --> False
jnp.allclose(y1, y3)  # --> False
```

I may be fundamentally misunderstanding how the PRNG works, but my expectation is that y1 and y2 are equal. Any insight into why these don't match? Otherwise I can create a jax issue.

A simpler MWE:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(41)

def test(key):
    return jax.random.normal(key, ())

y1 = jax.random.normal(key, (4,))
y2 = jax.vmap(test)(jax.random.split(key, 4))

jnp.allclose(y1, y2)  # --> False
```
It's definitely expected that these don't match. I could believe that there's no way to reproduce the same behaviour as before when using dropout, since we now have a vmap'd dropout and JAX may be doing something else under-the-hood. I think that's fine -- it's mostly inference mode that I'm concerned about; I'd prefer not to break any models that have already been serialised to disk, but dropout is really training-time-only.
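(As a sketch, not from this thread: `jax.vmap` is documented to be semantically equivalent to an explicit loop over the mapped axis, so the mismatch comes from how the keys are consumed rather than from `vmap` itself -- a single `(4,)` draw from the parent key and four scalar draws from four split keys are different, but equally valid, uses of the PRNG.)

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(41)
keys = jax.random.split(key, 4)

def test(k):
    return jax.random.normal(k, ())

# vmap over the split keys matches an explicit Python loop over them...
vmapped = jax.vmap(test)(keys)
looped = jnp.stack([test(k) for k in keys])
print(jnp.allclose(vmapped, looped))   # --> True

# ...but neither matches a single (4,) draw from the parent key,
# which consumes the PRNG differently.
batched = jax.random.normal(key, (4,))
print(jnp.allclose(batched, vmapped))  # --> False
```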
Ok that sounds reasonable. Pending any additional feedback I think this is ready to merge.
Excellent. Thanks for contributing!
Adds `dot_product_attention_weights` and `dot_product_attention` functions and tests.

Design considerations:

- `dot_product_attention` and `dot_product_attention_weights` don't take multi-head inputs -- instead attention heads are vmap'd over in `MultiheadAttention`. This allows for greater flexibility when creating other types of attention modules (see the sketch after this list).
- `dot_product_attention` signature -- `dropout_fn` is added as a single-argument callable, which should close over the dropout arguments like `key` and `inference`. The alternative I think would be to add a functional version of dropout and add its arguments to `dot_product_attention`, however this would make changing the dropout rate after initializing the module less intuitive, since the dropout rate would have to be an attribute of `MultiheadAttention`.
- The `mask` shape check is kept inside `dot_product_attention_weights`. The downside to this is that errors raised inside vmap'd functions are less obvious -- i.e. if the heads don't match then the vmap'd function will raise an error. The alternative is to pull the shape check out and put it back in `MultiheadAttention`.