In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
print(torch.__version__)

1.12.0+cu113


The interactions between queries (volitional cues) and keys (nonvolitional cues) result in attention pooling. The attention pooling selectively aggregates values (sensory inputs) to produce the output. In this section, we will describe attention pooling in greater detail to give you a high-level view of how attention mechanisms work in practice. Specifically, the Nadaraya-Watson kernel regression model proposed in 1964 is a simple yet complete example for demonstrating machine learning with attention mechanisms.

# 1 intuition of attention mechanisms

## generate mock dataset

In [2]:
# Fix the global RNG seed so the random training inputs (and the random draws
# that follow in a top-to-bottom run) are reproducible across kernel restarts.
torch.manual_seed(42)

n_train = 50  # number of training samples
# Training inputs: n_train uniform draws scaled to [0, 5), sorted ascending.
x_train, _ = torch.sort(torch.rand(n_train) * 5)

def f(x):
    """Noise-free target curve, applied elementwise: 2 * sin(x) + x ** 0.8."""
    sine_term = 2 * torch.sin(x)
    power_term = x ** 0.8
    return sine_term + power_term

# Noisy training targets: f(x_train) plus Gaussian noise (mean 0.0, std 0.5).
y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # output of train samples
x_test = torch.arange(0, 5, 0.5)  # test samples
y_truth = f(x_test)  # true label of test samples
n_test = len(x_test) 
n_test

10

In [3]:
print(x_train)

tensor([0.1750, 0.3232, 0.3445, 0.6826, 0.6932, 0.7577, 0.7591, 0.8219, 0.8320,
        0.9403, 0.9977, 1.0834, 1.0958, 1.1068, 1.3033, 1.6270, 1.8241, 1.8301,
        1.9882, 2.0055, 2.0410, 2.3363, 2.4483, 2.6827, 2.6923, 2.7316, 2.8533,
        2.8697, 2.8959, 3.2993, 3.3062, 3.3612, 3.3637, 3.5369, 3.6493, 3.7116,
        3.7555, 3.9016, 3.9295, 3.9970, 4.1728, 4.2189, 4.2192, 4.2819, 4.3514,
        4.5283, 4.6237, 4.6254, 4.7662, 4.9375])


In [4]:
print(y_truth)

tensor([0.0000, 1.5332, 2.6829, 3.3782, 3.5597, 3.2783, 2.6905, 2.0227, 1.5178,
        1.3759])


In [5]:
print(y_train)

tensor([0.4839, 0.4465, 1.2327, 2.6117, 2.3442, 2.4821, 2.2916, 1.7639, 1.8538,
        2.5228, 1.7965, 2.8513, 3.1080, 2.2103, 2.6587, 3.6102, 3.5720, 3.8102,
        3.4114, 3.0347, 3.1894, 3.7735, 3.7699, 3.8172, 3.7964, 3.6540, 2.7537,
        3.6843, 2.3706, 2.1517, 1.9694, 2.8241, 3.2272, 1.5330, 1.7377, 1.6393,
        1.8550, 1.9915, 1.5142, 0.8626, 1.8244, 1.1411, 1.4603, 1.0822, 0.7710,
        1.8000, 1.0719, 1.8532, 1.6951, 1.4659])


In [6]:
print(x_test)

tensor([0.0000, 0.5000, 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000,
        4.5000])


## 1.1 Average Pooling

We begin with perhaps the world’s “dumbest” estimator for this regression problem: using average pooling to average over all the training outputs:

$$f(x) = \frac{1}{n}\sum_{i=1}^n y_i.$$

In [7]:
# Average pooling: predict the mean training output for every test input,
# independent of the query.
y_hat = y_train.mean().repeat(n_test)
print(y_hat)

tensor([2.2875, 2.2875, 2.2875, 2.2875, 2.2875, 2.2875, 2.2875, 2.2875, 2.2875,
        2.2875])


## 1.2 Nonparametric Attention Pooling

In [8]:
def diff(queries, keys):
    """Pairwise differences: entry (i, j) is queries[i] - keys[j].

    Both inputs are flattened, so the result has shape
    (queries.numel(), keys.numel()).
    """
    query_column = queries.reshape(-1, 1)
    key_row = keys.reshape(1, -1)
    return query_column - key_row


def score_function(queries, keys):
    """Gaussian-kernel attention scores: -(query - key) ** 2 / 2 for every pair."""
    pairwise = diff(queries, keys)
    return -(pairwise ** 2) / 2

def attention_pool(scores, values):
    """Softmax the scores along dim 1 and use them to average the values.

    Returns:
        (pooled, weights): the weighted sum of ``values`` and the attention
        weights that produced it.
    """
    weights = F.softmax(scores, dim=1)
    pooled = torch.matmul(weights, values)
    return pooled, weights

# Nonparametric prediction: test inputs are the queries, training inputs the
# keys and training outputs the values.
y_hat, attention_weights = attention_pool(score_function(x_test, x_train), y_train)

In [49]:
# Shape walkthrough of the broadcasting used in diff(): a column of queries
# minus a row of keys yields one difference per (query, key) pair.
# NOTE(review): this cell's execution count (49) is out of sequence with its
# neighbours (8 and 9) — re-run the notebook top to bottom before sharing.
print(x_test.reshape((-1, 1)).shape)

print(x_train.reshape((1, -1)).shape)

a = x_test.reshape((-1, 1)) - x_train.reshape((1, -1))
print(a.shape)

torch.Size([10, 1])
torch.Size([1, 50])
torch.Size([10, 50])


In [9]:
print(y_hat)

tensor([2.0835, 2.2867, 2.5109, 2.7237, 2.8440, 2.7861, 2.5560, 2.2535, 1.9771,
        1.7711])


In [10]:
print(attention_weights)
print(attention_weights.shape)

tensor([[8.2032e-02, 7.9059e-02, 7.8498e-02, 6.5988e-02, 6.5506e-02, 6.2511e-02,
         6.2446e-02, 5.9421e-02, 5.8928e-02, 5.3535e-02, 5.0640e-02, 4.6321e-02,
         4.5696e-02, 4.5147e-02, 3.5626e-02, 2.2171e-02, 1.5780e-02, 1.5609e-02,
         1.1541e-02, 1.1149e-02, 1.0377e-02, 5.4377e-03, 4.1590e-03, 2.2796e-03,
         2.2217e-03, 1.9967e-03, 1.4217e-03, 1.3565e-03, 1.2578e-03, 3.6047e-04,
         3.5234e-04, 2.9336e-04, 2.9085e-04, 1.6001e-04, 1.0684e-04, 8.4954e-05,
         7.2121e-05, 4.1228e-05, 3.6959e-05, 2.8279e-05, 1.3793e-05, 1.1365e-05,
         1.1353e-05, 8.6959e-06, 6.4432e-06, 2.9374e-06, 1.8981e-06, 1.8837e-06,
         9.7218e-07, 4.2343e-07],
        [5.5944e-02, 5.8064e-02, 5.8270e-02, 5.8004e-02, 5.7888e-02, 5.7052e-02,
         5.7032e-02, 5.6001e-02, 5.5817e-02, 5.3531e-02, 5.2109e-02, 4.9751e-02,
         4.9387e-02, 4.9062e-02, 4.2713e-02, 3.1252e-02, 2.4546e-02, 2.4353e-02,
         1.9488e-02, 1.8990e-02, 1.7991e-02, 1.0927e-02, 8.8392e-03, 5.4472

## 1.3 Parametric Attention Pooling

In [11]:
# Shape: (n_train, n_train) — every row is a full copy of the training inputs
X_tile = x_train.repeat((n_train, 1))
print(X_tile)
print(X_tile.shape)

tensor([[0.1750, 0.3232, 0.3445,  ..., 4.6254, 4.7662, 4.9375],
        [0.1750, 0.3232, 0.3445,  ..., 4.6254, 4.7662, 4.9375],
        [0.1750, 0.3232, 0.3445,  ..., 4.6254, 4.7662, 4.9375],
        ...,
        [0.1750, 0.3232, 0.3445,  ..., 4.6254, 4.7662, 4.9375],
        [0.1750, 0.3232, 0.3445,  ..., 4.6254, 4.7662, 4.9375],
        [0.1750, 0.3232, 0.3445,  ..., 4.6254, 4.7662, 4.9375]])
torch.Size([50, 50])


In [12]:
# Shape: (n_train, n_train) — every row is a full copy of the training outputs
Y_tile = y_train.repeat((n_train, 1))
print(Y_tile)
print(Y_tile.shape)

tensor([[0.4839, 0.4465, 1.2327,  ..., 1.8532, 1.6951, 1.4659],
        [0.4839, 0.4465, 1.2327,  ..., 1.8532, 1.6951, 1.4659],
        [0.4839, 0.4465, 1.2327,  ..., 1.8532, 1.6951, 1.4659],
        ...,
        [0.4839, 0.4465, 1.2327,  ..., 1.8532, 1.6951, 1.4659],
        [0.4839, 0.4465, 1.2327,  ..., 1.8532, 1.6951, 1.4659],
        [0.4839, 0.4465, 1.2327,  ..., 1.8532, 1.6951, 1.4659]])
torch.Size([50, 50])


In [13]:
# Shape of keys: (n_train, n_train - 1). Row i holds every training input
# except the i-th one (the diagonal of X_tile is dropped).
keys = X_tile[torch.eye(n_train) == 0].reshape(n_train, -1)
print(keys)
print(keys.shape)

tensor([[0.3232, 0.3445, 0.6826,  ..., 4.6254, 4.7662, 4.9375],
        [0.1750, 0.3445, 0.6826,  ..., 4.6254, 4.7662, 4.9375],
        [0.1750, 0.3232, 0.6826,  ..., 4.6254, 4.7662, 4.9375],
        ...,
        [0.1750, 0.3232, 0.3445,  ..., 4.6237, 4.7662, 4.9375],
        [0.1750, 0.3232, 0.3445,  ..., 4.6237, 4.6254, 4.9375],
        [0.1750, 0.3232, 0.3445,  ..., 4.6237, 4.6254, 4.7662]])
torch.Size([50, 49])


In [14]:
# Shape of values: (n_train, n_train - 1). Row i holds every training output
# except the i-th one (the diagonal of Y_tile is dropped).
values = Y_tile[torch.eye(n_train) == 0].reshape(n_train, -1)
print(values)
print(values.shape)

tensor([[0.4465, 1.2327, 2.6117,  ..., 1.8532, 1.6951, 1.4659],
        [0.4839, 1.2327, 2.6117,  ..., 1.8532, 1.6951, 1.4659],
        [0.4839, 0.4465, 2.6117,  ..., 1.8532, 1.6951, 1.4659],
        ...,
        [0.4839, 0.4465, 1.2327,  ..., 1.0719, 1.6951, 1.4659],
        [0.4839, 0.4465, 1.2327,  ..., 1.0719, 1.8532, 1.4659],
        [0.4839, 0.4465, 1.2327,  ..., 1.0719, 1.8532, 1.6951]])
torch.Size([50, 49])


### train process

In [15]:
def score_function_with_parameter(queries, keys, w):
    """Gaussian-kernel score whose query-key difference is scaled by ``w``."""
    scaled_difference = (queries - keys) * w
    return -(scaled_difference ** 2) / 2
    

class NWKernelRegression(nn.Module):
    """Nadaraya-Watson kernel regression with one learnable scalar ``w``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Single scalar weight, initialised uniformly in [0, 1).
        initial_w = torch.rand((1,), requires_grad=True)
        self.w = nn.Parameter(initial_w)

    def forward(self, queries, keys, values):
        num_pairs = keys.shape[1]
        # Repeat each query once per key-value pair so that queries (and the
        # resulting attention weights) have shape
        # (no. of queries, no. of key-value pairs).
        queries = queries.repeat_interleave(num_pairs).reshape((-1, num_pairs))
        scores = score_function_with_parameter(queries, keys, self.w)
        self.attention_weights = nn.functional.softmax(scores, dim=1)
        # values has shape (no. of queries, no. of key-value pairs); one batched
        # (1 x pairs) @ (pairs x 1) product per query gives the weighted sum.
        weights_as_rows = self.attention_weights.unsqueeze(1)
        values_as_columns = values.unsqueeze(-1)
        return torch.bmm(weights_as_rows, values_as_columns).reshape(-1)

In [16]:
net = NWKernelRegression()
# reduction='none' keeps one squared error per training sample; they are summed
# explicitly below before back-propagation and logging.
loss = nn.MSELoss(reduction='none')
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)

for epoch in range(5):
    optimizer.zero_grad()
    # Each training input acts as a query against the `keys`/`values` tensors.
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    optimizer.step()
    # The reported loss was computed before this epoch's parameter update.
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')

epoch 1, loss 31.119806
epoch 2, loss 10.461807
epoch 3, loss 10.461460
epoch 4, loss 10.461108
epoch 5, loss 10.460758


In [17]:
print(net.w)

Parameter containing:
tensor([17.1402], requires_grad=True)


In [18]:
print(net.attention_weights)
print(net.attention_weights.shape)

tensor([[7.2999e-01, 2.7001e-01, 7.1064e-16,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [4.0903e-02, 9.5910e-01, 6.0871e-09,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.5529e-02, 9.8447e-01, 5.5434e-08,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 7.6590e-01, 4.1745e-02,
         4.7585e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.2702e-01, 4.5744e-01,
         1.1346e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.9473e-05, 4.5953e-05,
         9.9991e-01]], grad_fn=<SoftmaxBackward0>)
torch.Size([50, 49])


### evaluate process

In [19]:
# Shape of keys: (n_test, n_train); every row holds the same training inputs
# (i.e. the same keys).
# NOTE(review): this rebinds `keys`/`values`, which the training cell also
# reads — re-running the training cell after this one would use these
# test-shaped tensors instead; confirm whether separate names are wanted.
keys = x_train.repeat((n_test, 1))
# Shape of values: (n_test, n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()

In [20]:
print(net.w)

Parameter containing:
tensor([17.1402], requires_grad=True)


In [21]:
print(net.attention_weights)
print(net.attention_weights.shape)

tensor([[9.9998e-01, 1.9499e-05, 2.4042e-06, 1.7062e-28, 1.9839e-29, 2.1132e-35,
         1.5621e-35, 7.1901e-42, 6.2358e-43, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [3.6027e-06, 2.0024e-01, 5.6749e-01, 1.4791e-01, 8.2179e-02, 1.1443e-03,
         1.0324e-03, 4.8421e-06, 1.8423e-06, 8.4895e-12, 3.1332e-15, 3.8617e-21,
         4.4563e-22, 6.4202e-23, 1.3355e-40, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000

# 2 general framework of attention mechanisms - Attention Scoring Functions

In Section 1, we used a Gaussian kernel to model interactions between queries and keys. Treating the exponent of the Gaussian kernel in (11.2.6) as an attention scoring function (or scoring function for short), the results of this function were essentially fed into a softmax operation. As a result, we obtained a probability distribution (attention weights) over values that are paired with keys. In the end, the output of the attention pooling is simply a weighted sum of the values based on these attention weights.

At a high level, we can use the above algorithm to instantiate the framework of attention mechanisms in Fig. 11.1.3. Denoting an attention scoring function by $a$, Fig. 11.3.1 illustrates how the output of attention pooling can be computed as a weighted sum of values. Since attention weights are a probability distribution, the weighted sum is essentially a weighted average.

![](https://d2l.ai/_images/attention-output.svg)

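Mathematically, suppose that we have a query $\mathbf{q} \in \mathbb{R}^q$ and $m$ key-value pairs $(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)$, where any $\mathbf{k}_i \in \mathbb{R}^k$ and any $\mathbf{v}_i \in \mathbb{R}^v$. The attention pooling $f$ is instantiated as a weighted sum of the values:

$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,$$

where the attention weight $\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))}$ is obtained by applying a softmax to the scores produced by the attention scoring function $a$.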

## elementary knowledge - Masked Softmax Operation

As we just mentioned, a softmax operation is used to output a probability distribution as attention weights. In some cases, not all the values should be fed into attention pooling. For instance, for efficient minibatch processing in Section 10.5, some text sequences are padded with special tokens that do not carry meaning. To get an attention pooling over only meaningful tokens as values, we can specify a valid sequence length (in number of tokens) to filter out those beyond this specified range when computing softmax. In this way, we can implement such a masked softmax operation in the following masked_softmax function, where any value beyond the valid length is masked as zero.

In [22]:
# sequence_mask works on a 2D tensor; masked_softmax below takes the 3D scores
# and 1D or 2D valid_lens and flattens them before calling it.
def sequence_mask(attention_scores, valid_len, value=0):
    """Overwrite, in place, every entry at column >= valid_len[row] with ``value``.

    The input tensor itself is modified and returned.
    """
    num_columns = attention_scores.size(1)
    column_index = torch.arange(
        num_columns, dtype=torch.float32, device=attention_scores.device
    ).unsqueeze(0)
    keep = column_index < valid_len.unsqueeze(1)
    attention_scores[~keep] = value
    return attention_scores

    
def masked_softmax(attention_scores, valid_lens):
    """Softmax over the last axis, ignoring positions beyond each valid length.

    Args:
        attention_scores: 3D tensor (batch_size, no. of queries, no. of keys).
        valid_lens: ``None`` (plain softmax), a 1D tensor (batch_size,) with one
            valid length per batch element, or a 2D tensor
            (batch_size, no. of queries) with one valid length per query.

    Returns:
        A tensor shaped like ``attention_scores`` whose last axis sums to 1 and
        is zero at masked positions.

    The mask is applied out of place, so the caller's ``attention_scores`` is
    left untouched and a leaf tensor that requires grad can be passed directly.

    Defined in :numref:`sec_attention-scoring-functions`
    """
    if valid_lens is None:
        return nn.functional.softmax(attention_scores, dim=-1)

    shape = attention_scores.shape
    if valid_lens.dim() == 1:
        # One length per batch element: repeat it for every query of that element.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # keep[r, c] is True where column c is a valid position for flattened row r.
    positions = torch.arange(
        shape[-1], dtype=torch.float32, device=attention_scores.device)
    keep = (positions[None, :] < valid_lens[:, None]).reshape(shape)
    # On the last axis, replace masked elements with a very large negative
    # value, whose exponentiation outputs 0. masked_fill returns a new tensor
    # instead of writing into the input.
    masked_scores = attention_scores.masked_fill(~keep, -1e6)
    return nn.functional.softmax(masked_scores, dim=-1)

In [23]:
# 1D valid_lens: one valid length per batch element (2 for the first, 3 for the second).
masked_attention_weights = masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
print(masked_attention_weights)

tensor([[[0.7113, 0.2887, 0.0000, 0.0000],
         [0.4018, 0.5982, 0.0000, 0.0000]],

        [[0.4104, 0.2470, 0.3426, 0.0000],
         [0.2497, 0.4913, 0.2590, 0.0000]]])


In [24]:
# 2D valid_lens: one valid length per (batch element, query) pair.
masked_attention_weights = masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
print(masked_attention_weights)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4249, 0.1644, 0.4107, 0.0000]],

        [[0.6734, 0.3266, 0.0000, 0.0000],
         [0.2495, 0.1832, 0.3178, 0.2495]]])


## 2.1 Scaled Dot-Product Attention

A more computationally efficient design for the scoring function can be simply dot product. However, the dot product operation requires that both the query and the key have the same vector length, say $d$. Assume that all the elements of the query and the key are independent random variables with zero mean and unit variance. The dot product of both vectors has zero mean and a variance of $d$. To ensure that the variance of the dot product still remains one regardless of vector length, the scaled dot-product attention scoring function divides the dot product by $\sqrt{d}$:

$$a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} / \sqrt{d}.$$

In [25]:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional length masking."""

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Attend over ``values`` using scaled query-key dot products.

        Shapes:
            queries: (batch_size, no. of queries, d)
            keys: (batch_size, no. of key-value pairs, d)
            values: (batch_size, no. of key-value pairs, value dimension)
            valid_lens: (batch_size,) or (batch_size, no. of queries)
        """
        feature_dim = queries.shape[-1]
        # keys.transpose(1, 2) swaps the last two dimensions of keys, so bmm
        # yields one score per (query, key) pair.
        keys_transposed = keys.transpose(1, 2)
        scaled_scores = torch.bmm(queries, keys_transposed) / math.sqrt(feature_dim)
        self.attention_weights = masked_softmax(scaled_scores, valid_lens)
        dropped_weights = self.dropout(self.attention_weights)
        return torch.bmm(dropped_weights, values)

In [26]:
# Two batch elements with one 2-dimensional query each; the keys are all ones.
queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))

print(queries)

print(keys)

# The two value matrices in the values minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
print(values)

# One valid length per batch element: 2 for the first, 6 for the second.
valid_lens = torch.tensor([2, 6])

tensor([[[ 0.2017, -0.5536]],

        [[ 1.9334,  1.4100]]])
tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]],

        [[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 

In [27]:
attention = DotProductAttention(dropout=0.5)
# eval() puts the module in evaluation mode, which disables its dropout layer.
attention.eval()
output = attention(queries, keys, values, valid_lens)
print(output)
print(output.shape)

print(attention.attention_weights)
print(attention.attention_weights.shape)

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])
torch.Size([2, 1, 4])
tensor([[[0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
          0.0000, 0.0000]]])
torch.Size([2, 1, 10])


## 2.2 Additive Attention

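In general, when queries and keys are vectors of different lengths, we can use additive attention as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is

$$a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},$$

where $\mathbf W_q \in \mathbb{R}^{h \times q}$, $\mathbf W_k \in \mathbb{R}^{h \times k}$ and $\mathbf w_v \in \mathbb{R}^{h}$ are learnable parameters and $h$ is the number of hidden units.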

In [28]:
class AdditiveAttention(nn.Module):
    """Additive attention: scores are w_v^T tanh(W_q q + W_k k).

    Queries and keys may have different feature sizes; both are projected to
    ``num_hiddens`` features by lazily-initialised, bias-free linear layers.

    Args:
        num_hiddens: size of the shared hidden projection.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.LazyLinear(num_hiddens, bias=False)
        self.W_q = nn.LazyLinear(num_hiddens, bias=False)
        self.w_v = nn.LazyLinear(1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Return the attention-weighted combination of ``values``.

        ``valid_lens`` defaults to ``None`` (no masking), matching the
        signature of ``DotProductAttention.forward``.
        """
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)

In [29]:
# Queries have 20 features while keys have 2; AdditiveAttention projects both
# to num_hiddens features before adding them.
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# The two value matrices in the values minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
# eval() puts the module in evaluation mode, which disables its dropout layer.
attention.eval()
output = attention(queries, keys, values, valid_lens)
print(output)
print(output.shape)

print(attention.attention_weights)
print(attention.attention_weights.shape)

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
torch.Size([2, 1, 4])
tensor([[[0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
          0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)
torch.Size([2, 1, 10])




# further reading

## multi-head attention

In practice, given the same set of queries, keys, and values we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.

To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with $h$ independently learned linear projections. Then these $h$ projected queries, keys, and values are fed into attention pooling in parallel. In the end, $h$ attention pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called multi-head attention, where each of the $h$ attention pooling outputs is a head [Vaswani et al., 2017]. Using fully connected layers to perform learnable linear transformations, Fig. 11.5.1 describes multi-head attention.

![](https://d2l.ai/_images/multi-head-attention.svg)

In [30]:
def transpose_qkv(X, num_heads):
    """Split the hidden dimension into heads and fold the heads into the batch axis.

    Input:  (batch_size, no. of queries or key-value pairs, num_hiddens)
    Output: (batch_size * num_heads, no. of queries or key-value pairs,
             num_hiddens / num_heads)
    """
    batch_size, seq_len = X.shape[0], X.shape[1]
    # (batch_size, seq_len, num_heads, num_hiddens / num_heads)
    per_head = X.reshape(batch_size, seq_len, num_heads, -1)
    # (batch_size, num_heads, seq_len, num_hiddens / num_heads)
    heads_first = per_head.permute(0, 2, 1, 3)
    return heads_first.reshape(-1, seq_len, heads_first.shape[3])

def transpose_output(X, num_heads):
    """Reverse transpose_qkv.

    Input:  (batch_size * num_heads, n, d)
    Output: (batch_size, n, num_heads * d)
    """
    seq_len, head_dim = X.shape[1], X.shape[2]
    by_head = X.reshape(-1, num_heads, seq_len, head_dim)
    by_position = by_head.permute(0, 2, 1, 3)
    return by_position.reshape(by_position.shape[0], seq_len, -1)

In [31]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on scaled dot-product attention.

    Queries, keys and values are each linearly projected to ``num_hiddens``
    features, split into ``num_heads`` heads that attend in parallel, and the
    concatenated head outputs pass through a final linear layer ``W_o``.

    Args:
        key_size, query_size, value_size: input feature sizes.
        num_hiddens: projected feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability passed to the dot-product attention.
        bias: whether the linear projections use a bias term.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``num_hiddens``.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        # Fail early with a clear message; otherwise the head split in
        # transpose_qkv fails later with an opaque reshape error.
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by a positive "
                f"num_heads ({num_heads})")
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens=None):
        """Run all heads in parallel and merge their outputs.

        ``valid_lens`` defaults to ``None`` (no masking), matching
        ``DotProductAttention.forward``.
        """
        # Shape of queries, keys, values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens:
        # (batch_size,) or (batch_size, no. of queries)
        # After the transformation, shape of queries, keys, values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # Along axis 0, copy the first item (scalar or vector) num_heads
            # times, then copy the second item the same way, and so on.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

In [32]:
# 10 hidden units split across 5 heads.
num_hiddens, num_heads = 10, 5
multi_head_attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
# eval() puts the module (and its dropout layer) in evaluation mode.
multi_head_attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=10, out_features=10, bias=False)
  (W_k): Linear(in_features=10, out_features=10, bias=False)
  (W_v): Linear(in_features=10, out_features=10, bias=False)
  (W_o): Linear(in_features=10, out_features=10, bias=False)
)

In [33]:
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])

# X supplies the queries; Y supplies both the keys and the values.
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
output = multi_head_attention(X, Y, Y, valid_lens)

print(output)
print(output.shape)

tensor([[[ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243],
         [ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243],
         [ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243],
         [ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243]],

        [[ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243],
         [ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243],
         [ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243],
         [ 0.5506, -0.1278, -0.5473,  0.1268,  0.0913,  0.0385, -0.0627,
           0.3436,  0.0286,  0.2243]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([2, 4, 10])


In [34]:
print(multi_head_attention.attention.attention_weights)
print(multi_head_attention.attention.attention_weights.shape)

tensor([[[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000, 0.0000

## self-attention

In [35]:
# Same configuration as the multi-head example above: 10 hidden units, 5 heads.
num_hiddens, num_heads = 10, 5
multi_head_attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
# eval() puts the module (and its dropout layer) in evaluation mode.
multi_head_attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=10, out_features=10, bias=False)
  (W_k): Linear(in_features=10, out_features=10, bias=False)
  (W_v): Linear(in_features=10, out_features=10, bias=False)
  (W_o): Linear(in_features=10, out_features=10, bias=False)
)

In [36]:
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
# Self-attention: the same tensor X is passed as queries, keys and values.
output_after_attention = multi_head_attention(X, X, X, valid_lens)

print(output_after_attention)
print(output_after_attention.shape)

tensor([[[ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824],
         [ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824],
         [ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824],
         [ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824]],

        [[ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824],
         [ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824],
         [ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824],
         [ 0.4852, -0.4486,  0.0234, -0.2439, -0.1019, -0.6457, -0.4172,
           0.2136, -0.6824,  0.4824]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([2, 4, 10])


In [37]:
print(multi_head_attention.attention.attention_weights)
print(multi_head_attention.attention.attention_weights.shape)

tensor([[[0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000]],

        [[0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000],
         [0.3333, 0.3333, 0.3333, 0.0000]],

        [[0.5000, 0.5000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000, 0.0000],
 

# Reference
* https://d2l.ai/chapter_attention-mechanisms-and-transformers/index.html