# Chapter 15: Self-Attention and Transformers
## The Attention mechanism and its implementation in PyTorch
Computing self-attention of a sentence with GloVe embeddings and the `MultiheadAttention` class with PyTorch

Programs from the book: [_Python for Natural Language Processing_](https://link.springer.com/book/9783031575488)

__Author__: Pierre Nugues

## Modules

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
torch.manual_seed(1234)

<torch._C.Generator at 0x10cf074f0>

## Noncontextual embeddings

We load the GloVe embeddings

In [3]:
def read_embeddings(file):
    """
    Return the embeddings in the form of a dictionary.

    Each nonblank line of the file holds a word followed by the
    whitespace-separated components of its vector, as in the GloVe
    text files.

    :param file: path to the embedding text file, read as UTF-8
    :return: a dictionary mapping each word to a 1-D FloatTensor
    """
    embeddings = {}
    with open(file, encoding='utf8') as glove:
        for line in glove:
            values = line.split()
            if not values:
                # Skip blank lines (e.g. a trailing empty line), which would
                # otherwise raise an IndexError when reading the word.
                continue
            word, *vector = values
            embeddings[word] = torch.FloatTensor(
                [float(value) for value in vector])
    return embeddings

In [4]:
# PATH = '../../corpus/'
PATH = '../datasets/embeddings/'

In [5]:
embedding_file = PATH + 'glove.6B.50d.txt'
embeddings_dict = read_embeddings(embedding_file)

In [6]:
embeddings_dict['ship']

tensor([ 1.5213,  0.1052,  0.3816, -0.5080,  0.0324, -0.1348, -1.2474,  0.7981,
         0.8469, -1.1010,  0.8874,  1.3749,  0.4293,  0.6572, -0.2636, -0.4176,
        -0.4885,  0.9106, -1.7158, -0.4380,  0.7839,  0.1964, -0.4066, -0.5397,
         0.8244, -1.7434,  0.1428,  0.2804,  1.1688,  0.1690,  2.2271, -0.5827,
        -0.4572,  0.6281,  0.5444,  0.2846,  0.4448, -0.5534, -0.3649, -0.0164,
         0.4088, -0.8715,  1.5513, -0.8070, -0.1004, -0.2846, -0.3322, -0.5061,
         0.4827, -0.6620])

## Cosine similarity

Let us compute the cosine similarity of the words in a sentence:
> I must go back to my ship and to my crew

_Odyssey_, book I 

Remember that:
$$\cos(\mathbf{u}, \mathbf{v}) = \frac{\mathbf{u} \cdot \mathbf{v}}{||\mathbf{u}|| \cdot ||\mathbf{v} ||}$$

In [7]:
sentence_odyssey = 'I must go back to my ship and to my crew'
sentence_amazon = 'We process and ship your order'
# in the most cost-efficient way possible

In [8]:
words_a = sentence_amazon.lower().split()
words_o = sentence_odyssey.lower().split()
words_o

['i', 'must', 'go', 'back', 'to', 'my', 'ship', 'and', 'to', 'my', 'crew']

We build the embedding matrix

In [9]:
def embedding_matrix(words, embeddings_dict):
    """
    Stack the embeddings of a sequence of words into one tensor.

    With 1-D vectors, as produced by read_embeddings, the result has one
    row per word, in the order of `words`.

    :param words: the words of the sentence
    :param embeddings_dict: dictionary mapping a word to its embedding
    :return: the stacked embeddings
    :raises KeyError: if a word is missing from embeddings_dict
    """
    rows = tuple(embeddings_dict[token] for token in words)
    return torch.stack(rows, dim=0)

In [10]:
X_a = embedding_matrix(words_a, embeddings_dict)
X_o = embedding_matrix(words_o, embeddings_dict)

In [11]:
X_o.size()

torch.Size([11, 50])

In [12]:
X_a.size()

torch.Size([6, 50])

In [13]:
X_o[0][:10]

tensor([ 1.1891e-01,  1.5255e-01, -8.2073e-02, -7.4144e-01,  7.5917e-01,
        -4.8328e-01, -3.1009e-01,  5.1476e-01, -9.8708e-01,  6.1757e-04])

We compute the attention weights as the pairwise cosines of the word embeddings

In [14]:
def attn_cos_weights(embeddings_mat):
    """
    Return the pairwise cosine similarities between the rows of a matrix.

    :param embeddings_mat: a 2-D tensor with one embedding per row
    :return: a square tensor whose (i, j) entry is cos(row i, row j)
    """
    # Rescale every row to unit L2 norm; dot products then are cosines.
    unit_rows = F.normalize(embeddings_mat, p=2.0, dim=1)
    return torch.matmul(unit_rows, unit_rows.T)

In [15]:
def print_cos_weights(words, embeddings_dict):
    """
    Print the pairwise cosine similarities of the words of a sentence.

    The output is a tab-separated table: a header row with the words,
    then one row per word with its similarity to every word of the
    sentence, rounded to two decimals.

    :param words: the words of the sentence
    :param embeddings_dict: dictionary mapping a word to its embedding
    """
    cos_matrix = attn_cos_weights(embedding_matrix(words, embeddings_dict))
    print('\t' + ''.join(f'{word}\t' for word in words))
    for word, row in zip(words, cos_matrix):
        cells = ''.join(f"{value:.2f}" + '\t' for value in row)
        print(f'{word}\t' + cells)

In [16]:
print_cos_weights(words_o, embeddings_dict)

	i	must	go	back	to	my	ship	and	to	my	crew	
i	1.00	0.75	0.86	0.76	0.73	0.90	0.35	0.65	0.73	0.90	0.42	
must	0.75	1.00	0.85	0.68	0.87	0.69	0.42	0.69	0.87	0.69	0.45	
go	0.86	0.85	1.00	0.84	0.84	0.81	0.41	0.68	0.84	0.81	0.49	
back	0.76	0.68	0.84	1.00	0.83	0.76	0.49	0.77	0.83	0.76	0.51	
to	0.73	0.87	0.84	0.83	1.00	0.68	0.54	0.86	1.00	0.68	0.51	
my	0.90	0.69	0.81	0.76	0.68	1.00	0.38	0.63	0.68	1.00	0.44	
ship	0.35	0.42	0.41	0.49	0.54	0.38	1.00	0.46	0.54	0.38	0.78	
and	0.65	0.69	0.68	0.77	0.86	0.63	0.46	1.00	0.86	0.63	0.49	
to	0.73	0.87	0.84	0.83	1.00	0.68	0.54	0.86	1.00	0.68	0.51	
my	0.90	0.69	0.81	0.76	0.68	1.00	0.38	0.63	0.68	1.00	0.44	
crew	0.42	0.45	0.49	0.51	0.51	0.44	0.78	0.49	0.51	0.44	1.00	


In [17]:
print_cos_weights(words_a, embeddings_dict)

	we	process	and	ship	your	order	
we	1.00	0.64	0.70	0.36	0.75	0.64	
process	0.64	1.00	0.61	0.29	0.52	0.67	
and	0.70	0.61	1.00	0.46	0.58	0.69	
ship	0.36	0.29	0.46	1.00	0.37	0.52	
your	0.75	0.52	0.58	0.37	1.00	0.63	
order	0.64	0.67	0.69	0.52	0.63	1.00	


## Contextual embeddings

We design a new vector representation for _ship_ so that it receives an influence from _crew_ and the other words of its context. This influence will depend on the embeddings of the context words. Let us use the cosine similarities as attention weights.

In [18]:
attn_cos_weights(X_o)[6]

tensor([0.3466, 0.4178, 0.4068, 0.4853, 0.5401, 0.3791, 1.0000, 0.4586, 0.5401,
        0.3791, 0.7848])

We compute the new embedding of _ship_ as the sum of the noncontextual embeddings of the sentence, weighted by their cosine similarities with _ship_. The result is a contextual embedding.

In [19]:
# Weighted sum of the noncontextual embeddings of the sentence. The weights
# are row 6 ('ship') of attn_cos_weights(X_o), rounded to two decimals.
new_embeddings_ship = (0.35 * embeddings_dict['i'] +
                       0.42 * embeddings_dict['must'] +
                       0.41 * embeddings_dict['go'] +
                       0.49 * embeddings_dict['back'] +
                       0.54 * embeddings_dict['to'] +
                       0.38 * embeddings_dict['my'] +
                       1.00 * embeddings_dict['ship'] +
                       0.46 * embeddings_dict['and'] +
                       0.54 * embeddings_dict['to'] +
                       0.38 * embeddings_dict['my'] +
                       0.78 * embeddings_dict['crew'])
new_embeddings_ship

tensor([  3.2289,   0.6422,   1.4712,  -2.3538,   2.2414,  -0.4237,  -4.1052,
          2.6216,   0.1719,  -2.4324,   1.3882,   3.7241,  -1.9721,   1.1893,
          2.2511,   0.9502,  -0.7646,   1.0289,  -3.0553,  -3.6306,   0.8305,
          2.9299,   1.3221,  -0.7092,   2.9745, -10.5959,  -1.3168,   0.2059,
          3.5457,  -2.7711,  18.2672,   2.4817,  -3.5887,   0.3297,   1.2718,
          0.6539,   1.5873,   0.0195,   0.7724,  -1.4620,  -0.2067,  -1.2464,
          2.1504,  -0.1811,  -0.5026,  -0.2888,  -0.5060,  -1.9676,  -0.0605,
         -0.6725])

Exact computation with torch

In [20]:
(attn_cos_weights(X_o) @ X_o)[6]

tensor([ 3.2319e+00,  6.4082e-01,  1.4718e+00, -2.3434e+00,  2.2358e+00,
        -4.1877e-01, -4.1002e+00,  2.6211e+00,  1.8010e-01, -2.4360e+00,
         1.3923e+00,  3.7188e+00, -1.9603e+00,  1.1980e+00,  2.2394e+00,
         9.3763e-01, -7.7049e-01,  1.0349e+00, -3.0615e+00, -3.6259e+00,
         8.3401e-01,  2.9281e+00,  1.3165e+00, -7.1303e-01,  2.9667e+00,
        -1.0567e+01, -1.3099e+00,  2.0283e-01,  3.5362e+00, -2.7571e+00,
         1.8220e+01,  2.4698e+00, -3.5804e+00,  3.2604e-01,  1.2760e+00,
         6.5701e-01,  1.5889e+00,  1.1571e-02,  7.6620e-01, -1.4560e+00,
        -2.0362e-01, -1.2484e+00,  2.1550e+00, -1.8767e-01, -5.0253e-01,
        -2.9128e-01, -5.1006e-01, -1.9596e+00, -5.8853e-02, -6.7380e-01])

## Self-attention

Vaswani et al. (2017) defined attention as:
$$
\text{Attention}({Q}, {K}, {V}) = \text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})  {V},
$$
where
$$
\begin{array}{lcl}
{Q} &=& {X} {W}^Q,   \\
{K} &=& {X} {W}^K , \\
{V} &=& {X} {W}^V.\\
\end{array}
$$
and ${X}$ represents the complete input sequence (all the tokens).

