
Linear Attention Mechanism #2150

Closed
parmarsuraj99 opened this issue Sep 5, 2020 · 12 comments

Comments

@parmarsuraj99

parmarsuraj99 commented Sep 5, 2020

Describe the feature and the current behavior/state.

Are we going to add LinearAttention? If yes, I can start working on it

Relevant information

Which API type would this fall under (layer, metric, optimizer, etc.)? Layer.

Who will benefit from this feature? Anyone building Transformer blocks: linear attention runs in O(N) time, compared to the O(N²) cost of standard softmax dot-product attention.

Any other info.
Paper's website

@AakashKumarNain
Member

Please feel free to open a PR, @parmarsuraj99. Thank you!

@bhack
Contributor

bhack commented Sep 5, 2020

/cc @saberkun @tanzhenyu @dynamicwebpaige Do you have any internal plans for this?

@saberkun
Member

saberkun commented Sep 5, 2020

Looking at the PyTorch implementation, what's the difference from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/dense_attention.py?

@parmarsuraj99
Author

In the PyTorch implementation, the authors have implemented many variants of attention: https://github.com/idiap/fast-transformers/tree/master/fast_transformers/attention

The one referenced above is this specific variant:
https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py

The major difference is in how the attention is computed. Instead of a softmax, they apply a kernel feature map (here, one based on elu) to Q and K, which according to them has shown some improvements.

The change is focused only on this part:

depth = tf.constant(self.head_size, dtype=tf.float32)
query /= tf.sqrt(depth)

# Calculate dot product attention
logits = tf.einsum("...NHO,...MHO->...HNM", query, key)

# Apply mask
if mask is not None:
    mask = tf.cast(mask, tf.float32)

    # Possibly expand on the head dimension so broadcasting works
    if len(mask.shape) != len(logits.shape):
        mask = tf.expand_dims(mask, -3)

    logits += -10e9 * (1.0 - mask)

attn_coef = tf.nn.softmax(logits)

Could we make the attention computation a callable that runs after the linear projections of the inputs? Something like:

query = tf.einsum("...NI , HIO -> ...NHO", query, self.query_kernel)
key = tf.einsum("...MI , HIO -> ...MHO", key, self.key_kernel)
value = tf.einsum("...MI , HIO -> ...MHO", value, self.value_kernel)

output, attn_coef = scaled_dot_product_attention(query, key, value, mask)

or

output, attn_coef = linearized_attention(query, key, value, mask)
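
For reference, a minimal sketch (not taken from the referenced repo) of what such a linearized_attention callable might look like, using the elu(x) + 1 feature map from the paper and the [..., N, H, O] / [..., M, H, O] projection shapes above. The key-padding mask handling (a mask over the M key positions) and the eps constant are assumptions for illustration.

import tensorflow as tf

def linearized_attention(query, key, value, mask=None, eps=1e-6):
    # query: [..., N, H, O]; key/value: [..., M, H, O]; mask: [..., M] with 1 = keep (assumed shape).
    # Positive feature map applied to queries and keys instead of a softmax.
    q = tf.nn.elu(query) + 1.0
    k = tf.nn.elu(key) + 1.0

    if mask is not None:
        # Zero out padded key/value positions before they enter the sums.
        k *= tf.cast(mask, k.dtype)[..., tf.newaxis, tf.newaxis]

    # phi(K)^T V, contracted over the key length M: [..., H, O, O]
    kv = tf.einsum("...MHO,...MHD->...HOD", k, value)
    # Normalizer phi(Q) . sum_M phi(K): [..., N, H]
    z = 1.0 / (tf.einsum("...NHO,...HO->...NH", q, tf.reduce_sum(k, axis=-3)) + eps)
    # Output phi(Q) (phi(K)^T V), rescaled: [..., N, H, O]
    output = tf.einsum("...NHO,...HOD->...NHD", q, kv) * z[..., tf.newaxis]
    # The linear form never materializes the N x M attention matrix, so no coefficients are returned.
    return output, None

Because the two einsums contract over M and over the head dimension separately, the N x M logits matrix is never formed; that reordering is where the O(N) scaling comes from.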

@tanzhenyu
Contributor

(quoting @parmarsuraj99's proposal above)

This seems to be something we could ask users to implement by subclassing?

@saberkun
Member

saberkun commented Sep 8, 2020

@parmarsuraj99 Thanks!

To make the subclassing idea a bit more concrete: there are many innovations in how attention is computed, so we are making the Keras MultiHeadAttention layer subclassable:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/multi_head_attention.py#L399
A few attention variants can be implemented by overriding the compute_attention method, and those subclass layers could be a good fit to host inside packages like Addons.
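
As an illustration, a linear-attention variant might look roughly like the sketch below. This assumes the override point is the private _compute_attention(query, key, value, attention_mask=None, training=None) method with [batch, length, num_heads, head_dim] tensors, which can differ between TF/Keras versions, and it only supports a key-padding mask of shape [batch, source_length]; the full [batch, target, source] mask cannot be expressed in the linear form.

import tensorflow as tf

class LinearMultiHeadAttention(tf.keras.layers.MultiHeadAttention):
    # Sketch only: replaces the softmax attention with the elu(x) + 1 linear variant.
    # Assumes the base class calls _compute_attention(query, key, value,
    # attention_mask=None, training=None) with shapes
    # query: [B, T, num_heads, key_dim], key: [B, S, num_heads, key_dim],
    # value: [B, S, num_heads, value_dim]; verify against the installed Keras version.

    def _compute_attention(self, query, key, value,
                           attention_mask=None, training=None):
        q = tf.nn.elu(query) + 1.0
        k = tf.nn.elu(key) + 1.0
        if attention_mask is not None:
            # Assumed to be a key-padding mask of shape [B, S] with 1 = keep.
            k *= tf.cast(attention_mask, k.dtype)[:, :, tf.newaxis, tf.newaxis]
        kv = tf.einsum("BSND,BSNE->BNDE", k, value)               # phi(K)^T V
        z = 1.0 / (tf.einsum("BTND,BND->BTN", q,
                             tf.reduce_sum(k, axis=1)) + 1e-6)    # normalizer
        output = tf.einsum("BTND,BNDE->BTNE", q, kv) * z[..., tf.newaxis]
        # No N x M attention matrix exists in the linear form, so no scores are returned.
        return output, None

It would be used like the stock layer, e.g. LinearMultiHeadAttention(num_heads=8, key_dim=64)(query, value), with return_attention_scores left at its default since there are no per-pair scores to return.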

@parmarsuraj99
Author

parmarsuraj99 commented Sep 8, 2020

(quoting @saberkun's comment above)

Thanks!

This is exactly the implementation I was referring to: subclassing MultiHeadAttention with a flexible attention computation.

@seanpmorgan
Member

Given the feedback from the TF team, if you would like to submit and maintain a PR subclassing the Keras MHA, that would be fine to proceed with.

@abhishek-niranjan
Contributor

@seanpmorgan I'd like to contribute to this too. @parmarsuraj99 let me know if you'd like to work on it together?

@parmarsuraj99
Author

@abhishek-niranjan Sure. I'd really love to collaborate

@claverru

How is this going?

@seanpmorgan
Member

TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision:
TensorFlow Addons Wind Down

Please consider sending feature requests / contributions to other repositories in the TF community with similar charters to TFA:
Keras
Keras-CV
Keras-NLP
