Motivation. We may want our model to **combine** knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. <br>
Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.

- Instead of performing a single attention pooling, the queries, keys, and values can be transformed with $h$ independently learned linear projections.
- The $h$ resulting attention pooling outputs are concatenated and transformed with another learned <font color='red'>linear projection</font> to produce the final output.

This design is called **multi-head attention**, where each of the $h$ attention
pooling outputs is a **head**.

### Model

Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$ and $m$ key-value pairs $(\mathbf{k}_j, \mathbf{v}_j)$ with $\mathbf{k}_j \in \mathbb{R}^{d_k}$ and $\mathbf{v}_j \in \mathbb{R}^{d_v}$ ($j=1,\ldots,m$), each attention head $\mathbf{h}_i$ ($i=1,\ldots,h$) is computed as
\begin{equation}
\begin{aligned}
    \mathbf{h}_i &= f\left(\mathbf{W}_i^{(q)}\mathbf{q}, \{\mathbf{W}_i^{(k)}\mathbf{k}_j\}_{j=1}^{m}, \{\mathbf{W}_i^{(v)}\mathbf{v}_j\}_{j=1}^{m}\right) \in \mathbb{R}^{p_v} \\
                 &= \sum_{j=1}^{m} \alpha\left(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}_j\right) \mathbf{W}_i^{(v)}\mathbf{v}_j
\end{aligned}
\end{equation}
where the learnable parameters are $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$ and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$, $\alpha$ gives the attention weight of the $j$-th key-value pair, and $f$ is <font color='red'>attention pooling, such as additive attention or scaled dot-product attention</font>.

The multi-head attention output is another linear transformation via learnable parameters $\mathbf{W}_o \in \mathbb{R}^{p_o \times hp_v}$ of the concatenation of $h$ heads:
\begin{equation}
    \mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \mathbf{h}_2 \\ \vdots \\ \mathbf{h}_h \\ \end{bmatrix} \in \mathbb{R}^{p_o}.
\end{equation}
Based on this design, <font color='red'>each head</font> may attend to <font color='red'>different parts of the input</font>, so more sophisticated functions than a simple weighted average can be expressed.
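
The two equations above can be traced one head at a time. The following is a minimal sketch, assuming scaled dot-product attention as $f$ and a single query; the function name `multi_head_attention_loop` and the toy sizes are illustrative choices, not part of the original notes.

```python
import math
import torch


def multi_head_attention_loop(q, K, V, W_q, W_k, W_v, W_o):
    """Multi-head attention for one query, computed head by head.

    q: (d_q,), K: (m, d_k), V: (m, d_v)
    W_q: (h, p_q, d_q), W_k: (h, p_k, d_k), W_v: (h, p_v, d_v)
    W_o: (p_o, h * p_v)
    Assumes scaled dot-product attention, which requires p_q == p_k.
    """
    heads = []
    # Iterating over a tensor walks its first axis, i.e. one head at a time
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        q_i = Wq_i @ q      # projected query, shape (p_q,)
        K_i = K @ Wk_i.T    # projected keys, shape (m, p_k)
        V_i = V @ Wv_i.T    # projected values, shape (m, p_v)
        scores = K_i @ q_i / math.sqrt(q_i.shape[0])  # shape (m,)
        alpha = torch.softmax(scores, dim=0)          # attention weights
        heads.append(alpha @ V_i)                     # head h_i, shape (p_v,)
    # Concatenate the h heads and apply the output projection
    return W_o @ torch.cat(heads)


torch.manual_seed(0)
h, m = 4, 6
d_q = d_k = d_v = 8
p_q = p_k = p_v = 2
p_o = 8
q, K, V = torch.randn(d_q), torch.randn(m, d_k), torch.randn(m, d_v)
W_q = torch.randn(h, p_q, d_q)
W_k = torch.randn(h, p_k, d_k)
W_v = torch.randn(h, p_v, d_v)
W_o = torch.randn(p_o, h * p_v)

out = multi_head_attention_loop(q, K, V, W_q, W_k, W_v, W_o)
print(out.shape)  # torch.Size([8])
```

The `MultiHeadAttention` class in this section avoids the Python loop by projecting once per input and reshaping so that all heads are processed in a single batched call.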

In [1]:
import torch
from d2l import torch as d2l
from torch import nn

In [2]:
class MultiHeadAttention(d2l.Module):
    """Multi-head attention."""

    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))
        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)

    def transpose_qkv(self, X):
        """Transposition for parallel computation of multiple attention heads."""
        # Shape of input X: (batch_size, no. of queries or key-value pairs,
        # num_hiddens). Shape of output X: (batch_size, no. of queries or
        # key-value pairs, num_heads, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        # Shape of output X: (batch_size, num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # Shape of output: (batch_size * num_heads, no. of queries or key-value
        # pairs, num_hiddens / num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])

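    def transpose_output(self, X):
        """Reverse the operation of transpose_qkv."""
        # Shape of input X: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads). Split axis 0 into (batch_size, num_heads)
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        # Shape: (batch_size, no. of queries, num_heads, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        # Concatenate the heads: (batch_size, no. of queries, num_hiddens)
        return X.reshape(X.shape[0], X.shape[1], -1)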

**REMARK**. To avoid significant growth in computational cost and parameter count, we set $p_q = p_k = p_v = p_o/h$. In the implementation above, $p_o$ is specified via ```num_hiddens```, so each head works with vectors of size ```num_hiddens / num_heads```.
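
The effect of this choice on parameter count can be checked with a little arithmetic. The sketch below assumes that queries, keys, and values share one input size `d_in`, that `p_o` is divisible by the number of heads, and that biases are disabled; `projection_param_count` is a hypothetical helper written for this illustration.

```python
def projection_param_count(d_in, p_o, num_heads):
    """Count the weights in the query/key/value projections of all heads."""
    p_head = p_o // num_heads      # p_q = p_k = p_v = p_o / h
    per_head = 3 * p_head * d_in   # W_i^(q), W_i^(k), W_i^(v) for one head
    return num_heads * per_head


d_in, p_o = 100, 100
for h in (1, 5, 10):
    print(h, projection_param_count(d_in, p_o, h))
# 1 30000
# 5 30000
# 10 30000
```

Under these assumptions the total stays constant as the number of heads grows, which is why a single projection of width ```num_hiddens``` per input can be split into heads afterward.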

### nn.LazyLinear

```nn.LazyLinear``` defers initialization: the framework waits **until the first time we pass data** through the model to infer the input size of each layer.

**Parameters** <br>
- out_features (int) – size of each output sample. <br>

**Variables** <br>
- weight (torch.nn.parameter.UninitializedParameter) – the learnable weights of the module of shape (out_features, in_features). <br>
- bias (torch.nn.parameter.UninitializedParameter) – the learnable bias of the module of shape (out_features).

In [8]:
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))

In [9]:
net[0].weight

<UninitializedParameter>

In [10]:
X = torch.rand(2, 20)
net(X)
net[0].weight.shape

torch.Size([256, 20])

In [11]:
net

Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

In [20]:
X = torch.rand(2, 5, 20)
net(X)
net[0].weight.shape

torch.Size([256, 20])

### torch.repeat_interleave

```valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)``` <br>
dim (int, optional) – The dimension along which to repeat values. By default, use the **flattened** input array, and return a flat output array.

In [4]:
x = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(x, repeats=2, dim=0)

tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4]])
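
For the one-dimensional ```valid_lens``` of the toy example, the interleaved order matters. The sketch below contrasts it with tiling via `Tensor.repeat` and with the flattened default; the variable names `per_head`, `tiled`, and `flat` are illustrative.

```python
import torch

valid_lens = torch.tensor([3, 2])  # one valid length per batch element
num_heads = 5

# Each entry is repeated num_heads times before moving to the next entry
per_head = torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0)
print(per_head)  # tensor([3, 3, 3, 3, 3, 2, 2, 2, 2, 2])

# Tiling repeats the whole vector instead, giving a different order
tiled = valid_lens.repeat(num_heads)
print(tiled)  # tensor([3, 2, 3, 2, 3, 2, 3, 2, 3, 2])

# Without dim, the input is flattened first and the result is 1-D
x = torch.tensor([[1, 2], [3, 4]])
flat = torch.repeat_interleave(x, repeats=2)
print(flat)  # tensor([1, 1, 2, 2, 3, 3, 4, 4])
```

The interleaved order is the one that lines up with ```transpose_qkv```, which places the ```num_heads``` heads of the first batch element in consecutive rows before those of the second.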

### transpose_qkv

```queries = self.transpose_qkv(self.W_q(queries))```
- Shape of queries, keys, or values: (batch_size, no. of queries or key-value pairs, num_hiddens)
- After transposing, shape of output queries, keys, or values: (batch_size * num_heads, no. of queries or key-value pairs, num_hiddens / num_heads)
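
The reshaping can be checked in isolation. Below is a minimal sketch that copies the two methods into standalone functions (taking `num_heads` as an argument is a change made only for this illustration) and verifies where each head ends up.

```python
import torch


def transpose_qkv(X, num_heads):
    # (batch, steps, hiddens) -> (batch, steps, heads, hiddens / heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # -> (batch, heads, steps, hiddens / heads)
    X = X.permute(0, 2, 1, 3)
    # -> (batch * heads, steps, hiddens / heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    # (batch * heads, steps, hiddens / heads) -> (batch, heads, steps, ...)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # -> (batch, steps, heads, hiddens / heads)
    X = X.permute(0, 2, 1, 3)
    # -> (batch, steps, hiddens)
    return X.reshape(X.shape[0], X.shape[1], -1)


batch_size, num_steps, num_hiddens, num_heads = 2, 4, 100, 5
X = torch.arange(batch_size * num_steps * num_hiddens, dtype=torch.float32)
X = X.reshape(batch_size, num_steps, num_hiddens)

Y = transpose_qkv(X, num_heads)
print(Y.shape)  # torch.Size([10, 4, 20])

# Row b * num_heads + i of Y holds the i-th slice of features of batch item b
b, i = 1, 3
head_dim = num_hiddens // num_heads
assert torch.equal(Y[b * num_heads + i],
                   X[b, :, i * head_dim:(i + 1) * head_dim])

# transpose_output undoes transpose_qkv exactly
assert torch.equal(transpose_output(Y, num_heads), X)
```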

### Toy example

In [3]:
torch.manual_seed(123)

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))



### Variable dimensions

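Inputs to ```forward``` and the projected keys, before ```transpose_qkv```: <br>
queries.shape: torch.Size([2, 4, 100]) <br>
keys.shape: torch.Size([2, 6, 100]) <br>
self.W_k(keys).shape: torch.Size([2, 6, 100])

**After** projection and ```transpose_qkv``` (batch_size * num_heads = 10, num_hiddens / num_heads = 20): <br>
queries.shape: torch.Size([10, 4, 20]) <br>
keys.shape: torch.Size([10, 6, 20]) <br>
values.shape: torch.Size([10, 6, 20])

Weight shapes **after** we pass data, once lazy initialization has inferred the input size: <br>
self.W_q.weight.shape: torch.Size([100, 100]) <br>
self.W_k.weight.shape: torch.Size([100, 100]) <br>
self.W_v.weight.shape: torch.Size([100, 100])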

Example.
1. ```nn.LazyLinear``` infers ```in_features = 100``` from the last dimension of ```keys```, so the weight $W_k$ is 100 by 100.
2. ```self.W_k(keys)```, with ```keys.shape: torch.Size([2, 6, 100])```: the layer ```self.W_k``` is applied to each of the 2*6=12 vectors independently, producing 12 vectors of dimension 100.
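
Both points can be checked numerically. This is a minimal sketch, assuming a standalone bias-free layer named `W_k` in place of ```self.W_k``` and random keys of the same shape as in the toy example.

```python
import torch
from torch import nn

torch.manual_seed(0)
keys = torch.rand(2, 6, 100)
W_k = nn.LazyLinear(100, bias=False)

# The first forward pass infers in_features from the last axis of keys
out = W_k(keys)
print(W_k.weight.shape)  # torch.Size([100, 100])
print(out.shape)         # torch.Size([2, 6, 100])

# Flatten to 12 vectors, multiply by the weight, and restore the shape
flat_out = keys.reshape(12, 100) @ W_k.weight.T
same = torch.allclose(out, flat_out.reshape(2, 6, 100), atol=1e-5)
print(same)
```

The final line should print `True`: the linear layer acts on the last axis only, and the leading axes are treated as a batch of independent vectors.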