$d_k$ is the dimension of the keys (here, the dimension of the input embeddings) and $\sqrt{d_k}$ is a scaling factor. The $\text{softmax}$ function is defined as:
$$
\text{softmax}(x_1, x_2, ..., x_j, ..., x_n) = (\frac{e^{x_1}}{\sum_{i=1}^n e^{x_i}}, \frac{e^{x_2}}{\sum_{i=1}^n e^{x_i}}, ..., \frac{e^{x_j}}{\sum_{i=1}^n e^{x_i}}, ..., \frac{e^{x_n}}{\sum_{i=1}^n e^{x_i}})
$$

We omit the weight matrices and we use the same embeddings for ${Q}$, ${K}$, and ${V}$: the GloVe embeddings.

For the matrix above, the self-attention weights, $\text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})$, for _ship_ are:

In [21]:
dk = embeddings_dict['i'].size(dim=0)
dk = torch.tensor(dk)
dk

tensor(50)

In [22]:
attn_weights_o = F.softmax(
    X_o @ X_o.T/torch.sqrt(dk), dim=-1)
attn_weights_o[6]

tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281])

We simplify with a scale factor of `dk**-0.5`

In [23]:
attn_weights_o = F.softmax(
    X_o @ X_o.T * dk ** -0.5, dim=-1)
