Add attention functions and tests #181

Merged (3 commits, Aug 30, 2022)

Conversation

@jenkspt (Contributor) commented Aug 22, 2022

Adds dot_product_attention_weights and dot_product_attention functions and tests

Design considerations:

  • dot_product_attention and dot_product_attention_weights don't take multi-head inputs -- instead, attention heads are vmap'd over in MultiheadAttention. This allows for greater flexibility when creating other kinds of attention modules. (A rough sketch follows after this list.)
  • To simplify the dot_product_attention signature, dropout_fn is added as a single-argument callable, which should close over the dropout arguments such as key and inference. The alternative, I think, would be to add a functional version of dropout and expose its arguments on dot_product_attention; however, that would make changing the dropout rate after initializing the module less intuitive, since the dropout rate would then have to be an attribute of MultiheadAttention.
  • The mask shape check is kept inside dot_product_attention_weights. The downside is that errors raised inside vmap'd functions are less obvious -- i.e. if the heads don't match, then the vmap'd function raises the error. The alternative is to pull the shape check out and put it back in MultiheadAttention.
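
As a rough sketch of the first two points (illustrative only, not the PR's exact code -- the shapes and the dropout hook shown here are assumptions), the per-head functions operate on single-head (seq, size) arrays, and multi-head attention is obtained by vmap'ing over a leading head axis:

import jax
import jax.numpy as jnp

def dot_product_attention_weights(query, key_, mask=None):
    # query: (q_seq, qk_size); key_: (kv_seq, qk_size) -- single head, no batch.
    logits = query @ key_.T / jnp.sqrt(query.shape[-1])
    if mask is not None:
        logits = jnp.where(mask, logits, -jnp.inf)
    return jax.nn.softmax(logits, axis=-1)

def dot_product_attention(query, key_, value, mask=None, dropout_fn=None):
    weights = dot_product_attention_weights(query, key_, mask)
    if dropout_fn is not None:
        weights = dropout_fn(weights)  # closes over key/inference etc.
    return weights @ value  # (q_seq, vo_size)

# Hypothetical multi-head use: q, k, v carry a leading head axis, and
# out_axes=1 places the head axis after the query-sequence axis of the output.
q = k = v = jnp.ones((4, 7, 8))  # (num_heads, seq, size)
out = jax.vmap(dot_product_attention, in_axes=0, out_axes=1)(q, k, v)
out.shape    # --> (7, 4, 8)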

@patrick-kidger (Owner) left a comment

Nits aside, this LGTM!

Inline review comments (resolved) on equinox/nn/attention.py and tests/test_nn.py.
@patrick-kidger (Owner)

Have you checked (with the out_axes=1 change) that the results of MultiheadAttention really do remain unchanged after this update? If so I'm happy to merge this.

@jenkspt (Contributor, Author) commented Aug 23, 2022

> Have you checked (with the out_axes=1 change) that the results of MultiheadAttention really do remain unchanged after this update? If so I'm happy to merge this.

I checked the outputs and they match with out_axes=1. The tests should probably be updated to catch this (since using out_axes=0 passes the tests but doesn't match the outputs of the current MultiheadAttention). Writing these tests isn't easy, however -- maybe just copying a small input/output pair from the current MultiheadAttention and adding it to the tests would suffice?
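
Something like this is what I have in mind (a hypothetical sketch -- the configuration and the golden file name are made up, and the saved array would have to be captured from the current MultiheadAttention on exactly these inputs):

import jax
import jax.numpy as jnp
import equinox as eqx

def test_multihead_attention_regression():
    key = jax.random.PRNGKey(0)
    attn = eqx.nn.MultiheadAttention(num_heads=2, query_size=4, key=key)
    x = jnp.linspace(0.0, 1.0, 3 * 4).reshape(3, 4)  # (seq_len, query_size)
    out = attn(x, x, x)
    # Placeholder: an array saved (or hard-coded) from the pre-refactor
    # MultiheadAttention on exactly these inputs.
    expected = jnp.load("tests/multihead_golden.npy")
    assert jnp.allclose(out, expected, atol=1e-6)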

Also, while doing this I noticed that the outputs do not match when using dropout -- I'm not sure yet whether this is because dropout is applied in a different order, or whether there is a problem with vmap'ing when the key is closed over in dropout_fn.

@jenkspt (Contributor, Author) commented Aug 23, 2022

Yeah, the problem is vmap'ing with the key in dropout_fn.

@patrick-kidger (Owner) commented Aug 29, 2022

Just checking that you're not waiting on any input from me at the moment?

@jenkspt (Contributor, Author) commented Aug 30, 2022

The most recent changes fix the problem with closing over the key in dropout_fn by adding explicit key and inference arguments.

However, vmap'ing dropout doesn't match the un-vmap'd version.
For example:

import jax
import jax.numpy as jnp
from equinox import nn

key = jax.random.PRNGKey(41)

x = jnp.arange(4 * 5).reshape(4, 5)
dropout = nn.Dropout(0.5)

y1 = dropout(x, key=key)                      # un-vmap'd: one key for the whole array
y2 = jax.vmap(dropout, in_axes=0, out_axes=0)(x, key=jax.random.split(key, 4))  # one key per row
y3 = jax.vmap(dropout, in_axes=1, out_axes=1)(x, key=jax.random.split(key, 5))  # one key per column

jnp.allclose(y1, y2)    # --> False
jnp.allclose(y1, y3)    # --> False

I may be fundamentally misunderstanding how the PRNG works, but my expectation is that y1 and y2 are equal. Any insight into why these don't match? Otherwise I can create a jax issue.

A simpler MWE:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(41)

def test(key):
    return jax.random.normal(key, ())

y1 = jax.random.normal(key, (4,))              # four draws from the parent key
y2 = jax.vmap(test)(jax.random.split(key, 4))  # one draw from each of four child keys
jnp.allclose(y1, y2)    # --> False

@patrick-kidger (Owner) commented Aug 30, 2022

It's definitely expected that these don't match. jax.random.split should return new keys that produce random numbers statistically independent of those from their parent key. (And from each other.)

I could believe that there's no way to reproduce the same behaviour as before when using dropout, since we now have a vmap'd dropout and JAX may be doing something else under-the-hood. I think that's fine -- it's mostly inference mode that I'm concerned about; I'd prefer not to break any models that have already been serialised to disk, but dropout is really training-time-only.
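
For example, a quick check along those lines (a sketch, assuming Dropout's inference flag works as discussed above -- in inference mode dropout is the identity, so the vmap'd and un-vmap'd results should coincide exactly):

import jax
import jax.numpy as jnp
from equinox import nn

x = jnp.arange(4 * 5, dtype=jnp.float32).reshape(4, 5)
dropout = nn.Dropout(0.5)

# No randomness in inference mode, so no key is needed.
y1 = dropout(x, inference=True)
y2 = jax.vmap(lambda row: dropout(row, inference=True))(x)
jnp.allclose(y1, y2)    # --> True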

@jenkspt (Contributor, Author) commented Aug 30, 2022

Ok that sounds reasonable. Pending any additional feedback I think this is ready to merge.

@patrick-kidger merged commit bc3a8d9 into patrick-kidger:main on Aug 30, 2022
@patrick-kidger (Owner)

Excellent. Thanks for contributing!
(I'll be doing a new release shortly.)
