# The Attention mechanism and its implementation in PyTorch

Computing self-attention of a sentence with GloVe embeddings and the `MultiheadAttention` class with PyTorch

Author: Pierre Nugues

## Modules

In [1]:
import os

import torch
import torch.nn
import torch.nn.functional as F

## Noncontextual embeddings

We load the GloVe embeddings: 50-dimensional vectors from the `glove.6B.50d.txt` file.

In [2]:
def read_embeddings(file):
    """
    Return the embeddings in the form of a dictionary.

    Each non-blank line of the file holds a word followed by the components
    of its vector, separated by whitespace (the GloVe text format).

    :param file: path to the embedding file (UTF-8 encoded)
    :return: a dictionary mapping each word to a 1-D FloatTensor
    """
    embeddings = {}
    # The context manager closes the file even if parsing fails midway
    with open(file, encoding='utf8') as glove:
        for line in glove:
            # split() without arguments also strips surrounding whitespace
            values = line.split()
            if not values:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing on values[0]
                continue
            word = values[0]
            embeddings[word] = torch.FloatTensor(list(map(float, values[1:])))
    return embeddings

In [3]:
# Location of the 50-dimensional GloVe 6B vectors. Set the GLOVE_FILE
# environment variable to use another location instead of editing this
# machine-specific default path.
embedding_file = os.environ.get(
    'GLOVE_FILE',
    '/Users/pierre/Documents/Cours/EDAN20/corpus/glove.6B.50d.txt')
embeddings_dict = read_embeddings(embedding_file)

In [4]:
embeddings_dict['ship']

tensor([ 1.5213,  0.1052,  0.3816, -0.5080,  0.0324, -0.1348, -1.2474,  0.7981,
         0.8469, -1.1010,  0.8874,  1.3749,  0.4293,  0.6572, -0.2636, -0.4176,
        -0.4885,  0.9106, -1.7158, -0.4380,  0.7839,  0.1964, -0.4066, -0.5397,
         0.8244, -1.7434,  0.1428,  0.2804,  1.1688,  0.1690,  2.2271, -0.5827,
        -0.4572,  0.6281,  0.5444,  0.2846,  0.4448, -0.5534, -0.3649, -0.0164,
         0.4088, -0.8715,  1.5513, -0.8070, -0.1004, -0.2846, -0.3322, -0.5061,
         0.4827, -0.6620])

## Cosine similarity

Let us compute the cosine similarity of the words in a sentence:
> I must go back to my ship and to my crew

_Odyssey_, book I 

Remember that:
$$\cos(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u} \cdot \mathbf{v}}{||\mathbf{u}|| \cdot ||\mathbf{v} ||}$$

In [5]:
sentence_odyssey = 'I must go back to my ship and to my crew'
sentence_amazon = 'We process and ship your order'

In [6]:
words_a = sentence_amazon.lower().split()
words_o = sentence_odyssey.lower().split()
words_o

['i', 'must', 'go', 'back', 'to', 'my', 'ship', 'and', 'to', 'my', 'crew']

We build the embedding matrix

In [7]:
def embedding_matrix(words, embeddings=None):
    """
    Stack the noncontextual embeddings of a sequence of words into a matrix.

    :param words: the list of words; each must be a key of the embeddings
    :param embeddings: dictionary mapping words to 1-D tensors; defaults to
        the global ``embeddings_dict``
    :return: a tensor of size (len(words), embedding dimension), one row per word
    :raises KeyError: if a word has no embedding
    """
    if embeddings is None:
        embeddings = embeddings_dict
    return torch.stack([embeddings[word] for word in words])

In [8]:
embeddings_seq_a = embedding_matrix(words_a)
embeddings_seq_o = embedding_matrix(words_o)

In [9]:
embeddings_seq_o.size()

torch.Size([11, 50])

In [10]:
embeddings_seq_o[0][:10]

tensor([ 1.1891e-01,  1.5255e-01, -8.2073e-02, -7.4144e-01,  7.5917e-01,
        -4.8328e-01, -3.1009e-01,  5.1476e-01, -9.8708e-01,  6.1757e-04])

We compute the attention scores as the pairwise cosines of the word embeddings

In [11]:
def attn_cos_scores(embeddings_seq):
    """
    Return the matrix of pairwise cosine similarities between the rows of
    ``embeddings_seq``: entry [i, j] is cos(e_i, e_j), where e_i is row i.

    :param embeddings_seq: a 2-D tensor, one embedding per row
    :return: a square tensor of size (number of rows, number of rows)
    """
    # Normalize each row of the argument to unit length; keepdim=True keeps
    # the norms as a column so that the division broadcasts row by row
    E_normed = embeddings_seq / torch.norm(embeddings_seq, dim=-1, keepdim=True)
    # Dot products of unit vectors are their cosines
    attn_scores_cos = E_normed @ E_normed.T
    return attn_scores_cos

In [12]:
def print_cos_scores(words):
    """
    Print, as a tab-separated table, the pairwise cosine similarities
    between the embeddings of the words of a sentence.

    :param words: the list of words, used both as column and row labels
    """
    scores = attn_cos_scores(embedding_matrix(words))
    # Header: an empty corner cell, then one column label per word
    print('\t' + ''.join(f'{word}\t' for word in words))
    # Body: one row per word, its label followed by its similarities
    for label, row in zip(words, scores):
        print(f'{label}\t' + ''.join(f'{score:.2f}\t' for score in row))

In [13]:
print_cos_scores(words_o)

	i	must	go	back	to	my	ship	and	to	my	crew	
i	1.00	0.75	0.86	0.76	0.73	0.90	0.35	0.65	0.73	0.90	0.42	
must	0.75	1.00	0.85	0.68	0.87	0.69	0.42	0.69	0.87	0.69	0.45	
go	0.86	0.85	1.00	0.84	0.84	0.81	0.41	0.68	0.84	0.81	0.49	
back	0.76	0.68	0.84	1.00	0.83	0.76	0.49	0.77	0.83	0.76	0.51	
to	0.73	0.87	0.84	0.83	1.00	0.68	0.54	0.86	1.00	0.68	0.51	
my	0.90	0.69	0.81	0.76	0.68	1.00	0.38	0.63	0.68	1.00	0.44	
ship	0.35	0.42	0.41	0.49	0.54	0.38	1.00	0.46	0.54	0.38	0.78	
and	0.65	0.69	0.68	0.77	0.86	0.63	0.46	1.00	0.86	0.63	0.49	
to	0.73	0.87	0.84	0.83	1.00	0.68	0.54	0.86	1.00	0.68	0.51	
my	0.90	0.69	0.81	0.76	0.68	1.00	0.38	0.63	0.68	1.00	0.44	
crew	0.42	0.45	0.49	0.51	0.51	0.44	0.78	0.49	0.51	0.44	1.00	


## Contextual embeddings

We design a new vector representation for _ship_ so that it receives an influence from _crew_ and the other words of its context. This influence will depend on the embeddings of the context words. Let us use the cosine similarities as attention scores.

In [14]:
attn_cos_scores(embeddings_seq_o)[6]

tensor([0.3466, 0.4178, 0.4068, 0.4853, 0.5401, 0.3791, 1.0000, 0.4586, 0.5401,
        0.3791, 0.7848])

We compute the new embedding of _ship_ as the sum of the noncontextual embeddings of the words in the sentence, weighted by their cosine similarity to _ship_. The result is a contextual embedding.

In [15]:
# Contextual embedding of 'ship' (position 6) computed by hand: each word's
# GloVe vector is weighted by its cosine similarity to 'ship', taken from
# row 6 of the cosine matrix and rounded to two decimals. Repeated words
# ('to', 'my') contribute once per occurrence. These weights are not
# normalized (they sum to 5.75), so the result has a much larger magnitude
# than the original vectors.
new_embeddings_ship = (0.35 * embeddings_dict['i'] + 
                  0.42 * embeddings_dict['must'] + 
                  0.41 * embeddings_dict['go'] +
                  0.49 * embeddings_dict['back'] +
                  0.54 * embeddings_dict['to'] + 
                  0.38 * embeddings_dict['my'] +
                  1.00 * embeddings_dict['ship'] +
                  0.46 * embeddings_dict['and'] +
                  0.54 * embeddings_dict['to'] +
                  0.38 * embeddings_dict['my'] +
                  0.78 * embeddings_dict['crew'])
# Display the resulting 50-dimensional vector
new_embeddings_ship

tensor([  3.2289,   0.6422,   1.4712,  -2.3538,   2.2414,  -0.4237,  -4.1052,
          2.6216,   0.1719,  -2.4324,   1.3882,   3.7241,  -1.9721,   1.1893,
          2.2511,   0.9502,  -0.7646,   1.0289,  -3.0553,  -3.6306,   0.8305,
          2.9299,   1.3221,  -0.7092,   2.9745, -10.5959,  -1.3168,   0.2059,
          3.5457,  -2.7711,  18.2672,   2.4817,  -3.5887,   0.3297,   1.2718,
          0.6539,   1.5873,   0.0195,   0.7724,  -1.4620,  -0.2067,  -1.2464,
          2.1504,  -0.1811,  -0.5026,  -0.2888,  -0.5060,  -1.9676,  -0.0605,
         -0.6725])

Exact computation with torch

In [16]:
(attn_cos_scores(embeddings_seq_o) @ embeddings_seq_o)[6]

tensor([ 3.2319e+00,  6.4082e-01,  1.4718e+00, -2.3434e+00,  2.2358e+00,
        -4.1877e-01, -4.1002e+00,  2.6211e+00,  1.8010e-01, -2.4360e+00,
         1.3923e+00,  3.7188e+00, -1.9603e+00,  1.1980e+00,  2.2394e+00,
         9.3763e-01, -7.7049e-01,  1.0349e+00, -3.0615e+00, -3.6259e+00,
         8.3401e-01,  2.9281e+00,  1.3165e+00, -7.1303e-01,  2.9667e+00,
        -1.0567e+01, -1.3099e+00,  2.0283e-01,  3.5362e+00, -2.7571e+00,
         1.8220e+01,  2.4698e+00, -3.5804e+00,  3.2604e-01,  1.2760e+00,
         6.5701e-01,  1.5889e+00,  1.1571e-02,  7.6620e-01, -1.4560e+00,
        -2.0362e-01, -1.2484e+00,  2.1550e+00, -1.8767e-01, -5.0253e-01,
        -2.9128e-01, -5.1006e-01, -1.9596e+00, -5.8853e-02, -6.7380e-01])

## Self-attention

Vaswani et al. (2017) defined attention as:
$$
\text{Attention}({Q}, {K}, {V}) = \text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})  {V},
$$
where
$$
\begin{array}{lcl}
{Q} &=& {X} {W}_Q,   \\
{K} &=& {X} {W}_K , \\
{V} &=& {X} {W}_V.\\
\end{array}
$$
and ${X}$ represents the complete input sequence (all the tokens).

$d_k$ is the dimension of the keys (here, the dimension of the input embeddings) and $\sqrt{d_k}$ is a scaling factor. The $\text{softmax}$ function is defined as:
$$
\text{softmax}(x_1, x_2, ..., x_j, ..., x_n) = (\frac{e^{x_1}}{\sum_{i=1}^n e^{x_i}}, \frac{e^{x_2}}{\sum_{i=1}^n e^{x_i}}, ..., \frac{e^{x_j}}{\sum_{i=1}^n e^{x_i}}, ..., \frac{e^{x_n}}{\sum_{i=1}^n e^{x_i}})
$$

We omit the weight matrices and use the same embeddings for ${Q}$, ${K}$, and ${V}$: the GloVe embeddings.

For the matrix above, the self-attention weights, $\text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})$, for _ship_ are:

In [17]:
# d_k: dimension of the embedding vectors, stored directly as a tensor so
# that torch.sqrt can be applied to it
dk = torch.tensor(embeddings_dict['i'].size(0))
dk

tensor(50)

In [18]:
# Row-wise softmax of the scaled dot products between all pairs of words
# (Q = K = the GloVe embeddings of the sentence)
attn_scores_o = F.softmax(
    embeddings_seq_o @ embeddings_seq_o.T / torch.sqrt(dk),
    dim=-1,
)
# Row 6 holds the attention weights of 'ship'
attn_scores_o[6]

tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281])

The scaled and normalized attention scores

In [19]:
def print_attn_scores(words):
    """
    Print, as a tab-separated table, the scaled dot-product attention
    weights softmax(E E^T / sqrt(d_k)) between the words of a sentence,
    where E stacks the words' embeddings and d_k is their dimension.

    :param words: the list of words, used both as column and row labels
    """
    E = embedding_matrix(words)
    n_words, dim = E.size()
    # Row-wise softmax: row i holds the weights word i gives to every word
    weights = F.softmax(E @ E.T / torch.sqrt(torch.tensor(dim)), dim=-1)
    print('\t' + ''.join(f'{word}\t' for word in words))
    for i in range(n_words):
        cells = ''.join(f'{weights[i, j]:.2f}\t' for j in range(n_words))
        print(f'{words[i]}\t{cells}')

In [20]:
print_attn_scores(words_o)

	i	must	go	back	to	my	ship	and	to	my	crew	
i	0.36	0.05	0.07	0.05	0.04	0.19	0.01	0.02	0.04	0.19	0.01	
must	0.14	0.20	0.10	0.06	0.11	0.10	0.03	0.05	0.11	0.10	0.02	
go	0.18	0.09	0.14	0.09	0.08	0.13	0.02	0.04	0.08	0.13	0.02	
back	0.14	0.05	0.09	0.19	0.08	0.12	0.03	0.06	0.08	0.12	0.03	
to	0.11	0.11	0.09	0.09	0.15	0.08	0.04	0.07	0.15	0.08	0.03	
my	0.19	0.03	0.05	0.04	0.03	0.29	0.01	0.02	0.03	0.29	0.01	
ship	0.03	0.03	0.03	0.04	0.05	0.03	0.55	0.03	0.05	0.03	0.13	
and	0.10	0.08	0.07	0.10	0.12	0.09	0.04	0.15	0.12	0.09	0.04	
to	0.11	0.11	0.09	0.09	0.15	0.08	0.04	0.07	0.15	0.08	0.03	
my	0.19	0.03	0.05	0.04	0.03	0.29	0.01	0.02	0.03	0.29	0.01	
crew	0.06	0.05	0.05	0.06	0.05	0.06	0.21	0.04	0.05	0.06	0.31	


For _ship:_

In [21]:
attn_scores_o[6]

tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281])

The weight is 55% for _ship_ and 13% for _crew_; the rest comes from the other words.

The new contextual embedding for _ship_ is then a linear combination:

In [22]:
# Contextual embedding of 'ship' computed by hand from the scaled softmax
# attention weights of row 6, rounded to two decimals. Unlike the cosine
# weights, these weights sum to 1, so the result stays on the scale of the
# original GloVe vectors. Repeated words ('to', 'my') contribute once per
# occurrence.
self_attention_ship = (0.03 * embeddings_dict['i'] + 
                  0.03 * embeddings_dict['must'] + 
                  0.03 * embeddings_dict['go'] +
                  0.04 * embeddings_dict['back'] +
                  0.05 * embeddings_dict['to'] + 
                  0.03 * embeddings_dict['my'] +
                  0.55 * embeddings_dict['ship'] +
                  0.03 * embeddings_dict['and'] +
                  0.05 * embeddings_dict['to'] +
                  0.03 * embeddings_dict['my'] +
                  0.13 * embeddings_dict['crew'])
# Display the resulting 50-dimensional vector
self_attention_ship

tensor([ 1.0442,  0.0966,  0.3467, -0.4238,  0.2203, -0.0956, -0.9915,  0.6637,
         0.4368, -0.7943,  0.5639,  0.9838,  0.0240,  0.5066,  0.0732, -0.1740,
        -0.3322,  0.5614, -1.1613, -0.5717,  0.4356,  0.4120, -0.0659, -0.3336,
         0.6579, -1.7421, -0.0344,  0.1440,  0.8547, -0.1430,  2.6614, -0.0553,
        -0.5376,  0.3057,  0.4068,  0.2231,  0.3959, -0.2940, -0.1163, -0.1340,
         0.1709, -0.5332,  0.9552, -0.4178, -0.1058, -0.1715, -0.2251, -0.3923,
         0.2098, -0.3625])

Exact computation with torch of the whole matrix
$$
\text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})  {V} :
$$

In [23]:
self_attention_output_o = attn_scores_o @ embeddings_seq_o

The contextual embeddings for _ship:_

In [24]:
self_attention_output_o[6]

tensor([ 1.0387,  0.1033,  0.3426, -0.4320,  0.2237, -0.0958, -0.9926,  0.6662,
         0.4424, -0.7942,  0.5638,  0.9921,  0.0205,  0.5082,  0.0743, -0.1773,
        -0.3408,  0.5675, -1.1545, -0.5718,  0.4288,  0.4191, -0.0658, -0.3339,
         0.6682, -1.7473, -0.0485,  0.1531,  0.8642, -0.1447,  2.6571, -0.0545,
        -0.5343,  0.3160,  0.4041,  0.2277,  0.3958, -0.2916, -0.1126, -0.1385,
         0.1744, -0.5375,  0.9499, -0.4145, -0.1039, -0.1755, -0.2213, -0.3995,
         0.2119, -0.3610])

We can now write a `self_attention` function 

In [None]:
def self_attention(input_seq):
    """
    Scaled dot-product self-attention without learned projections, i.e.
    with Q = K = V = input_seq:

        softmax(X X^T / sqrt(d_k)) X

    where d_k is the size of the last dimension of X.

    :param input_seq: tensor of size (..., sequence length, d_k): a single
        (sequence length, d_k) matrix or a batch of such matrices
    :return: the pair (attn_output, attn_scores), of sizes
        (..., sequence length, d_k) and (..., sequence length, sequence length)
    """
    dk = torch.tensor(input_seq.size(-1))
    # transpose(-2, -1) swaps only the last two axes. It matches .T on 2-D
    # matrices, and it stays correct for batched inputs, where .T would
    # reverse every dimension.
    attn_scores = F.softmax(input_seq @ input_seq.transpose(-2, -1) / torch.sqrt(dk), dim=-1)
    attn_output = attn_scores @ input_seq
    return attn_output, attn_scores

The word _ship_ in another context: _We process and ship your order_

In [67]:
attention_output_a, attn_scores_a = self_attention(embeddings_seq_a)

Attention scores for _ship:_

In [68]:
attn_scores_a[3]

tensor([0.0431, 0.0258, 0.0419, 0.7811, 0.0490, 0.0590])

In [69]:
print_attn_scores(words_a)

	we	process	and	ship	your	order	
we	0.61	0.06	0.06	0.02	0.20	0.05	
process	0.17	0.50	0.08	0.03	0.11	0.11	
and	0.22	0.12	0.30	0.08	0.15	0.13	
ship	0.04	0.03	0.04	0.78	0.05	0.06	
your	0.14	0.03	0.03	0.02	0.74	0.04	
order	0.16	0.13	0.10	0.09	0.18	0.34	


The new contextual embeddings for _ship:_

In [70]:
attention_output_a[3]

tensor([ 1.2758,  0.1034,  0.2720, -0.4776,  0.1746, -0.1060, -0.9901,  0.6328,
         0.6967, -0.8847,  0.7106,  1.2264,  0.2491,  0.5023, -0.1277, -0.2361,
        -0.3709,  0.6545, -1.2587, -0.5332,  0.6681,  0.1687, -0.2567, -0.4218,
         0.6960, -1.7077, -0.0052,  0.1572,  1.0763,  0.0410,  2.5467, -0.3418,
        -0.5414,  0.4175,  0.4147,  0.2666,  0.3770, -0.4228, -0.2462, -0.0377,
         0.3202, -0.7298,  1.2020, -0.5636, -0.0899, -0.1845, -0.2390, -0.4307,
         0.3828, -0.4905])

## PyTorch implementation
 
PyTorch has an implementation of self-attention encapsulated in the `MultiheadAttention` class. Before entering the attention module, the query, key, and value each go through a linear layer. The output also goes through a linear layer. The query, key, and value projection weights are initialized with Xavier's algorithm.

In [30]:
from torch.nn import MultiheadAttention

# The projection weights are initialized randomly: fix the seed so that the
# layer, and hence the outputs below, are reproducible across runs
torch.manual_seed(0)
# Embedding dimension 50 (the GloVe vectors), a single head, and no bias,
# so that identity weights reduce the layer to plain self-attention
att_layer = MultiheadAttention(50,
                               1,
                               bias=False,
                               batch_first=True)

In [31]:
(attn_output, attn_scores) = att_layer(embeddings_seq_o, embeddings_seq_o, embeddings_seq_o)

The attention score for _ship:_

In [32]:
attn_scores[6]

tensor([0.1005, 0.0725, 0.0798, 0.0987, 0.0696, 0.0930, 0.0930, 0.0785, 0.0696,
        0.0930, 0.1516], grad_fn=<SelectBackward0>)

In [33]:
attn_output

tensor([[-1.3139e-01, -3.6727e-01, -1.8204e-01, -2.4207e-01,  1.4971e-01,
         -1.1306e-01,  2.4283e-01, -2.1270e-01,  3.5660e-01,  5.4594e-01,
          1.6515e-01, -1.4105e-01,  3.6244e-02, -8.5999e-02, -1.8266e-01,
          3.8833e-01, -3.4812e-01, -2.7353e-01, -3.2677e-01, -6.7322e-01,
          4.3488e-02, -6.8552e-02,  3.9725e-01, -2.1517e-01,  1.5552e-02,
          1.5287e-01,  1.8118e-01, -2.8809e-01,  2.4120e-01, -2.8585e-01,
         -7.9411e-02, -1.3777e-01, -1.9382e-01, -1.9365e-03, -4.6244e-02,
          1.1650e-01,  1.3885e-01, -2.8046e-01,  1.0420e-01, -5.4482e-01,
          2.2524e-01, -2.3087e-01, -3.8066e-01,  2.4418e-01, -4.7407e-01,
         -4.1701e-01, -3.1227e-01, -1.4631e-01,  1.1491e-01,  2.2806e-01],
        [-1.2243e-01, -3.8173e-01, -1.9070e-01, -2.4583e-01,  1.4383e-01,
         -1.1336e-01,  2.5795e-01, -2.0619e-01,  3.6367e-01,  5.5137e-01,
          1.5844e-01, -1.5577e-01,  4.2192e-02, -7.6187e-02, -1.7239e-01,
          3.9157e-01, -3.4213e-01, -2

### The initial dense layers

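The initial weight values of the 4 matrices: the query, key, and value matrices (concatenated in `in_proj_weight`) and the output matrix (`out_proj.weight`)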

In [34]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[ 0.0831,  0.0224,  0.1531,  ...,  0.1413,  0.0540,  0.0741],
                      [ 0.0662,  0.0010,  0.1485,  ...,  0.0335,  0.0900,  0.1510],
                      [ 0.0128, -0.0897, -0.0168,  ...,  0.1347, -0.0138,  0.1366],
                      ...,
                      [ 0.0172,  0.0685,  0.1596,  ..., -0.0977, -0.1345,  0.1724],
                      [ 0.0812, -0.0550,  0.0471,  ...,  0.1455,  0.0330,  0.1082],
                      [-0.0687,  0.1496, -0.0260,  ..., -0.0008,  0.0710, -0.1402]])),
             ('out_proj.weight',
              tensor([[-0.0877,  0.1369,  0.0201,  ..., -0.0538,  0.0559, -0.0571],
                      [-0.0083, -0.0460, -0.0495,  ...,  0.1036, -0.0824, -0.0321],
                      [ 0.0777,  0.0244,  0.0257,  ...,  0.1305, -0.0999, -0.1050],
                      ...,
                      [ 0.0337, -0.0537,  0.0896,  ..., -0.0816,  0.0030,  0.0837],
                      [ 0.0438,  0.035

The three input matrices are concatenated

In [35]:
att_layer.state_dict()['in_proj_weight'].size()

torch.Size([150, 50])

The output matrix

In [36]:
att_layer.state_dict()['out_proj.weight'].size()

torch.Size([50, 50])

### By-passing the dense layers

We set the weights of the dense layers to identity matrices so that the layers pass their inputs through unchanged. This lets us recover the attention values and scores of `self_attention()`.

In [37]:
i_50 = torch.eye(50)
i_50

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [38]:
att_layer.state_dict()['out_proj.weight'][:] = i_50

In [39]:
att_layer.state_dict()['in_proj_weight'].size()

torch.Size([150, 50])

In [40]:
i_3_50 = torch.vstack((i_50, i_50, i_50))
i_3_50.size()

torch.Size([150, 50])

In [41]:
att_layer.state_dict()['in_proj_weight'][:] = i_3_50

In [42]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[1., 0., 0.,  ..., 0., 0., 0.],
                      [0., 1., 0.,  ..., 0., 0., 0.],
                      [0., 0., 1.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 1., 0., 0.],
                      [0., 0., 0.,  ..., 0., 1., 0.],
                      [0., 0., 0.,  ..., 0., 0., 1.]])),
             ('out_proj.weight',
              tensor([[1., 0., 0.,  ..., 0., 0., 0.],
                      [0., 1., 0.,  ..., 0., 0., 0.],
                      [0., 0., 1.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 1., 0., 0.],
                      [0., 0., 0.,  ..., 0., 1., 0.],
                      [0., 0., 0.,  ..., 0., 0., 1.]]))])

### Multihead attention without the dense layers

We obtain now the same results as the `self_attention()` function for _ship:_

We apply the layer to the sentence again:

In [43]:
(attn_output, attn_scores) = att_layer(embeddings_seq_o, embeddings_seq_o, embeddings_seq_o)

The attention scores for _ship:_

In [46]:
attn_scores[6]

tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281], grad_fn=<SelectBackward0>)

The embedding vector for _ship_

In [45]:
attn_output[6]

tensor([ 1.0387,  0.1033,  0.3426, -0.4320,  0.2237, -0.0958, -0.9926,  0.6662,
         0.4424, -0.7942,  0.5638,  0.9921,  0.0205,  0.5082,  0.0743, -0.1773,
        -0.3408,  0.5674, -1.1545, -0.5718,  0.4288,  0.4191, -0.0658, -0.3339,
         0.6682, -1.7473, -0.0485,  0.1531,  0.8642, -0.1447,  2.6571, -0.0545,
        -0.5343,  0.3160,  0.4041,  0.2277,  0.3958, -0.2916, -0.1126, -0.1385,
         0.1744, -0.5375,  0.9499, -0.4145, -0.1039, -0.1755, -0.2213, -0.3995,
         0.2119, -0.3610], grad_fn=<SelectBackward0>)

## Test with a simple matrix

Three words, dimension of embeddings: 4

In [87]:
test_input_sequence = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                                 [0.0, 1.5, 1.0, 1.0],
                                 [0.0, 1.0, 1.0, 1.0]])

In [88]:
test_input_sequence.size()

torch.Size([3, 4])

### Self-attention from the book

In [89]:
self_attention(test_input_sequence)

(tensor([[0.4519, 0.6852, 0.5481, 1.0000],
         [0.1045, 1.1609, 0.8955, 1.0000],
         [0.1387, 1.1034, 0.8613, 1.0000]]),
 tensor([[0.4519, 0.2741, 0.2741],
         [0.1045, 0.5307, 0.3648],
         [0.1387, 0.4842, 0.3771]]))

### Multihead attention from PyTorch

In [90]:
# The projection weights are initialized randomly: fix the seed so that the
# layer, and hence the outputs below, are reproducible across runs
torch.manual_seed(0)
# Embedding dimension 4 (the test vectors), a single head, and no bias, so
# that identity weights reduce the layer to plain self-attention
att_layer = MultiheadAttention(4, 
                               1, 
                               bias=False)

The multihead attention layer uses a random (Xavier) initialization of its dense layers, so the results differ from those of `self_attention()`

In [91]:
att_layer(test_input_sequence, 
          test_input_sequence,
          test_input_sequence)

(tensor([[-1.0728, -0.8113, -0.4689,  0.6076],
         [-1.0398, -0.7560, -0.4240,  0.6074],
         [-1.0482, -0.7704, -0.4357,  0.6072]], grad_fn=<SqueezeBackward1>),
 tensor([[0.2403, 0.4111, 0.3487],
         [0.3710, 0.3092, 0.3198],
         [0.3352, 0.3318, 0.3331]], grad_fn=<SqueezeBackward1>))

Weights of the dense layers

In [92]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[ 6.1394e-02, -6.1015e-01,  4.7498e-01,  1.9639e-01],
                      [-5.7650e-01,  4.3781e-01,  4.9654e-03, -3.5933e-01],
                      [-4.1028e-01, -4.8474e-02, -5.4061e-01,  5.4894e-01],
                      [ 5.9218e-01,  4.1155e-01, -3.4587e-01,  2.5921e-01],
                      [-5.4377e-04,  5.5142e-02,  4.4350e-01,  3.0742e-01],
                      [-1.3947e-01, -5.4506e-01, -2.2638e-01,  5.4712e-01],
                      [ 5.2317e-01,  2.6272e-01,  2.5883e-01, -3.7431e-01],
                      [-4.4056e-01,  1.1487e-01, -5.2629e-01,  2.9210e-01],
                      [-4.6415e-01, -3.9766e-01, -5.8648e-01, -6.0914e-01],
                      [ 3.8215e-01,  9.8802e-02,  4.3574e-01,  2.6193e-01],
                      [-4.5063e-01, -5.5486e-01, -4.9369e-01, -3.5490e-01],
                      [-5.2275e-02,  6.0433e-01,  4.1626e-01, -8.4171e-02]])),
             ('out_proj.weight',
              tensor

### By-passing the dense layers

We use weights of identity matrices

In [93]:
i_4 = torch.eye(4)

In [94]:
att_layer.state_dict()['out_proj.weight'][:] = i_4

In [95]:
i_3_4 = torch.vstack((i_4, i_4, i_4))
i_3_4.size()

torch.Size([12, 4])

We set these weights

In [96]:
att_layer.state_dict()['in_proj_weight'][:] = i_3_4

In [97]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.],
                      [1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.],
                      [1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]])),
             ('out_proj.weight',
              tensor([[1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]]))])

Now we have the same results as with `self_attention()`

In [98]:
att_layer(test_input_sequence, 
          test_input_sequence,
          test_input_sequence)

(tensor([[0.4519, 0.6852, 0.5481, 1.0000],
         [0.1045, 1.1609, 0.8955, 1.0000],
         [0.1387, 1.1034, 0.8613, 1.0000]], grad_fn=<SqueezeBackward1>),
 tensor([[0.4519, 0.2741, 0.2741],
         [0.1045, 0.5307, 0.3648],
         [0.1387, 0.4842, 0.3771]], grad_fn=<SqueezeBackward1>))