attn_weights_o[6]

tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281])

The scaled and normalized attention weights

In [24]:
def print_attn_weights(words, embeddings_dict):
    """
    Print the scaled dot-product self-attention weights of a sentence.

    The weights are softmax(E E^T / sqrt(d_k)), where E is the embedding
    matrix of the words and d_k its number of columns. They are printed
    as a tab-separated table (header row of words, then one row per
    word), rounded to two decimals.

    :param words: the words of the sentence
    :param embeddings_dict: dictionary mapping a word to its embedding
    """
    E = embedding_matrix(words, embeddings_dict)
    n_words, d_k = E.size()
    weights = F.softmax(E @ E.T * d_k ** -0.5, dim=-1)
    print('\t' + ''.join(f'{word}\t' for word in words[:n_words]))
    for word, row in zip(words[:n_words], weights):
        cells = ''.join(f"{value:.2f}" + '\t' for value in row[:n_words])
        print(f'{word}\t' + cells)

In [25]:
print_attn_weights(words_o, embeddings_dict)

	i	must	go	back	to	my	ship	and	to	my	crew	
i	0.36	0.05	0.07	0.05	0.04	0.19	0.01	0.02	0.04	0.19	0.01	
must	0.14	0.20	0.10	0.06	0.11	0.10	0.03	0.05	0.11	0.10	0.02	
go	0.18	0.09	0.14	0.09	0.08	0.13	0.02	0.04	0.08	0.13	0.02	
back	0.14	0.05	0.09	0.19	0.08	0.12	0.03	0.06	0.08	0.12	0.03	
to	0.11	0.11	0.09	0.09	0.15	0.08	0.04	0.07	0.15	0.08	0.03	
my	0.19	0.03	0.05	0.04	0.03	0.29	0.01	0.02	0.03	0.29	0.01	
ship	0.03	0.03	0.03	0.04	0.05	0.03	0.55	0.03	0.05	0.03	0.13	
and	0.10	0.08	0.07	0.10	0.12	0.09	0.04	0.15	0.12	0.09	0.04	
to	0.11	0.11	0.09	0.09	0.15	0.08	0.04	0.07	0.15	0.08	0.03	
my	0.19	0.03	0.05	0.04	0.03	0.29	0.01	0.02	0.03	0.29	0.01	
crew	0.06	0.05	0.05	0.06	0.05	0.06	0.21	0.04	0.05	0.06	0.31	


