
# Attention Mechanism

As a bit of a historical digression, attention research is an enormous field with a long history in cognitive neuroscience. Focalization and concentration of consciousness are of the essence of attention: they enable humans to prioritize their perception, withdrawing from some stimuli in order to deal effectively with others.

As a result, we do not process all the information available in our sensory input; at any time, we are aware of only a small fraction of the information in the environment. Cognitive neuroscience distinguishes several types of attention, such as selective attention, covert attention, and spatial attention. The theory that sparked recent work on attention in deep learning is the feature integration theory of selective attention, developed by [Anne Treisman and Garry Gelade in a 1980 paper](https://www.sciencedirect.com/science/article/abs/pii/0010028580900055).

The paper proposes that when a stimulus is perceived, features are registered early, automatically, and in parallel, while objects are identified separately and at a later stage of processing. The theory has been one of the most influential psychological models of human visual attention.

In a seq2seq model, we encode the information of the source sequence into the state of a recurrent unit and then pass that state to the decoder to generate the target sequence. However, a token in the target sequence may be closely related to only one or a few tokens in the source sequence rather than to the whole source sequence.

For example, when translating “Hello world.” to “Bonjour le monde.”, “Bonjour” maps to “Hello” and “monde” maps to “world”. **In the seq2seq model, the decoder may implicitly select the corresponding information from the state passed by the encoder. The attention mechanism, however, makes this selection explicit.**

**Attention is a generalized pooling method with bias alignment over inputs. The core component of the attention mechanism is the attention layer, called attention for short. An input of the attention layer is called a query.**

For a query, attention returns an output based on the memory: a set of key-value pairs encoded in the attention layer. To be more specific, assume that the memory contains $n$ key-value pairs $(k_1, v_1), \ldots, (k_n, v_n)$, with $k_i \in \mathbb R^{d_k}$ and $v_i \in \mathbb R^{d_v}$. Given a query $q \in \mathbb R^{d_q}$, the attention layer returns an output $o \in \mathbb R^{d_v}$ with the same shape as a value.

<img src='https://github.com/rahiakela/img-repo/blob/master/attention-mechanism.png?raw=1' width='800'/>

To compute the output of attention, we first use a score function $\alpha$ that measures the similarity between the query and a key. For each of the keys $k_1, \ldots, k_n$, we compute the scores $a_1, \ldots, a_n$ by

$$ a_i=\alpha(q,k_i)$$

Next we use softmax to obtain the attention weights, i.e.,

$$b = \mathrm{softmax}(a), \quad \text{where } b_i = \frac{\exp(a_i)}{\sum_j \exp(a_j)}, \quad b = [b_1, \ldots, b_n]^T.$$

Finally, the output is a weighted sum of the values:

$$o = \sum_{i=1}^{n} b_i v_i.$$

<img src='https://github.com/rahiakela/img-repo/blob/master/attention-output.png?raw=1' width='800'/>
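The three steps above (score, softmax, weighted sum) can be sketched in plain Python for a single query. This is an illustration under simple assumptions, not the MXNet implementation developed below: vectors are Python lists, and the plain inner product `dot_score` stands in for the score function $\alpha$. All function and variable names here are hypothetical.

```python
import math

def softmax(scores):
    """Map scores a_1..a_n to weights b_1..b_n that sum to 1."""
    # Subtracting the max leaves the ratios unchanged but avoids overflow in exp.
    m = max(scores)
    exps = [math.exp(a - m) for a in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(query, keys, values, score_fn):
    """Return o = sum_i b_i * v_i, where b = softmax(alpha(q, k_i))."""
    scores = [score_fn(query, k) for k in keys]   # a_i = alpha(q, k_i)
    weights = softmax(scores)                     # b = softmax(a)
    dim_v = len(values[0])
    return [sum(b * v[j] for b, v in zip(weights, values)) for j in range(dim_v)]

def dot_score(q, k):
    # Illustrative choice of alpha (an assumption): the plain inner product.
    return sum(x * y for x, y in zip(q, k))

demo_query = [1.0, 0.0]
demo_keys = [[1.0, 0.0], [0.0, 1.0]]
demo_values = [[10.0, 0.0], [0.0, 10.0]]
print(attention_pool(demo_query, demo_keys, demo_values, dot_score))
```

The query matches the first key better, so the output leans toward the first value, while the weights still sum to 1.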

Different choices of the score function lead to different attention layers. Below, we introduce two commonly used attention layers. Before diving into their implementation, we introduce two operators they rely on:

- `masked_softmax`, a masked version of the softmax operator, and
- `batch_dot`, a specialized dot operator for batches of matrices, which MXNet provides as `nd.batch_dot`.
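`batch_dot` multiplies matrices pair by pair along the leading batch axis: inputs of shapes (b, n, m) and (b, m, p) give an output of shape (b, n, p), and `transpose_b=True` transposes each matrix of the second input first, so inputs of shapes (b, n, m) and (b, p, m) also give (b, n, p). The hypothetical helper `batch_dot_py` below only illustrates these semantics on nested Python lists; the notebook itself calls MXNet's `nd.batch_dot`.

```python
def batch_dot_py(A, B, transpose_b=False):
    """Hypothetical pure-Python illustration of a batched matrix product.

    A: list of b matrices, each n x m (nested lists).
    B: list of b matrices, each m x p, or p x m if transpose_b=True.
    Returns a list of b matrices, each n x p.
    """
    out = []
    for a, bm in zip(A, B):
        if transpose_b:
            bm = [list(col) for col in zip(*bm)]  # transpose p x m -> m x p
        cols = list(zip(*bm))                    # columns of the m x p matrix
        out.append([[sum(x * y for x, y in zip(row, col)) for col in cols]
                    for row in a])
    return out

# One query (1 x 2) per example against three keys (3 x 2): scores are 1 x 3.
Q = [[[1.0, 2.0]], [[0.0, 1.0]]]
K = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
     [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]]
print(batch_dot_py(Q, K, transpose_b=True))  # [[[1.0, 2.0, 3.0]], [[0.0, 2.0, 1.0]]]
```

With one query per example and `transpose_b=True`, each output row holds the inner products of the query with every key, i.e., unscaled dot-product scores.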


## Setup

```python
!pip install -U mxnet-cu101==1.7.0
```

```python
import math
from mxnet import nd
from mxnet.gluon import nn
```

## Masked Softmax

The masked softmax takes a 3-D input and filters out some of its elements by specifying a valid length for the last dimension: every value beyond the valid length receives a weight of 0. Let us implement the `masked_softmax` function.

```python
def masked_softmax(X, valid_length):
    """Perform softmax on the last axis of X, ignoring positions beyond valid_length."""
    # X: 3-D tensor, valid_length: 1-D or 2-D tensor
    if valid_length is None:
        return X.softmax()
    else:
        shape = X.shape
        if valid_length.ndim == 1:
            # One valid length per example: repeat it for every row of that example
            valid_length = valid_length.repeat(shape[1], axis=0)
        else:
            # One valid length per row: flatten to match the reshaped X
            valid_length = valid_length.reshape((-1,))
        # Fill masked elements with a large negative value, whose exp is 0
        X = nd.SequenceMask(X.reshape((-1, shape[-1])), valid_length, True,
                            axis=1, value=-1e6)
        return X.softmax().reshape(shape)
```

To test it, we construct a random input consisting of two $2\times 4$ matrices and specify a valid length of 2 for the first example and 3 for the second.

```python
masked_softmax(nd.random.uniform(shape=(2, 2, 4)), nd.array([2, 3]))
```

```
[[[0.488994   0.511006   0.         0.        ]
  [0.4365484  0.56345165 0.         0.        ]]

 [[0.288171   0.3519408  0.3598882  0.        ]
  [0.29034296 0.25239873 0.45725837 0.        ]]]
<NDArray 2x2x4 @cpu(0)>
```

As the output shows, the values outside the valid lengths are masked to zero, and the remaining weights in each row still sum to 1.
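The masking trick can be illustrated without MXNet: positions at or beyond the valid length are overwritten with a large negative number (-1e6, as in `masked_softmax`), so their exponentials underflow to 0. The sketch below handles a single row with a scalar valid length; `masked_softmax_row` is a hypothetical helper, not part of the notebook.

```python
import math

def masked_softmax_row(row, valid_len):
    """Softmax over row[:valid_len]; entries at index >= valid_len get weight 0."""
    masked = [x if i < valid_len else -1e6 for i, x in enumerate(row)]
    m = max(masked)
    exps = [math.exp(x - m) for x in masked]  # exp(-1e6 - m) underflows to 0.0
    total = sum(exps)
    return [e / total for e in exps]

row = [0.5, 0.1, 0.9, 0.3]
print(masked_softmax_row(row, 2))  # last two weights are 0.0
print(masked_softmax_row(row, 3))  # only the last weight is 0.0
```

In `masked_softmax`, a 1-D `valid_length` plays the role of `valid_len` for every row of the same example, while a 2-D one gives each row its own value.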

## Dot Product Attention

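The dot product attention assumes that the query and the keys have the same dimension, $d_q = d_k = d$. Its score function is the inner product scaled by $\sqrt{d}$:

$$\alpha(q, k) = \langle q, k \rangle / \sqrt{d}.$$

For $m$ queries $\mathbf Q \in \mathbb R^{m\times d}$ and $n$ keys $\mathbf K \in \mathbb R^{n\times d}$, all scores are computed at once as $\alpha(\mathbf Q, \mathbf K) = \mathbf Q \mathbf K^T / \sqrt{d} \in \mathbb R^{m\times n}$, which is what `batch_dot` with `transpose_b=True` computes in the code below.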
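Why divide by $\sqrt{d}$? Assuming the components of the query and key are independent with zero mean and unit variance, their dot product has mean 0 and variance $d$, so unscaled scores grow with the dimension and can push the softmax into regions with very small gradients. Dividing by $\sqrt{d}$ brings the variance back to 1. The seeded simulation below is a rough empirical check of this claim; the dimension, sample size, and seed are arbitrary choices.

```python
import math
import random

def sample_variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / (len(xs) - 1)

random.seed(0)
d, n_samples = 64, 5000
raw, scaled = [], []
for _ in range(n_samples):
    # Query and key with i.i.d. N(0, 1) components (an assumption for this check).
    q = [random.gauss(0.0, 1.0) for _ in range(d)]
    k = [random.gauss(0.0, 1.0) for _ in range(d)]
    s = sum(x * y for x, y in zip(q, k))
    raw.append(s)
    scaled.append(s / math.sqrt(d))

print(f"variance of q.k:         {sample_variance(raw):.2f}  (theory: {d})")
print(f"variance of q.k/sqrt(d): {sample_variance(scaled):.3f}  (theory: 1)")
```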

```python
class DotProductAttention(nn.Block):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size, ) or (batch_size, #queries)
    def forward(self, query, key, value, valid_length=None):
        d = query.shape[-1]
        # Set transpose_b=True to swap the last two dimensions of key;
        # scores: (batch_size, #queries, #kv_pairs)
        scores = nd.batch_dot(query, key, transpose_b=True) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        # Output: (batch_size, #queries, dim_v)
        return nd.batch_dot(attention_weights, value)
```

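As a toy example, we use a batch of two examples, each with a single query and 10 key-value pairs. Every key is a vector of ones, so all keys receive the same score and the attention weights are uniform over the valid positions: the output is the mean of the first 2 value rows for the first example and of the first 6 for the second. Dropout is only active in training mode (inside `autograd.record()`), so it has no effect here.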

```python
atten = DotProductAttention(dropout=0.5)
atten.initialize()
keys = nd.ones((2, 10, 2))
values = nd.arange(40).reshape((1, 10, 4)).repeat(2, axis=0)
atten(nd.ones((2, 1, 2)), keys, values, nd.array([2, 6]))
```

```
[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>
```
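To double-check these numbers without MXNet, the plain-Python sketch below recomputes scaled dot-product attention with a masked softmax, one query at a time. With identical keys, the output should be the mean of the first `valid_len` value rows. `scaled_dot_attention`, `toy_keys`, and `toy_values` are hypothetical names, chosen so as not to overwrite the `keys` and `values` NDArrays that are reused later.

```python
import math

def scaled_dot_attention(query, keys, values, valid_len):
    """Scaled dot-product attention for one query, with a masked softmax."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    # Mask positions at or beyond valid_len with -1e6, as masked_softmax does.
    scores = [s if i < valid_len else -1e6 for i, s in enumerate(scores)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))]

toy_keys = [[1.0, 1.0] for _ in range(10)]                      # like nd.ones((10, 2))
toy_values = [[4.0 * i + j for j in range(4)] for i in range(10)]  # like nd.arange(40).reshape((10, 4))
for valid_len in (2, 6):
    print(valid_len, scaled_dot_attention([1.0, 1.0], toy_keys, toy_values, valid_len))
```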

## Multilayer Perceptron Attention

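In multilayer perceptron attention (also known as additive attention), the query and the key are projected into a common $h$-dimensional space by learnable matrices $\mathbf W_k\in\mathbb R^{h\times d_k}$ and $\mathbf W_q\in\mathbb R^{h\times d_q}$, and the result is scored with a learnable vector $\mathbf v\in\mathbb R^{h}$: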

$$\alpha(\mathbf q, \mathbf k) = \mathbf v^T \tanh(\mathbf W_q \mathbf q + \mathbf W_k \mathbf k).$$
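Because $\mathbf W_q$ and $\mathbf W_k$ map the query and the key into the same $h$-dimensional space, this score does not require $d_q = d_k$, unlike dot product attention. The sketch below evaluates the formula for a 3-dimensional query and a 2-dimensional key with $h = 2$; the weight values are arbitrary numbers picked for illustration, not learned parameters, and `matvec` and `additive_score` are hypothetical helpers.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def additive_score(q, k, W_q, W_k, v):
    """alpha(q, k) = v^T tanh(W_q q + W_k k)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W_q, q), matvec(W_k, k))]
    return sum(vi * hi for vi, hi in zip(v, hidden))

# Arbitrary illustrative parameters: h = 2, d_q = 3, d_k = 2.
W_q = [[0.1, 0.2, 0.3],
       [0.0, -0.1, 0.2]]
W_k = [[0.5, -0.5],
       [0.3, 0.3]]
v = [1.0, -1.0]
print(additive_score([1.0, 0.0, 2.0], [1.0, 1.0], W_q, W_k, v))
```

Here $\mathbf W_q \mathbf q + \mathbf W_k \mathbf k = [0.7, 1.0]$, so the score equals $\tanh(0.7) - \tanh(1.0)$.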


```python
class MLPAttention(nn.Block):  # This class is saved in d2l.
    def __init__(self, units, dropout, **kwargs):
        super(MLPAttention, self).__init__(**kwargs)
        # Use flatten=False to keep the 3-D shapes of query and key.
        # No activation here: tanh is applied after adding the two projections.
        self.W_q = nn.Dense(units, use_bias=False, flatten=False)
        self.W_k = nn.Dense(units, use_bias=False, flatten=False)
        self.v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_length=None):
        # Project the query with W_q and the key with W_k
        query, key = self.W_q(query), self.W_k(key)
        # Expand query to (batch_size, #queries, 1, units) and key to
        # (batch_size, 1, #kv_pairs, units), then add them with broadcasting.
        features = nd.tanh(query.expand_dims(axis=2) + key.expand_dims(axis=1))
        # v maps each feature vector to a scalar: (batch_size, #queries, #kv_pairs)
        scores = self.v(features).squeeze(axis=-1)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        return nd.batch_dot(attention_weights, value)
```
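The `expand_dims` calls in `forward` turn the projected queries, of shape (batch_size, #queries, units), and the projected keys, of shape (batch_size, #kv_pairs, units), into (batch_size, #queries, 1, units) and (batch_size, 1, #kv_pairs, units); broadcasting their sum yields one feature vector per (query, key) pair, so the scores have shape (batch_size, #queries, #kv_pairs). For a single example, the same pairing can be spelled out with explicit loops following the score formula $\alpha$ above. The projected vectors are made-up numbers and `pairwise_additive_scores` is a hypothetical helper.

```python
import math

def pairwise_additive_scores(q_proj, k_proj, v):
    """scores[i][j] = v^T tanh(q_proj[i] + k_proj[j]) for every (query, key) pair.

    q_proj: #queries projected queries (W_q q_i), each of length h.
    k_proj: #kv_pairs projected keys (W_k k_j), each of length h.
    """
    return [[sum(vh * math.tanh(a + b) for vh, a, b in zip(v, qi, kj))
             for kj in k_proj]    # inner loop plays the role of the key axis
            for qi in q_proj]     # outer loop plays the role of the query axis

q_proj = [[0.2, -0.1], [0.0, 0.5]]               # 2 queries, h = 2
k_proj = [[0.1, 0.1], [0.3, -0.2], [0.0, 0.0]]   # 3 keys, h = 2
scores = pairwise_additive_scores(q_proj, k_proj, [1.0, 1.0])
print(len(scores), len(scores[0]))  # 2 3, i.e. (#queries, #kv_pairs)
```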

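Next, we apply `MLPAttention` to the same toy input. The keys are still all identical, so every valid key again receives the same score whatever the (randomly initialized) parameter values are, and the output matches that of the dot product attention example.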

```python
atten = MLPAttention(units=8, dropout=0.1)
atten.initialize()
atten(nd.ones((2, 1, 2)), keys, values, nd.array([2, 6]))
```

```
[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>
```