For _ship:_

In [26]:
attn_weights_o[6]

tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281])

We have the weights of 55% for _ship_ and 13% for _crew_, the rest from the other words.

The new contextual embedding of _ship_ is then a linear combination:

In [27]:
# Same weighted sum, now with the softmax-scaled attention weights
# attn_weights_o[6] rounded to two decimals. The result approximates the
# exact row 6 of attn_weights_o @ X_o computed below.
self_attention_ship = (0.03 * embeddings_dict['i'] +
                       0.03 * embeddings_dict['must'] +
                       0.03 * embeddings_dict['go'] +
                       0.04 * embeddings_dict['back'] +
                       0.05 * embeddings_dict['to'] +
                       0.03 * embeddings_dict['my'] +
                       0.55 * embeddings_dict['ship'] +
                       0.03 * embeddings_dict['and'] +
                       0.05 * embeddings_dict['to'] +
                       0.03 * embeddings_dict['my'] +
                       0.13 * embeddings_dict['crew'])
self_attention_ship

tensor([ 1.0442,  0.0966,  0.3467, -0.4238,  0.2203, -0.0956, -0.9915,  0.6637,
         0.4368, -0.7943,  0.5639,  0.9838,  0.0240,  0.5066,  0.0732, -0.1740,
        -0.3322,  0.5614, -1.1613, -0.5717,  0.4356,  0.4120, -0.0659, -0.3336,
         0.6579, -1.7421, -0.0344,  0.1440,  0.8547, -0.1430,  2.6614, -0.0553,
        -0.5376,  0.3057,  0.4068,  0.2231,  0.3959, -0.2940, -0.1163, -0.1340,
         0.1709, -0.5332,  0.9552, -0.4178, -0.1058, -0.1715, -0.2251, -0.3923,
         0.2098, -0.3625])

Exact computation with torch of the complete output matrix:
$$
\text{softmax}(\frac{{Q}  {K}^\intercal}{\sqrt{d_k}})  {V} :
$$

In [28]:
self_attention_output_o = attn_weights_o @ X_o

The contextual embeddings for _ship:_

In [29]:
self_attention_output_o[6]

tensor([ 1.0387,  0.1033,  0.3426, -0.4320,  0.2237, -0.0958, -0.9926,  0.6662,
         0.4424, -0.7942,  0.5638,  0.9921,  0.0205,  0.5082,  0.0743, -0.1773,
        -0.3408,  0.5675, -1.1545, -0.5718,  0.4288,  0.4191, -0.0658, -0.3339,
         0.6682, -1.7473, -0.0485,  0.1531,  0.8642, -0.1447,  2.6571, -0.0545,
        -0.5343,  0.3160,  0.4041,  0.2277,  0.3958, -0.2916, -0.1126, -0.1385,
         0.1744, -0.5375,  0.9499, -0.4145, -0.1039, -0.1755, -0.2213, -0.3995,
         0.2119, -0.3610])

We can now write an `attention` function 

In [30]:
def attention(Q, K, V):
    """
    Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Leading (batch) axes are supported: only the last two axes take
    part in the matrix products.

    :param Q: queries, shape (..., n_q, d_k)
    :param K: keys, shape (..., n_k, d_k)
    :param V: values, shape (..., n_k, d_v)
    :return: a pair (attn_output, attn_weights) of shapes
        (..., n_q, d_v) and (..., n_q, n_k); each row of attn_weights
        sums to 1.
    """
    d_k = K.size(dim=-1)
    scale = d_k ** -0.5
    # transpose(-2, -1) swaps only the last two axes. K.T would reverse
    # all the axes of a batched (3-D) tensor and break the product.
    attn_weights = F.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1)
    attn_output = attn_weights @ V
    return attn_output, attn_weights

The word _ship_ in another context: _We process and ship your order_

In [31]:
attention_output_a, attn_weights_a = attention(X_a, X_a, X_a)

Attention weights for _ship:_

In [32]:
attn_weights_a[3]

tensor([0.0431, 0.0258, 0.0419, 0.7811, 0.0490, 0.0590])

In [33]:
print_attn_weights(words_a, embeddings_dict)

	we	process	and	ship	your	order	
we	0.61	0.06	0.06	0.02	0.20	0.05	
process	0.17	0.50	0.08	0.03	0.11	0.11	
and	0.22	0.12	0.30	0.08	0.15	0.13	
ship	0.04	0.03	0.04	0.78	0.05	0.06	
your	0.14	0.03	0.03	0.02	0.74	0.04	
order	0.16	0.13	0.10	0.09	0.18	0.34	


The new contextual embeddings for _ship:_

In [34]:
attention_output_a[3]

tensor([ 1.2758,  0.1034,  0.2720, -0.4776,  0.1746, -0.1060, -0.9901,  0.6328,
         0.6967, -0.8847,  0.7106,  1.2264,  0.2491,  0.5023, -0.1277, -0.2361,
        -0.3709,  0.6545, -1.2587, -0.5332,  0.6681,  0.1687, -0.2567, -0.4218,
         0.6960, -1.7077, -0.0052,  0.1572,  1.0763,  0.0410,  2.5467, -0.3418,
        -0.5414,  0.4175,  0.4147,  0.2666,  0.3770, -0.4228, -0.2462, -0.0377,
         0.3202, -0.7298,  1.2020, -0.5636, -0.0899, -0.1845, -0.2390, -0.4307,
         0.3828, -0.4905])

## Attention class

In [35]:
class Attention(nn.Module):
    """
    Single-head scaled dot-product self-attention with learned projections.

    The input X is projected to queries, keys, and values by three
    bias-free linear layers, and softmax(Q K^T / sqrt(d_k)) V is returned
    together with the attention weights.

    :param d_model: size of the input embeddings (last axis of X)
    :param d_k: size of the query, key, and value projections
    """

    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.WQ = nn.Linear(d_model, d_k, bias=False)
        self.WK = nn.Linear(d_model, d_k, bias=False)
        self.WV = nn.Linear(d_model, d_k, bias=False)

    def forward(self, X):
        """
        Apply self-attention to X.

        :param X: input of shape (..., seq_len, d_model)
        :return: (attn_output, attn_weights) of shapes
            (..., seq_len, d_k) and (..., seq_len, seq_len)
        """
        Q = self.WQ(X)
        K = self.WK(X)
        V = self.WV(X)
        # Swap only the last two axes so that batched inputs also work;
        # K.T would reverse all the axes of a 3-D tensor.
        attn_weights = F.softmax(
            Q @ K.transpose(-2, -1) * self.d_k ** -0.5, dim=-1)
        attn_output = attn_weights @ V
        return attn_output, attn_weights

In [36]:
attn = Attention(X_o.size(dim=1), X_o.size(dim=1))
attn(X_o)

(tensor([[-1.6623e-01, -3.2289e-01, -3.6165e-01, -2.7903e-02, -2.3599e-01,
           4.7430e-01,  6.5428e-01, -2.8029e-01,  2.0756e-02,  5.3674e-01,
          -8.3914e-02, -5.4726e-01, -5.4179e-01,  2.7358e-01,  8.2042e-01,
          -8.8333e-02,  5.8626e-02,  3.1333e-01, -4.4335e-01,  4.0806e-01,
           3.1551e-01, -4.1586e-02, -1.5990e-01,  1.9130e-01,  4.0484e-01,
          -2.3194e-01,  1.1097e-01,  3.9249e-02, -2.8538e-01,  4.4958e-02,
          -1.1391e-01,  5.1877e-01, -3.1351e-02,  1.8618e-01,  3.7673e-01,
          -3.7783e-01,  3.2962e-01, -3.5203e-01,  2.1866e-01,  2.1663e-01,
           3.3362e-01, -5.1468e-01,  1.7154e-01, -1.7607e-02,  5.5630e-01,
           2.4584e-01,  5.0733e-01, -2.8888e-01,  2.9656e-01, -3.4202e-01],
         [-1.7108e-01, -3.1969e-01, -3.5901e-01, -2.8647e-02, -2.3346e-01,
           4.7101e-01,  6.5450e-01, -2.7812e-01,  1.8074e-02,  5.3323e-01,
          -8.2419e-02, -5.4471e-01, -5.4338e-01,  2.7221e-01,  8.2259e-01,
          -8.9572e-02,  

## Multihead Attention

In [37]:
class MultiheadAttention(nn.Module):
    """
    Multihead self-attention built from h single-head `Attention` modules.

    Each head projects the d_model-dimensional input to d_k = d_model // h
    dimensions. The h head outputs are concatenated back to d_model
    dimensions and mixed by the bias-free output projection WO.

    :param d_model: size of the input embeddings; must be divisible by h
    :param h: number of heads
    :raises ValueError: if d_model is not divisible by h
    """

    def __init__(self, d_model, h):
        super().__init__()
        if d_model % h != 0:
            # Otherwise the concatenated heads have h * (d_model // h)
            # columns and cannot be fed to WO, which expects d_model.
            raise ValueError(
                f'd_model ({d_model}) must be divisible by h ({h})')
        self.h = h
        d_k = d_model // h
        self.attn_modules = nn.ModuleList(
            [Attention(d_model, d_k) for _ in range(h)])
        self.WO = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X):
        """
        Apply the h attention heads to X and combine them.

        :param X: input whose last axis has size d_model
        :return: (attn_output, attn_weights): WO applied to the
            concatenation of the head outputs, and the attention weights
            averaged over the heads
        """
        attn_heads, attn_weights = zip(
            *[attn_module(X) for attn_module in self.attn_modules])
        attn_output = self.WO(torch.cat(attn_heads, dim=-1))
        # Average the per-head weight matrices.
        attn_weights = torch.stack(attn_weights).mean(dim=0)
        return attn_output, attn_weights

In [38]:
h = 5
multihead_attn = MultiheadAttention(X_o.size(dim=1), h)

In [39]:
multihead_attn(X_o)

(tensor([[-0.0476,  0.0971, -0.2834, -0.0025, -0.0784, -0.1255,  0.0030,  0.0129,
           0.1317,  0.1839, -0.1862, -0.2681,  0.1798,  0.1930, -0.1245,  0.0827,
           0.0704, -0.2206, -0.2917,  0.1680,  0.0738,  0.0976, -0.3515,  0.1676,
          -0.1140, -0.2064, -0.0919, -0.0226, -0.4319, -0.5557, -0.2912, -0.0844,
          -0.0881, -0.2669,  0.1640,  0.1006,  0.1010,  0.0087,  0.4225, -0.3989,
           0.4763,  0.0778,  0.1797,  0.0377, -0.1116, -0.1889,  0.1885, -0.1090,
           0.0822, -0.1300],
         [-0.0422,  0.0979, -0.2805,  0.0017, -0.0836, -0.1249, -0.0054,  0.0156,
           0.1330,  0.1845, -0.1814, -0.2664,  0.1776,  0.1950, -0.1214,  0.0706,
           0.0782, -0.2158, -0.2916,  0.1679,  0.0771,  0.0986, -0.3488,  0.1726,
          -0.1181, -0.2064, -0.0880, -0.0173, -0.4313, -0.5478, -0.3017, -0.0807,
          -0.0915, -0.2621,  0.1648,  0.0938,  0.0985,  0.0124,  0.4275, -0.4023,
           0.4690,  0.0771,  0.1849,  0.0387, -0.1051, -0.1829,  0.19

## PyTorch implementation
 
PyTorch has an implementation of self-attention encapsulated in the `MultiheadAttention` class. Before the attention computation, the query, key, and value each go through a linear layer, and the output also goes through a linear layer. The three input layers are initialized with Xavier's algorithm.

In [40]:
from torch.nn import MultiheadAttention

att_layer = MultiheadAttention(50,
                               1,
                               bias=False,
                               batch_first=True)

In [41]:
(attn_output, attn_weights) = att_layer(X_o, X_o, X_o)

The attention weights for _ship:_

In [42]:
attn_weights[6]

tensor([0.0806, 0.0874, 0.1017, 0.1062, 0.0982, 0.0740, 0.0890, 0.1023, 0.0982,
        0.0740, 0.0885], grad_fn=<SelectBackward0>)

In [43]:
attn_output[6]

tensor([-0.0423,  0.0588, -0.2737,  0.2912,  0.1943, -0.5409, -0.1762,  0.2302,
         0.1108,  0.1127,  0.4632,  0.1016,  0.0829, -0.4936, -0.0664,  0.3236,
         0.3227, -0.0398, -0.1062, -0.2222,  0.3521, -0.4974, -0.2566, -0.3431,
        -0.3069,  0.0561, -0.1297, -0.1030,  0.1801,  0.2085, -0.2368,  0.1301,
        -0.0616,  0.4501, -0.2524, -0.1762,  0.4751, -0.2228,  0.0771,  0.3694,
        -0.0914, -0.1948,  0.0567,  0.0920, -0.1048,  0.1368, -0.3046, -0.3925,
         0.4692,  0.0542], grad_fn=<SelectBackward0>)

### The initial dense layers

The initial values of the four weight matrices: query, key, and value (concatenated in `in_proj_weight`) and output (`out_proj.weight`)

In [44]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[-0.0096,  0.1481,  0.1668,  ...,  0.1107,  0.0094,  0.1637],
                      [-0.1581,  0.1049, -0.0774,  ...,  0.1085,  0.1400,  0.0530],
                      [-0.1002, -0.0320,  0.0373,  ..., -0.0191, -0.1164, -0.0301],
                      ...,
                      [ 0.0568,  0.0302,  0.0529,  ..., -0.1093,  0.0568,  0.0989],
                      [ 0.0490, -0.1305, -0.1599,  ..., -0.1245, -0.1216,  0.0492],
                      [ 0.0905,  0.0845, -0.0155,  ...,  0.0216,  0.0237,  0.1466]])),
             ('out_proj.weight',
              tensor([[ 0.1042,  0.1293,  0.0220,  ..., -0.0726,  0.1329,  0.0067],
                      [-0.0826,  0.0951, -0.0153,  ...,  0.0756,  0.0262, -0.1396],
                      [ 0.0387, -0.0094,  0.1204,  ..., -0.0228, -0.0403, -0.1186],
                      ...,
                      [-0.0911,  0.0308, -0.0224,  ...,  0.0062,  0.0424,  0.1293],
                      [ 0.0966,  0.124

The three input matrices are concatenated

In [45]:
att_layer.state_dict()['in_proj_weight'].size()

torch.Size([150, 50])

The output matrix

In [46]:
att_layer.state_dict()['out_proj.weight'].size()

torch.Size([50, 50])

### By-passing the dense layers

We create identity matrices to pass through the dense layers and recover the attention output values and weights

In [47]:
i_50 = torch.eye(50)
i_50

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [48]:
# state_dict() tensors share storage with the layer's parameters, so this
# slice assignment overwrites out_proj.weight in place with the identity.
att_layer.state_dict()['out_proj.weight'][:] = i_50

In [49]:
att_layer.state_dict()['in_proj_weight'].size()

torch.Size([150, 50])

In [50]:
i_3_50 = torch.vstack((i_50, i_50, i_50))
i_3_50.size()

torch.Size([150, 50])

In [51]:
# Overwrite in place the concatenated query/key/value projections with
# three stacked identity matrices, so Q = K = V = X.
att_layer.state_dict()['in_proj_weight'][:] = i_3_50

In [52]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[1., 0., 0.,  ..., 0., 0., 0.],
                      [0., 1., 0.,  ..., 0., 0., 0.],
                      [0., 0., 1.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 1., 0., 0.],
                      [0., 0., 0.,  ..., 0., 1., 0.],
                      [0., 0., 0.,  ..., 0., 0., 1.]])),
             ('out_proj.weight',
              tensor([[1., 0., 0.,  ..., 0., 0., 0.],
                      [0., 1., 0.,  ..., 0., 0., 0.],
                      [0., 0., 1.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 1., 0., 0.],
                      [0., 0., 0.,  ..., 0., 1., 0.],
                      [0., 0., 0.,  ..., 0., 0., 1.]]))])

### Multihead attention without the dense layers

We now obtain the same results for _ship_ as with our own self-attention computation, `attn_weights_o @ X_o`:

We apply the modified layer to the sentence:

In [53]:
(attn_output, attn_weights) = att_layer(X_o, X_o, X_o)

The attention weights for _ship:_

In [54]:
attn_weights[6]

tensor([0.0303, 0.0302, 0.0276, 0.0407, 0.0459, 0.0343, 0.5530, 0.0297, 0.0459,
        0.0343, 0.1281], grad_fn=<SelectBackward0>)

The embedding vector for _ship_

In [55]:
attn_output[6]

tensor([ 1.0387,  0.1033,  0.3426, -0.4320,  0.2237, -0.0958, -0.9926,  0.6662,
         0.4424, -0.7942,  0.5638,  0.9921,  0.0205,  0.5082,  0.0743, -0.1773,
        -0.3408,  0.5675, -1.1545, -0.5718,  0.4288,  0.4191, -0.0658, -0.3339,
         0.6682, -1.7473, -0.0485,  0.1531,  0.8642, -0.1447,  2.6571, -0.0545,
        -0.5343,  0.3160,  0.4041,  0.2277,  0.3958, -0.2916, -0.1126, -0.1385,
         0.1744, -0.5375,  0.9499, -0.4145, -0.1039, -0.1755, -0.2213, -0.3995,
         0.2119, -0.3610], grad_fn=<SelectBackward0>)

## Multihead

In [56]:
att_layer_5 = MultiheadAttention(50,
                                 5,
                                 bias=False,
                                 batch_first=True)

In [57]:
attn_output, attn_weights = att_layer_5(X_o, X_o, X_o)

In [58]:
attn_output.size()

torch.Size([11, 50])

In [59]:
attn_weights.size()

torch.Size([11, 11])

In [60]:
att_layer_5.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[ 0.1360,  0.0887,  0.0160,  ..., -0.0507, -0.0443,  0.0677],
                      [-0.0965, -0.0508, -0.0583,  ..., -0.0545,  0.0740,  0.0370],
                      [ 0.1667, -0.0200, -0.0527,  ..., -0.0334, -0.1026, -0.0321],
                      ...,
                      [-0.0631, -0.0444,  0.0494,  ...,  0.1621, -0.1284, -0.0097],
                      [ 0.0005, -0.0813,  0.0025,  ..., -0.0986,  0.0662, -0.0093],
                      [ 0.0406, -0.0478, -0.0056,  ..., -0.1667,  0.1692,  0.1685]])),
             ('out_proj.weight',
              tensor([[ 0.0088,  0.0021, -0.0036,  ...,  0.0429, -0.0320,  0.0913],
                      [ 0.0726,  0.1074,  0.0703,  ...,  0.1014,  0.0193, -0.0654],
                      [-0.1264,  0.0508,  0.0701,  ..., -0.0122,  0.1093, -0.0150],
                      ...,
                      [-0.0537, -0.1326,  0.0687,  ..., -0.1040,  0.1059,  0.0237],
                      [-0.0228,  0.140

In [61]:
att_layer_5.state_dict()['in_proj_weight'].size()

torch.Size([150, 50])

In [62]:
att_layer_5.state_dict()['out_proj.weight'].size()

torch.Size([50, 50])

## Test with a simple matrix

Three words, dimension of embeddings: 4

In [63]:
X_test = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                       [0.0, 1.5, 1.0, 1.0],
                       [0.0, 1.0, 1.0, 1.0]])

In [64]:
X_test.size()

torch.Size([3, 4])

### Self-attention
We use the function above

In [65]:
attention(X_test, X_test, X_test)

(tensor([[0.4519, 0.6852, 0.5481, 1.0000],
         [0.1045, 1.1609, 0.8955, 1.0000],
         [0.1387, 1.1034, 0.8613, 1.0000]]),
 tensor([[0.4519, 0.2741, 0.2741],
         [0.1045, 0.5307, 0.3648],
         [0.1387, 0.4842, 0.3771]]))

### Multihead attention from PyTorch

In [66]:
att_layer = MultiheadAttention(4,
                               1,
                               bias=False)

The multihead attention uses a Xavier initialization of the dense layers. The results will be different from those of `attention()`.

In [67]:
att_layer(X_test,
          X_test,
          X_test)

(tensor([[-0.1571, -0.0579,  0.1843,  0.1282],
         [-0.1326,  0.0203,  0.2545,  0.1003],
         [-0.1359,  0.0097,  0.2450,  0.1040]], grad_fn=<SqueezeBackward1>),
 tensor([[0.2597, 0.3744, 0.3659],
         [0.4984, 0.2479, 0.2538],
         [0.4644, 0.2631, 0.2725]], grad_fn=<SqueezeBackward1>))

Weights of the dense layers

In [68]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[-0.2128, -0.4922, -0.4521,  0.2865],
                      [ 0.2480, -0.3926, -0.5545,  0.3429],
                      [ 0.0617, -0.5343, -0.0888, -0.1824],
                      [-0.0799, -0.2073,  0.4222,  0.2441],
                      [-0.0213,  0.5616,  0.1008, -0.3557],
                      [-0.4620,  0.0832,  0.5685, -0.3563],
                      [-0.2444, -0.5901,  0.0873,  0.6115],
                      [-0.0226, -0.4285,  0.0756,  0.1348],
                      [ 0.5236,  0.2901,  0.0021, -0.5789],
                      [ 0.2230, -0.4890,  0.4181, -0.5811],
                      [ 0.1001, -0.3766, -0.0605,  0.5007],
                      [ 0.4468, -0.5748,  0.5644, -0.1387]])),
             ('out_proj.weight',
              tensor([[-0.2071,  0.3655,  0.1681, -0.2090],
                      [-0.2451,  0.1686,  0.2852,  0.1841],
                      [-0.4539, -0.2133,  0.2407,  0.4938],
                      [ 0.3667, 

### By-passing the dense layers

We use weights of identity matrices

In [69]:
i_4 = torch.eye(4)

In [70]:
# state_dict() tensors share storage with the layer's parameters, so this
# slice assignment overwrites out_proj.weight in place with the identity.
att_layer.state_dict()['out_proj.weight'][:] = i_4

In [71]:
i_3_4 = torch.vstack((i_4, i_4, i_4))
i_3_4.size()

torch.Size([12, 4])

We set these weights

In [72]:
# Overwrite in place the concatenated query/key/value projections with
# three stacked identity matrices, so Q = K = V = X_test.
att_layer.state_dict()['in_proj_weight'][:] = i_3_4

In [73]:
att_layer.state_dict()

OrderedDict([('in_proj_weight',
              tensor([[1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.],
                      [1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.],
                      [1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]])),
             ('out_proj.weight',
              tensor([[1., 0., 0., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]]))])

Now we have the same results as with `attention()`

In [74]:
att_layer(X_test,
          X_test,
          X_test)

(tensor([[0.4519, 0.6852, 0.5481, 1.0000],
         [0.1045, 1.1609, 0.8955, 1.0000],
         [0.1387, 1.1034, 0.8613, 1.0000]], grad_fn=<SqueezeBackward1>),
 tensor([[0.4519, 0.2741, 0.2741],
         [0.1045, 0.5307, 0.3648],
         [0.1387, 0.4842, 0.3771]], grad_fn=<SqueezeBackward1>))